chongzhou committed on
Commit
45c2c68
·
1 Parent(s): 7209747

move cached feature to CUDA

Browse files
Files changed (1) hide show
  1. sam2/sam2_video_predictor.py +3 -3
sam2/sam2_video_predictor.py CHANGED
@@ -882,9 +882,9 @@ class SAM2VideoPredictor(SAM2Base):
882
  image, backbone_out = inference_state["cached_features"].get(
883
  frame_idx, (None, None)
884
  )
 
885
  if backbone_out is None:
886
  # Cache miss -- we will run inference on a single image
887
- device = inference_state["device"]
888
  image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
889
  backbone_out = self.forward_image(image)
890
  # Cache the most recent frame's feature (for repeated interactions with
@@ -900,10 +900,10 @@ class SAM2VideoPredictor(SAM2Base):
900
  for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
901
  expanded_backbone_out["backbone_fpn"][i] = feat.expand(
902
  batch_size, -1, -1, -1
903
- )
904
  for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
905
  pos = pos.expand(batch_size, -1, -1, -1)
906
- expanded_backbone_out["vision_pos_enc"][i] = pos
907
 
908
  features = self._prepare_backbone_features(expanded_backbone_out)
909
  features = (expanded_image,) + features
 
882
  image, backbone_out = inference_state["cached_features"].get(
883
  frame_idx, (None, None)
884
  )
885
+ device = inference_state["device"]
886
  if backbone_out is None:
887
  # Cache miss -- we will run inference on a single image
 
888
  image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
889
  backbone_out = self.forward_image(image)
890
  # Cache the most recent frame's feature (for repeated interactions with
 
900
  for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
901
  expanded_backbone_out["backbone_fpn"][i] = feat.expand(
902
  batch_size, -1, -1, -1
903
+ ).to(device)
904
  for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
905
  pos = pos.expand(batch_size, -1, -1, -1)
906
+ expanded_backbone_out["vision_pos_enc"][i] = pos.to(device)
907
 
908
  features = self._prepare_backbone_features(expanded_backbone_out)
909
  features = (expanded_image,) + features