xuan2k committed on
Commit 4979800
1 parent: e5efca7

update demo

Files changed (8)
  1. .gitignore +2 -1
  2. .log/log.txt +5 -5
  3. SegFormer +1 -1
  4. mask.png +0 -0
  5. output.png +0 -0
  6. streamlit_test.py +3 -0
  7. test.png +0 -0
  8. test.py +168 -242
.gitignore CHANGED
@@ -2,4 +2,5 @@ __pycache__
  *.pyc
  checkpoints/
  I2SB/
- *.pth
+ *.pth
+ SegFormer/
.log/log.txt CHANGED
@@ -1,6 +1,6 @@
- [19:02:29] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
+ [19:58:55] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
  INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
- [19:02:33] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
- [19:02:44] INFO (0:00:14) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
- [19:02:49] INFO (0:00:19) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
- [19:02:50] INFO (0:00:20) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
+ [19:58:58] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
+ [19:59:02] INFO (0:00:07) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
+ [19:59:06] INFO (0:00:11) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
+ [19:59:08] INFO (0:00:13) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
SegFormer CHANGED
@@ -1 +1 @@
- Subproject commit 64ab11278eb30b8e2d8ea1d10a777fc5b1563948
+ Subproject commit ccc3dd500c4091a583b4b2749e35da501e670aca
mask.png ADDED
output.png CHANGED
streamlit_test.py ADDED
@@ -0,0 +1,3 @@
+ import streamlit as st
+
+ st.write("Hello")
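
The added streamlit_test.py is a three-line smoke test of the Streamlit install; presumably it is launched from the repo root with the standard Streamlit CLI:

    streamlit run streamlit_test.py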
test.png CHANGED
test.py CHANGED
@@ -40,6 +40,7 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
  import sys

  sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")

  import numpy as np
  import torch
@@ -62,6 +63,18 @@ from I2SB.i2sb import Runner, ckpt_util, download_ckpt
  from I2SB.logger import Logger
  from I2SB.sample import *



  import cv2
@@ -89,13 +102,6 @@ if os.environ.get('IS_MY_DEBUG') is not None:
  inpainting_enable = False
  kosmos_enable = False

- if lama_cleaner_enable:
- try:
- from lama_cleaner.model_manager import ModelManager
- from lama_cleaner.schema import Config as lama_Config
- except Exception as e:
- lama_cleaner_enable = False
-
  # segment anything
  from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

@@ -191,13 +197,16 @@ def get_point(img, sel_pix, evt: gr.SelectData):


  def undo_button(orig_img, sel_pix):
- temp = orig_img.copy()
- temp = np.array(temp, dtype=np.uint8)
- if len(sel_pix) != 0:
- sel_pix.pop()
- for point in sel_pix:
- cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
- return Image.fromarray(temp).convert("RGB")

  def clear_button(orig_img):

@@ -256,10 +265,22 @@ def load_i2sb_model():
  runner.ema = ExponentialMovingAverage(
  runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight

  print("Loading time:", (time.time()-s)*1e3, "ms.")
  i2sb_model = runner
  return runner

  def plot_boxes_to_image(image_pil, tgt):
  H, W = tgt["size"]
  boxes = tgt["boxes"]
@@ -326,42 +347,6 @@ def load_image(image_path):
  return image_pil, image


-
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
- caption = caption.lower()
- caption = caption.strip()
- if not caption.endswith("."):
- caption = caption + "."
- model = model.to(device)
- image = image.to(device)
- with torch.no_grad():
- outputs = model(image[None], captions=[caption])
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
- logits.shape[0]
-
- # filter output
- logits_filt = logits.clone()
- boxes_filt = boxes.clone()
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
- logits_filt = logits_filt[filt_mask] # num_filt, 256
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
- logits_filt.shape[0]
-
- # get phrase
- tokenlizer = model.tokenizer
- tokenized = tokenlizer(caption)
- # build pred
- pred_phrases = []
- for logit, box in zip(logits_filt, boxes_filt):
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
- if with_logits:
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
- else:
- pred_phrases.append(pred_phrase)
-
- return boxes_filt, pred_phrases
-
  def show_mask(mask, ax, random_color=False):
  if random_color:
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -447,99 +432,45 @@ def load_sd_model(device):
  )
  sd_model = sd_model.to(device)

- def forward_i2sb(img, mask):
- print(np.unique(img),mask.shape)
  mask = np.where(mask > 0, 1, 0)
  img_tensor = i2sb_transforms(img).to(
  i2sb_opt.device).unsqueeze(0)

  mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
  i2sb_opt.device).unsqueeze(0).unsqueeze(0)
- print("POST PROCESSING\t", torch.unique(img_tensor))
- # corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
  f = time.time()
  xs, _ = i2sb_model.ddpm_sampling(
  ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
  recon_img = xs[:, 0, ...].to(i2sb_opt.device)
- tu.save_image((recon_img+1)/2, "output.png")
  print(recon_img.shape)
- return transforms.ToPILImage()(((recon_img+1)/2)[0])

- def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
- try:
- logger.info(f'_______lama_cleaner_process_______1____')
- ori_image = image
- if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
- # rotate image
- logger.info(f'_______lama_cleaner_process_______2____')
- ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
- logger.info(f'_______lama_cleaner_process_______3____')
- image = ori_image
-
- logger.info(f'_______lama_cleaner_process_______4____')
- original_shape = ori_image.shape
- logger.info(f'_______lama_cleaner_process_______5____')
- interpolation = cv2.INTER_CUBIC
-
- size_limit = cleaner_size_limit
- if size_limit == -1:
- logger.info(f'_______lama_cleaner_process_______6____')
- size_limit = max(image.shape)
- else:
- logger.info(f'_______lama_cleaner_process_______7____')
- size_limit = int(size_limit)
-
- logger.info(f'_______lama_cleaner_process_______8____')
- config = lama_Config(
- ldm_steps=25,
- ldm_sampler='plms',
- zits_wireframe=True,
- hd_strategy='Original',
- hd_strategy_crop_margin=196,
- hd_strategy_crop_trigger_size=1280,
- hd_strategy_resize_limit=2048,
- prompt='',
- use_croper=False,
- croper_x=0,
- croper_y=0,
- croper_height=512,
- croper_width=512,
- sd_mask_blur=5,
- sd_strength=0.75,
- sd_steps=50,
- sd_guidance_scale=7.5,
- sd_sampler='ddim',
- sd_seed=42,
- cv2_flag='INPAINT_NS',
- cv2_radius=5,
- )
-
- logger.info(f'_______lama_cleaner_process_______9____')
- if config.sd_seed == -1:
- config.sd_seed = random.randint(1, 999999999)
-
- # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
- logger.info(f'_______lama_cleaner_process_______10____')
- image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
- # logger.info(f"Resized image shape_1_: {image.shape}")
-
- # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
- logger.info(f'_______lama_cleaner_process_______11____')
- mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
- # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
-
- logger.info(f'_______lama_cleaner_process_______12____')
- res_np_img = lama_cleaner_model(image, mask, config)
- logger.info(f'_______lama_cleaner_process_______13____')
- torch.cuda.empty_cache()

- logger.info(f'_______lama_cleaner_process_______14____')
- image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
- logger.info(f'_______lama_cleaner_process_______15____')
- except Exception as e:
- logger.info(f'lama_cleaner_process[Error]:' + str(e))
- image = None
- return image

  # visualization
  def draw_selected_mask(mask, draw):
@@ -632,27 +563,15 @@ def get_time_cost(run_task_time, time_cost_str):
  return run_task_time, time_cost_str

  def run_anything_task(input_image, input_points, origin_image, task_type,
- mask_source_radio, cleaner_size_limit=1080):

  run_task_time = 0
  time_cost_str = ''
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
  print("HERE................", task_type)
- if (task_type == 'Kosmos-2'):
- global kosmos_model, kosmos_processor
- if isinstance(input_image, dict):
- image_pil, image = load_image(input_image['image'].convert("RGB"))
- input_img = input_image['image']
- else:
- image_pil, image = load_image(input_image.convert("RGB"))
- input_img = input_image
-
- kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
- return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
-
  if input_image is None:
- return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None

  file_temp = int(time.time())
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
@@ -682,92 +601,119 @@ def run_anything_task(input_image, input_points, origin_image, task_type,
  groundingdino_device = 'cpu'

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
- if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
- image = np.array(input_img)
- if sam_predictor:
- sam_predictor.set_image(image)
-
- if sam_predictor:
- logger.info(f"Forward with: {input_points}")
- masks, _, _, _ = sam_predictor.predict(
- point_coords = np.array(input_points),
- point_labels = np.array([1 for _ in range(len(input_points))]),
- # boxes = transformed_boxes,
- multimask_output = False,
- )
- # masks: [9, 1, 512, 512]
- assert sam_checkpoint, 'sam_checkpoint is not found!'
  else:
- run_mode = "rectangle"
-
- # draw output image
- plt.figure(figsize=(10, 10))
- plt.imshow(origin_image)
- for mask in masks:
- show_mask(mask, plt.gca(), random_color=True)
- # for box, label in zip(boxes_filt, pred_phrases):
- # show_box(box.cpu().numpy(), plt.gca(), label)
- plt.axis('off')
- image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
- plt.savefig(image_path, bbox_inches="tight")
- plt.clf()
- plt.close('all')
- segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
- os.remove(image_path)
  output_images.append(Image.fromarray(segment_image_result))
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
  if task_type == 'detection' or task_type == 'segment':
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
- elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
- if mask_source_radio == mask_source_segment:
- task_type = 'remove'

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
- if mask_source_radio == mask_source_draw:
- input_mask_pil = input_image['mask']
- input_mask = np.array(input_mask_pil.convert("L"))
- mask_pil = input_mask_pil
- mask = input_mask
  else:
- masks_ori = copy.deepcopy(masks)
- masks = torch.where(masks > 0, True, False)
- mask = masks[0][0].cpu().numpy()
- mask_pil = Image.fromarray(mask)
  output_images.append(mask_pil.convert("RGB"))
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

- if task_type in ['inpainting', 'outpainting']:
  # image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
- input_img.save("test.png")
- image_inpainting = forward_i2sb(input_img, mask)
-
- print("RESULT\t", np.array(image_inpainting))
  else:
  # remove from mask
  aasds = 1

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
- image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
  if image_inpainting is None:
  logger.info(f'run_anything_task_failed_')
- return None, None, None, None, None, None, None

  # output_images.append(image_inpainting)
  # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
  output_images.append(image_inpainting)
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
  else:
  logger.info(f"task_type:{task_type} error!")
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None

  def change_radio_display(task_type, mask_source_radio, orig_img):
  mask_source_radio_visible = False
@@ -789,20 +735,19 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
  mask_source_radio_visible = True
  if task_type == "relate anything":
  num_relation_visible = True
- if task_type == "segment":
- ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
- elif task_type == "inpainting":
  ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)

  return (gr.Radio.update(visible=mask_source_radio_visible),
  gr.Slider.update(visible=num_relation_visible),
  gr.Gallery.update(visible=image_gallery_visible),
- gr.Radio.update(visible=kosmos_input_visible),
- gr.Image.update(visible=kosmos_output_visible),
- gr.HighlightedText.update(visible=kosmos_text_output_visible),
  ret, [],
- gr.Button("Undo point", visible = task_type == "segment"),
- gr.Button("Clear point", visible = task_type == "segment"),)

  def get_model_device(module):
  try:
@@ -832,10 +777,11 @@ def main_gradio(args):
  with gr.Row():
  with gr.Column():
  selected_points = gr.State([])
- original_image = gr.State()
  task_types = ["segment"]
  if inpainting_enable:
  task_types.append("inpainting")


  input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
@@ -854,7 +800,7 @@ def main_gradio(args):
  with gr.Row():
  with gr.Column():

- undo_point_button = gr.Button("Undo point")
  undo_point_button.click(
  fn= undo_button,
  inputs=[original_image, selected_points],
@@ -863,7 +809,7 @@ def main_gradio(args):

  with gr.Column():

- clear_point_button = gr.Button("Clear point")
  clear_point_button.click(
  fn= clear_button,
  inputs=[original_image],
@@ -876,10 +822,15 @@ def main_gradio(args):
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
  value=mask_source_draw, label="Mask from",
  visible=False)
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)

- kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
-
  run_button = gr.Button(label="Run", visible=True)
  # with gr.Accordion("Advanced options", open=False) as advanced_options:
  # box_threshold = gr.Slider(
@@ -900,47 +851,21 @@ def main_gradio(args):

  with gr.Column():
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
- ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
  time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)

- kosmos_output = gr.Image(type="pil", label="result images", visible=False)
- kosmos_text_output = gr.HighlightedText(
- label="Generated Description",
- combine_adjacent=False,
- show_legend=True,
- visible=False,
- ).style(color_map=color_map)
- # record which text span (label) is selected
- selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
-
- # record the current `entities`
- entity_output = gr.Textbox(visible=False)
-
- # get the current selected span label
- def get_text_span_label(evt: gr.SelectData):
- if evt.value[-1] is None:
- return -1
- return int(evt.value[-1])
- # and set this information to `selected`
- kosmos_text_output.select(get_text_span_label, None, selected)

- # update output image when we change the span (enity) selection
- def update_output_image(img_input, image_output, entities, idx):
- entities = ast.literal_eval(entities)
- updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
- return updated_image
- selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])

  run_button.click(fn=run_anything_task, inputs=[
  input_image, selected_points, original_image, task_type,
- mask_source_radio],
- outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)

  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
  outputs=[mask_source_radio, num_relation])
  task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
  outputs=[mask_source_radio, num_relation,
- image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button, clear_point_button
  ])

  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
@@ -985,6 +910,7 @@ if __name__ == "__main__":

  if sam_enable:
  load_sam_model(device)

  if inpainting_enable:
  load_sd_model(device)
 
  import sys

  sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
+ sys.path.insert(0, "/home/ubuntu/Thesis-Demo/SegFormer")

  import numpy as np
  import torch
 
  from I2SB.logger import Logger
  from I2SB.sample import *

+ from pathlib import Path
+
+ inpaint_checkpoint = Path("/home/ubuntu/Thesis-Demo/I2SB/results")
+
+ if not inpaint_checkpoint.exists():
+ os.system("pip install transformers==4.32.0")
+
+ # SegFormer
+ from PIL import Image
+
+ from SegFormer.mmseg.apis import inference_segmentor, init_segmentor, visualize_result_pyplot
+ from SegFormer.mmseg.core.evaluation import get_palette


  import cv2
 
  inpainting_enable = False
  kosmos_enable = False

  # segment anything
  from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

 


  def undo_button(orig_img, sel_pix):
+ if orig_img:
+ temp = orig_img.copy()
+ temp = np.array(temp, dtype=np.uint8)
+ if len(sel_pix) != 0:
+ sel_pix.pop()
+ for point in sel_pix:
+ cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
+ return Image.fromarray(temp).convert("RGB")
+ return orig_img
+

  def clear_button(orig_img):

 
  runner.ema = ExponentialMovingAverage(
  runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight

+ logger.info(f"I2SB Loading time:\t {(time.time()-s)*1e3} ms.")
  print("Loading time:", (time.time()-s)*1e3, "ms.")
  i2sb_model = runner
  return runner

+ def load_segformer(device):
+ global segformer_model
+ s = time.time()
+ config = "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py"
+ checkpoint = "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth"
+ model = init_segmentor(config, checkpoint, device=device)
+
+ logger.info(f"SegFormer Loading time:\t {(time.time()-s)*1e3} ms.")
+ segformer_model = model
+ return model
+
  def plot_boxes_to_image(image_pil, tgt):
  H, W = tgt["size"]
  boxes = tgt["boxes"]
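
For context on the new load_segformer/forward_segformer pair: the model is built once from the config and checkpoint above and then reused for every request. A minimal end-to-end sketch along those lines, assuming the mmseg-style API bundled with the SegFormer submodule, a CUDA device string, and the repo's test.png as input image:

    import cv2
    import numpy as np
    from PIL import Image
    from SegFormer.mmseg.apis import init_segmentor, inference_segmentor

    config = "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py"
    checkpoint = "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth"

    model = init_segmentor(config, checkpoint, device="cuda:0")  # build model and load weights once
    img = np.array(Image.open("test.png").convert("RGB"))
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)               # BGR ordering, as in forward_segformer
    result = inference_segmentor(model, img_bgr)                 # list with one per-pixel class-index map
    mask = np.asarray(result[0], dtype=np.uint8)                 # later binarized and fed to the inpainter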
 
  return image_pil, image


  def show_mask(mask, ax, random_color=False):
  if random_color:
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
  )
  sd_model = sd_model.to(device)

+ def forward_i2sb(img, mask, dilation_mask_extend):
+
+
+ print(np.unique(mask),mask.shape)
  mask = np.where(mask > 0, 1, 0)
+ print(np.unique(mask),mask.shape)
+ mask = mask.astype(np.uint8)
+ if dilation_mask_extend.isdigit():
+
+ kernel_size = int(dilation_mask_extend)
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (int(kernel_size), int(kernel_size)))
+ mask = cv2.dilate(mask, kernel, iterations = 1)
+
  img_tensor = i2sb_transforms(img).to(
  i2sb_opt.device).unsqueeze(0)

  mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
  i2sb_opt.device).unsqueeze(0).unsqueeze(0)
+ # print("POST PROCESSING\t", torch.unique(img_tensor))
+ corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
+ print("DOUBLE CHECK:\t", corrupt_tensor.shape)
+ print("DOUBLE CHECK:\t", img_tensor.shape)
+ print("DOUBLE CHECK:\t", mask_tensor.shape)
  f = time.time()
  xs, _ = i2sb_model.ddpm_sampling(
  ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
  recon_img = xs[:, 0, ...].to(i2sb_opt.device)
+ # tu.save_image((recon_img+1)/2, "output.png")
+ # tu.save_image((corrupt_tensor+1)/2, "output.png")
  print(recon_img.shape)
+ return transforms.ToPILImage()(((recon_img+1)/2)[0]), transforms.ToPILImage()(((corrupt_tensor+1)/2)[0])

+ def forward_segformer(img):
+ img_np = np.array(img)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

+ result = inference_segmentor(segformer_model, img_np)
+
+ return np.asarray(result[0], dtype=np.uint8)

  # visualization
  def draw_selected_mask(mask, draw):
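
The new dilation_mask_extend argument only grows the binary mask with plain OpenCV morphology before I2SB blanks out the masked region. A small illustrative sketch of that step, using a toy mask and image and an assumed kernel size:

    import cv2
    import numpy as np

    mask = np.zeros((256, 256), dtype=np.uint8)
    mask[100:120, 100:140] = 1                        # toy binary mask marking the region to inpaint

    kernel_size = 7                                   # value typed into the "Dilation kernel size" textbox
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    dilated = cv2.dilate(mask, kernel, iterations=1)  # grows the region by ~kernel_size//2 px on each side

    img = np.random.rand(256, 256)                    # toy normalized image
    corrupt = img * (1 - dilated) + dilated           # same masking form forward_i2sb applies before sampling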
 
  return run_task_time, time_cost_str

  def run_anything_task(input_image, input_points, origin_image, task_type,
+ mask_source_radio, segmentation_radio, dilation_mask_extend):

  run_task_time = 0
  time_cost_str = ''
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
  print("HERE................", task_type)
+
  if input_image is None:
+ return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))

  file_temp = int(time.time())
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
 
  groundingdino_device = 'cpu'

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
+ if task_type == 'segment' or task_type == 'pipeline':
+ image = np.array(origin_image)
+ if segmentation_radio == "SAM":
+ if sam_predictor:
+ sam_predictor.set_image(image)
+
+ if sam_predictor:
+ logger.info(f"Forward with: {input_points}")
+ masks, _, _, _ = sam_predictor.predict(
+ point_coords = np.array(input_points),
+ point_labels = np.array([1 for _ in range(len(input_points))]),
+ # boxes = transformed_boxes,
+ multimask_output = False,
+ )
+ # masks: [9, 1, 512, 512]
+ assert sam_checkpoint, 'sam_checkpoint is not found!'
+ else:
+ run_mode = "rectangle"
+
+ # draw output image
+ plt.figure(figsize=(10, 10))
+ plt.imshow(origin_image)
+ for mask in masks:
+ show_mask(mask, plt.gca(), random_color=True)
+ # for box, label in zip(boxes_filt, pred_phrases):
+ # show_box(box.cpu().numpy(), plt.gca(), label)
+ plt.axis('off')
+ image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
+ plt.savefig(image_path, bbox_inches="tight")
+ plt.clf()
+ plt.close('all')
+ segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+ os.remove(image_path)
+
  else:
+ masks = forward_segformer(image)
+
+ segment_image_result = visualize_result_pyplot(segformer_model, image, masks, get_palette("wtm"), dilation=dilation_mask_extend)# if task_type == "pipeline" else None)
+
  output_images.append(Image.fromarray(segment_image_result))
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
  if task_type == 'detection' or task_type == 'segment':
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
+ elif task_type in ['inpainting', 'outpainting'] or task_type == 'pipeline':

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
+ if task_type == "pipeline":
+ if segmentation_radio == "SAM":
+ masks_ori = copy.deepcopy(masks)
+ print(masks.shape)
+ # masks = torch.where(masks > 0, True, False)
+ mask = masks[0]
+ mask_pil = Image.fromarray(mask)
+ mask = np.where(mask == True, 1, 0)
+ else:
+ mask = masks
+ save_mask = copy.deepcopy(mask)
+ save_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
+ print((save_mask.dtype))
+ mask_pil = Image.fromarray(save_mask)
+
  else:
+ if mask_source_radio == mask_source_draw:
+ input_mask_pil = input_image['mask']
+ input_mask = np.array(input_mask_pil.convert("L"))
+ mask_pil = input_mask_pil
+ mask = input_mask
+ else:
+ pass
+ # masks_ori = copy.deepcopy(masks)
+ # masks = torch.where(masks > 0, True, False)
+ # mask = masks[0][0].cpu().numpy()
+ # mask_pil = Image.fromarray(mask)
  output_images.append(mask_pil.convert("RGB"))
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

+ if task_type in ['inpainting', 'pipeline']:
  # image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
+ # input_img.save("test.png")
+ w, h = input_img.size
+ input_img = input_img.resize((256,256))
+ image_inpainting, corrupted = forward_i2sb(input_img, mask, dilation_mask_extend)
+ input_img = input_img.resize((w,h))
+ corrupted = corrupted.resize((w,h))
+ image_inpainting = image_inpainting.resize((w,h))
+ # print("RESULT\t", np.array(image_inpainting))
  else:
  # remove from mask
  aasds = 1

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
  if image_inpainting is None:
  logger.info(f'run_anything_task_failed_')
+ return None, None, None, None

  # output_images.append(image_inpainting)
  # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
+ output_images.append(corrupted)
  output_images.append(image_inpainting)
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
  else:
  logger.info(f"task_type:{task_type} error!")
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))

  def change_radio_display(task_type, mask_source_radio, orig_img):
  mask_source_radio_visible = False
 
  mask_source_radio_visible = True
  if task_type == "relate anything":
  num_relation_visible = True
+ if task_type == "inpainting":
  ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
+ elif task_type in ["segment", "pipeline"]:
+ ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)

  return (gr.Radio.update(visible=mask_source_radio_visible),
  gr.Slider.update(visible=num_relation_visible),
  gr.Gallery.update(visible=image_gallery_visible),
+ gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible= task_type != "inpainting"),
+ gr.Textbox(label="Dilation kernel size", value='7', visible= task_type == "pipeline"),
  ret, [],
+ gr.Button("Undo point", visible = task_type != "inpainting"),
+ gr.Button("Clear point", visible = task_type != "inpainting"),)

  def get_model_device(module):
  try:
 
  with gr.Row():
  with gr.Column():
  selected_points = gr.State([])
+ original_image = gr.State(None)
  task_types = ["segment"]
  if inpainting_enable:
  task_types.append("inpainting")
+ task_types.append("pipeline")


  input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
 
  with gr.Row():
  with gr.Column():

+ undo_point_button = gr.Button("Undo point", visible= True if original_image is not None else False)
  undo_point_button.click(
  fn= undo_button,
  inputs=[original_image, selected_points],
 

  with gr.Column():

+ clear_point_button = gr.Button("Clear point", visible= True if original_image is not None else False)
  clear_point_button.click(
  fn= clear_button,
  inputs=[original_image],
 
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
  value=mask_source_draw, label="Mask from",
  visible=False)
+
+ segmentation_radio = gr.Radio(["SegFormer", "SAM"],
+ value="SAM", label="Segementation Model",
+ visible=True)
+
+ dilation_mask_extend = gr.Textbox(label="Dilation kernel size", value='5', visible=False)
+
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)

  run_button = gr.Button(label="Run", visible=True)
  # with gr.Accordion("Advanced options", open=False) as advanced_options:
  # box_threshold = gr.Slider(
 

  with gr.Column():
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
+ ).style(preview=True, columns=[5], object_fit="scale-down", height=512)
  time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)



  run_button.click(fn=run_anything_task, inputs=[
  input_image, selected_points, original_image, task_type,
+ mask_source_radio, segmentation_radio, dilation_mask_extend],
+ outputs=[image_gallery, image_gallery, time_cost, time_cost], show_progress=True, queue=True)

  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
  outputs=[mask_source_radio, num_relation])
  task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
  outputs=[mask_source_radio, num_relation,
+ image_gallery, segmentation_radio, dilation_mask_extend, input_image, selected_points, undo_point_button, clear_point_button
  ])

  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
 

  if sam_enable:
  load_sam_model(device)
+ load_segformer(device)

  if inpainting_enable:
  load_sd_model(device)