chongzhou committed
Commit 65665c1
Parent: b88e069

fix the resizing issue during concurrent visiting

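The root cause: the app shares one predictor across all Gradio sessions, and the predictor cached per-image state (features, input size, original size) on itself, so one visitor's set_image silently overwrote another's and masks were resized against the wrong dimensions. This commit makes the predictors stateless: set_image returns the per-image state, predict takes it back as arguments, and each session keeps its own copy in session_state. A minimal sketch of the new call pattern, with names taken from the diff below (the surrounding handler code is elided):

    # On image upload: set_image now returns all per-image state
    # instead of caching it on the shared predictor instance.
    features, input_size, original_size = predictor.set_image(nd_image)
    session_state['feature'] = features
    session_state['input_size'] = input_size
    session_state['original_size'] = original_size

    # On prediction: each session passes its own sizes back in, so a
    # concurrent upload can no longer clobber this session's resizing.
    masks, scores, _ = predictor.predict(
        features=session_state['feature'],
        input_size=session_state['input_size'],
        original_size=session_state['original_size'],
        point_coords=coord_np,
        point_labels=label_np,
    )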
app.py CHANGED
@@ -109,6 +109,8 @@ def reset(session_state):
     session_state['ori_image'] = None
     session_state['image_with_prompt'] = None
     session_state['feature'] = None
+    session_state['input_size'] = None
+    session_state['original_size'] = None
     return None, None, None, session_state


@@ -119,6 +121,8 @@ def reset_all(session_state):
     session_state['ori_image'] = None
     session_state['image_with_prompt'] = None
     session_state['feature'] = None
+    session_state['input_size'] = None
+    session_state['original_size'] = None
     return None, None, None, None, None, None, session_state


@@ -149,7 +153,7 @@ def on_image_upload(
     session_state['image_with_prompt'] = copy.deepcopy(image)
     print("Image changed")
     nd_image = np.array(image)
-    session_state['feature'] = predictor.set_image(nd_image)
+    session_state['feature'], session_state['input_size'], session_state['original_size'] = predictor.set_image(nd_image)

     return image, None, None, session_state

@@ -190,12 +194,16 @@ def segment_with_points(
         fill=point_color,
     )
     image = session_state['image_with_prompt']
+    print(f"image: {image.size}")
+    nd_image = np.array(session_state['ori_image'])

     if ENABLE_ONNX:
         coord_np = np.array(session_state['coord_list'])[None]
         label_np = np.array(session_state['label_list'])[None]
         masks, scores, _ = predictor.predict(
             features=session_state['feature'],
+            input_size=session_state['input_size'],
+            original_size=session_state['original_size'],
             point_coords=coord_np,
             point_labels=label_np,
         )
@@ -206,6 +214,8 @@ def segment_with_points(
         label_np = np.array(session_state['label_list'])
         masks, scores, logits = predictor.predict(
             features=session_state['feature'],
+            input_size=session_state['input_size'],
+            original_size=session_state['original_size'],
             point_coords=coord_np,
             point_labels=label_np,
             num_multimask_outputs=4,
@@ -233,7 +243,7 @@ def segment_with_points(
     binary_mask = np.where(annotations[0] > 0.5, 255, 0).astype(np.uint8)
     mask = Image.fromarray(binary_mask)
     binary_mask = np.expand_dims(binary_mask, axis=2)
-    crop = Image.fromarray(np.concatenate((session_state['ori_image'], binary_mask), axis=2), "RGBA")
+    crop = Image.fromarray(np.concatenate((nd_image, binary_mask), axis=2), "RGBA")
     return seg, mask, crop, session_state


@@ -282,6 +292,8 @@ def segment_with_box(
         point_labels = np.array([2, 3])[None]
         masks, _, _ = predictor.predict(
             features=session_state['feature'],
+            input_size=session_state['input_size'],
+            original_size=session_state['original_size'],
             point_coords=point_coords,
             point_labels=point_labels,
         )
@@ -289,6 +301,8 @@ def segment_with_box(
     else:
         masks, scores, _ = predictor.predict(
             features=session_state['feature'],
+            input_size=session_state['input_size'],
+            original_size=session_state['original_size'],
             box=box_np,
             num_multimask_outputs=1,
         )
segment_anything/onnx/predictor_onnx.py CHANGED
@@ -53,34 +53,30 @@ class SamPredictorONNX:
         input_image = self.transform.apply_image(image)
         input_image = input_image.transpose(2, 0, 1)[None, :, :, :]
         self.reset_image()
-        self.original_size = image.shape[:2]
-        self.input_size = tuple(input_image.shape[-2:])
+        original_size = image.shape[:2]
+        input_size = tuple(input_image.shape[-2:])
         input_image = self.preprocess(input_image).astype(np.float32)
         outputs = self.encoder.run(None, {'image': input_image})
-        self.features = outputs[0]
-        self.is_image_set = True
+        features = outputs[0]

-        return self.features
+        return features, input_size, original_size

     def predict(
         self,
-        features: np.ndarray = None,
+        features: np.ndarray,
+        input_size: Tuple[int, int],
+        original_size: Tuple[int, int],
         point_coords: Optional[np.ndarray] = None,
         point_labels: Optional[np.ndarray] = None,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        if features is None and not self.is_image_set:
-            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
-        if features is None:
-            features = self.features
-
-        point_coords = self.transform.apply_coords(point_coords, self.original_size)
+        point_coords = self.transform.apply_coords(point_coords, original_size)
         outputs = self.decoder.run(None, {
             'image_embeddings': features,
             'point_coords': point_coords.astype(np.float32),
             'point_labels': point_labels.astype(np.float32)
         })
         scores, low_res_masks = outputs[0], outputs[1]
-        masks = self.postprocess_masks(low_res_masks)
+        masks = self.postprocess_masks(low_res_masks, input_size, original_size)
         masks = masks > self.mask_threshold

         return masks, scores, low_res_masks
@@ -102,10 +98,10 @@ class SamPredictorONNX:
         x = np.pad(x, ((0, 0), (0, 0), (0, padh), (0, padw)), mode='constant', constant_values=0)
         return x

-    def postprocess_masks(self, mask: np.ndarray):
+    def postprocess_masks(self, mask: np.ndarray, input_size: Tuple[int, int], original_size: Tuple[int, int]) -> np.ndarray:
         mask = mask.squeeze(0).transpose(1, 2, 0)
         mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
-        mask = mask[:self.input_size[0], :self.input_size[1], :]
-        mask = cv2.resize(mask, (self.original_size[1], self.original_size[0]), interpolation=cv2.INTER_LINEAR)
+        mask = mask[:input_size[0], :input_size[1], :]
+        mask = cv2.resize(mask, (original_size[1], original_size[0]), interpolation=cv2.INTER_LINEAR)
         mask = mask.transpose(2, 0, 1)[None, :, :, :]
         return mask
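With the ONNX predictor now stateless, the caller holds the per-image tuple itself. A minimal usage sketch; the constructor arguments and model paths below are hypothetical, only the set_image/predict signatures come from this diff:

    import numpy as np

    # Hypothetical construction; the real constructor lives elsewhere in the repo.
    predictor = SamPredictorONNX(encoder_path='encoder.onnx', decoder_path='decoder.onnx')

    image = np.zeros((480, 640, 3), dtype=np.uint8)        # any HxWx3 uint8 image
    features, input_size, original_size = predictor.set_image(image)

    coord_np = np.array([[[320, 240]]], dtype=np.float32)  # 1xNx2 point prompts
    label_np = np.array([[1]], dtype=np.float32)           # 1 = foreground click
    masks, scores, low_res_masks = predictor.predict(
        features=features,
        input_size=input_size,
        original_size=original_size,
        point_coords=coord_np,
        point_labels=label_np,
    )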
segment_anything/predictor.py CHANGED
@@ -59,13 +59,15 @@ class SamPredictor:
         input_image_torch = torch.as_tensor(input_image, device=self.device)
         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

-        return self.set_torch_image(input_image_torch, image.shape[:2])
+        input_size = tuple(input_image_torch.shape[-2:])
+        original_size = image.shape[:2]
+
+        return self.set_torch_image(input_image_torch), input_size, original_size

     @torch.no_grad()
     def set_torch_image(
         self,
         transformed_image: torch.Tensor,
-        original_image_size: Tuple[int, ...],
     ) -> torch.Tensor:
         """
         Calculates the image embeddings for the provided image, allowing
@@ -75,8 +77,6 @@ class SamPredictor:
         Arguments:
           transformed_image (torch.Tensor): The input image, with shape
             1x3xHxW, which has been transformed with ResizeLongestSide.
-          original_image_size (tuple(int, int)): The size of the image
-            before transformation, in (H, W) format.
         """
         assert (
             len(transformed_image.shape) == 4
@@ -85,24 +85,23 @@ class SamPredictor:
         ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
         self.reset_image()

-        self.original_size = original_image_size
-        self.input_size = tuple(transformed_image.shape[-2:])
         input_image = self.model.preprocess(transformed_image)
-        self.features = self.model.image_encoder(input_image)
-        self.is_image_set = True
+        features = self.model.image_encoder(input_image)

-        return self.features
+        return features

     def predict(
         self,
-        features: torch.Tensor = None,
+        features: torch.Tensor,
+        input_size: Tuple[int, int],
+        original_size: Tuple[int, int],
         point_coords: Optional[np.ndarray] = None,
         point_labels: Optional[np.ndarray] = None,
         box: Optional[np.ndarray] = None,
         mask_input: Optional[np.ndarray] = None,
         num_multimask_outputs: int = 3,
         return_logits: bool = False,
-        use_stability_score: bool = False
+        use_stability_score: bool = False,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -134,24 +133,18 @@ class SamPredictor:
           of masks and H=W=256. These low resolution logits can be passed to
           a subsequent iteration as mask input.
         """
-        if features is None and not self.is_image_set:
-            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
-
-        if features is None:
-            features = self.features
-
         # Transform input prompts
         coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
         if point_coords is not None:
             assert (
                 point_labels is not None
             ), "point_labels must be supplied if point_coords is supplied."
-            point_coords = self.transform.apply_coords(point_coords, self.original_size)
+            point_coords = self.transform.apply_coords(point_coords, original_size)
             coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
             labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
             coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
         if box is not None:
-            box = self.transform.apply_boxes(box, self.original_size)
+            box = self.transform.apply_boxes(box, original_size)
             box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
             box_torch = box_torch[None, :]
         if mask_input is not None:
@@ -160,6 +153,8 @@ class SamPredictor:

         masks, iou_predictions, low_res_masks = self.predict_torch(
             features,
+            input_size,
+            original_size,
             coords_torch,
             labels_torch,
             box_torch,
@@ -178,6 +173,8 @@ class SamPredictor:
     def predict_torch(
         self,
         features: torch.Tensor,
+        input_size: Tuple[int, int],
+        original_size: Tuple[int, int],
         point_coords: Optional[torch.Tensor],
         point_labels: Optional[torch.Tensor],
         boxes: Optional[torch.Tensor] = None,
@@ -249,7 +246,7 @@ class SamPredictor:
         )

         # Upscale the masks to the original image resolution
-        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+        masks = self.model.postprocess_masks(low_res_masks, input_size, original_size)

         if not return_logits:
             masks = masks > self.model.mask_threshold
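The PyTorch path mirrors the ONNX one: per-image state travels through return values and keyword arguments rather than predictor attributes, so concurrent sessions only need to pass their own tuple back in. A minimal sketch, assuming sam_model and pil_image are already constructed (neither appears in this diff):

    import numpy as np

    predictor = SamPredictor(sam_model)  # shared instance, now holds no per-image state

    features, input_size, original_size = predictor.set_image(np.array(pil_image))
    masks, scores, logits = predictor.predict(
        features=features,
        input_size=input_size,
        original_size=original_size,
        box=np.array([100, 100, 400, 300]),  # XYXY box prompt
        num_multimask_outputs=1,
    )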