SkalskiP committed on
Commit
39840e5
1 Parent(s): f590b07

EfficientSAM support added

Files changed (4)
  1. .gitattributes +2 -0
  2. app.py +38 -8
  3. utils/__init__.py +0 -0
  4. utils/efficient_sam.py +47 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+efficient_sam_s_cpu.jit filter=lfs diff=lfs merge=lfs -text
+efficient_sam_s_gpu.jit filter=lfs diff=lfs merge=lfs -text
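The two new entries register the TorchScript checkpoints with Git LFS, so the repository stores lightweight pointers instead of the binary files. A minimal sanity check, not part of this commit, that the checkpoints were actually materialized on checkout (assumes the files sit in the repository root):

import os

# If `git lfs pull` was skipped, each .jit file on disk is a tiny text
# pointer rather than the real checkpoint, and torch.jit.load would fail.
for checkpoint in ("efficient_sam_s_cpu.jit", "efficient_sam_s_gpu.jit"):
    size = os.path.getsize(checkpoint)
    assert size > 1024, f"{checkpoint} looks like an LFS pointer ({size} bytes)"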
app.py CHANGED
@@ -1,33 +1,63 @@
 from typing import List
 
+import torch
 import gradio as gr
 import numpy as np
 import supervision as sv
 from inference.models import YOLOWorld
 
+from utils.efficient_sam import load, inference_with_box
+
 MARKDOWN = """
-# YOLO-World 🌎
+# YOLO-World 🔥 [with Efficient-SAM]
+
+This is a demo of zero-shot instance segmentation using [YOLO-World](https://github.com/AILab-CVC/YOLO-World) and [Efficient-SAM](https://github.com/yformer/EfficientSAM).
 
 Powered by Roboflow [Inference](https://github.com/roboflow/inference) and [Supervision](https://github.com/roboflow/supervision).
 """
 
-MODEL = YOLOWorld(model_id="yolo_world/l")
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+EFFICIENT_SAM_MODEL = load(device=DEVICE)
+YOLO_WORLD_MODEL = YOLOWorld(model_id="yolo_world/l")
+
 BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
-LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
+MASK_ANNOTATOR = sv.MaskAnnotator()
+LABEL_ANNOTATOR = sv.LabelAnnotator()
 
 
 def process_categories(categories: str) -> List[str]:
     return [category.strip() for category in categories.split(',')]
 
 
-def process_image(input_image: np.ndarray, categories: str) -> np.ndarray:
+def process_image(
+    input_image: np.ndarray,
+    categories: str,
+    confidence_threshold: float = 0.003,
+    iou_threshold: float = 0.5,
+    with_segmentation: bool = True,
+    with_confidence: bool = True
+) -> np.ndarray:
     categories = process_categories(categories)
-    MODEL.set_classes(categories)
-    results = MODEL.infer(input_image, confidence=0.003)
-    detections = sv.Detections.from_inference(results).with_nms(0.1)
+    YOLO_WORLD_MODEL.set_classes(categories)
+    results = YOLO_WORLD_MODEL.infer(input_image, confidence=confidence_threshold)
+    detections = sv.Detections.from_inference(results).with_nms(iou_threshold)
+    if with_segmentation:
+        masks = []
+        for [x_min, y_min, x_max, y_max] in detections.xyxy:
+            box = np.array([[x_min, y_min], [x_max, y_max]])
+            mask = inference_with_box(input_image, box, EFFICIENT_SAM_MODEL, DEVICE)
+            masks.append(mask)
+        detections.mask = np.array(masks)
+
+    labels = [
+        f"{categories[class_id]}: {confidence:.2f}" if with_confidence else f"{categories[class_id]}"
+        for class_id, confidence in
+        zip(detections.class_id, detections.confidence)
+    ]
     output_image = input_image.copy()
+    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
     output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
-    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
+    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
     return output_image
utils/__init__.py ADDED
File without changes
utils/efficient_sam.py ADDED
@@ -0,0 +1,47 @@
+import torch
+import numpy as np
+from torchvision.transforms import ToTensor
+
+GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit"
+CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit"
+
+
+def load(device: torch.device) -> torch.jit.ScriptModule:
+    if device.type == "cuda":
+        model = torch.jit.load(GPU_EFFICIENT_SAM_CHECKPOINT)
+    else:
+        model = torch.jit.load(CPU_EFFICIENT_SAM_CHECKPOINT)
+    model.eval()
+    return model
+
+
+def inference_with_box(
+    image: np.ndarray,
+    box: np.ndarray,
+    model: torch.jit.ScriptModule,
+    device: torch.device
+) -> np.ndarray:
+    bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
+    bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
+    img_tensor = ToTensor()(image)
+
+    predicted_logits, predicted_iou = model(
+        img_tensor[None, ...].to(device),
+        bbox.to(device),
+        bbox_labels.to(device),
+    )
+    predicted_logits = predicted_logits.cpu()
+    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
+    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
+
+    max_predicted_iou = -1
+    selected_mask_using_predicted_iou = None
+    for m in range(all_masks.shape[0]):
+        curr_predicted_iou = predicted_iou[m]
+        if (
+            curr_predicted_iou > max_predicted_iou
+            or selected_mask_using_predicted_iou is None
+        ):
+            max_predicted_iou = curr_predicted_iou
+            selected_mask_using_predicted_iou = all_masks[m]
+    return selected_mask_using_predicted_iou
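`inference_with_box` packs the box into a `[batch, queries, points, 2]` prompt tensor; the labels `[2, 3]` follow the SAM point-label convention, marking the two points as a box's top-left and bottom-right corners. The model returns several candidate masks per query, and the final loop keeps the one with the highest predicted IoU. A hedged standalone usage sketch (image path and box coordinates are placeholders, and OpenCV is an assumption):

import cv2
import numpy as np
import torch

from utils.efficient_sam import load, inference_with_box

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load(device=device)

# The model expects RGB input; OpenCV reads BGR, so convert first.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
box = np.array([[100, 100], [400, 300]], dtype=np.float32)  # [[x_min, y_min], [x_max, y_max]]
mask = inference_with_box(image, box, model, device)  # boolean H x W array
print(mask.shape, mask.dtype)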