move cached feature to CUDA
Browse files
sam2/sam2_video_predictor.py
CHANGED
@@ -882,9 +882,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|
882 |
image, backbone_out = inference_state["cached_features"].get(
|
883 |
frame_idx, (None, None)
|
884 |
)
|
|
|
885 |
if backbone_out is None:
|
886 |
# Cache miss -- we will run inference on a single image
|
887 |
-
device = inference_state["device"]
|
888 |
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
889 |
backbone_out = self.forward_image(image)
|
890 |
# Cache the most recent frame's feature (for repeated interactions with
|
@@ -900,10 +900,10 @@ class SAM2VideoPredictor(SAM2Base):
|
|
900 |
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
901 |
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
902 |
batch_size, -1, -1, -1
|
903 |
-
)
|
904 |
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
905 |
pos = pos.expand(batch_size, -1, -1, -1)
|
906 |
-
expanded_backbone_out["vision_pos_enc"][i] = pos
|
907 |
|
908 |
features = self._prepare_backbone_features(expanded_backbone_out)
|
909 |
features = (expanded_image,) + features
|
|
|
882 |
image, backbone_out = inference_state["cached_features"].get(
|
883 |
frame_idx, (None, None)
|
884 |
)
|
885 |
+
device = inference_state["device"]
|
886 |
if backbone_out is None:
|
887 |
# Cache miss -- we will run inference on a single image
|
|
|
888 |
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
889 |
backbone_out = self.forward_image(image)
|
890 |
# Cache the most recent frame's feature (for repeated interactions with
|
|
|
900 |
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
901 |
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
902 |
batch_size, -1, -1, -1
|
903 |
+
).to(device)
|
904 |
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
905 |
pos = pos.expand(batch_size, -1, -1, -1)
|
906 |
+
expanded_backbone_out["vision_pos_enc"][i] = pos.to(device)
|
907 |
|
908 |
features = self._prepare_backbone_features(expanded_backbone_out)
|
909 |
features = (expanded_image,) + features
|