chongzhou committed
Commit 3ccde9c · 1 Parent(s): cf4b18a

rollback sam2_base and sam2_video_predictor

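This commit undoes a set of ZeroGPU-oriented edits (one removed comment below is explicitly marked "modified for ZeroGPU"): small per-frame tensors such as object pointers and score logits were being moved to inference_state["storage_device"], and the per-frame feature cache had been disabled. The rollback restores the upstream behavior of leaving those tensors on the device where they were computed. A minimal sketch of the two placement strategies, assuming a current_out dict shaped like the one in the diffs below (tensor shape and storage device are illustrative):

```python
import torch

current_out = {"obj_ptr": torch.randn(1, 256)}  # small per-frame tensor; shape illustrative
storage_device = torch.device("cpu")            # stand-in for inference_state["storage_device"]

# ZeroGPU variant being rolled back: offload the small tensor to the storage device
obj_ptr_offloaded = current_out["obj_ptr"].to(storage_device, non_blocking=True)

# behavior restored by this commit: keep the tensor where it was computed (GPU)
obj_ptr = current_out["obj_ptr"]
```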
sam2/modeling/sam2_base.py CHANGED
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
                         if self.use_signed_tpos_enc_to_obj_ptrs
                         else abs(frame_idx - t)
                     ),
-                    out["obj_ptr"].to(device),
+                    out["obj_ptr"],
                 )
                 for t, out in ptr_cond_outputs.items()
             ]
@@ -630,7 +630,7 @@
                        t, unselected_cond_outputs.get(t, None)
                    )
                    if out is not None:
-                        pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device)))
+                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
                # If we have at least one object pointer, add them to the across attention
                if len(pos_and_ptrs) > 0:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
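Both hunks strip a .to(device) call from the object pointers that feed pos_and_ptrs, a list of (temporal_position, obj_ptr) pairs that is unzipped with zip(*...) before cross attention. A standalone illustration of that pairing-and-unzipping idiom, with made-up positions and pointer tensors (the final stacking step is an assumption about what the surrounding code does next):

```python
import torch

# hypothetical stand-ins for the (t_diff, out["obj_ptr"]) pairs built above
pos_and_ptrs = [(1, torch.randn(1, 256)), (2, torch.randn(1, 256))]

if len(pos_and_ptrs) > 0:
    # unzip the pairs into parallel tuples of positions and pointers
    pos_list, ptrs_list = zip(*pos_and_ptrs)
    obj_ptrs = torch.stack(ptrs_list, dim=0)  # one pointer per past frame
```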
sam2/sam2_video_predictor.py CHANGED
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
                 size=(batch_size, self.hidden_dim),
                 fill_value=NO_OBJ_SCORE,
                 dtype=torch.float32,
-                device=inference_state["storage_device"],
+                device=inference_state["device"],
             ),
             "object_score_logits": torch.full(
                 size=(batch_size, 1),
@@ -478,7 +478,7 @@
                 # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                 fill_value=10.0,
                 dtype=torch.float32,
-                device=inference_state["storage_device"],
+                device=inference_state["device"],
             ),
         }
         empty_mask_ptr = None
@@ -545,9 +545,7 @@
                 frame_idx=frame_idx,
                 batch_size=batch_size,
                 high_res_masks=high_res_masks,
-                object_score_logits=consolidated_out["object_score_logits"].to(
-                    device, non_blocking=True
-                ),
+                object_score_logits=consolidated_out["object_score_logits"],
                 is_mask_from_pts=True,  # these frames are what the user interacted with
             )
             consolidated_out["maskmem_features"] = maskmem_features
@@ -881,10 +879,9 @@
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
         # Look up in the cache first
-        # image, backbone_out = inference_state["cached_features"].get(
-        #     frame_idx, (None, None)
-        # )
-        image, backbone_out = None, None
+        image, backbone_out = inference_state["cached_features"].get(
+            frame_idx, (None, None)
+        )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
@@ -892,7 +889,7 @@
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
-            # inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

         # expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
@@ -967,11 +964,9 @@
         pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
-        # object pointer is a small tensor, so we always keep it on GPU memory for fast access (modified for ZeroGPU)
-        obj_ptr = current_out["obj_ptr"].to(storage_device, non_blocking=True)
-        object_score_logits = current_out["object_score_logits"].to(
-            storage_device, non_blocking=True
-        )
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+        obj_ptr = current_out["obj_ptr"]
+        object_score_logits = current_out["object_score_logits"]
         # make a compact version of this frame's output to reduce the state size
         compact_current_out = {
             "maskmem_features": maskmem_features,
@@ -1023,7 +1018,6 @@
         `maskmem_pos_enc` is the same across frames and objects, so we cache it as
         a constant in the inference session to reduce session storage size.
         """
-        storage_device = inference_state["storage_device"]
         model_constants = inference_state["constants"]
         # "out_maskmem_pos_enc" should be either a list of tensors or None
         out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
@@ -1032,9 +1026,7 @@
             assert isinstance(out_maskmem_pos_enc, list)
             # only take the slice for one object, since it's same across objects
             maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
-            model_constants["maskmem_pos_enc"] = maskmem_pos_enc.to(
-                storage_device, non_blocking=True
-            )
+            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
         else:
             maskmem_pos_enc = model_constants["maskmem_pos_enc"]
         # expand the cached maskmem_pos_enc to the actual batch size
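Two things stand out in this file. First, the removed assignment in _get_maskmem_pos_enc called .to(storage_device, non_blocking=True) on maskmem_pos_enc, which at that point is a Python list of tensors rather than a tensor, so that path would have raised AttributeError; the rollback removes the broken call. Second, the _get_image_feature hunks re-enable the single-entry feature cache: look up frame_idx in inference_state["cached_features"], recompute on a miss, then overwrite the cache with only the most recent frame. The restored comment mentions an LRU cache as a possible extension to hold more frames; a minimal sketch of that idea using collections.OrderedDict (the capacity, helper name, and compute_fn callback are hypothetical, not part of the codebase):

```python
from collections import OrderedDict

CACHE_CAPACITY = 4  # hypothetical number of frames to keep cached

def lookup_or_compute(cached_features, frame_idx, compute_fn):
    """LRU lookup: reuse cached backbone features, evicting the oldest entry."""
    if frame_idx in cached_features:
        cached_features.move_to_end(frame_idx)  # mark as most recently used
        return cached_features[frame_idx]
    image, backbone_out = compute_fn(frame_idx)  # cache miss: run the backbone
    cached_features[frame_idx] = (image, backbone_out)
    if len(cached_features) > CACHE_CAPACITY:
        cached_features.popitem(last=False)  # evict the least recently used frame
    return image, backbone_out
```

Here cached_features would be an OrderedDict stored in inference_state instead of the plain single-entry dict the restored code uses.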
trimm_examples.py ADDED
@@ -0,0 +1,30 @@
+import os
+
+from moviepy.editor import VideoFileClip
+
+# Define the folder and duration
+input_folder = "examples"
+output_folder = "examples/trimmed"
+trim_duration = 3  # seconds
+
+# Create output folder if it doesn't exist
+os.makedirs(output_folder, exist_ok=True)
+
+# Process each .mp4 file
+for filename in os.listdir(input_folder):
+    if filename.lower().endswith(".mp4"):
+        input_path = os.path.join(input_folder, filename)
+        output_path = os.path.join(output_folder, filename)
+
+        print(f"Trimming: {input_path} -> {output_path}")
+        try:
+            clip = VideoFileClip(input_path).subclip(0, trim_duration)
+            clip.write_videofile(
+                output_path,
+                codec="libx264",
+                audio_codec="aac",
+                verbose=False,
+                logger=None,
+            )
+        except Exception as e:
+            print(f"Failed to process {filename}: {e}")