rollback sam2_base and sam2_video_predictor
- sam2/modeling/sam2_base.py +2 -2
- sam2/sam2_video_predictor.py +11 -19
- trimm_examples.py +30 -0
sam2/modeling/sam2_base.py
CHANGED
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
                         if self.use_signed_tpos_enc_to_obj_ptrs
                         else abs(frame_idx - t)
                     ),
-                    out["obj_ptr"]
+                    out["obj_ptr"],
                 )
                 for t, out in ptr_cond_outputs.items()
             ]
@@ -630,7 +630,7 @@ class SAM2Base(torch.nn.Module):
                     t, unselected_cond_outputs.get(t, None)
                 )
                 if out is not None:
-                    pos_and_ptrs.append((t_diff, out["obj_ptr"]
+                    pos_and_ptrs.append((t_diff, out["obj_ptr"]))
             # If we have at least one object pointer, add them to the across attention
             if len(pos_and_ptrs) > 0:
                 pos_list, ptrs_list = zip(*pos_and_ptrs)
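A toy sketch (not from the repo; shapes and values are made up) of what the restored lines rebuild: pos_and_ptrs pairs each object pointer with its temporal distance from the current frame, and the code just after the hunk unzips it into parallel lists.

import torch

frame_idx = 7
ptr_cond_outputs = {3: {"obj_ptr": torch.zeros(1, 256)}, 5: {"obj_ptr": torch.ones(1, 256)}}
# conditioning frames: (unsigned temporal distance, pointer) pairs, as in the restored comprehension
pos_and_ptrs = [(abs(frame_idx - t), out["obj_ptr"]) for t, out in ptr_cond_outputs.items()]
# non-conditioning frames are appended the same way, e.g. the frame one step back
pos_and_ptrs.append((1, torch.zeros(1, 256)))
pos_list, ptrs_list = zip(*pos_and_ptrs)  # pos_list == (4, 2, 1)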
sam2/sam2_video_predictor.py
CHANGED
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
                 size=(batch_size, self.hidden_dim),
                 fill_value=NO_OBJ_SCORE,
                 dtype=torch.float32,
-                device=inference_state["
+                device=inference_state["device"],
             ),
             "object_score_logits": torch.full(
                 size=(batch_size, 1),
@@ -478,7 +478,7 @@ class SAM2VideoPredictor(SAM2Base):
                 # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                 fill_value=10.0,
                 dtype=torch.float32,
-                device=inference_state["
+                device=inference_state["device"],
             ),
         }
         empty_mask_ptr = None
@@ -545,9 +545,7 @@ class SAM2VideoPredictor(SAM2Base):
                 frame_idx=frame_idx,
                 batch_size=batch_size,
                 high_res_masks=high_res_masks,
-                object_score_logits=consolidated_out["object_score_logits"]
-                    device, non_blocking=True
-                ),
+                object_score_logits=consolidated_out["object_score_logits"],
                 is_mask_from_pts=True,  # these frames are what the user interacted with
             )
             consolidated_out["maskmem_features"] = maskmem_features
@@ -881,10 +879,9 @@ class SAM2VideoPredictor(SAM2Base):
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
         # Look up in the cache first
-
-
-
-        image, backbone_out = None, None
+        image, backbone_out = inference_state["cached_features"].get(
+            frame_idx, (None, None)
+        )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
@@ -892,7 +889,7 @@ class SAM2VideoPredictor(SAM2Base):
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
-
+            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

         # expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
@@ -967,11 +964,9 @@ class SAM2VideoPredictor(SAM2Base):
         pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
-        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
-        obj_ptr = current_out["obj_ptr"]
-        object_score_logits = current_out["object_score_logits"]
-            storage_device, non_blocking=True
-        )
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+        obj_ptr = current_out["obj_ptr"]
+        object_score_logits = current_out["object_score_logits"]
         # make a compact version of this frame's output to reduce the state size
         compact_current_out = {
             "maskmem_features": maskmem_features,
@@ -1023,7 +1018,6 @@ class SAM2VideoPredictor(SAM2Base):
         `maskmem_pos_enc` is the same across frames and objects, so we cache it as
         a constant in the inference session to reduce session storage size.
         """
-        storage_device = inference_state["storage_device"]
         model_constants = inference_state["constants"]
         # "out_maskmem_pos_enc" should be either a list of tensors or None
         out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
@@ -1032,9 +1026,7 @@ class SAM2VideoPredictor(SAM2Base):
             assert isinstance(out_maskmem_pos_enc, list)
             # only take the slice for one object, since it's same across objects
             maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
-            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
-                storage_device, non_blocking=True
-            )
+            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
         else:
             maskmem_pos_enc = model_constants["maskmem_pos_enc"]
         # expand the cached maskmem_pos_enc to the actual batch size
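For reference, a minimal sketch (the function and argument names here are invented; only `cached_features` and `images` come from the diff) of the single-frame feature cache that the rollback restores in `_get_image_feature`: look up the current frame, recompute on a miss, and keep only the most recent entry.

def lookup_or_compute_features(inference_state, frame_idx, forward_image):
    # cache hit: reuse the features from the last frame the user interacted with
    image, backbone_out = inference_state["cached_features"].get(frame_idx, (None, None))
    if backbone_out is None:
        # cache miss: encode this single frame (exact image storage layout assumed here)
        image = inference_state["images"][frame_idx].unsqueeze(0)
        backbone_out = forward_image(image)
        # keep a one-entry cache so repeated prompts on the same frame stay cheap
        inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
    return image, backbone_out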
trimm_examples.py
ADDED
@@ -0,0 +1,30 @@
+import os
+
+from moviepy.editor import VideoFileClip
+
+# Define the folder and duration
+input_folder = "examples"
+output_folder = "examples/trimmed"
+trim_duration = 3  # seconds
+
+# Create output folder if it doesn't exist
+os.makedirs(output_folder, exist_ok=True)
+
+# Process each .mp4 file
+for filename in os.listdir(input_folder):
+    if filename.lower().endswith(".mp4"):
+        input_path = os.path.join(input_folder, filename)
+        output_path = os.path.join(output_folder, filename)
+
+        print(f"Trimming: {input_path} -> {output_path}")
+        try:
+            clip = VideoFileClip(input_path).subclip(0, trim_duration)
+            clip.write_videofile(
+                output_path,
+                codec="libx264",
+                audio_codec="aac",
+                verbose=False,
+                logger=None,
+            )
+        except Exception as e:
+            print(f"Failed to process {filename}: {e}")