liuyizhang commited on
Commit
7a7f9d8
1 Parent(s): 5d0da89
app.py CHANGED
@@ -44,7 +44,7 @@ from lama_cleaner.model_manager import ModelManager
44
  from lama_cleaner.schema import Config
45
 
46
  # segment anything
47
- from segment_anything import build_sam, SamPredictor
48
 
49
  # diffusers
50
  import PIL
@@ -238,6 +238,7 @@ groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
238
  # initialize SAM
239
  logger.info(f"initialize SAM model...")
240
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
 
241
 
242
  # initialize stable-diffusion-inpainting
243
  logger.info(f"initialize stable-diffusion-inpainting...")
@@ -319,11 +320,168 @@ def lama_cleaner_process(image, mask):
319
  image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
320
  return image
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  mask_source_draw = "draw a mask on input image"
323
  mask_source_segment = "type what to detect below"
324
 
325
- def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
326
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
 
 
 
327
  text_prompt = text_prompt.strip()
328
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
329
  if text_prompt == '':
@@ -333,7 +491,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
333
  return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂')
334
 
335
  file_temp = int(time.time())
336
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
337
 
338
  # load image
339
  input_mask_pil = input_image['mask']
@@ -364,7 +522,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
364
  groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
365
  )
366
  if boxes_filt.size(0) == 0:
367
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
368
  return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
369
  boxes_filt_ori = copy.deepcopy(boxes_filt)
370
 
@@ -380,7 +538,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
380
  os.remove(image_path)
381
  output_images.append(detection_image_result)
382
 
383
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_2_')
384
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
385
  image = np.array(input_image['image'])
386
  sam_predictor.set_image(image)
@@ -416,15 +574,15 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
416
  os.remove(image_path)
417
  output_images.append(segment_image_result)
418
 
419
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_3_')
420
  if task_type == 'detection' or task_type == 'segment':
421
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
422
  return output_images, gr.Gallery.update(label='result images')
423
  elif task_type == 'inpainting' or task_type == 'remove':
424
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
425
  task_type = 'remove'
426
 
427
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_4_')
428
  if mask_source_radio == mask_source_draw:
429
  mask_pil = input_mask_pil
430
  mask = input_mask
@@ -437,6 +595,8 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
437
  mask_pil = Image.fromarray(mask)
438
 
439
  image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
 
 
440
  mask_pil.convert("RGB").save(image_path)
441
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
442
  os.remove(image_path)
@@ -480,6 +640,8 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
480
  mask_pil = mix_masks(mask_imgs)
481
 
482
  image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
 
 
483
  mask_pil.convert("RGB").save(image_path)
484
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
485
  os.remove(image_path)
@@ -492,25 +654,35 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
492
  image_inpainting.save(image_path)
493
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
494
  os.remove(image_path)
495
- logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
496
  output_images.append(image_result)
497
  return output_images, gr.Gallery.update(label='result images')
498
  else:
499
  logger.info(f"task_type:{task_type} error!")
500
- logger.info(f'run_grounded_sam_[{file_temp}]_9_9_')
501
  return output_images, gr.Gallery.update(label='result images')
502
 
503
- def change_radio_display(task_type, mask_source_radio):
504
  text_prompt_visible = True
505
  inpaint_prompt_visible = False
506
  mask_source_radio_visible = False
 
 
 
507
  if task_type == "inpainting":
508
  inpaint_prompt_visible = True
509
  if task_type == "inpainting" or task_type == "remove":
510
  mask_source_radio_visible = True
511
  if mask_source_radio == mask_source_draw:
512
  text_prompt_visible = False
513
- return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible)
 
 
 
 
 
 
 
514
 
515
  if __name__ == "__main__":
516
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
@@ -525,15 +697,16 @@ if __name__ == "__main__":
525
  with gr.Row():
526
  with gr.Column():
527
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
528
- task_type = gr.Radio(["detection", "segment", "inpainting", "remove"], value="detection",
529
- label='Task type',interactive=True, visible=True)
530
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
531
  value=mask_source_segment, label="Mask from",
532
- interactive=True, visible=False)
533
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
534
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
 
535
  run_button = gr.Button(label="Run")
536
- with gr.Accordion("Advanced options", open=False):
537
  box_threshold = gr.Slider(
538
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
539
  )
@@ -551,14 +724,16 @@ if __name__ == "__main__":
551
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
552
 
553
  with gr.Column():
554
- gallery = gr.Gallery(
555
- label="result images", show_label=True, elem_id="gallery"
556
- ).style(grid=[2], full_width=True, full_height=True)
557
-
558
- run_button.click(fn=run_grounded_sam, inputs=[
559
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend], outputs=[gallery, gallery])
560
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
561
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
 
 
562
 
563
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Thanks for their excellent work.'
564
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
44
  from lama_cleaner.schema import Config
45
 
46
  # segment anything
47
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
48
 
49
  # diffusers
50
  import PIL
238
  # initialize SAM
239
  logger.info(f"initialize SAM model...")
240
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
241
+ sam_mask_generator = SamAutomaticMaskGenerator(sam_predictor)
242
 
243
  # initialize stable-diffusion-inpainting
244
  logger.info(f"initialize stable-diffusion-inpainting...")
320
  image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
321
  return image
322
 
323
+ # relate anything
324
+ from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, show_mask
325
+ from ram_train_eval import RamModel,RamPredictor
326
+ from mmengine.config import Config
327
+ input_size = 512
328
+ hidden_size = 256
329
+ num_classes = 56
330
+
331
+ # load ram model
332
+ model_path = "./checkpoints/ram_epoch12.pth"
333
+ config = dict(
334
+ model=dict(
335
+ pretrained_model_name_or_path='bert-base-uncased',
336
+ load_pretrained_weights=False,
337
+ num_transformer_layer=2,
338
+ input_feature_size=256,
339
+ output_feature_size=768,
340
+ cls_feature_size=512,
341
+ num_relation_classes=56,
342
+ pred_type='attention',
343
+ loss_type='multi_label_ce',
344
+ ),
345
+ load_from=model_path,
346
+ )
347
+ config = Config(config)
348
+
349
+ class Predictor(RamPredictor, device='cpu'):
350
+ def __init__(self,config):
351
+ self.config = config
352
+ self.device = torch.device(device)
353
+ self._build_model()
354
+
355
+ def _build_model(self):
356
+ self.model = RamModel(**self.config.model).to(self.device)
357
+ if self.config.load_from is not None:
358
+ self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
359
+ self.model.train()
360
+ ram_model = Predictor(config, device)
361
+
362
+ # visualization
363
+ def draw_selected_mask(mask, draw):
364
+ color = (255, 0, 0, 153)
365
+ nonzero_coords = np.transpose(np.nonzero(mask))
366
+ for coord in nonzero_coords:
367
+ draw.point(coord[::-1], fill=color)
368
+
369
+ def draw_object_mask(mask, draw):
370
+ color = (0, 0, 255, 153)
371
+ nonzero_coords = np.transpose(np.nonzero(mask))
372
+ for coord in nonzero_coords:
373
+ draw.point(coord[::-1], fill=color)
374
+
375
+ def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
376
+ # Define the colors to use for each word
377
+ color_red = (255, 0, 0)
378
+ color_black = (0, 0, 0)
379
+ color_blue = (0, 0, 255)
380
+
381
+ # Define the initial font size and spacing between words
382
+ font_size = 40
383
+
384
+ # Create a new image with the specified width and white background
385
+ image = Image.new('RGB', (width, 60), (255, 255, 255))
386
+
387
+ # Load the specified font
388
+ font = ImageFont.truetype(font_path, font_size)
389
+
390
+ # Keep increasing the font size until all words fit within the desired width
391
+ while True:
392
+ # Create a draw object for the image
393
+ draw = ImageDraw.Draw(image)
394
+
395
+ word_spacing = font_size / 2
396
+ # Draw each word in the appropriate color
397
+ x_offset = word_spacing
398
+ draw.text((x_offset, 0), word1, color_red, font=font)
399
+ x_offset += font.getsize(word1)[0] + word_spacing
400
+ draw.text((x_offset, 0), word2, color_black, font=font)
401
+ x_offset += font.getsize(word2)[0] + word_spacing
402
+ draw.text((x_offset, 0), word3, color_blue, font=font)
403
+
404
+ word_sizes = [font.getsize(word) for word in [word1, word2, word3]]
405
+ total_width = sum([size[0] for size in word_sizes]) + word_spacing * 3
406
+
407
+ # Stop increasing font size if the image is within the desired width
408
+ if total_width <= width:
409
+ break
410
+
411
+ # Increase font size and reset the draw object
412
+ font_size -= 1
413
+ image = Image.new('RGB', (width, 50), (255, 255, 255))
414
+ font = ImageFont.truetype(font_path, font_size)
415
+ draw = None
416
+
417
+ return image
418
+
419
+ def concatenate_images_vertical(image1, image2):
420
+ # Get the dimensions of the two images
421
+ width1, height1 = image1.size
422
+ width2, height2 = image2.size
423
+
424
+ # Create a new image with the combined height and the maximum width
425
+ new_image = Image.new('RGBA', (max(width1, width2), height1 + height2))
426
+
427
+ # Paste the first image at the top of the new image
428
+ new_image.paste(image1, (0, 0))
429
+
430
+ # Paste the second image below the first image
431
+ new_image.paste(image2, (0, height1))
432
+
433
+ return new_image
434
+
435
+ def relate_anything(input_image, k):
436
+ w, h = input_image.size
437
+ max_edge = 1500
438
+ if w > max_edge or h > max_edge:
439
+ ratio = max(w, h) / max_edge
440
+ new_size = (int(w / ratio), int(h / ratio))
441
+ input_image.thumbnail(new_size)
442
+
443
+ # load image
444
+ pil_image = input_image.convert('RGBA')
445
+ image = np.array(input_image)
446
+ sam_masks = sam_mask_generator.generate(image)
447
+ filtered_masks = sort_and_deduplicate(sam_masks)
448
+
449
+ feat_list = []
450
+ for fm in filtered_masks:
451
+ feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
452
+ feat_list.append(feat)
453
+ feat = torch.cat(feat_list, dim=1).to(device)
454
+ matrix_output, rel_triplets = ram_model.predict(feat)
455
+
456
+ pil_image_list = []
457
+ for i, rel in enumerate(rel_triplets[:k]):
458
+ s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
459
+ relation = relation_classes[r]
460
+
461
+ mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
462
+ mask_draw = ImageDraw.Draw(mask_image)
463
+
464
+ draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
465
+ draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
466
+
467
+ current_pil_image = pil_image.copy()
468
+ current_pil_image.alpha_composite(mask_image)
469
+
470
+ title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
471
+ concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
472
+ pil_image_list.append(concate_pil_image)
473
+
474
+ yield pil_image_list
475
+
476
+
477
  mask_source_draw = "draw a mask on input image"
478
  mask_source_segment = "type what to detect below"
479
 
480
+ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
481
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
482
+ if task_type == "relate anything":
483
+ return relate_anything(input_image['image'], num_relation)
484
+
485
  text_prompt = text_prompt.strip()
486
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
487
  if text_prompt == '':
491
  return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂')
492
 
493
  file_temp = int(time.time())
494
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
495
 
496
  # load image
497
  input_mask_pil = input_image['mask']
522
  groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
523
  )
524
  if boxes_filt.size(0) == 0:
525
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
526
  return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
527
  boxes_filt_ori = copy.deepcopy(boxes_filt)
528
 
538
  os.remove(image_path)
539
  output_images.append(detection_image_result)
540
 
541
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
542
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
543
  image = np.array(input_image['image'])
544
  sam_predictor.set_image(image)
574
  os.remove(image_path)
575
  output_images.append(segment_image_result)
576
 
577
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
578
  if task_type == 'detection' or task_type == 'segment':
579
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
580
  return output_images, gr.Gallery.update(label='result images')
581
  elif task_type == 'inpainting' or task_type == 'remove':
582
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
583
  task_type = 'remove'
584
 
585
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
586
  if mask_source_radio == mask_source_draw:
587
  mask_pil = input_mask_pil
588
  mask = input_mask
595
  mask_pil = Image.fromarray(mask)
596
 
597
  image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
598
+ # if reverse_mask:
599
+ # mask_pil = mask_pil.point(lambda _: 255-_)
600
  mask_pil.convert("RGB").save(image_path)
601
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
602
  os.remove(image_path)
640
  mask_pil = mix_masks(mask_imgs)
641
 
642
  image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
643
+ # if reverse_mask:
644
+ # mask_pil = mask_pil.point(lambda _: 255-_)
645
  mask_pil.convert("RGB").save(image_path)
646
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
647
  os.remove(image_path)
654
  image_inpainting.save(image_path)
655
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
656
  os.remove(image_path)
657
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
658
  output_images.append(image_result)
659
  return output_images, gr.Gallery.update(label='result images')
660
  else:
661
  logger.info(f"task_type:{task_type} error!")
662
+ logger.info(f'run_anything_task_[{file_temp}]_9_9_')
663
  return output_images, gr.Gallery.update(label='result images')
664
 
665
+ def change_radio_display(task_type, mask_source_radio, num_relation): #, gsa_gallery, ram_gallery):
666
  text_prompt_visible = True
667
  inpaint_prompt_visible = False
668
  mask_source_radio_visible = False
669
+ num_relation_visible = False
670
+ # gsa_gallery_visible = True
671
+ # ram_gallery_visible = False
672
  if task_type == "inpainting":
673
  inpaint_prompt_visible = True
674
  if task_type == "inpainting" or task_type == "remove":
675
  mask_source_radio_visible = True
676
  if mask_source_radio == mask_source_draw:
677
  text_prompt_visible = False
678
+ if task_type == "relate anything":
679
+ text_prompt_visible = False
680
+ num_relation_visible = True
681
+ # gsa_gallery_visible = False
682
+ # ram_gallery_visible = True
683
+ return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible),
684
+ gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
685
+ # gr.Gallery.update(visible=gas_gallery_visible), gr.Gallery.update(visible=ram_gallery_visible)
686
 
687
  if __name__ == "__main__":
688
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
697
  with gr.Row():
698
  with gr.Column():
699
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
700
+ task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything"], value="detection",
701
+ label='Task type', visible=True)
702
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
703
  value=mask_source_segment, label="Mask from",
704
+ visible=False)
705
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
706
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
707
+ num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
708
  run_button = gr.Button(label="Run")
709
+ with gr.Accordion("Advanced options", open=False) as advanced_options:
710
  box_threshold = gr.Slider(
711
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
712
  )
724
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
725
 
726
  with gr.Column():
727
+ # gsa_gallery = gr.Gallery(
728
+ # label="result images", show_label=True, elem_id="gsa_gallery"
729
+ # ).style(grid=[2], full_width=True, full_height=True)
730
+ gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="gallery").style(preview=True, columns=5, object_fit="scale-down")
731
+
732
+
733
+ run_button.click(fn=run_anything_task, inputs=[
734
+ input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[gsa_gallery, gsa_gallery])
735
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
736
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
737
 
738
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Thanks for their excellent work.'
739
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
assets/OpenSans-Bold.ttf ADDED
Binary file (225 kB). View file
checkpoints/ram_epoch12.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:036ddbb89e3376b61cb548c8cac3007c3ab7236fb6ac82207d4ccf4039654297
3
+ size 333991817
ram_train_eval.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from datetime import timedelta
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from mmengine.config import Config
10
+ from mmengine.utils import ProgressBar
11
+ from transformers import AutoConfig, AutoModel
12
+
13
+ class RamDataset(torch.utils.data.Dataset):
14
+ def __init__(self, data_path, is_train=True, num_relation_classes=56):
15
+ super().__init__()
16
+ self.num_relation_classes = num_relation_classes
17
+ data = np.load(data_path, allow_pickle=True)
18
+ self.samples = data["arr_0"]
19
+ sample_num = self.samples.size
20
+ self.sample_idx_list = []
21
+ for idx in range(sample_num):
22
+ if self.samples[idx]["is_train"] == is_train:
23
+ self.sample_idx_list.append(idx)
24
+
25
+ def __getitem__(self, idx):
26
+ sample = self.samples[self.sample_idx_list[idx]]
27
+ object_num = sample["feat"].shape[0]
28
+ embedding = torch.from_numpy(sample["feat"])
29
+ gt_rels = sample["relations"]
30
+ rel_target = self._get_target(object_num, gt_rels)
31
+ return embedding, rel_target, gt_rels
32
+
33
+ def __len__(self):
34
+ return len(self.sample_idx_list)
35
+
36
+ def _get_target(self, object_num, gt_rels):
37
+ rel_target = torch.zeros([self.num_relation_classes, object_num, object_num])
38
+ for ii, jj, cls_relationship in gt_rels:
39
+ rel_target[cls_relationship, ii, jj] = 1
40
+ return rel_target
41
+
42
+
43
+ class RamModel(nn.Module):
44
+ def __init__(
45
+ self,
46
+ pretrained_model_name_or_path,
47
+ load_pretrained_weights=True,
48
+ num_transformer_layer=2,
49
+ input_feature_size=256,
50
+ output_feature_size=768,
51
+ cls_feature_size=512,
52
+ num_relation_classes=56,
53
+ pred_type="attention",
54
+ loss_type="bce",
55
+ ):
56
+ super().__init__()
57
+ # 0. config
58
+ self.cls_feature_size = cls_feature_size
59
+ self.num_relation_classes = num_relation_classes
60
+ self.pred_type = pred_type
61
+ self.loss_type = loss_type
62
+
63
+ # 1. fc input and output
64
+ self.fc_input = nn.Sequential(
65
+ nn.Linear(input_feature_size, output_feature_size),
66
+ nn.LayerNorm(output_feature_size),
67
+ )
68
+ self.fc_output = nn.Sequential(
69
+ nn.Linear(output_feature_size, output_feature_size),
70
+ nn.LayerNorm(output_feature_size),
71
+ )
72
+ # 2. transformer model
73
+ if load_pretrained_weights:
74
+ self.model = AutoModel.from_pretrained(pretrained_model_name_or_path)
75
+ else:
76
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
77
+ self.model = AutoModel.from_config(config)
78
+ if num_transformer_layer != "all" and isinstance(num_transformer_layer, int):
79
+ self.model.encoder.layer = self.model.encoder.layer[:num_transformer_layer]
80
+ # 3. predict head
81
+ self.cls_sub = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
82
+ self.cls_obj = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
83
+ # 4. loss
84
+ if self.loss_type == "bce":
85
+ self.bce_loss = nn.BCEWithLogitsLoss()
86
+ elif self.loss_type == "multi_label_ce":
87
+ print("Use Multi Label Cross Entropy Loss.")
88
+
89
+ def forward(self, embeds, attention_mask=None):
90
+ """
91
+ embeds: (batch_size, token_num, feature_size)
92
+ attention_mask: (batch_size, token_num)
93
+ """
94
+ # 1. fc input
95
+ embeds = self.fc_input(embeds)
96
+ # 2. transformer model
97
+ position_ids = torch.ones([1, embeds.shape[1]]).to(embeds.device).to(torch.long)
98
+ outputs = self.model.forward(inputs_embeds=embeds, attention_mask=attention_mask, position_ids=position_ids)
99
+ embeds = outputs["last_hidden_state"]
100
+ # 3. fc output
101
+ embeds = self.fc_output(embeds)
102
+ # 4. predict head
103
+ batch_size, token_num, feature_size = embeds.shape
104
+ sub_embeds = self.cls_sub(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
105
+ obj_embeds = self.cls_obj(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
106
+ if self.pred_type == "attention":
107
+ cls_pred = sub_embeds @ torch.transpose(obj_embeds, 2, 3) / self.cls_feature_size**0.5 # noqa
108
+ elif self.pred_type == "einsum":
109
+ cls_pred = torch.einsum("nrsc,nroc->nrso", sub_embeds, obj_embeds)
110
+ return cls_pred
111
+
112
+ def loss(self, pred, target, attention_mask):
113
+ loss_dict = dict()
114
+ batch_size, relation_num, _, _ = pred.shape
115
+
116
+ mask = torch.zeros_like(pred).to(pred.device)
117
+ for idx in range(batch_size):
118
+ n = torch.sum(attention_mask[idx]).to(torch.int)
119
+ mask[idx, :, :n, :n] = 1
120
+ pred = pred * mask - 9999 * (1 - mask)
121
+
122
+ if self.loss_type == "bce":
123
+ loss = self.bce_loss(pred, target)
124
+ elif self.loss_type == "multi_label_ce":
125
+ input_tensor = torch.permute(pred, (1, 0, 2, 3))
126
+ target_tensor = torch.permute(target, (1, 0, 2, 3))
127
+ input_tensor = pred.reshape([relation_num, -1])
128
+ target_tensor = target.reshape([relation_num, -1])
129
+ loss = self.multilabel_categorical_crossentropy(target_tensor, input_tensor)
130
+ weight = loss / loss.max()
131
+ loss = loss * weight
132
+ loss = loss.mean()
133
+ loss_dict["loss"] = loss
134
+
135
+ # running metric
136
+ recall_20 = get_recall_N(pred, target, object_num=20)
137
+ loss_dict["recall@20"] = recall_20
138
+ return loss_dict
139
+
140
+ def multilabel_categorical_crossentropy(self, y_true, y_pred):
141
+ """
142
+ https://kexue.fm/archives/7359
143
+ """
144
+ y_pred = (1 - 2 * y_true) * y_pred
145
+ y_pred_neg = y_pred - y_true * 9999
146
+ y_pred_pos = y_pred - (1 - y_true) * 9999
147
+ zeros = torch.zeros_like(y_pred[..., :1])
148
+ y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
149
+ y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
150
+ neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
151
+ pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
152
+ return neg_loss + pos_loss
153
+
154
+
155
+ def get_recall_N(y_pred, y_true, object_num=20):
156
+ """
157
+ y_pred: [batch_size, 56, object_num, object_num]
158
+ y_true: [batch_size, 56, object_num, object_num]
159
+ """
160
+
161
+ device = y_pred.device
162
+ recall_list = []
163
+
164
+ for idx in range(len(y_true)):
165
+ sample_y_true = []
166
+ sample_y_pred = []
167
+
168
+ # find topk
169
+ _, topk_indices = torch.topk(
170
+ y_true[idx : idx + 1].reshape(
171
+ [
172
+ -1,
173
+ ]
174
+ ),
175
+ k=object_num,
176
+ )
177
+ for index in topk_indices:
178
+ pred_cls = index // (y_true.shape[2] ** 2)
179
+ index_subject_object = index % (y_true.shape[2] ** 2)
180
+ pred_subject = index_subject_object // y_true.shape[2]
181
+ pred_object = index_subject_object % y_true.shape[2]
182
+ if y_true[idx, pred_cls, pred_subject, pred_object] == 0:
183
+ continue
184
+ sample_y_true.append([pred_subject, pred_object, pred_cls])
185
+
186
+ # find topk
187
+ _, topk_indices = torch.topk(
188
+ y_pred[idx : idx + 1].reshape(
189
+ [
190
+ -1,
191
+ ]
192
+ ),
193
+ k=object_num,
194
+ )
195
+ for index in topk_indices:
196
+ pred_cls = index // (y_pred.shape[2] ** 2)
197
+ index_subject_object = index % (y_pred.shape[2] ** 2)
198
+ pred_subject = index_subject_object // y_pred.shape[2]
199
+ pred_object = index_subject_object % y_pred.shape[2]
200
+ sample_y_pred.append([pred_subject, pred_object, pred_cls])
201
+
202
+ recall = len([x for x in sample_y_pred if x in sample_y_true]) / (len(sample_y_true) + 1e-8)
203
+ recall_list.append(recall)
204
+
205
+ recall = torch.tensor(recall_list).to(device).mean() * 100
206
+ return recall
207
+
208
+
209
+ class RamTrainer(object):
210
+ def __init__(self, config):
211
+ self.config = config
212
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
213
+ self._build_dataset()
214
+ self._build_dataloader()
215
+ self._build_model()
216
+ self._build_optimizer()
217
+ self._build_lr_scheduler()
218
+
219
+ def _build_dataset(self):
220
+ self.dataset = RamDataset(**self.config.dataset)
221
+
222
+ def _build_dataloader(self):
223
+ self.dataloader = torch.utils.data.DataLoader(
224
+ self.dataset,
225
+ batch_size=self.config.dataloader.batch_size,
226
+ shuffle=True if self.config.dataset.is_train else False,
227
+ )
228
+
229
+ def _build_model(self):
230
+ self.model = RamModel(**self.config.model).to(self.device)
231
+ if self.config.load_from is not None:
232
+ self.model.load_state_dict(torch.load(self.config.load_from))
233
+ self.model.train()
234
+
235
+ def _build_optimizer(self):
236
+ self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.optim.lr, weight_decay=self.config.optim.weight_decay, eps=self.config.optim.eps, betas=self.config.optim.betas)
237
+
238
+ def _build_lr_scheduler(self):
239
+ self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config.optim.lr_scheduler.step, gamma=self.config.optim.lr_scheduler.gamma)
240
+
241
+ def train(self):
242
+ t_start = time.time()
243
+ running_avg_loss = 0
244
+ for epoch_idx in range(self.config.num_epoch):
245
+ for batch_idx, batch_data in enumerate(self.dataloader):
246
+ batch_embeds = batch_data[0].to(torch.float32).to(self.device)
247
+ batch_target = batch_data[1].to(torch.float32).to(self.device)
248
+ attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
249
+ batch_pred = self.model.forward(batch_embeds, attention_mask)
250
+ loss_dict = self.model.loss(batch_pred, batch_target, attention_mask)
251
+ loss = loss_dict["loss"]
252
+ recall_20 = loss_dict["recall@20"]
253
+ self.optimizer.zero_grad()
254
+ loss.backward()
255
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.optim.max_norm, self.config.optim.norm_type)
256
+ self.optimizer.step()
257
+ running_avg_loss += loss.item()
258
+
259
+ if batch_idx % 100 == 0:
260
+ t_current = time.time()
261
+ num_finished_step = epoch_idx * self.config.num_epoch * len(self.dataloader) + batch_idx + 1
262
+ num_to_do_step = (self.config.num_epoch - epoch_idx - 1) * len(self.dataloader) + (len(self.dataloader) - batch_idx - 1)
263
+ avg_speed = num_finished_step / (t_current - t_start)
264
+ eta = num_to_do_step / avg_speed
265
+ print(
266
+ "ETA={:0>8}, Epoch={}, Batch={}/{}, LR={}, Loss={:.4f}, RunningAvgLoss={:.4f}, Recall@20={:.2f}%".format(
267
+ str(timedelta(seconds=int(eta))), epoch_idx + 1, batch_idx, len(self.dataloader), self.lr_scheduler.get_last_lr()[0], loss.item(), running_avg_loss / num_finished_step, recall_20.item()
268
+ )
269
+ )
270
+ self.lr_scheduler.step()
271
+ if not os.path.exists(self.config.output_dir):
272
+ os.makedirs(self.config.output_dir)
273
+ save_path = os.path.join(self.config.output_dir, "epoch_{}.pth".format(epoch_idx + 1))
274
+ print("Save epoch={} checkpoint to {}".format(epoch_idx + 1, save_path))
275
+ torch.save(self.model.state_dict(), save_path)
276
+ return save_path
277
+
278
+
279
+ class RamPredictor(object):
280
+ def __init__(self, config):
281
+ self.config = config
282
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
283
+ self._build_dataset()
284
+ self._build_dataloader()
285
+ self._build_model()
286
+
287
+ def _build_dataset(self):
288
+ self.dataset = RamDataset(**self.config.dataset)
289
+
290
+ def _build_dataloader(self):
291
+ self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.config.dataloader.batch_size, shuffle=False)
292
+
293
+ def _build_model(self):
294
+ self.model = RamModel(**self.config.model).to(self.device)
295
+ if self.config.load_from is not None:
296
+ self.model.load_state_dict(torch.load(self.config.load_from))
297
+ self.model.eval()
298
+
299
+ def predict(self, batch_embeds, pred_keep_num=100):
300
+ """
301
+ Parameters
302
+ ----------
303
+ batch_embeds: (batch_size=1, token_num, feature_size)
304
+ pred_keep_num: int
305
+ Returns
306
+ -------
307
+ batch_pred: (batch_size, relation_num, object_num, object_num)
308
+ pred_rels: [[sub_id, obj_id, rel_id], ...]
309
+ """
310
+ if not isinstance(batch_embeds, torch.Tensor):
311
+ batch_embeds = torch.asarray(batch_embeds)
312
+ batch_embeds = batch_embeds.to(torch.float32).to(self.device)
313
+ attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
314
+ batch_pred = self.model.forward(batch_embeds, attention_mask)
315
+ for idx_i in range(batch_pred.shape[2]):
316
+ batch_pred[:, :, idx_i, idx_i] = -9999
317
+ batch_pred = batch_pred.sigmoid()
318
+
319
+ pred_rels = []
320
+ _, topk_indices = torch.topk(
321
+ batch_pred.reshape(
322
+ [
323
+ -1,
324
+ ]
325
+ ),
326
+ k=pred_keep_num,
327
+ )
328
+
329
+ # subject, object, relation
330
+ for index in topk_indices:
331
+ pred_relation = index // (batch_pred.shape[2] ** 2)
332
+ index_subject_object = index % (batch_pred.shape[2] ** 2)
333
+ pred_subject = index_subject_object // batch_pred.shape[2]
334
+ pred_object = index_subject_object % batch_pred.shape[2]
335
+ pred = [pred_subject.item(), pred_object.item(), pred_relation.item()]
336
+ pred_rels.append(pred)
337
+ return batch_pred, pred_rels
338
+
339
+ def eval(self):
340
+ sum_recall_20 = 0.0
341
+ sum_recall_50 = 0.0
342
+ sum_recall_100 = 0.0
343
+ prog_bar = ProgressBar(len(self.dataloader))
344
+ for batch_idx, batch_data in enumerate(self.dataloader):
345
+ batch_embeds = batch_data[0]
346
+ batch_target = batch_data[1]
347
+ gt_rels = batch_data[2]
348
+ batch_pred, pred_rels = self.predict(batch_embeds)
349
+ this_recall_20 = get_recall_N(batch_pred, batch_target, object_num=20)
350
+ this_recall_50 = get_recall_N(batch_pred, batch_target, object_num=50)
351
+ this_recall_100 = get_recall_N(batch_pred, batch_target, object_num=100)
352
+ sum_recall_20 += this_recall_20.item()
353
+ sum_recall_50 += this_recall_50.item()
354
+ sum_recall_100 += this_recall_100.item()
355
+ prog_bar.update()
356
+ recall_20 = sum_recall_20 / len(self.dataloader)
357
+ recall_50 = sum_recall_50 / len(self.dataloader)
358
+ recall_100 = sum_recall_100 / len(self.dataloader)
359
+ metric = {
360
+ "recall_20": recall_20,
361
+ "recall_50": recall_50,
362
+ "recall_100": recall_100,
363
+ }
364
+ return metric
365
+
366
+
367
+ if __name__ == "__main__":
368
+ # Config
369
+ config = dict(
370
+ dataset=dict(
371
+ data_path="./data/feat_0420.npz",
372
+ is_train=True,
373
+ num_relation_classes=56,
374
+ ),
375
+ dataloader=dict(
376
+ batch_size=4,
377
+ ),
378
+ model=dict(
379
+ pretrained_model_name_or_path="bert-base-uncased",
380
+ load_pretrained_weights=True,
381
+ num_transformer_layer=2,
382
+ input_feature_size=256,
383
+ output_feature_size=768,
384
+ cls_feature_size=512,
385
+ num_relation_classes=56,
386
+ pred_type="attention",
387
+ loss_type="multi_label_ce",
388
+ ),
389
+ optim=dict(
390
+ lr=1e-4,
391
+ weight_decay=0.05,
392
+ eps=1e-8,
393
+ betas=(0.9, 0.999),
394
+ max_norm=0.01,
395
+ norm_type=2,
396
+ lr_scheduler=dict(
397
+ step=[6, 10],
398
+ gamma=0.1,
399
+ ),
400
+ ),
401
+ num_epoch=12,
402
+ output_dir="./work_dirs",
403
+ load_from=None,
404
+ )
405
+
406
+ # Train
407
+ config = Config(config)
408
+ trainer = RamTrainer(config)
409
+ last_model_path = trainer.train()
410
+
411
+ # Test/Eval
412
+ config.dataset.is_train = False
413
+ config.load_from = last_model_path
414
+ predictor = RamPredictor(config)
415
+ metric = predictor.eval()
416
+ print(metric)
ram_utils.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class MLP(nn.Module):
9
+ def __init__(self, input_size, hidden_size, num_classes, dropout_prob=0.1):
10
+ super(MLP, self).__init__()
11
+ self.fc1 = nn.Linear(input_size, hidden_size)
12
+ self.relu = nn.ReLU()
13
+ self.dropout = nn.Dropout(dropout_prob)
14
+ self.fc2 = nn.Linear(hidden_size, num_classes)
15
+
16
+ def forward(self, x):
17
+ out = self.fc1(x)
18
+ out = self.relu(out)
19
+ out = self.dropout(out)
20
+ out = self.fc2(out)
21
+ return out
22
+
23
+
24
+ def show_anns(anns, color_code='auto'):
25
+ if len(anns) == 0:
26
+ return
27
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
28
+ ax = plt.gca()
29
+ ax.set_autoscale_on(False)
30
+ polygons = []
31
+ color = []
32
+ for ann in sorted_anns:
33
+ m = ann['segmentation']
34
+ img = np.ones((m.shape[0], m.shape[1], 3))
35
+ color_mask = np.random.random((1, 3)).tolist()[0]
36
+ if color_code == 'auto':
37
+ for i in range(3):
38
+ img[:,:,i] = color_mask[i]
39
+ elif color_code == 'red':
40
+ for i in range(3):
41
+ img[:,:,0] = 1
42
+ img[:,:,1] = 0
43
+ img[:,:,2] = 0
44
+ else:
45
+ for i in range(3):
46
+ img[:,:,0] = 0
47
+ img[:,:,1] = 0
48
+ img[:,:,2] = 1
49
+ return np.dstack((img, m*0.35))
50
+
51
+
52
+ def show_points(coords, labels, ax, marker_size=375):
53
+ pos_points = coords[labels==1]
54
+ neg_points = coords[labels==0]
55
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
56
+ s=marker_size, edgecolor='white', linewidth=1.25)
57
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
58
+ s=marker_size, edgecolor='white', linewidth=1.25)
59
+
60
+ def show_mask(m):
61
+ img = np.ones((m.shape[0], m.shape[1], 3))
62
+ color_mask = np.random.random((1, 3)).tolist()[0]
63
+ for i in range(3):
64
+ img[:,:,0] = 1
65
+ img[:,:,1] = 0
66
+ img[:,:,2] = 0
67
+
68
+ return np.dstack((img, m*0.35))
69
+
70
+
71
+ def iou(mask1, mask2):
72
+ intersection = np.logical_and(mask1, mask2)
73
+ union = np.logical_or(mask1, mask2)
74
+ iou_score = np.sum(intersection) / np.sum(union)
75
+ return iou_score
76
+
77
+
78
+ def sort_and_deduplicate(sam_masks, iou_threshold=0.8):
79
+ # Sort the sam_masks list based on the area value
80
+ sorted_masks = sorted(sam_masks, key=lambda x: x['area'], reverse=True)
81
+
82
+ # Deduplicate masks based on the given iou_threshold
83
+ filtered_masks = []
84
+ for mask in sorted_masks:
85
+ duplicate = False
86
+ for filtered_mask in filtered_masks:
87
+ if iou(mask['segmentation'], filtered_mask['segmentation']) > iou_threshold:
88
+ duplicate = True
89
+ break
90
+
91
+ if not duplicate:
92
+ filtered_masks.append(mask)
93
+
94
+ return filtered_masks
95
+
96
+
97
+ relation_classes = ['over',
98
+ 'in front of',
99
+ 'beside',
100
+ 'on',
101
+ 'in',
102
+ 'attached to',
103
+ 'hanging from',
104
+ 'on back of',
105
+ 'falling off',
106
+ 'going down',
107
+ 'painted on',
108
+ 'walking on',
109
+ 'running on',
110
+ 'crossing',
111
+ 'standing on',
112
+ 'lying on',
113
+ 'sitting on',
114
+ 'flying over',
115
+ 'jumping over',
116
+ 'jumping from',
117
+ 'wearing',
118
+ 'holding',
119
+ 'carrying',
120
+ 'looking at',
121
+ 'guiding',
122
+ 'kissing',
123
+ 'eating',
124
+ 'drinking',
125
+ 'feeding',
126
+ 'biting',
127
+ 'catching',
128
+ 'picking',
129
+ 'playing with',
130
+ 'chasing',
131
+ 'climbing',
132
+ 'cleaning',
133
+ 'playing',
134
+ 'touching',
135
+ 'pushing',
136
+ 'pulling',
137
+ 'opening',
138
+ 'cooking',
139
+ 'talking to',
140
+ 'throwing',
141
+ 'slicing',
142
+ 'driving',
143
+ 'riding',
144
+ 'parked on',
145
+ 'driving on',
146
+ 'about to hit',
147
+ 'kicking',
148
+ 'swinging',
149
+ 'entering',
150
+ 'exiting',
151
+ 'enclosing',
152
+ 'leaning on',]
requirements.txt CHANGED
@@ -22,11 +22,6 @@ yapf
22
  numba
23
  segment_anything
24
 
25
- # ftfy
26
- # uuid
27
- # psutil
28
- # facexlib
29
  lama-cleaner==0.25.0
30
- # tensorflow
31
- # easydict
32
-
22
  numba
23
  segment_anything
24
 
 
 
 
 
25
  lama-cleaner==0.25.0
26
+ openmim==0.1.5
27
+ mmcv==2.0.0