chongzhou commited on
Commit
435ddbc
1 Parent(s): 4c7dc89

save image embeddings in gradio session to avoid repeatedly encoding

Browse files
app.py CHANGED
@@ -107,6 +107,7 @@ def reset(session_state):
107
  session_state['box_list'] = []
108
  session_state['ori_image'] = None
109
  session_state['image_with_prompt'] = None
 
110
  return None, session_state
111
 
112
 
@@ -116,6 +117,7 @@ def reset_all(session_state):
116
  session_state['box_list'] = []
117
  session_state['ori_image'] = None
118
  session_state['image_with_prompt'] = None
 
119
  return None, None, session_state
120
 
121
 
@@ -145,8 +147,8 @@ def on_image_upload(
145
  session_state['ori_image'] = copy.deepcopy(image)
146
  session_state['image_with_prompt'] = copy.deepcopy(image)
147
  print("Image changed")
148
- # nd_image = np.array(global_image)
149
- # predictor.set_image(nd_image)
150
 
151
  return image, session_state
152
 
@@ -188,13 +190,11 @@ def segment_with_points(
188
  )
189
  image = session_state['image_with_prompt']
190
 
191
- nd_image = np.array(session_state['ori_image'])
192
- predictor.set_image(nd_image)
193
-
194
  if ENABLE_ONNX:
195
  coord_np = np.array(session_state['coord_list'])[None]
196
  label_np = np.array(session_state['label_list'])[None]
197
  masks, scores, _ = predictor.predict(
 
198
  point_coords=coord_np,
199
  point_labels=label_np,
200
  )
@@ -204,6 +204,7 @@ def segment_with_points(
204
  coord_np = np.array(session_state['coord_list'])
205
  label_np = np.array(session_state['label_list'])
206
  masks, scores, logits = predictor.predict(
 
207
  point_coords=coord_np,
208
  point_labels=label_np,
209
  num_multimask_outputs=4,
@@ -271,18 +272,18 @@ def segment_with_box(
271
  )
272
 
273
  box_np = np.array(box)
274
- nd_image = np.array(session_state['ori_image'])
275
- predictor.set_image(nd_image)
276
  if ENABLE_ONNX:
277
  point_coords = box_np.reshape(2, 2)[None]
278
  point_labels = np.array([2, 3])[None]
279
  masks, _, _ = predictor.predict(
 
280
  point_coords=point_coords,
281
  point_labels=point_labels,
282
  )
283
  annotations = masks[:, 0, :, :]
284
  else:
285
  masks, scores, _ = predictor.predict(
 
286
  box=box_np,
287
  num_multimask_outputs=1,
288
  )
@@ -312,7 +313,8 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
312
  'label_list': [],
313
  'box_list': [],
314
  'ori_image': None,
315
- 'image_with_prompt': None
 
316
  })
317
 
318
  with gr.Row():
 
107
  session_state['box_list'] = []
108
  session_state['ori_image'] = None
109
  session_state['image_with_prompt'] = None
110
+ session_state['feature'] = None
111
  return None, session_state
112
 
113
 
 
117
  session_state['box_list'] = []
118
  session_state['ori_image'] = None
119
  session_state['image_with_prompt'] = None
120
+ session_state['feature'] = None
121
  return None, None, session_state
122
 
123
 
 
147
  session_state['ori_image'] = copy.deepcopy(image)
148
  session_state['image_with_prompt'] = copy.deepcopy(image)
149
  print("Image changed")
150
+ nd_image = np.array(image)
151
+ session_state['feature'] = predictor.set_image(nd_image)
152
 
153
  return image, session_state
154
 
 
190
  )
191
  image = session_state['image_with_prompt']
192
 
 
 
 
193
  if ENABLE_ONNX:
194
  coord_np = np.array(session_state['coord_list'])[None]
195
  label_np = np.array(session_state['label_list'])[None]
196
  masks, scores, _ = predictor.predict(
197
+ features=session_state['feature'],
198
  point_coords=coord_np,
199
  point_labels=label_np,
200
  )
 
204
  coord_np = np.array(session_state['coord_list'])
205
  label_np = np.array(session_state['label_list'])
206
  masks, scores, logits = predictor.predict(
207
+ features=session_state['feature'],
208
  point_coords=coord_np,
209
  point_labels=label_np,
210
  num_multimask_outputs=4,
 
272
  )
273
 
274
  box_np = np.array(box)
 
 
275
  if ENABLE_ONNX:
276
  point_coords = box_np.reshape(2, 2)[None]
277
  point_labels = np.array([2, 3])[None]
278
  masks, _, _ = predictor.predict(
279
+ features=session_state['feature'],
280
  point_coords=point_coords,
281
  point_labels=point_labels,
282
  )
283
  annotations = masks[:, 0, :, :]
284
  else:
285
  masks, scores, _ = predictor.predict(
286
+ features=session_state['feature'],
287
  box=box_np,
288
  num_multimask_outputs=1,
289
  )
 
313
  'label_list': [],
314
  'box_list': [],
315
  'ori_image': None,
316
+ 'image_with_prompt': None,
317
+ 'feature': None
318
  })
319
 
320
  with gr.Row():
segment_anything/onnx/predictor_onnx.py CHANGED
@@ -60,17 +60,22 @@ class SamPredictorONNX:
60
  self.features = outputs[0]
61
  self.is_image_set = True
62
 
 
 
63
  def predict(
64
  self,
 
65
  point_coords: Optional[np.ndarray] = None,
66
  point_labels: Optional[np.ndarray] = None,
67
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
68
- if not self.is_image_set:
69
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
 
 
70
 
71
  point_coords = self.transform.apply_coords(point_coords, self.original_size)
72
  outputs = self.decoder.run(None, {
73
- 'image_embeddings': self.features,
74
  'point_coords': point_coords.astype(np.float32),
75
  'point_labels': point_labels.astype(np.float32)
76
  })
 
60
  self.features = outputs[0]
61
  self.is_image_set = True
62
 
63
+ return self.features
64
+
65
  def predict(
66
  self,
67
+ features: np.ndarray = None,
68
  point_coords: Optional[np.ndarray] = None,
69
  point_labels: Optional[np.ndarray] = None,
70
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
71
+ if features is None and not self.is_image_set:
72
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
73
+ if features is None:
74
+ features = self.features
75
 
76
  point_coords = self.transform.apply_coords(point_coords, self.original_size)
77
  outputs = self.decoder.run(None, {
78
+ 'image_embeddings': features,
79
  'point_coords': point_coords.astype(np.float32),
80
  'point_labels': point_labels.astype(np.float32)
81
  })
segment_anything/predictor.py CHANGED
@@ -37,7 +37,7 @@ class SamPredictor:
37
  self,
38
  image: np.ndarray,
39
  image_format: str = "RGB",
40
- ) -> None:
41
  """
42
  Calculates the image embeddings for the provided image, allowing
43
  masks to be predicted with the 'predict' method.
@@ -59,14 +59,14 @@ class SamPredictor:
59
  input_image_torch = torch.as_tensor(input_image, device=self.device)
60
  input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
61
 
62
- self.set_torch_image(input_image_torch, image.shape[:2])
63
 
64
  @torch.no_grad()
65
  def set_torch_image(
66
  self,
67
  transformed_image: torch.Tensor,
68
  original_image_size: Tuple[int, ...],
69
- ) -> None:
70
  """
71
  Calculates the image embeddings for the provided image, allowing
72
  masks to be predicted with the 'predict' method. Expects the input
@@ -91,8 +91,11 @@ class SamPredictor:
91
  self.features = self.model.image_encoder(input_image)
92
  self.is_image_set = True
93
 
 
 
94
  def predict(
95
  self,
 
96
  point_coords: Optional[np.ndarray] = None,
97
  point_labels: Optional[np.ndarray] = None,
98
  box: Optional[np.ndarray] = None,
@@ -131,9 +134,12 @@ class SamPredictor:
131
  of masks and H=W=256. These low resolution logits can be passed to
132
  a subsequent iteration as mask input.
133
  """
134
- if not self.is_image_set:
135
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
136
 
 
 
 
137
  # Transform input prompts
138
  coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
139
  if point_coords is not None:
@@ -153,6 +159,7 @@ class SamPredictor:
153
  mask_input_torch = mask_input_torch[None, :, :, :]
154
 
155
  masks, iou_predictions, low_res_masks = self.predict_torch(
 
156
  coords_torch,
157
  labels_torch,
158
  box_torch,
@@ -170,6 +177,7 @@ class SamPredictor:
170
  @torch.no_grad()
171
  def predict_torch(
172
  self,
 
173
  point_coords: Optional[torch.Tensor],
174
  point_labels: Optional[torch.Tensor],
175
  boxes: Optional[torch.Tensor] = None,
@@ -211,7 +219,7 @@ class SamPredictor:
211
  of masks and H=W=256. These low res logits can be passed to
212
  a subsequent iteration as mask input.
213
  """
214
- if not self.is_image_set:
215
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
216
 
217
  if point_coords is not None:
@@ -228,7 +236,7 @@ class SamPredictor:
228
 
229
  # Predict masks
230
  low_res_masks, iou_predictions = self.model.mask_decoder(
231
- image_embeddings=self.features,
232
  image_pe=self.model.prompt_encoder.get_dense_pe(),
233
  sparse_prompt_embeddings=sparse_embeddings,
234
  dense_prompt_embeddings=dense_embeddings,
 
37
  self,
38
  image: np.ndarray,
39
  image_format: str = "RGB",
40
+ ) -> torch.Tensor:
41
  """
42
  Calculates the image embeddings for the provided image, allowing
43
  masks to be predicted with the 'predict' method.
 
59
  input_image_torch = torch.as_tensor(input_image, device=self.device)
60
  input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
61
 
62
+ return self.set_torch_image(input_image_torch, image.shape[:2])
63
 
64
  @torch.no_grad()
65
  def set_torch_image(
66
  self,
67
  transformed_image: torch.Tensor,
68
  original_image_size: Tuple[int, ...],
69
+ ) -> torch.Tensor:
70
  """
71
  Calculates the image embeddings for the provided image, allowing
72
  masks to be predicted with the 'predict' method. Expects the input
 
91
  self.features = self.model.image_encoder(input_image)
92
  self.is_image_set = True
93
 
94
+ return self.features
95
+
96
  def predict(
97
  self,
98
+ features: torch.Tensor = None,
99
  point_coords: Optional[np.ndarray] = None,
100
  point_labels: Optional[np.ndarray] = None,
101
  box: Optional[np.ndarray] = None,
 
134
  of masks and H=W=256. These low resolution logits can be passed to
135
  a subsequent iteration as mask input.
136
  """
137
+ if features is None and not self.is_image_set:
138
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
139
 
140
+ if features is None:
141
+ features = self.features
142
+
143
  # Transform input prompts
144
  coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
145
  if point_coords is not None:
 
159
  mask_input_torch = mask_input_torch[None, :, :, :]
160
 
161
  masks, iou_predictions, low_res_masks = self.predict_torch(
162
+ features,
163
  coords_torch,
164
  labels_torch,
165
  box_torch,
 
177
  @torch.no_grad()
178
  def predict_torch(
179
  self,
180
+ features: torch.Tensor,
181
  point_coords: Optional[torch.Tensor],
182
  point_labels: Optional[torch.Tensor],
183
  boxes: Optional[torch.Tensor] = None,
 
219
  of masks and H=W=256. These low res logits can be passed to
220
  a subsequent iteration as mask input.
221
  """
222
+ if features is None and not self.is_image_set:
223
  raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
224
 
225
  if point_coords is not None:
 
236
 
237
  # Predict masks
238
  low_res_masks, iou_predictions = self.model.mask_decoder(
239
+ image_embeddings=features,
240
  image_pe=self.model.prompt_encoder.get_dense_pe(),
241
  sparse_prompt_embeddings=sparse_embeddings,
242
  dense_prompt_embeddings=dense_embeddings,