import colorsys
import gc
from copy import deepcopy
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes import Soft
from PIL import Image, ImageDraw
from transformers import AutoModel, Sam2VideoProcessor

def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
    """Generate a deterministic pastel RGB color for a given object id.

    Uses the golden ratio to distribute hues; low-medium saturation, high value.
    """
    golden_ratio_conjugate = 0.61803398875
    # Map obj_id (1-based) to hue in [0, 1)
    hue = (obj_id * golden_ratio_conjugate) % 1.0
    saturation = 0.45
    value = 1.0
    r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value)
    return int(r_f * 255), int(g_f * 255), int(b_f * 255)

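# Illustrative sketch (not called by the app): the golden-ratio hue step keeps
# consecutive object ids visually far apart on the color wheel, so each new
# object gets a distinct pastel. Running this helper prints the first few
# (obj_id, RGB) assignments.
def _print_example_palette(num_objects: int = 5) -> None:
    for obj_id in range(1, num_objects + 1):
        print(obj_id, pastel_color_for_object(obj_id))
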
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Load video frames as PIL Images using OpenCV. Returns (frames, info)."""
    cap = cv2.VideoCapture(video_path_or_url)
    frames = []
    print("loading video frames")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame_rgb))
    # Gather fps if available
    fps_val = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    print("loaded video frames")
    info = {
        "num_frames": len(frames),
        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
    }
    return frames, info

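# Illustrative sketch (not used by the app): given the (frames, info) pair from
# try_load_video_frames, keep roughly every n-th frame to approximate a target
# FPS. The app instead caps clips at a maximum duration further below; this is
# only a sketch of how the reported FPS could be used.
def subsample_frames(
    frames: list[Image.Image], src_fps: float | None, target_fps: float = 6.0
) -> list[Image.Image]:
    if not src_fps or src_fps <= target_fps:
        return frames
    step = max(1, round(src_fps / target_fps))
    return frames[::step]
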
def overlay_masks_on_frame(
    frame: Image.Image,
    masks_per_object: dict[int, np.ndarray],
    color_by_obj: dict[int, tuple[int, int, int]],
    alpha: float = 0.5,
) -> Image.Image:
    """Overlay per-object soft masks onto the RGB frame.

    masks_per_object: mapping of obj_id -> (H, W) float mask in [0, 1]
    color_by_obj: mapping of obj_id -> (R, G, B)
    """
    base = np.array(frame).astype(np.float32) / 255.0  # H, W, 3 in [0, 1]
    overlay = base.copy()
    for obj_id, mask in masks_per_object.items():
        if mask is None:
            continue
        if mask.dtype != np.float32:
            mask = mask.astype(np.float32)
        # Ensure shape is H x W
        if mask.ndim == 3:
            mask = mask.squeeze()
        mask = np.clip(mask, 0.0, 1.0)
        color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0
        # Blend: overlay = (1 - a*m)*overlay + (a*m)*color
        a = alpha
        m = mask[..., None]
        overlay = (1.0 - a * m) * overlay + (a * m) * color
    out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(out)

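# Illustrative sketch (not wired into the app): exercise the alpha blend above on
# a synthetic frame with a circular soft mask for a single object. The demo size
# (128x128) and object id 1 are arbitrary choices for the example.
def _overlay_demo() -> Image.Image:
    frame = Image.new("RGB", (128, 128), (230, 230, 230))
    yy, xx = np.mgrid[0:128, 0:128]
    mask = (((xx - 64) ** 2 + (yy - 64) ** 2) < 40**2).astype(np.float32)
    return overlay_masks_on_frame(frame, {1: mask}, {1: pastel_color_for_object(1)}, alpha=0.65)
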
def get_device_and_dtype() -> tuple[str, torch.dtype]:
    device = "cpu"
    dtype = torch.bfloat16
    return device, dtype

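# Illustrative variant (an assumption for local use, not the app's behavior): the
# app pins CPU here and only moves copies of the model to CUDA inside the
# propagation handler. On a machine with a resident GPU one might probe CUDA
# instead, along the lines of this sketch.
def get_device_and_dtype_local() -> tuple[str, torch.dtype]:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device, torch.bfloat16
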
class AppState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.video_frames: list[Image.Image] = []
        self.inference_session = None
        self.model: Optional[AutoModel] = None
        self.processor: Optional[Sam2VideoProcessor] = None
        self.device: str = "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.video_fps: float | None = None
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
        self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
        # Cache of composited frames (original + masks + clicks)
        self.composited_frames: dict[int, Image.Image] = {}
        # UI state for click handler
        self.current_frame_idx: int = 0
        self.current_obj_id: int = 1
        self.current_label: str = "positive"
        self.current_clear_old: bool = True
        self.current_prompt_type: str = "Points"  # or "Boxes"
        self.pending_box_start: tuple[int, int] | None = None
        self.pending_box_start_frame_idx: int | None = None
        self.pending_box_start_obj_id: int | None = None
        self.is_switching_model: bool = False
        # Model selection
        self.model_repo_key: str = "tiny"
        self.model_repo_id: str | None = None
        self.session_repo_id: str | None = None

    def __repr__(self):
        return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"

    @property
    def num_frames(self) -> int:
        # Exposed as a property because callers access it as `state.num_frames`
        return len(self.video_frames)

def _model_repo_from_key(key: str) -> str:
    mapping = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base_plus": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }
    return mapping.get(key, mapping["base_plus"])

def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
        if GLOBAL_STATE.model_repo_id == desired_repo:
            return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
        # Different repo requested: dispose current and reload
        GLOBAL_STATE.model = None
        GLOBAL_STATE.processor = None
    print(f"Loading model from {desired_repo}")
    device, dtype = get_device_and_dtype()
    model = AutoModel.from_pretrained(desired_repo)
    processor = Sam2VideoProcessor.from_pretrained(desired_repo)
    model.to(device, dtype=dtype)
    GLOBAL_STATE.model = model
    GLOBAL_STATE.processor = processor
    GLOBAL_STATE.device = device
    GLOBAL_STATE.dtype = dtype
    GLOBAL_STATE.model_repo_id = desired_repo
    return model, processor, device, dtype

def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
    """Ensure the model/processor match the selected repo and an inference_session exists.

    If a video is already loaded, re-initialize the inference session when needed.
    """
    load_model_if_needed(GLOBAL_STATE)
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
        if GLOBAL_STATE.video_frames:
            # Clear session-related UI caches when switching model
            GLOBAL_STATE.masks_by_frame.clear()
            GLOBAL_STATE.clicks_by_frame_obj.clear()
            GLOBAL_STATE.boxes_by_frame_obj.clear()
            GLOBAL_STATE.composited_frames.clear()
            GLOBAL_STATE.inference_session = None
            GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
                inference_device=GLOBAL_STATE.device,
                video_storage_device="cpu",
                dtype=GLOBAL_STATE.dtype,
            )
        GLOBAL_STATE.session_repo_id = desired_repo

def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
    """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
    # Reset ONLY video-related fields, keep model loaded
    GLOBAL_STATE.video_frames = []
    GLOBAL_STATE.inference_session = None
    GLOBAL_STATE.masks_by_frame = {}
    GLOBAL_STATE.color_by_obj = {}
    load_model_if_needed(GLOBAL_STATE)
    # Gradio Video may provide a dict with 'name' or a direct file path
    video_path: Optional[str] = None
    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    elif isinstance(video, str):
        video_path = video
    if not video_path:
        raise gr.Error("Invalid video input.")
    frames, info = try_load_video_frames(video_path)
    if len(frames) == 0:
        raise gr.Error("No frames could be loaded from the video.")
    # Enforce a max duration of 8 seconds (trim if longer); skip trimming when FPS is unknown
    MAX_SECONDS = 8.0
    trimmed_note = ""
    fps_in = info.get("fps")
    if fps_in and fps_in > 0:
        max_frames_allowed = int(MAX_SECONDS * fps_in)
        if len(frames) > max_frames_allowed:
            frames = frames[:max_frames_allowed]
            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
            info["num_frames"] = len(frames)
    GLOBAL_STATE.video_frames = frames
    # Capture the original FPS if provided by the loader
    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
    # Initialize session
    inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )
    GLOBAL_STATE.inference_session = inference_session
    first_frame = frames[0]
    max_idx = len(frames) - 1
    status = (
        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
    )
    return GLOBAL_STATE, 0, max_idx, first_frame, status

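# Illustrative sketch of the trimming rule above (not called by the app): with a
# known FPS, the 8 s cap keeps at most int(8.0 * fps) frames, e.g. 240 frames at
# 30 fps or 200 at 25 fps; when the FPS is unknown the clip is kept whole.
def trim_to_max_duration(frames: list[Image.Image], fps: float | None, max_seconds: float = 8.0) -> list[Image.Image]:
    if not fps or fps <= 0:
        return frames
    return frames[: int(max_seconds * fps)]
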
def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    # Work on a copy so the drawing below never mutates the stored original frame
    frame = state.video_frames[frame_idx].copy()
    masks = state.masks_by_frame.get(frame_idx, {})
    out_img = frame
    if len(masks) != 0:
        out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65)
    # Draw crosses for conditioning frames only (frames with recorded clicks)
    clicks_map = state.clicks_by_frame_obj.get(frame_idx)
    if clicks_map:
        draw = ImageDraw.Draw(out_img)
        cross_half = 6
        for obj_id, pts in clicks_map.items():
            for x, y, lbl in pts:
                color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
                # horizontal
                draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
                # vertical
                draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
    # Draw a temporary cross for the first corner in box mode
    if (
        state.pending_box_start is not None
        and state.pending_box_start_frame_idx == frame_idx
        and state.pending_box_start_obj_id is not None
    ):
        draw = ImageDraw.Draw(out_img)
        x, y = state.pending_box_start
        cross_half = 6
        color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
        draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
        draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
    # Draw boxes for conditioning frames
    box_map = state.boxes_by_frame_obj.get(frame_idx)
    if box_map:
        draw = ImageDraw.Draw(out_img)
        for obj_id, boxes in box_map.items():
            color = state.color_by_obj.get(obj_id, (255, 255, 255))
            for x1, y1, x2, y2 in boxes:
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
    # Save to cache and return
    state.composited_frames[frame_idx] = out_img
    return out_img

def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    # Serve from cache when available
    cached = state.composited_frames.get(frame_idx)
    if cached is not None:
        return cached
    return compose_frame(state, frame_idx)


def _ensure_color_for_obj(state: AppState, obj_id: int):
    if obj_id not in state.color_by_obj:
        state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)

def on_image_click(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
) -> Image.Image:
    if state is None or state.inference_session is None:
        return img  # no-op preview when not ready
    if state.is_switching_model:
        # Gracefully ignore input during model switch; return current preview unchanged
        return update_frame_display(state, int(frame_idx))
    # Parse click coordinates from event
    x = y = None
    if evt is not None:
        # Try different gradio event data shapes for robustness
        try:
            if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                x, y = int(evt.index[0]), int(evt.index[1])
            elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value:
                x, y = int(evt.value["x"]), int(evt.value["y"])
        except Exception:
            x = y = None
    if x is None or y is None:
        raise gr.Error("Could not read click coordinates.")
    _ensure_color_for_obj(state, int(obj_id))
    processor = state.processor
    model = state.model
    inference_session = state.inference_session
    original_size = None
    pixel_values = None
    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
        original_size = inputs.original_sizes[0]
        pixel_values = inputs.pixel_values[0]
    if state.current_prompt_type == "Boxes":
        # Two-click box input
        if state.pending_box_start is None:
            # For boxes, always clear old inputs (points) for this object on this frame
            frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
            frame_clicks[int(obj_id)] = []
            state.composited_frames.pop(int(frame_idx), None)
            state.pending_box_start = (int(x), int(y))
            state.pending_box_start_frame_idx = int(frame_idx)
            state.pending_box_start_obj_id = int(obj_id)
            # Invalidate cache so temporary cross is drawn
            state.composited_frames.pop(int(frame_idx), None)
            return update_frame_display(state, int(frame_idx))
        else:
            x1, y1 = state.pending_box_start
            x2, y2 = int(x), int(y)
            # Clear temporary state and invalidate cache
            state.pending_box_start = None
            state.pending_box_start_frame_idx = None
            state.pending_box_start_obj_id = None
            state.composited_frames.pop(int(frame_idx), None)
            x_min, y_min = min(x1, x2), min(y1, y2)
            x_max, y_max = max(x1, x2), max(y1, y2)
            processor.add_inputs_to_inference_session(
                inference_session=inference_session,
                frame_idx=int(frame_idx),
                obj_ids=int(obj_id),
                input_boxes=[[[x_min, y_min, x_max, y_max]]],
                clear_old_inputs=True,  # for boxes, always clear old inputs
                original_size=original_size,
            )
            frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
            obj_boxes = frame_boxes.setdefault(int(obj_id), [])
            # For boxes, always clear old inputs
            obj_boxes.clear()
            obj_boxes.append((x_min, y_min, x_max, y_max))
            state.composited_frames.pop(int(frame_idx), None)
    else:
        # Points mode
        label_int = 1 if str(label).lower().startswith("pos") else 0
        # If clear_old is enabled, clear prior boxes for this object on this frame
        if bool(clear_old):
            frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
            frame_boxes[int(obj_id)] = []
            state.composited_frames.pop(int(frame_idx), None)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=int(frame_idx),
            obj_ids=int(obj_id),
            input_points=[[[[int(x), int(y)]]]],
            input_labels=[[[int(label_int)]]],
            original_size=original_size,
            clear_old_inputs=bool(clear_old),
        )
        frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
        obj_clicks = frame_clicks.setdefault(int(obj_id), [])
        if bool(clear_old):
            obj_clicks.clear()
        obj_clicks.append((int(x), int(y), int(label_int)))
        state.composited_frames.pop(int(frame_idx), None)
    # Forward on that frame
    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
    H = inference_session.video_height
    W = inference_session.video_width
    # Detach and move off GPU as early as possible to reduce GPU memory pressure
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
    # Map returned masks to object ids. For a single-object forward it's [1, 1, H, W],
    # but to be safe, iterate over session.obj_ids order.
    masks_for_frame: dict[int, np.ndarray] = {}
    obj_ids_order = list(inference_session.obj_ids)
    for i, oid in enumerate(obj_ids_order):
        mask_i = video_res_masks[i]
        # mask_i shape could be (1, H, W) or (H, W); squeeze to 2D
        mask_2d = mask_i.cpu().numpy().squeeze()
        masks_for_frame[int(oid)] = mask_2d
    state.masks_by_frame[int(frame_idx)] = masks_for_frame
    # Invalidate cache for this frame to force recomposition
    state.composited_frames.pop(int(frame_idx), None)
    # Return updated preview
    return update_frame_display(state, int(frame_idx))

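# Helper sketch (not used by the handler above) mirroring the prompt nesting this
# app passes for a single object on a single frame: points are nested as
# [batch][object][point][x, y] and labels as [batch][object][point]. This only
# restates the shapes used in on_image_click; it is not a general-purpose builder.
def build_single_object_prompts(points: list[tuple[int, int]], labels: list[int]) -> tuple[list, list]:
    input_points = [[[[int(x), int(y)] for x, y in points]]]
    input_labels = [[[int(lbl) for lbl in labels]]]
    return input_points, input_labels
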
@spaces.GPU
def propagate_masks(GLOBAL_STATE: gr.State):
    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
        # This is a generator handler, so report the problem via a yield
        yield GLOBAL_STATE, "Load a video first.", gr.update()
        return
    processor = deepcopy(GLOBAL_STATE.processor)
    model = deepcopy(GLOBAL_STATE.model)
    inference_session = deepcopy(GLOBAL_STATE.inference_session)
    # Move inference to CUDA so the ZeroGPU allocation is actually used
    inference_session.inference_device = "cuda"
    inference_session.cache.inference_device = "cuda"
    model.to("cuda")
    total = max(1, GLOBAL_STATE.num_frames)
    processed = 0
    # Initial status; no slider change yet
    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
    last_frame_idx = 0
    with torch.inference_mode():
        for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
            pixel_values = None
            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
            last_frame_idx = frame_idx
            masks_for_frame: dict[int, np.ndarray] = {}
            obj_ids_order = list(inference_session.obj_ids)
            for i, oid in enumerate(obj_ids_order):
                mask_2d = video_res_masks[i].cpu().numpy().squeeze()
                masks_for_frame[int(oid)] = mask_2d
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
            # Invalidate cache for that frame to force recomposition
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)
            processed += 1
            # Every 30th frame (or at the end), move the slider to the current frame
            # so the preview updates via the slider binding
            if processed % 30 == 0 or processed == total:
                yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
    # Final status; ensure the slider points to the last processed frame
    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)

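# Small accessor sketch (not used by the UI): masks_by_frame maps
# frame_idx -> {obj_id -> (H, W) mask array}. Thresholding at 0.5 yields a
# boolean mask whether the stored values are soft scores or already binary.
def get_binary_mask(state: AppState, frame_idx: int, obj_id: int) -> np.ndarray | None:
    mask = state.masks_by_frame.get(int(frame_idx), {}).get(int(obj_id))
    return None if mask is None else (mask > 0.5)
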
def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
    # Reset only session-related state; keep the uploaded video and the model
    if not GLOBAL_STATE.video_frames:
        # Nothing loaded; keep behavior
        return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video."
    # Clear prompts and caches
    GLOBAL_STATE.masks_by_frame.clear()
    GLOBAL_STATE.clicks_by_frame_obj.clear()
    GLOBAL_STATE.boxes_by_frame_obj.clear()
    GLOBAL_STATE.composited_frames.clear()
    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None
    # Dispose and re-init the inference session for the current model with the existing frames
    try:
        if GLOBAL_STATE.inference_session is not None:
            GLOBAL_STATE.inference_session.reset_inference_session()
    except Exception:
        pass
    GLOBAL_STATE.inference_session = None
    gc.collect()
    ensure_session_for_current_model(GLOBAL_STATE)
    # Keep the current slider index if possible
    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
    slider_value = gr.update(value=current_idx)
    status = "Session reset. Prompts cleared; video preserved."
    return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status

theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")

with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
    GLOBAL_STATE = gr.State(AppState())
    gr.Markdown(
        """
### SAM2 Video Tracking · powered by Hugging Face 🤗 Transformers

Segment and track objects across a video with SAM2 (Segment Anything 2). This demo runs the official implementation from the Hugging Face Transformers library for interactive, promptable video segmentation.
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
**Quick start**
- **Load a video**: Upload your own or pick an example below.
- **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy).
- **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames.
- **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically.
"""
            )
        with gr.Column():
            gr.Markdown(
                """
**Working with results**
- **Preview**: Use the slider to navigate frames and see the current masks.
- **Propagate**: Click “Propagate across video” to track all defined objects through the entire video. The preview follows progress periodically to keep things responsive.
- **Export**: Render an MP4 for smooth playback using the original video FPS.
- **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
"""
            )
    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video(label="Upload video", sources=["upload", "webcam"], interactive=True)
            ckpt_radio = gr.Radio(
                choices=["tiny", "small", "base_plus", "large"],
                value="tiny",
                label="SAM2.1 checkpoint",
            )
            ckpt_progress = gr.Markdown(visible=False)
            load_status = gr.Markdown(visible=True)
            reset_btn = gr.Button("Reset Session", variant="secondary")
        with gr.Column(scale=2):
            preview = gr.Image(label="Preview", interactive=True)
            with gr.Row():
                frame_slider = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0, interactive=True)
                with gr.Column(scale=0):
                    propagate_btn = gr.Button("Propagate across video", variant="primary")
                    propagate_status = gr.Markdown(visible=True)
            with gr.Row():
                obj_id_inp = gr.Number(value=1, precision=0, label="Object ID", scale=0)
                label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
                clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
                prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")

    # Wire events
    def _on_video_change(GLOBAL_STATE: gr.State, video):
        GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
        return (
            GLOBAL_STATE,
            gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
            first_frame,
            status,
        )

    video_in.change(
        _on_video_change,
        inputs=[GLOBAL_STATE, video_in],
        outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
        show_progress=True,
    )

    # Examples are defined above the render button.
    # Each example row must match the number of inputs (GLOBAL_STATE, video_in)
    examples_list = [
        [None, "./deers.mp4"],
        [None, "./penguins.mp4"],
        [None, "./foot.mp4"],
    ]
    with gr.Row():
        gr.Examples(
            examples=examples_list,
            inputs=[GLOBAL_STATE, video_in],
            fn=_on_video_change,
            outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
            label="Examples",
            cache_examples=False,
            examples_per_page=5,
        )

    # The render MP4 button sits below the examples; its handler is defined further down
    with gr.Row():
        render_btn = gr.Button("Render MP4 for smooth playback", variant="primary")
        playback_video = gr.Video(label="Rendered Playback", interactive=False)

    def _on_ckpt_change(s: AppState, key: str):
        if s is not None and key:
            key = str(key)
            if key != s.model_repo_key:
                # Update and drop the current model so it reloads lazily next time
                s.is_switching_model = True
                s.model_repo_key = key
                s.model_repo_id = None
                s.model = None
                s.processor = None
        # Stream progress text while loading (first yield shows the text)
        yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
        ensure_session_for_current_model(s)
        if s is not None:
            s.is_switching_model = False
        # Final yield hides the text
        yield gr.update(visible=False, value="")

    ckpt_radio.change(_on_ckpt_change, inputs=[GLOBAL_STATE, ckpt_radio], outputs=[ckpt_progress])

    def _sync_frame_idx(state_in: AppState, idx: int):
        if state_in is not None:
            state_in.current_frame_idx = int(idx)
        return update_frame_display(state_in, int(idx))

    frame_slider.change(
        _sync_frame_idx,
        inputs=[GLOBAL_STATE, frame_slider],
        outputs=preview,
    )

    def _sync_obj_id(s: AppState, oid):
        if s is not None and oid is not None:
            s.current_obj_id = int(oid)
        return gr.update()

    obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])

    def _sync_label(s: AppState, lab: str):
        if s is not None and lab is not None:
            s.current_label = str(lab)
        return gr.update()

    label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])

    def _sync_prompt_type(s: AppState, val: str):
        if s is not None and val is not None:
            s.current_prompt_type = str(val)
            s.pending_box_start = None
        is_points = str(val).lower() == "points"
        # Show labels only for points; hide and disable clear_old when boxes
        updates = [
            gr.update(visible=is_points),
            gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
        ]
        return updates

    prompt_type.change(
        _sync_prompt_type,
        inputs=[GLOBAL_STATE, prompt_type],
        outputs=[label_radio, clear_old_chk],
    )

    # Image click to add a point and run forward on that frame
    preview.select(
        on_image_click, [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], preview
    )

    # Playback via MP4 rendering only: compose every frame and encode it with
    # OpenCV's VideoWriter at the original FPS (falls back to 12 fps when unknown)
    def _render_video(s: AppState):
        if s is None or s.num_frames == 0:
            raise gr.Error("Load a video first.")
        fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
        # Compose all frames (the cache helps if they are already prepared)
        frames_np = []
        first = compose_frame(s, 0)
        h, w = first.size[1], first.size[0]
        for idx in range(s.num_frames):
            img = s.composited_frames.get(idx)
            if img is None:
                img = compose_frame(s, idx)
            frames_np.append(np.array(img)[:, :, ::-1])  # BGR for cv2
            # Periodically release CPU memory to reduce pressure
            if (idx + 1) % 60 == 0:
                gc.collect()
        out_path = "/tmp/sam2_playback.mp4"
        try:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
            for fr_bgr in frames_np:
                writer.write(fr_bgr)
            writer.release()
            return out_path
        except Exception as e:
            print(f"Failed to render video with cv2: {e}")
            raise gr.Error(f"Failed to render video: {e}")

    render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])

    # While propagating, we stream two outputs: status text and slider value updates
    propagate_btn.click(
        propagate_masks,
        inputs=[GLOBAL_STATE],
        outputs=[GLOBAL_STATE, propagate_status, frame_slider],
    )

    reset_btn.click(
        reset_session,
        inputs=GLOBAL_STATE,
        outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status],
    )

demo.queue(api_open=False).launch()