SkalskiP committed
Commit f0c408b
1 Parent(s): f7e104c

Box prompt working

Files changed (2)
  1. app.py +55 -12
  2. utils/efficient_sam.py +47 -0
app.py CHANGED
@@ -1,29 +1,40 @@
-import time
+from typing import Tuple
+
 import gradio as gr
 import numpy as np
 import supervision as sv
-from PIL import Image
 import torch
+from PIL import Image
 from transformers import SamModel, SamProcessor
-from typing import Tuple
 
+from utils.efficient_sam import load, inference_with_box
 
 MARKDOWN = """
 # EfficientSAM sv. SAM
+
+This is a demo for comparing the performance of
+[EfficientSAM](https://arxiv.org/abs/2312.00863) and
+[SAM](https://arxiv.org/abs/2304.02643).
 """
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
 SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+EFFICIENT_SAM_MODEL = load(device=DEVICE)
 MASK_ANNOTATOR = sv.MaskAnnotator(
     color=sv.Color.red(),
     color_lookup=sv.ColorLookup.INDEX)
+BOX_ANNOTATOR = sv.BoundingBoxAnnotator(
+    color=sv.Color.red(),
+    color_lookup=sv.ColorLookup.INDEX)
 
 
 def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
     bgr_image = image[:, :, ::-1]
     annotated_bgr_image = MASK_ANNOTATOR.annotate(
         scene=bgr_image, detections=detections)
+    annotated_bgr_image = BOX_ANNOTATOR.annotate(
+        scene=annotated_bgr_image, detections=detections)
     return annotated_bgr_image[:, :, ::-1]
 
 
@@ -34,8 +45,11 @@ def efficient_sam_inference(
     x_max: int,
     y_max: int
 ) -> np.ndarray:
-    time.sleep(0.2)
-    return image
+    box = np.array([[x_min, y_min], [x_max, y_max]])
+    mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
+    mask = mask[np.newaxis, ...]
+    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
+    return annotate_image(image=image, detections=detections)
 
 
 def sam_inference(
@@ -78,6 +92,10 @@ def inference(
     )
 
 
+def clear(image: np.ndarray) -> Tuple[None, None]:
+    return (None, None)
+
+
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Tab(label="Box prompt"):
@@ -90,8 +108,8 @@ with gr.Blocks() as demo:
                 y_min_number = gr.Number(label="y_min")
                 x_max_number = gr.Number(label="x_max")
                 y_max_number = gr.Number(label="y_max")
-            efficient_sam_output_image = gr.Image()
-            sam_output_image = gr.Image()
+            efficient_sam_output_image = gr.Image(label="EfficientSAM")
+            sam_output_image = gr.Image(label="SAM")
         with gr.Row():
             submit_button = gr.Button("Submit")
 
@@ -99,11 +117,32 @@ with gr.Blocks() as demo:
         fn=inference,
         examples=[
             [
-                'https://media.roboflow.com/notebooks/examples/dog.jpeg',
+                'https://media.roboflow.com/efficient-sam/beagle.jpeg',
                 69,
-                247,
-                624,
-                930
+                26,
+                625,
+                704
+            ],
+            [
+                'https://media.roboflow.com/efficient-sam/corgi.jpg',
+                801,
+                510,
+                1782,
+                993
+            ],
+            [
+                'https://media.roboflow.com/efficient-sam/horses.jpg',
+                814,
+                696,
+                1523,
+                1183
+            ],
+            [
+                'https://media.roboflow.com/efficient-sam/bears.jpg',
+                653,
+                874,
+                1173,
+                1229
             ]
         ],
         inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
@@ -115,11 +154,15 @@ with gr.Blocks() as demo:
         inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
         outputs=efficient_sam_output_image
     )
-
    submit_button.click(
        sam_inference,
        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
        outputs=sam_output_image
    )
+   input_image.change(
+       clear,
+       inputs=input_image,
+       outputs=[efficient_sam_output_image, sam_output_image]
+   )
 
 demo.launch(debug=False, show_error=True)
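Note on the new inference path: `efficient_sam_inference` turns the four box coordinates into a 2×2 corner array, runs EfficientSAM via `inference_with_box`, and wraps the returned mask in an `sv.Detections` so the shared annotators can draw it. A minimal sketch of the same flow outside Gradio, assuming the TorchScript checkpoints referenced in `utils/efficient_sam.py` are in the working directory and that `beagle.jpeg` is a hypothetical local copy of the first example image:

```python
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.efficient_sam import load, inference_with_box

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load(device=DEVICE)  # loads the GPU or CPU .jit checkpoint

# Box from the first example row: (x_min, y_min), (x_max, y_max)
image = np.asarray(Image.open("beagle.jpeg").convert("RGB"))
box = np.array([[69, 26], [625, 704]])

mask = inference_with_box(image, box, model, DEVICE)  # (H, W) bool mask
mask = mask[np.newaxis, ...]                          # add a leading mask axis
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
print(detections.xyxy)  # tight box recovered from the predicted mask
```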
utils/efficient_sam.py CHANGED
@@ -0,0 +1,47 @@
+import torch
+import numpy as np
+from torchvision.transforms import ToTensor
+
+GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit"
+CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit"
+
+
+def load(device: torch.device) -> torch.jit.ScriptModule:
+    if device.type == "cuda":
+        model = torch.jit.load(GPU_EFFICIENT_SAM_CHECKPOINT)
+    else:
+        model = torch.jit.load(CPU_EFFICIENT_SAM_CHECKPOINT)
+    model.eval()
+    return model
+
+
+def inference_with_box(
+    image: np.ndarray,
+    box: np.ndarray,
+    model: torch.jit.ScriptModule,
+    device: torch.device
+) -> np.ndarray:
+    bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
+    bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
+    img_tensor = ToTensor()(image)
+
+    predicted_logits, predicted_iou = model(
+        img_tensor[None, ...].to(device),
+        bbox.to(device),
+        bbox_labels.to(device),
+    )
+    predicted_logits = predicted_logits.cpu()
+    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
+    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
+
+    max_predicted_iou = -1
+    selected_mask_using_predicted_iou = None
+    for m in range(all_masks.shape[0]):
+        curr_predicted_iou = predicted_iou[m]
+        if (
+            curr_predicted_iou > max_predicted_iou
+            or selected_mask_using_predicted_iou is None
+        ):
+            max_predicted_iou = curr_predicted_iou
+            selected_mask_using_predicted_iou = all_masks[m]
+    return selected_mask_using_predicted_iou
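Two implementation details worth noting. First, the box prompt goes through a SAM-style point-prompt interface: the corners are reshaped to `[1, 1, 2, 2]` (batch, queries, points, xy) and paired with labels `[2, 3]`, which in the SAM convention mark a box's top-left and bottom-right corners. Second, the selection loop keeps the candidate mask with the highest predicted IoU; a sketch of an equivalent vectorized form (not part of the commit, reusing `all_masks` and `predicted_iou` from the function body above):

```python
import numpy as np

# predicted_iou: (num_masks,) scores; all_masks: (num_masks, H, W) booleans.
# np.argmax finds the index of the highest-scoring candidate in one call,
# matching what the explicit loop above computes.
best = int(np.argmax(predicted_iou))
selected_mask = all_masks[best]
```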