jhj0517 commited on
Commit
62faa17
·
1 Parent(s): a503e15

Update sam2

Browse files
segment-anything-2/sam2/sam2_video_predictor.py CHANGED
@@ -4,6 +4,7 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
7
  from collections import OrderedDict
8
 
9
  import torch
@@ -44,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
44
  async_loading_frames=False,
45
  ):
46
  """Initialize a inference state."""
 
47
  images, video_height, video_width = load_video_frames(
48
  video_path=video_path,
49
  image_size=self.image_size,
50
  offload_video_to_cpu=offload_video_to_cpu,
51
  async_loading_frames=async_loading_frames,
 
52
  )
53
  inference_state = {}
54
  inference_state["images"] = images
@@ -64,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
64
  # the original video height and width, used for resizing final output scores
65
  inference_state["video_height"] = video_height
66
  inference_state["video_width"] = video_width
67
- inference_state["device"] = torch.device("cuda")
68
  if offload_state_to_cpu:
69
  inference_state["storage_device"] = torch.device("cpu")
70
  else:
71
- inference_state["storage_device"] = torch.device("cuda")
72
  # inputs on each frame
73
  inference_state["point_inputs_per_obj"] = {}
74
  inference_state["mask_inputs_per_obj"] = {}
@@ -103,6 +106,23 @@ class SAM2VideoPredictor(SAM2Base):
103
  self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
104
  return inference_state
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def _obj_id_to_idx(self, inference_state, obj_id):
107
  """Map client-side object id to model-side object index."""
108
  obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
@@ -146,29 +166,66 @@ class SAM2VideoPredictor(SAM2Base):
146
  return len(inference_state["obj_idx_to_id"])
147
 
148
  @torch.inference_mode()
149
- def add_new_points(
150
  self,
151
  inference_state,
152
  frame_idx,
153
  obj_id,
154
- points,
155
- labels,
156
  clear_old_points=True,
157
  normalize_coords=True,
 
158
  ):
159
  """Add new points to a frame."""
160
  obj_idx = self._obj_id_to_idx(inference_state, obj_id)
161
  point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
162
  mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
163
 
164
- if not isinstance(points, torch.Tensor):
 
 
 
 
 
 
 
165
  points = torch.tensor(points, dtype=torch.float32)
166
- if not isinstance(labels, torch.Tensor):
 
 
167
  labels = torch.tensor(labels, dtype=torch.int32)
168
  if points.dim() == 2:
169
  points = points.unsqueeze(0) # add batch dimension
170
  if labels.dim() == 1:
171
  labels = labels.unsqueeze(0) # add batch dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  if normalize_coords:
173
  video_H = inference_state["video_height"]
174
  video_W = inference_state["video_width"]
@@ -215,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
215
  prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
216
 
217
  if prev_out is not None and prev_out["pred_masks"] is not None:
218
- prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
 
219
  # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
220
  prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
221
  current_out, _ = self._run_single_frame_inference(
@@ -251,6 +309,10 @@ class SAM2VideoPredictor(SAM2Base):
251
  )
252
  return frame_idx, obj_ids, video_res_masks
253
 
 
 
 
 
254
  @torch.inference_mode()
255
  def add_new_mask(
256
  self,
@@ -531,7 +593,7 @@ class SAM2VideoPredictor(SAM2Base):
531
  storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
532
  # Find all the frames that contain temporary outputs for any objects
533
  # (these should be the frames that have just received clicks for mask inputs
534
- # via `add_new_points` or `add_new_mask`)
535
  temp_frame_inds = set()
536
  for obj_temp_output_dict in temp_output_dict_per_obj.values():
537
  temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
@@ -734,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
734
  )
735
  if backbone_out is None:
736
  # Cache miss -- we will run inference on a single image
737
- image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
 
738
  backbone_out = self.forward_image(image)
739
  # Cache the most recent frame's feature (for repeated interactions with
740
  # a frame; we can use an LRU cache for more frames in the future).
@@ -895,4 +958,4 @@ class SAM2VideoPredictor(SAM2Base):
895
  for t in range(frame_idx_begin, frame_idx_end + 1):
896
  non_cond_frame_outputs.pop(t, None)
897
  for obj_output_dict in inference_state["output_dict_per_obj"].values():
898
- obj_output_dict["non_cond_frame_outputs"].pop(t, None)
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ import warnings
8
  from collections import OrderedDict
9
 
10
  import torch
 
45
  async_loading_frames=False,
46
  ):
47
  """Initialize a inference state."""
48
+ compute_device = self.device # device of the model
49
  images, video_height, video_width = load_video_frames(
50
  video_path=video_path,
51
  image_size=self.image_size,
52
  offload_video_to_cpu=offload_video_to_cpu,
53
  async_loading_frames=async_loading_frames,
54
+ compute_device=compute_device,
55
  )
56
  inference_state = {}
57
  inference_state["images"] = images
 
67
  # the original video height and width, used for resizing final output scores
68
  inference_state["video_height"] = video_height
69
  inference_state["video_width"] = video_width
70
+ inference_state["device"] = compute_device
71
  if offload_state_to_cpu:
72
  inference_state["storage_device"] = torch.device("cpu")
73
  else:
74
+ inference_state["storage_device"] = compute_device
75
  # inputs on each frame
76
  inference_state["point_inputs_per_obj"] = {}
77
  inference_state["mask_inputs_per_obj"] = {}
 
106
  self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
107
  return inference_state
108
 
109
+ @classmethod
110
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
111
+ """
112
+ Load a pretrained model from the Hugging Face hub.
113
+
114
+ Arguments:
115
+ model_id (str): The Hugging Face repository ID.
116
+ **kwargs: Additional arguments to pass to the model constructor.
117
+
118
+ Returns:
119
+ (SAM2VideoPredictor): The loaded model.
120
+ """
121
+ from sam2.build_sam import build_sam2_video_predictor_hf
122
+
123
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
124
+ return sam_model
125
+
126
  def _obj_id_to_idx(self, inference_state, obj_id):
127
  """Map client-side object id to model-side object index."""
128
  obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
 
166
  return len(inference_state["obj_idx_to_id"])
167
 
168
  @torch.inference_mode()
169
+ def add_new_points_or_box(
170
  self,
171
  inference_state,
172
  frame_idx,
173
  obj_id,
174
+ points=None,
175
+ labels=None,
176
  clear_old_points=True,
177
  normalize_coords=True,
178
+ box=None,
179
  ):
180
  """Add new points to a frame."""
181
  obj_idx = self._obj_id_to_idx(inference_state, obj_id)
182
  point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
183
  mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
184
 
185
+ if (points is not None) != (labels is not None):
186
+ raise ValueError("points and labels must be provided together")
187
+ if points is None and box is None:
188
+ raise ValueError("at least one of points or box must be provided as input")
189
+
190
+ if points is None:
191
+ points = torch.zeros(0, 2, dtype=torch.float32)
192
+ elif not isinstance(points, torch.Tensor):
193
  points = torch.tensor(points, dtype=torch.float32)
194
+ if labels is None:
195
+ labels = torch.zeros(0, dtype=torch.int32)
196
+ elif not isinstance(labels, torch.Tensor):
197
  labels = torch.tensor(labels, dtype=torch.int32)
198
  if points.dim() == 2:
199
  points = points.unsqueeze(0) # add batch dimension
200
  if labels.dim() == 1:
201
  labels = labels.unsqueeze(0) # add batch dimension
202
+
203
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
204
+ # along with the user-provided points (consistent with how SAM 2 is trained).
205
+ if box is not None:
206
+ if not clear_old_points:
207
+ raise ValueError(
208
+ "cannot add box without clearing old points, since "
209
+ "box prompt must be provided before any point prompt "
210
+ "(please use clear_old_points=True instead)"
211
+ )
212
+ if inference_state["tracking_has_started"]:
213
+ warnings.warn(
214
+ "You are adding a box after tracking starts. SAM 2 may not always be "
215
+ "able to incorporate a box prompt for *refinement*. If you intend to "
216
+ "use box prompt as an *initial* input before tracking, please call "
217
+ "'reset_state' on the inference state to restart from scratch.",
218
+ category=UserWarning,
219
+ stacklevel=2,
220
+ )
221
+ if not isinstance(box, torch.Tensor):
222
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
223
+ box_coords = box.reshape(1, 2, 2)
224
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
225
+ box_labels = box_labels.reshape(1, 2)
226
+ points = torch.cat([box_coords, points], dim=1)
227
+ labels = torch.cat([box_labels, labels], dim=1)
228
+
229
  if normalize_coords:
230
  video_H = inference_state["video_height"]
231
  video_W = inference_state["video_width"]
 
272
  prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
273
 
274
  if prev_out is not None and prev_out["pred_masks"] is not None:
275
+ device = inference_state["device"]
276
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
277
  # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
278
  prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
279
  current_out, _ = self._run_single_frame_inference(
 
309
  )
310
  return frame_idx, obj_ids, video_res_masks
311
 
312
+ def add_new_points(self, *args, **kwargs):
313
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
314
+ return self.add_new_points_or_box(*args, **kwargs)
315
+
316
  @torch.inference_mode()
317
  def add_new_mask(
318
  self,
 
593
  storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
594
  # Find all the frames that contain temporary outputs for any objects
595
  # (these should be the frames that have just received clicks for mask inputs
596
+ # via `add_new_points_or_box` or `add_new_mask`)
597
  temp_frame_inds = set()
598
  for obj_temp_output_dict in temp_output_dict_per_obj.values():
599
  temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
 
796
  )
797
  if backbone_out is None:
798
  # Cache miss -- we will run inference on a single image
799
+ device = inference_state["device"]
800
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
801
  backbone_out = self.forward_image(image)
802
  # Cache the most recent frame's feature (for repeated interactions with
803
  # a frame; we can use an LRU cache for more frames in the future).
 
958
  for t in range(frame_idx_begin, frame_idx_end + 1):
959
  non_cond_frame_outputs.pop(t, None)
960
  for obj_output_dict in inference_state["output_dict_per_obj"].values():
961
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
segment-anything-2/sam2/utils/misc.py CHANGED
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
106
  A list of video frames to be load asynchronously without blocking session start.
107
  """
108
 
109
- def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
 
 
 
 
 
 
 
 
110
  self.img_paths = img_paths
111
  self.image_size = image_size
112
  self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ class AsyncVideoFrameLoader:
119
  # video_height and video_width be filled when loading the first image
120
  self.video_height = None
121
  self.video_width = None
 
122
 
123
  # load the first frame to fill video_height and video_width and also
124
  # to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
152
  img -= self.img_mean
153
  img /= self.img_std
154
  if not self.offload_video_to_cpu:
155
- img = img.cuda(non_blocking=True)
156
  self.images[index] = img
157
  return img
158
 
@@ -167,6 +176,7 @@ def load_video_frames(
167
  img_mean=(0.485, 0.456, 0.406),
168
  img_std=(0.229, 0.224, 0.225),
169
  async_loading_frames=False,
 
170
  ):
171
  """
172
  Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,7 +189,15 @@ def load_video_frames(
179
  if isinstance(video_path, str) and os.path.isdir(video_path):
180
  jpg_folder = video_path
181
  else:
182
- raise NotImplementedError("Only JPEG frames are supported at this moment")
 
 
 
 
 
 
 
 
183
 
184
  frame_names = [
185
  p
@@ -196,7 +214,12 @@ def load_video_frames(
196
 
197
  if async_loading_frames:
198
  lazy_images = AsyncVideoFrameLoader(
199
- img_paths, image_size, offload_video_to_cpu, img_mean, img_std
 
 
 
 
 
200
  )
201
  return lazy_images, lazy_images.video_height, lazy_images.video_width
202
 
@@ -204,9 +227,9 @@ def load_video_frames(
204
  for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
205
  images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
206
  if not offload_video_to_cpu:
207
- images = images.cuda()
208
- img_mean = img_mean.cuda()
209
- img_std = img_std.cuda()
210
  # normalize by mean and std
211
  images -= img_mean
212
  images /= img_std
@@ -220,10 +243,25 @@ def fill_holes_in_mask_scores(mask, max_area):
220
  # Holes are those connected components in background with area <= self.max_area
221
  # (background regions are those with mask scores <= 0)
222
  assert max_area > 0, "max_area must be positive"
223
- labels, areas = get_connected_components(mask <= 0)
224
- is_hole = (labels > 0) & (areas <= max_area)
225
- # We fill holes with a small positive mask score (0.1) to change them to foreground.
226
- mask = torch.where(is_hole, 0.1, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  return mask
228
 
229
 
@@ -235,4 +273,4 @@ def concat_points(old_point_inputs, new_points, new_labels):
235
  points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
236
  labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
237
 
238
- return {"point_coords": points, "point_labels": labels}
 
106
  A list of video frames to be load asynchronously without blocking session start.
107
  """
108
 
109
+ def __init__(
110
+ self,
111
+ img_paths,
112
+ image_size,
113
+ offload_video_to_cpu,
114
+ img_mean,
115
+ img_std,
116
+ compute_device,
117
+ ):
118
  self.img_paths = img_paths
119
  self.image_size = image_size
120
  self.offload_video_to_cpu = offload_video_to_cpu
 
127
  # video_height and video_width be filled when loading the first image
128
  self.video_height = None
129
  self.video_width = None
130
+ self.compute_device = compute_device
131
 
132
  # load the first frame to fill video_height and video_width and also
133
  # to cache it (since it's most likely where the user will click)
 
161
  img -= self.img_mean
162
  img /= self.img_std
163
  if not self.offload_video_to_cpu:
164
+ img = img.to(self.compute_device, non_blocking=True)
165
  self.images[index] = img
166
  return img
167
 
 
176
  img_mean=(0.485, 0.456, 0.406),
177
  img_std=(0.229, 0.224, 0.225),
178
  async_loading_frames=False,
179
+ compute_device=torch.device("cuda"),
180
  ):
181
  """
182
  Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
 
189
  if isinstance(video_path, str) and os.path.isdir(video_path):
190
  jpg_folder = video_path
191
  else:
192
+ raise NotImplementedError(
193
+ "Only JPEG frames are supported at this moment. For video files, you may use "
194
+ "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
195
+ "```\n"
196
+ "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
197
+ "```\n"
198
+ "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
199
+ "ffmpeg to start the JPEG file from 00000.jpg."
200
+ )
201
 
202
  frame_names = [
203
  p
 
214
 
215
  if async_loading_frames:
216
  lazy_images = AsyncVideoFrameLoader(
217
+ img_paths,
218
+ image_size,
219
+ offload_video_to_cpu,
220
+ img_mean,
221
+ img_std,
222
+ compute_device,
223
  )
224
  return lazy_images, lazy_images.video_height, lazy_images.video_width
225
 
 
227
  for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
228
  images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
229
  if not offload_video_to_cpu:
230
+ images = images.to(compute_device)
231
+ img_mean = img_mean.to(compute_device)
232
+ img_std = img_std.to(compute_device)
233
  # normalize by mean and std
234
  images -= img_mean
235
  images /= img_std
 
243
  # Holes are those connected components in background with area <= self.max_area
244
  # (background regions are those with mask scores <= 0)
245
  assert max_area > 0, "max_area must be positive"
246
+
247
+ input_mask = mask
248
+ try:
249
+ labels, areas = get_connected_components(mask <= 0)
250
+ is_hole = (labels > 0) & (areas <= max_area)
251
+ # We fill holes with a small positive mask score (0.1) to change them to foreground.
252
+ mask = torch.where(is_hole, 0.1, mask)
253
+ except Exception as e:
254
+ # Skip the post-processing step on removing small holes if the CUDA kernel fails
255
+ warnings.warn(
256
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
257
+ "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
258
+ "functionality may be limited (which doesn't affect the results in most cases; see "
259
+ "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
260
+ category=UserWarning,
261
+ stacklevel=2,
262
+ )
263
+ mask = input_mask
264
+
265
  return mask
266
 
267
 
 
273
  points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
274
  labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
275
 
276
+ return {"point_coords": points, "point_labels": labels}