John Ho committed
Commit 579e65b · 1 Parent(s): 0db2411

added new variable for reference_frame_idx

Files changed (2):
  1. app.py +14 -2
  2. samv2_handler.py +5 -1
app.py CHANGED
@@ -116,11 +116,17 @@ def process_image(
     )
 
 
-@spaces.GPU(duration=300)
+@spaces.GPU(
+    duration=120
+)  # user must have 2-minute of inference time left at the time of calling
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
 def process_video(
-    video_path: str, variant: str, masks: Union[list, str], drop_masks: bool = False
+    video_path: str,
+    variant: str,
+    masks: Union[list, str],
+    drop_masks: bool = False,
+    ref_frame_idx: int = 0,
 ):
     """
     SAM2 Video Segmentation
@@ -148,6 +154,7 @@ def process_video(
         do_tidy_up=True,
         drop_mask=drop_masks,
         async_frame_load=True,
+        ref_frame_idx=ref_frame_idx,
     )
 
 
@@ -196,6 +203,11 @@ with gr.Blocks() as demo:
             """,
         ),
         gr.Checkbox(label="remove base64 encoded masks from result JSON"),
+        gr.Number(
+            label="frame index for the provided object masks",
+            value=0,
+            precision=0,
+        ),
     ],
     outputs=gr.JSON(label="Output JSON"),
     title="SAM2 for Videos",
samv2_handler.py CHANGED
@@ -161,6 +161,7 @@ def run_sam_video_inference(
     do_tidy_up: bool = False,
     drop_mask: bool = True,
     async_frame_load: bool = False,
+    ref_frame_idx: int = 0,
 ):
     # put video frames into directory
     # TODO:
@@ -183,7 +184,10 @@ def run_sam_video_inference(
     )
     for i, mask in enumerate(masks):
         model.add_new_mask(
-            inference_state=inference_state, frame_idx=0, obj_id=i, mask=mask
+            inference_state=inference_state,
+            frame_idx=ref_frame_idx,
+            obj_id=i,
+            mask=mask,
         )
     masks_generator = model.propagate_in_video(inference_state)
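The hunk above only moves the seeding frame from a hard-coded 0 to ref_frame_idx. For readers unfamiliar with the underlying predictor, here is a standalone sketch of the same idea against a SAM2 video predictor; the config/checkpoint names, frame directory, and mask array are placeholders and are not taken from this repository.

    # Standalone sketch: seed SAM2 at a non-zero reference frame, then propagate.
    # Assumptions: the `sam2` package is installed; paths and the mask array below
    # are placeholders, not values from this commit.
    import numpy as np
    import torch
    from sam2.build_sam import build_sam2_video_predictor

    predictor = build_sam2_video_predictor("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt")

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(video_path="frames_dir/")  # directory of extracted frames
        ref_frame_idx = 12                                       # frame the provided mask describes
        mask = np.zeros((720, 1280), dtype=bool)                 # placeholder per-object mask
        predictor.add_new_mask(
            inference_state=state, frame_idx=ref_frame_idx, obj_id=0, mask=mask
        )
        # propagate_in_video yields per-frame results starting from the seeded state.
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
            pass  # collect or post-process masks here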