yjwnb6 committed
Commit 3ca680e · 1 Parent(s): 9375c3b

Update app with new modes and assets

.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 sam2/checkpoints/unsamv2_plus_ckpt.pt filter=lfs diff=lfs merge=lfs -text
+demo/*.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+_tmp/
__pycache__/app.cpython-310.pyc ADDED
Binary file (27.9 kB)
 
app.py CHANGED
@@ -5,10 +5,13 @@ from __future__ import annotations
 
 import logging
 import os
 import sys
 import threading
 from pathlib import Path
-from typing import List, Optional, Sequence
 
 import cv2
 import gradio as gr
@@ -25,7 +28,8 @@ SAM2_REPO = REPO_ROOT / "sam2"
 if SAM2_REPO.exists():
     sys.path.insert(0, str(SAM2_REPO))
 
-from sam2.build_sam import build_sam2  # noqa: E402
 from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402
 
 logging.basicConfig(level=logging.INFO)
@@ -46,30 +50,38 @@ GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
 GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
 ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
 ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
 
 POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
 POINT_COLORS_BGR = {
     1: (72, 201, 127),  # green-ish for positives
     0: (64, 76, 225),  # red-ish for negatives
 }
-MASK_COLOR_BGR = (0, 196, 255)
-OUTLINE_COLOR_BGR = (0, 165, 255)
 
 DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
 
 
-def _load_default_image() -> Optional[np.ndarray]:
-    if not DEFAULT_IMAGE_PATH.exists():
-        LOGGER.warning("Default image missing at %s", DEFAULT_IMAGE_PATH)
         return None
-    img_bgr = cv2.imread(str(DEFAULT_IMAGE_PATH), cv2.IMREAD_COLOR)
     if img_bgr is None:
-        LOGGER.warning("Could not read default image at %s", DEFAULT_IMAGE_PATH)
         return None
     return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
 
 
-DEFAULT_IMAGE = _load_default_image()
 
 
 class ModelManager:
@@ -102,10 +114,56 @@ class ModelManager:
     def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
         return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)
 
 
 MODEL_MANAGER = ModelManager()
 
 
 def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
     if image is None:
         return None
@@ -120,6 +178,176 @@ def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
     return img
 
 
 def choose_device() -> torch.device:
     preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
     if preference == "cpu":
@@ -181,7 +409,7 @@ def draw_overlay(
 ) -> np.ndarray:
     canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
     if mask is not None:
-        mask_bool = mask.astype(bool)
         overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
         overlay[mask_bool] = MASK_COLOR_BGR
         canvas_bgr = np.where(
@@ -189,12 +417,6 @@ def draw_overlay(
             (canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
             canvas_bgr,
         )
-        contours, _ = cv2.findContours(
-            mask_bool.astype(np.uint8),
-            mode=cv2.RETR_EXTERNAL,
-            method=cv2.CHAIN_APPROX_SIMPLE,
-        )
-        cv2.drawContours(canvas_bgr, contours, -1, OUTLINE_COLOR_BGR, 2)
     for (x, y), lbl in zip(points, labels):
         color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
         center = (int(round(x)), int(round(y)))
@@ -325,121 +547,671 @@ def _run_segmentation(
     return overlay, status
 
 
-if spaces is not None and ZERO_GPU_ENABLED:
-    segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
-else:
-    segment_fn = _run_segmentation
 
 
-def build_demo() -> gr.Blocks:
-    with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
-        gr.Markdown(
-            """## UnSAMv2 · Interactive Granularity Control
-            Upload an image, add positive/negative clicks, tune granularity, and run segmentation."""
         )
 
-        image_state = gr.State(DEFAULT_IMAGE)
-        points_state = gr.State([])
-        labels_state = gr.State([])
 
-        image_input = gr.Image(
-            label="Image · clicks & mask",
-            type="numpy",
-            height=480,
-            value=DEFAULT_IMAGE,
-            sources=["upload"],
-        )
 
-        with gr.Row():
-            point_mode = gr.Radio(
-                choices=list(POINT_MODE_TO_LABEL.keys()),
-                value="Foreground (+)",
-                label="Click type",
-            )
-            granularity_slider = gr.Slider(
-                minimum=GRANULARITY_MIN,
-                maximum=GRANULARITY_MAX,
-                value=0.2,
-                step=0.01,
-                label="Granularity",
-                info="Lower = finer details, Higher = coarser regions",
-            )
-            segment_button = gr.Button("Segment", variant="primary")
-
-        with gr.Row():
-            undo_button = gr.Button("Undo last click")
-            clear_button = gr.Button("Clear clicks")
-
-        status_markdown = gr.Markdown(" Ready.")
-
-        image_input.upload(
-            handle_image_upload,
-            inputs=[image_input],
-            outputs=[
-                image_input,
-                image_state,
-                points_state,
-                labels_state,
-                status_markdown,
-            ],
-        )
 
-        image_input.clear(
-            handle_image_upload,
-            inputs=[image_input],
-            outputs=[
-                image_input,
-                image_state,
-                points_state,
-                labels_state,
-                status_markdown,
-            ],
-        )
 
-        image_input.select(
-            handle_click,
-            inputs=[
-                point_mode,
-                points_state,
-                labels_state,
-                image_state,
-            ],
-            outputs=[
-                image_input,
-                points_state,
-                labels_state,
-                status_markdown,
-            ],
         )
 
-        undo_button.click(
-            undo_last_click,
-            inputs=[image_state, points_state, labels_state],
-            outputs=[
-                image_input,
-                points_state,
-                labels_state,
-                status_markdown,
-            ],
         )
 
-        clear_button.click(
-            clear_clicks,
-            inputs=[image_state],
-            outputs=[
-                image_input,
-                points_state,
-                labels_state,
-                status_markdown,
-            ],
         )
 
-        segment_button.click(
-            segment_fn,
-            inputs=[image_state, points_state, labels_state, granularity_slider],
-            outputs=[image_input, status_markdown],
         )
 
     demo.queue(max_size=8)
     return demo
 
5
 
6
  import logging
7
  import os
8
+ import shutil
9
  import sys
10
+ import tempfile
11
  import threading
12
+ import uuid
13
  from pathlib import Path
14
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
15
 
16
  import cv2
17
  import gradio as gr
 
28
  if SAM2_REPO.exists():
29
  sys.path.insert(0, str(SAM2_REPO))
30
 
31
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator # noqa: E402
32
+ from sam2.build_sam import build_sam2, build_sam2_video_predictor # noqa: E402
33
  from sam2.sam2_image_predictor import SAM2ImagePredictor # noqa: E402
34
 
35
  logging.basicConfig(level=logging.INFO)
 
50
  GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
51
  ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
52
  ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
53
+ MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
54
+ WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
55
+ WHOLE_IMAGE_MAX_MASKS = 1000
56
 
57
  POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
58
  POINT_COLORS_BGR = {
59
  1: (72, 201, 127), # green-ish for positives
60
  0: (64, 76, 225), # red-ish for negatives
61
  }
62
+ MASK_COLOR_BGR = (0, 0, 255)
 
63
 
64
  DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
65
+ WHOLE_IMAGE_DEFAULT_PATH = REPO_ROOT / "demo" / "sa_291195.jpg"
66
+ DEFAULT_VIDEO_PATH = REPO_ROOT / "demo" / "bedroom.mp4"
67
 
68
 
69
+ def _load_image_from_path(path: Path) -> Optional[np.ndarray]:
70
+ if not path.exists():
71
+ LOGGER.warning("Default image missing at %s", path)
72
  return None
73
+ img_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
74
  if img_bgr is None:
75
+ LOGGER.warning("Could not read default image at %s", path)
76
  return None
77
  return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
78
 
79
 
80
+ DEFAULT_IMAGE = _load_image_from_path(DEFAULT_IMAGE_PATH)
81
+ WHOLE_IMAGE_DEFAULT = _load_image_from_path(WHOLE_IMAGE_DEFAULT_PATH)
82
+
83
+ TMP_ROOT = REPO_ROOT / "_tmp"
84
+ TMP_ROOT.mkdir(exist_ok=True)
85
 
86
 
87
  class ModelManager:
 
114
  def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
115
  return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)
116
 
117
+ def make_auto_mask_generator(
118
+ self,
119
+ device: torch.device,
120
+ **kwargs,
121
+ ) -> SAM2AutomaticMaskGenerator:
122
+ return SAM2AutomaticMaskGenerator(self.get_model(device), **kwargs)
123
+
124
 
125
  MODEL_MANAGER = ModelManager()
126
 
127
 
128
+ class VideoPredictorManager:
129
+ """Caches heavy video predictors per device."""
130
+
131
+ def __init__(self) -> None:
132
+ self._predictors: dict[str, torch.nn.Module] = {}
133
+ self._lock = threading.Lock()
134
+
135
+ def _build(self, device: torch.device) -> torch.nn.Module:
136
+ LOGGER.info("Loading UnSAMv2 video predictor onto %s", device)
137
+ return build_sam2_video_predictor(
138
+ CONFIG_PATH,
139
+ ckpt_path=str(CKPT_PATH),
140
+ device=device,
141
+ )
142
+
143
+ def get_predictor(self, device: torch.device) -> torch.nn.Module:
144
+ key = (
145
+ f"{device.type}:{device.index}"
146
+ if device.type == "cuda"
147
+ else device.type
148
+ )
149
+ with self._lock:
150
+ if key not in self._predictors:
151
+ self._predictors[key] = self._build(device)
152
+ return self._predictors[key]
153
+
154
+
155
+ VIDEO_PREDICTOR_MANAGER = VideoPredictorManager()
156
+
157
+
158
+ def make_empty_video_state() -> Dict[str, Any]:
159
+ return {
160
+ "frame_dir": None,
161
+ "frame_paths": [],
162
+ "fps": 0.0,
163
+ "frame_size": (0, 0),
164
+ }
165
+
166
+
167
  def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
168
  if image is None:
169
  return None
 
178
  return img
179
 
180
 
181
+ def make_temp_subdir(prefix: str) -> Path:
182
+ TMP_ROOT.mkdir(exist_ok=True)
183
+ return Path(tempfile.mkdtemp(prefix=prefix, dir=str(TMP_ROOT)))
184
+
185
+
186
+ def remove_dir_if_exists(path_str: Optional[str]) -> None:
187
+ if not path_str:
188
+ return
189
+ path = Path(path_str)
190
+ if path.exists():
191
+ shutil.rmtree(path, ignore_errors=True)
192
+
193
+
194
+ def load_rgb_image(path: Path) -> np.ndarray:
195
+ bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
196
+ if bgr is None:
197
+ raise FileNotFoundError(f"Failed to read frame at {path}")
198
+ return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
199
+
200
+
201
+ def resolve_video_path(video_value: Any) -> Optional[str]:
202
+ if video_value is None:
203
+ return None
204
+ if isinstance(video_value, str):
205
+ return video_value
206
+ if isinstance(video_value, dict):
207
+ return video_value.get("name") or video_value.get("path")
208
+ # Gradio may pass a FileData/MediaData object with a .name attribute
209
+ for attr in ("name", "path", "video", "data"):
210
+ candidate = getattr(video_value, attr, None)
211
+ if isinstance(candidate, str):
212
+ return candidate
213
+ return None
214
+
215
+
216
+ def match_mask_to_image(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
217
+ mask_arr = np.asarray(mask)
218
+ if mask_arr.ndim == 3:
219
+ mask_arr = mask_arr.squeeze()
220
+ h, w = image.shape[:2]
221
+ if mask_arr.shape[:2] != (h, w):
222
+ mask_arr = cv2.resize(
223
+ mask_arr.astype(np.float32),
224
+ (w, h),
225
+ interpolation=cv2.INTER_NEAREST,
226
+ )
227
+ return mask_arr.astype(bool)
228
+
229
+
230
+ def colorize_mask_collection(
231
+ image: np.ndarray,
232
+ masks: Sequence[np.ndarray],
233
+ alpha: float = 0.55,
234
+ ) -> np.ndarray:
235
+ if not masks:
236
+ return image
237
+ canvas = image.astype(np.float32)
238
+ rng = np.random.default_rng(1337)
239
+ for mask in masks:
240
+ mask_arr = match_mask_to_image(mask, image)
241
+ if not mask_arr.any():
242
+ continue
243
+ color = rng.integers(20, 235, size=3)
244
+ canvas[mask_arr] = (
245
+ canvas[mask_arr] * (1.0 - alpha) + color * alpha
246
+ )
247
+ return canvas.clip(0, 255).astype(np.uint8)
248
+
249
+
250
+ def render_video_overlay(
251
+ video_state: Dict[str, Any],
252
+ frame_idx: int,
253
+ pts: Sequence[Sequence[float]],
254
+ lbls: Sequence[int],
255
+ ) -> Optional[np.ndarray]:
256
+ frame_paths: List[str] = list(video_state.get("frame_paths", []))
257
+ if not frame_paths:
258
+ return None
259
+ safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
260
+ frame = load_rgb_image(Path(frame_paths[safe_idx]))
261
+ return draw_overlay(frame, None, pts, lbls)
262
+
263
+
264
+ def mask_entries_to_arrays(entries: Sequence[Dict[str, Any]]) -> List[np.ndarray]:
265
+ arrays: List[np.ndarray] = []
266
+ for entry in entries:
267
+ seg = entry.get("segmentation", entry)
268
+ if isinstance(seg, np.ndarray):
269
+ mask = seg
270
+ elif isinstance(seg, dict):
271
+ from sam2.utils.amg import rle_to_mask
272
+
273
+ mask = rle_to_mask(seg)
274
+ else:
275
+ mask = np.asarray(seg)
276
+ arrays.append(mask.astype(bool))
277
+ return arrays
278
+
279
+
280
+ def summarize_masks(entries: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
281
+ summary: List[Dict[str, Any]] = []
282
+ for idx, entry in enumerate(entries, start=1):
283
+ summary.append(
284
+ {
285
+ "mask": idx,
286
+ "area": int(entry.get("area", 0)),
287
+ "pred_iou": round(float(entry.get("predicted_iou", 0.0)), 3),
288
+ "stability": round(float(entry.get("stability_score", 0.0)), 3),
289
+ }
290
+ )
291
+ return summary
292
+
293
+
294
+ def extract_video_frames(video_path: str) -> Tuple[List[Path], float, Tuple[int, int], Path]:
295
+ cap = cv2.VideoCapture(video_path)
296
+ if not cap.isOpened():
297
+ raise ValueError("Could not open the uploaded video.")
298
+ fps = cap.get(cv2.CAP_PROP_FPS)
299
+ if not fps or fps <= 1e-3:
300
+ fps = 12.0
301
+ frame_dir = make_temp_subdir("video_frames_")
302
+ frame_paths: List[Path] = []
303
+ height = width = 0
304
+ idx = 0
305
+ while True:
306
+ ok, frame = cap.read()
307
+ if not ok:
308
+ break
309
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
310
+ if idx == 0:
311
+ height, width = rgb.shape[:2]
312
+ out_path = frame_dir / f"{idx:05d}.jpg"
313
+ if not cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
314
+ cap.release()
315
+ raise RuntimeError(f"Failed to write frame {idx} to disk")
316
+ frame_paths.append(out_path)
317
+ idx += 1
318
+ if idx >= MAX_VIDEO_FRAMES:
319
+ LOGGER.warning(
320
+ "Stopping frame extraction at %d frames per UNSAMV2_MAX_VIDEO_FRAMES",
321
+ MAX_VIDEO_FRAMES,
322
+ )
323
+ break
324
+ cap.release()
325
+ if not frame_paths:
326
+ remove_dir_if_exists(str(frame_dir))
327
+ raise ValueError("No frames decoded from the provided video.")
328
+ if height == 0 or width == 0:
329
+ sample = load_rgb_image(frame_paths[0])
330
+ height, width = sample.shape[:2]
331
+ return frame_paths, float(fps), (height, width), frame_dir
332
+
333
+
334
+ def write_video_from_frames(frames: Sequence[np.ndarray], fps: float) -> Path:
335
+ if not frames:
336
+ raise ValueError("No frames available to write video output.")
337
+ height, width = frames[0].shape[:2]
338
+ safe_fps = fps if fps and fps > 0 else 12.0
339
+ out_path = TMP_ROOT / f"video_seg_{uuid.uuid4().hex}.mp4"
340
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
341
+ writer = cv2.VideoWriter(str(out_path), fourcc, safe_fps, (width, height))
342
+ if not writer.isOpened():
343
+ raise RuntimeError("Failed to initialize video writer. Check codec support.")
344
+ for frame in frames:
345
+ if frame.shape[:2] != (height, width):
346
+ raise ValueError("All frames must share the same spatial resolution.")
347
+ writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
348
+ writer.release()
349
+ return out_path
350
+
351
  def choose_device() -> torch.device:
352
  preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
353
  if preference == "cpu":
 
409
  ) -> np.ndarray:
410
  canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
411
  if mask is not None:
412
+ mask_bool = match_mask_to_image(mask, image)
413
  overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
414
  overlay[mask_bool] = MASK_COLOR_BGR
415
  canvas_bgr = np.where(
 
417
  (canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
418
  canvas_bgr,
419
  )
420
  for (x, y), lbl in zip(points, labels):
421
  color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
422
  center = (int(round(x)), int(round(y)))
 
547
  return overlay, status
548
 
549
 
550
+ def run_whole_image_segmentation(
551
+ image: Optional[np.ndarray],
552
+ granularity: float,
553
+ pred_iou_thresh: float,
554
+ stability_thresh: float,
555
+ ):
556
+ img = ensure_uint8(image)
557
+ if img is None:
558
+ return None, [], "Upload an image to run whole-image segmentation."
559
+ device = choose_device()
560
+ mask_generator = MODEL_MANAGER.make_auto_mask_generator(
561
+ device=device,
562
+ points_per_side=WHOLE_IMAGE_POINTS_PER_SIDE,
563
+ points_per_batch=128,
564
+ pred_iou_thresh=float(pred_iou_thresh),
565
+ stability_score_thresh=float(stability_thresh),
566
+ mask_threshold=-1.0,
567
+ box_nms_thresh=0.7,
568
+ crop_n_layers=0,
569
+ min_mask_region_area=0,
570
+ use_m2m=USE_M2M_REFINEMENT,
571
+ output_mode="binary_mask",
572
+ )
573
+ try:
574
+ masks = mask_generator.generate(img, gra=float(granularity))
575
+ except Exception as exc:
576
+ LOGGER.exception("Whole-image segmentation failed")
577
+ return None, [], f"Whole-image segmentation failed: {exc}"
578
+ if not masks:
579
+ return img, [], "Mask generator did not return any regions. Try lowering thresholds."
580
+ trimmed = masks[:WHOLE_IMAGE_MAX_MASKS]
581
+ mask_arrays = mask_entries_to_arrays(trimmed)
582
+ overlay = colorize_mask_collection(img, mask_arrays)
583
+ table = summarize_masks(trimmed)
584
+ status = (
585
+ f"Generated {len(trimmed)} masks | granularity={granularity:.2f}, "
586
+ f"IoU≥{pred_iou_thresh:.2f}, stability≥{stability_thresh:.2f}"
587
+ )
588
+ return overlay, table, status
589
 
590
 
591
+ def handle_video_upload(
592
+ video_file: Any,
593
+ current_state: Optional[Dict[str, Any]] = None,
594
+ ):
595
+ if current_state:
596
+ remove_dir_if_exists(current_state.get("frame_dir"))
597
+ state = make_empty_video_state()
598
+ if isinstance(video_file, (list, tuple)):
599
+ video_file = video_file[0] if video_file else None
600
+ video_path = resolve_video_path(video_file)
601
+ if not video_path:
602
+ return (
603
+ gr.update(value=None, visible=False),
604
+ state,
605
+ gr.update(value=0, minimum=0, maximum=0, interactive=False),
606
+ [],
607
+ [],
608
+ 0,
609
+ "Upload a video to start adding clicks.",
610
  )
611
+ try:
612
+ frame_paths, fps, frame_size, frame_dir = extract_video_frames(video_path)
613
+ except Exception as exc:
614
+ LOGGER.exception("Video decoding failed")
615
+ return (
616
+ gr.update(value=None, visible=False),
617
+ state,
618
+ gr.update(value=0, minimum=0, maximum=0, interactive=False),
619
+ [],
620
+ [],
621
+ 0,
622
+ f"Video decoding failed: {exc}",
623
+ )
624
+ state.update(
625
+ {
626
+ "frame_dir": str(frame_dir),
627
+ "frame_paths": [str(p) for p in frame_paths],
628
+ "fps": fps,
629
+ "frame_size": frame_size,
630
+ }
631
+ )
632
+ first_overlay = render_video_overlay(state, 0, [], [])
633
+ slider_update = gr.update(
634
+ value=0,
635
+ minimum=0,
636
+ maximum=len(frame_paths) - 1,
637
+ step=1,
638
+ interactive=True,
639
+ )
640
+ status = f"Loaded video with {len(frame_paths)} frames at {fps:.1f} FPS."
641
+ return (
642
+ gr.update(value=first_overlay, visible=True),
643
+ state,
644
+ slider_update,
645
+ [],
646
+ [],
647
+ 0,
648
+ status,
649
+ )
650
 
 
 
 
651
 
652
+ def handle_video_frame_change(
653
+ frame_idx: int,
654
+ video_state: Dict[str, Any],
655
+ ):
656
+ overlay = render_video_overlay(video_state, frame_idx, [], [])
657
+ if overlay is None:
658
+ return gr.update(), [], [], 0, "Upload a video first."
659
+ safe_idx = int(np.clip(frame_idx, 0, len(video_state.get("frame_paths", [])) - 1))
660
+ status = f"Annotating frame {safe_idx}."
661
+ return overlay, [], [], safe_idx, status
662
 
663
 
664
+ def handle_video_click(
665
+ point_mode: str,
666
+ pts: List[Sequence[float]],
667
+ lbls: List[int],
668
+ video_state: Dict[str, Any],
669
+ frame_idx: int,
670
+ evt: gr.SelectData,
671
+ ):
672
+ overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
673
+ if overlay is None:
674
+ return gr.update(), pts, lbls, "Upload a video first."
675
+ if evt.index is None:
676
+ return overlay, pts, lbls, "Couldn't read click position."
677
+ x, y = evt.index
678
+ label = POINT_MODE_TO_LABEL.get(point_mode, 1)
679
+ pts = pts + [[float(x), float(y)]]
680
+ lbls = lbls + [label]
681
+ overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
682
+ status = (
683
+ f"Added {'positive' if label == 1 else 'negative'} click at "
684
+ f"({int(x)}, {int(y)}) on frame {int(frame_idx)}."
685
+ )
686
+ return overlay, pts, lbls, status
687
 
688
+
689
+ def undo_video_click(
690
+ video_state: Dict[str, Any],
691
+ pts: List[Sequence[float]],
692
+ lbls: List[int],
693
+ frame_idx: int,
694
+ ):
695
+ if not pts:
696
+ return gr.update(), pts, lbls, "No clicks to undo."
697
+ pts = pts[:-1]
698
+ lbls = lbls[:-1]
699
+ overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
700
+ return overlay, pts, lbls, "Removed the last click."
701
+
702
+
703
+ def clear_video_clicks(video_state: Dict[str, Any], frame_idx: int):
704
+ overlay = render_video_overlay(video_state, frame_idx, [], [])
705
+ return overlay, [], [], "Cleared all clicks for the selected frame."
706
+
707
+
708
+ def reset_video_interface(current_state: Dict[str, Any]):
709
+ remove_dir_if_exists(current_state.get("frame_dir"))
710
+ state = make_empty_video_state()
711
+ return (
712
+ gr.update(value=None, visible=False),
713
+ state,
714
+ gr.update(value=0, minimum=0, maximum=0, interactive=False),
715
+ [],
716
+ [],
717
+ 0,
718
+ "Cleared video. Upload a new clip to continue.",
719
+ )
720
+
721
+
722
+ def run_video_segmentation(
723
+ video_state: Dict[str, Any],
724
+ pts: List[Sequence[float]],
725
+ lbls: List[int],
726
+ frame_idx: int,
727
+ granularity: float,
728
+ ):
729
+ frame_paths: List[str] = list(video_state.get("frame_paths", []))
730
+ if not frame_paths:
731
+ return None, "Upload a video to segment."
732
+ if not pts:
733
+ return None, "Add at least one click on the annotation frame."
734
+ frame_dir = video_state.get("frame_dir")
735
+ if not frame_dir:
736
+ return None, "Video frames are unavailable. Please re-upload the video."
737
+ safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
738
+ device = choose_device()
739
+ predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
740
+ inference_state = predictor.init_state(video_path=frame_dir)
741
+ predictor.reset_state(inference_state)
742
+ coords = np.asarray(pts, dtype=np.float32)
743
+ labels = np.asarray(lbls, dtype=np.int32)
744
+ try:
745
+ _, obj_ids, mask_logits = predictor.add_new_points_or_box(
746
+ inference_state=inference_state,
747
+ frame_idx=safe_idx,
748
+ obj_id=1,
749
+ points=coords,
750
+ labels=labels,
751
+ gra=float(granularity),
752
  )
753
+ except Exception as exc:
754
+ LOGGER.exception("Video add_new_points_or_box failed")
755
+ return None, f"Video segmentation failed during prompting: {exc}"
756
+ video_masks: Dict[int, Dict[int, np.ndarray]] = {}
757
+ video_masks[safe_idx] = {
758
+ int(obj_id): (mask_logits[i] > -1.0).cpu().numpy()
759
+ for i, obj_id in enumerate(obj_ids)
760
+ }
761
+ try:
762
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
763
+ inference_state,
764
+ gra=float(granularity),
765
+ ):
766
+ video_masks[out_frame_idx] = {
767
+ int(obj_id): (out_mask_logits[i] > -1.0).cpu().numpy()
768
+ for i, obj_id in enumerate(out_obj_ids)
769
+ }
770
+ except Exception as exc:
771
+ LOGGER.exception("Video propagation failed")
772
+ return None, f"Video propagation failed: {exc}"
773
+
774
+ overlays: List[np.ndarray] = []
775
+ for idx, frame_path in enumerate(frame_paths):
776
+ base = load_rgb_image(Path(frame_path))
777
+ mask = video_masks.get(idx, {}).get(1)
778
+ overlays.append(draw_overlay(base, mask, [], []))
779
+ try:
780
+ video_path = write_video_from_frames(overlays, video_state.get("fps", 12.0))
781
+ except Exception as exc:
782
+ LOGGER.exception("Failed to encode output video")
783
+ return None, f"Tracking succeeded but video export failed: {exc}"
784
+
785
+ status = (
786
+ f"Tracked object from frame {safe_idx} across {len(frame_paths)} frames | "
787
+ f"granularity={granularity:.2f}"
788
+ )
789
+ return str(video_path), status
790
 
791
+
792
+ def run_video_frame_segmentation(
793
+ video_state: Dict[str, Any],
794
+ pts: List[Sequence[float]],
795
+ lbls: List[int],
796
+ frame_idx: int,
797
+ granularity: float,
798
+ ):
799
+ frame_paths: List[str] = list(video_state.get("frame_paths", []))
800
+ if not frame_paths:
801
+ return None, "Upload a video to segment."
802
+ if not pts:
803
+ return None, "Add at least one click on the annotation frame."
804
+ frame_dir = video_state.get("frame_dir")
805
+ if not frame_dir:
806
+ return None, "Video frames are unavailable. Please re-upload the video."
807
+ safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
808
+ device = choose_device()
809
+ predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
810
+ inference_state = predictor.init_state(video_path=frame_dir)
811
+ predictor.reset_state(inference_state)
812
+ coords = np.asarray(pts, dtype=np.float32)
813
+ labels = np.asarray(lbls, dtype=np.int32)
814
+ try:
815
+ _, obj_ids, mask_logits = predictor.add_new_points_or_box(
816
+ inference_state=inference_state,
817
+ frame_idx=safe_idx,
818
+ obj_id=1,
819
+ points=coords,
820
+ labels=labels,
821
+ gra=float(granularity),
822
  )
823
+ except Exception as exc:
824
+ LOGGER.exception("Video frame segmentation failed")
825
+ return None, f"Frame segmentation failed: {exc}"
826
+ if not obj_ids:
827
+ return None, "Predictor did not return a mask for this frame."
828
+ mask = (mask_logits[0] > -1.0).cpu().numpy()
829
+ base = load_rgb_image(Path(frame_paths[safe_idx]))
830
+ overlay = draw_overlay(base, mask, pts, lbls)
831
+ status = (
832
+ f"Segmented frame {safe_idx} with {len(pts)} clicks | "
833
+ f"granularity={granularity:.2f}"
834
+ )
835
+ return overlay, status
836
+
837
 
838
+ if spaces is not None and ZERO_GPU_ENABLED:
839
+ segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
840
+ else:
841
+ segment_fn = _run_segmentation
842
+
843
+
844
+ def build_demo() -> gr.Blocks:
845
+ with gr.Blocks(title="UnSAMv2 Interactive + Whole Image + Video", theme=gr.themes.Soft()) as demo:
846
+ gr.Markdown(
847
+ """
848
+ <div style="text-align:center">
849
+ <h2>UnSAMv2 · Segment Anything at Any Granularity</h2>
850
+ </div>
851
+ """
852
  )
853
 
854
+ gr.HTML(
855
+ """
856
+ <style>
857
+ #mode-tabs button[role="tab"] {
858
+ flex: 0 0 auto;
859
+ min-width: 160px;
860
+ }
861
+ #mode-tabs [role="tablist"],
862
+ #mode-tabs .tab-nav,
863
+ #mode-tabs > div:first-child {
864
+ display: flex !important;
865
+ justify-content: center !important;
866
+ gap: 0.75rem;
867
+ }
868
+ </style>
869
+ """
870
  )
871
 
872
+ with gr.Tabs(elem_id="mode-tabs"):
873
+ # Interactive Image Tab
874
+ with gr.Tab("Interactive Image Segmentation"):
875
+ image_state = gr.State(DEFAULT_IMAGE)
876
+ points_state = gr.State([])
877
+ labels_state = gr.State([])
878
+
879
+ image_input = gr.Image(
880
+ label="Image · clicks & mask",
881
+ type="numpy",
882
+ height=480,
883
+ value=DEFAULT_IMAGE,
884
+ sources=["upload"],
885
+ )
886
+
887
+ with gr.Row(equal_height=True):
888
+ point_mode = gr.Radio(
889
+ choices=list(POINT_MODE_TO_LABEL.keys()),
890
+ value="Foreground (+)",
891
+ label="Click type",
892
+ )
893
+ granularity_slider = gr.Slider(
894
+ minimum=GRANULARITY_MIN,
895
+ maximum=GRANULARITY_MAX,
896
+ value=0.2,
897
+ step=0.01,
898
+ label="Granularity",
899
+ info="Lower = finer details, Higher = coarser regions",
900
+ )
901
+ segment_button = gr.Button("Segment", variant="primary")
902
+
903
+ with gr.Row():
904
+ undo_button = gr.Button("Undo last click")
905
+ clear_button = gr.Button("Clear clicks")
906
+
907
+ status_markdown = gr.Markdown(" Ready for interactive clicks.")
908
+
909
+ image_input.upload(
910
+ handle_image_upload,
911
+ inputs=[image_input],
912
+ outputs=[
913
+ image_input,
914
+ image_state,
915
+ points_state,
916
+ labels_state,
917
+ status_markdown,
918
+ ],
919
+ )
920
+
921
+ image_input.clear(
922
+ handle_image_upload,
923
+ inputs=[image_input],
924
+ outputs=[
925
+ image_input,
926
+ image_state,
927
+ points_state,
928
+ labels_state,
929
+ status_markdown,
930
+ ],
931
+ )
932
+
933
+ image_input.select(
934
+ handle_click,
935
+ inputs=[
936
+ point_mode,
937
+ points_state,
938
+ labels_state,
939
+ image_state,
940
+ ],
941
+ outputs=[
942
+ image_input,
943
+ points_state,
944
+ labels_state,
945
+ status_markdown,
946
+ ],
947
+ )
948
+
949
+ undo_button.click(
950
+ undo_last_click,
951
+ inputs=[image_state, points_state, labels_state],
952
+ outputs=[
953
+ image_input,
954
+ points_state,
955
+ labels_state,
956
+ status_markdown,
957
+ ],
958
+ )
959
+
960
+ clear_button.click(
961
+ clear_clicks,
962
+ inputs=[image_state],
963
+ outputs=[
964
+ image_input,
965
+ points_state,
966
+ labels_state,
967
+ status_markdown,
968
+ ],
969
+ )
970
+
971
+ segment_button.click(
972
+ segment_fn,
973
+ inputs=[image_state, points_state, labels_state, granularity_slider],
974
+ outputs=[image_input, status_markdown],
975
+ )
976
+
977
+ # Whole Image Tab
978
+ with gr.Tab("Whole Image Segmentation"):
979
+ whole_image_input = gr.Image(
980
+ label="Image · automatic masks",
981
+ type="numpy",
982
+ height=480,
983
+ value=WHOLE_IMAGE_DEFAULT if WHOLE_IMAGE_DEFAULT is not None else DEFAULT_IMAGE,
984
+ sources=["upload"],
985
+ )
986
+ whole_granularity = gr.Slider(
987
+ minimum=GRANULARITY_MIN,
988
+ maximum=GRANULARITY_MAX,
989
+ value=0.15,
990
+ step=0.01,
991
+ label="Granularity",
992
+ )
993
+ whole_generate_btn = gr.Button("Generate masks", variant="primary")
994
+ with gr.Accordion("Advanced mask filtering", open=False):
995
+ pred_iou_thresh = gr.Slider(
996
+ minimum=0.1,
997
+ maximum=0.99,
998
+ value=0.77,
999
+ step=0.01,
1000
+ label="Predicted IoU threshold",
1001
+ )
1002
+ stability_thresh = gr.Slider(
1003
+ minimum=0.1,
1004
+ maximum=0.99,
1005
+ value=0.9,
1006
+ step=0.01,
1007
+ label="Stability threshold",
1008
+ )
1009
+
1010
+ whole_overlay = gr.Image(label="Mask overlay", height=480)
1011
+ whole_table = gr.Dataframe(
1012
+ headers=["mask", "area", "pred_iou", "stability"],
1013
+ datatype=["number", "number", "number", "number"],
1014
+ label="Mask stats",
1015
+ wrap=True,
1016
+ visible=False,
1017
+ )
1018
+ whole_status = gr.Markdown(" Ready for whole-image masks.")
1019
+
1020
+ whole_generate_btn.click(
1021
+ run_whole_image_segmentation,
1022
+ inputs=[
1023
+ whole_image_input,
1024
+ whole_granularity,
1025
+ pred_iou_thresh,
1026
+ stability_thresh,
1027
+ ],
1028
+ outputs=[whole_overlay, whole_table, whole_status],
1029
+ )
1030
+
1031
+ # Video Tab
1032
+ with gr.Tab("Video Segmentation"):
1033
+ video_state = gr.State(make_empty_video_state())
1034
+ video_points_state = gr.State([])
1035
+ video_labels_state = gr.State([])
1036
+ annotation_frame_state = gr.State(0)
1037
+
1038
+ with gr.Row(equal_height=True):
1039
+ with gr.Column(scale=1, min_width=360):
1040
+ upload_button = gr.UploadButton(
1041
+ "Upload video",
1042
+ file_types=["video"],
1043
+ file_count="single",
1044
+ )
1045
+ frame_display = gr.Image(
1046
+ label="Video · add clicks",
1047
+ type="numpy",
1048
+ height=420,
1049
+ interactive=True,
1050
+ visible=False,
1051
+ )
1052
+ frame_slider = gr.Slider(
1053
+ minimum=0,
1054
+ maximum=0,
1055
+ value=0,
1056
+ step=1,
1057
+ interactive=False,
1058
+ label="Select frame",
1059
+ )
1060
+ video_point_mode = gr.Radio(
1061
+ choices=list(POINT_MODE_TO_LABEL.keys()),
1062
+ value="Foreground (+)",
1063
+ label="Click type",
1064
+ )
1065
+ with gr.Row():
1066
+ video_undo = gr.Button("Undo click")
1067
+ video_clear = gr.Button("Clear clicks")
1068
+ video_granularity = gr.Slider(
1069
+ minimum=GRANULARITY_MIN,
1070
+ maximum=GRANULARITY_MAX,
1071
+ value=0.33,
1072
+ step=0.01,
1073
+ label="Granularity",
1074
+ )
1075
+ with gr.Row():
1076
+ video_frame_btn = gr.Button("Segment frame", variant="secondary")
1077
+ video_segment_btn = gr.Button("Propagate video", variant="primary")
1078
+
1079
+ with gr.Column(scale=1, min_width=320):
1080
+ video_output = gr.Video(
1081
+ label="Segmented preview",
1082
+ autoplay=False,
1083
+ height=420,
1084
+ )
1085
+
1086
+ video_status = gr.Markdown(" Ready for video segmentation.")
1087
+
1088
+ upload_button.upload(
1089
+ handle_video_upload,
1090
+ inputs=[upload_button, video_state],
1091
+ outputs=[
1092
+ frame_display,
1093
+ video_state,
1094
+ frame_slider,
1095
+ video_points_state,
1096
+ video_labels_state,
1097
+ annotation_frame_state,
1098
+ video_status,
1099
+ ],
1100
+ )
1101
+
1102
+ if DEFAULT_VIDEO_PATH.exists():
1103
+ def _load_default_video(state):
1104
+ return handle_video_upload(str(DEFAULT_VIDEO_PATH), state)
1105
+
1106
+ demo.load(
1107
+ _load_default_video,
1108
+ inputs=[video_state],
1109
+ outputs=[
1110
+ frame_display,
1111
+ video_state,
1112
+ frame_slider,
1113
+ video_points_state,
1114
+ video_labels_state,
1115
+ annotation_frame_state,
1116
+ video_status,
1117
+ ],
1118
+ queue=False,
1119
+ )
1120
+
1121
+ frame_slider.change(
1122
+ handle_video_frame_change,
1123
+ inputs=[frame_slider, video_state],
1124
+ outputs=[
1125
+ frame_display,
1126
+ video_points_state,
1127
+ video_labels_state,
1128
+ annotation_frame_state,
1129
+ video_status,
1130
+ ],
1131
+ )
1132
+
1133
+ frame_display.select(
1134
+ handle_video_click,
1135
+ inputs=[
1136
+ video_point_mode,
1137
+ video_points_state,
1138
+ video_labels_state,
1139
+ video_state,
1140
+ annotation_frame_state,
1141
+ ],
1142
+ outputs=[
1143
+ frame_display,
1144
+ video_points_state,
1145
+ video_labels_state,
1146
+ video_status,
1147
+ ],
1148
+ )
1149
+
1150
+ frame_display.clear(
1151
+ reset_video_interface,
1152
+ inputs=[video_state],
1153
+ outputs=[
1154
+ frame_display,
1155
+ video_state,
1156
+ frame_slider,
1157
+ video_points_state,
1158
+ video_labels_state,
1159
+ annotation_frame_state,
1160
+ video_status,
1161
+ ],
1162
+ )
1163
+
1164
+ video_frame_btn.click(
1165
+ run_video_frame_segmentation,
1166
+ inputs=[
1167
+ video_state,
1168
+ video_points_state,
1169
+ video_labels_state,
1170
+ annotation_frame_state,
1171
+ video_granularity,
1172
+ ],
1173
+ outputs=[frame_display, video_status],
1174
+ )
1175
+
1176
+ video_undo.click(
1177
+ undo_video_click,
1178
+ inputs=[
1179
+ video_state,
1180
+ video_points_state,
1181
+ video_labels_state,
1182
+ annotation_frame_state,
1183
+ ],
1184
+ outputs=[
1185
+ frame_display,
1186
+ video_points_state,
1187
+ video_labels_state,
1188
+ video_status,
1189
+ ],
1190
+ )
1191
+
1192
+ video_clear.click(
1193
+ clear_video_clicks,
1194
+ inputs=[video_state, annotation_frame_state],
1195
+ outputs=[
1196
+ frame_display,
1197
+ video_points_state,
1198
+ video_labels_state,
1199
+ video_status,
1200
+ ],
1201
+ )
1202
+
1203
+ video_segment_btn.click(
1204
+ run_video_segmentation,
1205
+ inputs=[
1206
+ video_state,
1207
+ video_points_state,
1208
+ video_labels_state,
1209
+ annotation_frame_state,
1210
+ video_granularity,
1211
+ ],
1212
+ outputs=[video_output, video_status],
1213
+ )
1214
+
1215
  demo.queue(max_size=8)
1216
  return demo
1217
 
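For readers skimming the diff, here is a minimal sketch of how the whole-image path added above could be exercised outside Gradio. It relies only on calls that appear in this commit (`build_sam2`, `SAM2AutomaticMaskGenerator`, and the UnSAMv2-specific `gra=` keyword on `generate`); the config path, thresholds, and input image below are illustrative placeholders, not values taken from the Space's unchanged code.

```python
# Sketch only: mirrors the whole-image segmentation path wired up in app.py above.
# Config path, thresholds, and image are placeholders, not values from this repo.
import cv2
import torch

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",                 # placeholder config path
    ckpt_path="sam2/checkpoints/unsamv2_plus_ckpt.pt",    # checkpoint tracked via LFS above
    device=device,
)

generator = SAM2AutomaticMaskGenerator(
    model,
    points_per_side=64,
    pred_iou_thresh=0.77,
    stability_score_thresh=0.9,
    mask_threshold=-1.0,
    output_mode="binary_mask",
)

img_bgr = cv2.imread("demo/sa_291195.jpg", cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# `gra` is the UnSAMv2 granularity control threaded through generate() in app.py.
masks = generator.generate(img_rgb, gra=0.15)
print(f"Generated {len(masks)} masks")
```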
demo/bedroom.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1be76d5d19b066e8ad7c565d88a98e11a8f8d456a707508a7aa35390def70e30
+size 2380401
demo/sa_291195.jpg ADDED

Git LFS Details

  • SHA256: 35ad56b5cd80355dcdb135d4df64439fcf0338bf44133418015f4ea6c214b4ab
  • Pointer size: 131 Bytes
  • Size of remote file: 666 kB
sam2/sam2/__pycache__/automatic_mask_generator.cpython-310.pyc ADDED
Binary file (13.9 kB)
 
sam2/sam2/__pycache__/sam2_video_predictor.cpython-310.pyc ADDED
Binary file (25 kB)
 
sam2/sam2/utils/__pycache__/amg.cpython-310.pyc ADDED
Binary file (12.1 kB)
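Similarly, a hedged sketch of the click-and-propagate flow that the new Video tab drives through `run_video_segmentation`, using only the predictor calls visible in the diff (`build_sam2_video_predictor`, `init_state`, `add_new_points_or_box`, and `propagate_in_video` with `gra=`); the frame directory, click coordinates, and granularity value are made up for illustration.

```python
# Sketch only: follows run_video_segmentation() from the app.py diff above.
# Frame directory, click, and granularity are illustrative placeholders.
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml",                 # placeholder config path
    ckpt_path="sam2/checkpoints/unsamv2_plus_ckpt.pt",
    device=device,
)

# Directory of JPEG frames, e.g. produced by extract_video_frames() in app.py.
state = predictor.init_state(video_path="_tmp/video_frames_example")
predictor.reset_state(state)

# One positive click on frame 0; labels: 1 = foreground, 0 = background.
points = np.array([[320.0, 240.0]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)
_, obj_ids, mask_logits = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
    gra=0.33,  # UnSAMv2 granularity, as passed by the app
)

# Propagate the prompt through the clip and collect per-frame boolean masks.
masks_per_frame = {}
for frame_idx, out_obj_ids, out_logits in predictor.propagate_in_video(state, gra=0.33):
    masks_per_frame[frame_idx] = {
        int(obj_id): (out_logits[i] > -1.0).cpu().numpy()
        for i, obj_id in enumerate(out_obj_ids)
    }
print(f"Tracked masks on {len(masks_per_frame)} frames")
```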