pg56714 committed
Commit 65f5e56
1 Parent(s): a375a27

Update app.py

Files changed (1): app.py (+97 -35)
app.py CHANGED
@@ -4,9 +4,14 @@ import gradio as gr
 import numpy as np
 import supervision as sv
 import torch
-from transformers import SamModel, SamProcessor
+import time
+from PIL import Image
 
-from utils.efficient_sam import load, inference_with_box
+from torchvision.transforms import ToTensor
+
+# from transformers import SamModel, SamProcessor
+
+from efficient_sam.build_efficient_sam import build_efficient_sam_vits
 
 from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
 from efficientvit.sam_model_zoo import create_sam_model
@@ -30,10 +35,10 @@ PROMPT_COLOR = sv.Color.from_hex("#D3D3D3")
 MASK_COLOR = sv.Color.from_hex("#FF0000")
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
-SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+# SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE).eval()
+# SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
 
-EFFICIENT_SAM_MODEL = load(device=DEVICE)
+EFFICIENT_SAM_MODEL = build_efficient_sam_vits().to(DEVICE).eval()
 
 MASK_ANNOTATOR = sv.MaskAnnotator(color=MASK_COLOR, color_lookup=sv.ColorLookup.INDEX)
 
@@ -52,9 +57,11 @@ def annotate_image_with_box_prompt_result(
 ) -> np.ndarray:
     h, w, _ = image.shape
     bgr_image = image[:, :, ::-1]
+
     annotated_bgr_image = MASK_ANNOTATOR.annotate(
-        scene=bgr_image, detections=detections
+        scene=bgr_image.copy(), detections=detections
     )
+
     annotated_bgr_image = sv.draw_rectangle(
         scene=annotated_bgr_image,
         rect=sv.Rect(
@@ -66,35 +73,78 @@ def annotate_image_with_box_prompt_result(
         color=PROMPT_COLOR,
         thickness=sv.calculate_optimal_line_thickness(resolution_wh=(w, h)),
     )
+
     return annotated_bgr_image[:, :, ::-1]
 
 
 def efficientvit_sam_box_inference(
     image: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int
 ) -> np.ndarray:
+    t1 = time.time()
+
     box = np.array([[x_min, y_min, x_max, y_max]])
     EFFICIENTVITSAM.set_image(image)
     mask = EFFICIENTVITSAM.predict(box=box, multimask_output=False)
     mask = mask[0]
     detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-    return annotate_image_with_box_prompt_result(
+    result = annotate_image_with_box_prompt_result(
         image=image,
         detections=detections,
-        x_min=x_min,
-        y_min=y_min,
         x_max=x_max,
+        x_min=x_min,
         y_max=y_max,
+        y_min=y_min,
+    )
+    t2 = time.time()
+
+    print(f"timecost: {t2-t1}")
+
+    return result
+
+
+def inference_with_box(
+    image: np.ndarray,
+    box: np.ndarray,
+    model: torch.jit.ScriptModule,
+    device: torch.device,
+) -> np.ndarray:
+    bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
+    bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
+    img_tensor = ToTensor()(image)
+
+    predicted_logits, predicted_iou = model(
+        img_tensor[None, ...].to(device),
+        bbox.to(device),
+        bbox_labels.to(device),
     )
+    predicted_logits = predicted_logits.cpu()
+    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
+    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
+
+    max_predicted_iou = -1
+    selected_mask_using_predicted_iou = None
+    for m in range(all_masks.shape[0]):
+        curr_predicted_iou = predicted_iou[m]
+        if (
+            curr_predicted_iou > max_predicted_iou
+            or selected_mask_using_predicted_iou is None
+        ):
+            max_predicted_iou = curr_predicted_iou
+            selected_mask_using_predicted_iou = all_masks[m]
+    return selected_mask_using_predicted_iou
 
 
 def efficient_sam_box_inference(
     image: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int
 ) -> np.ndarray:
+    t1 = time.time()
+
     box = np.array([[x_min, y_min], [x_max, y_max]])
     mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
    mask = mask[np.newaxis, ...]
     detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-    return annotate_image_with_box_prompt_result(
+
+    result = annotate_image_with_box_prompt_result(
         image=image,
         detections=detections,
         x_max=x_max,
@@ -102,11 +152,18 @@ def efficient_sam_box_inference(
         y_max=y_max,
         y_min=y_min,
     )
+    t2 = time.time()
+
+    print(f"timecost: {t2-t1}")
+
+    return result
 
 
 # def sam_box_inference(
 #     image: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int
 # ) -> np.ndarray:
+#     t1 = time.time()
+
 #     input_boxes = [[[x_min, y_min, x_max, y_max]]]
 #     inputs = SAM_PROCESSOR(
 #         Image.fromarray(image), input_boxes=[input_boxes], return_tensors="pt"
@@ -122,7 +179,8 @@ def efficient_sam_box_inference(
 #     )[0][0][0].numpy()
 #     mask = mask[np.newaxis, ...]
 #     detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-#     return annotate_image_with_box_prompt_result(
+
+#     result = annotate_image_with_box_prompt_result(
 #         image=image,
 #         detections=detections,
 #         x_max=x_max,
@@ -130,6 +188,11 @@
 #         y_max=y_max,
 #         y_min=y_min,
 #     )
+#     t2 = time.time()
+
+#     print(f"timecost: {t2-t1}")
+
+#     return result
 
 
 def box_inference(
@@ -159,31 +222,30 @@ box_inputs = [box_input_image, x_min_number, y_min_number, x_max_number, y_max_n
 
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
-    with gr.Tab(label="Box prompt"):
-        with gr.Row():
-            with gr.Column():
-                box_input_image.render()
-                with gr.Accordion(label="Box", open=False):
-                    with gr.Row():
-                        x_min_number.render()
-                        y_min_number.render()
-                        x_max_number.render()
-                        y_max_number.render()
-            efficientvit_sam_box_output_image = gr.Image(label="EfficientVit-SAM")
-            efficient_sam_box_output_image = gr.Image(label="EfficientSAM")
-            # sam_box_output_image = gr.Image(label="SAM")
-        with gr.Row():
-            submit_box_inference_button = gr.Button("Submit")
-        gr.Examples(
-            # fn=box_inference,
-            examples=BOX_EXAMPLES,
-            inputs=box_inputs,
-            outputs=[
-                efficientvit_sam_box_output_image,
-                efficient_sam_box_output_image,
-                # sam_box_output_image,
-            ],
+    with gr.Row():
+        box_input_image.render()
+        efficientvit_sam_box_output_image = gr.Image(label="EfficientVit-SAM")
+        efficient_sam_box_output_image = gr.Image(label="EfficientSAM")
+        # sam_box_output_image = gr.Image(label="SAM")
+
+    with gr.Row():
+        x_min_number.render()
+        y_min_number.render()
+        x_max_number.render()
+        y_max_number.render()
+        submit_box_inference_button = gr.Button(
+            value="Submit", scale=1, variant="primary"
         )
+    gr.Examples(
+        # fn=box_inference,
+        examples=BOX_EXAMPLES,
+        inputs=box_inputs,
+        outputs=[
+            efficientvit_sam_box_output_image,
+            efficient_sam_box_output_image,
+            # sam_box_output_image,
+        ],
+    )
 
     submit_box_inference_button.click(
         efficientvit_sam_box_inference,
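
For context, a minimal standalone sketch of the EfficientSAM box-prompt path this commit introduces (model build, corner-point encoding of the box, best-mask selection by predicted IoU). It condenses the logic of inference_with_box from the diff above; the checkpoint availability, "example.jpg", and the box coordinates are assumptions for illustration, not part of the commit.

# Standalone sketch; assumes the efficient_sam package and its ViT-S
# weights are available locally. Mirrors inference_with_box in app.py.
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor

from efficient_sam.build_efficient_sam import build_efficient_sam_vits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_efficient_sam_vits().to(device).eval()

image = np.asarray(Image.open("example.jpg").convert("RGB"))  # placeholder path
box = np.array([[100, 100], [400, 400]])  # [[x_min, y_min], [x_max, y_max]], placeholder

# A box prompt is encoded as two labeled points: label 2 for the top-left
# corner, label 3 for the bottom-right corner (the [2, 3] labels used by
# inference_with_box in the diff above).
pts = torch.reshape(torch.tensor(box), [1, 1, 2, 2]).to(device)
labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2]).to(device)

with torch.no_grad():
    logits, iou = model(ToTensor()(image)[None, ...].to(device), pts, labels)

# Threshold the mask logits and keep the candidate with the highest
# predicted IoU, as the loop in inference_with_box does.
masks = torch.ge(torch.sigmoid(logits[0, 0]), 0.5).cpu().numpy()
best = int(iou[0, 0].argmax().item())
mask = masks[best]  # boolean HxW mask
print(mask.shape, mask.sum())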