#!/usr/bin/env python3
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""

from __future__ import annotations

import logging
import os
import shutil
import sys
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import cv2
import gradio as gr
import numpy as np
import torch

try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - optional dependency on Spaces runtime
    spaces = None
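
# Note: `spaces` is only present on the Hugging Face Spaces runtime. When it is
# missing (e.g. when running locally), `wrap_with_zero_gpu` below returns the
# undecorated function, so the demo still runs on CPU or a local GPU.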

REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT / "sam2"
if SAM2_REPO.exists():
    sys.path.insert(0, str(SAM2_REPO))

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator  # noqa: E402
from sam2.build_sam import build_sam2, build_sam2_video_predictor  # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("unsamv2-gradio")

USE_M2M_REFINEMENT = True
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
CKPT_PATH = Path(
    os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
).resolve()
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
    )

GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
ZERO_GPU_WHOLE_DURATION = int(
    os.getenv("UNSAMV2_ZEROGPU_WHOLE_DURATION", str(ZERO_GPU_DURATION))
)
ZERO_GPU_VIDEO_DURATION = int(
    os.getenv("UNSAMV2_ZEROGPU_VIDEO_DURATION", str(max(120, ZERO_GPU_DURATION)))
)
MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
WHOLE_IMAGE_MAX_MASKS = 1000

POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
POINT_COLORS_BGR = {
    1: (72, 201, 127),  # green-ish for positives
    0: (64, 76, 225),  # red-ish for negatives
}
MASK_COLOR_BGR = (0, 0, 255)

DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
WHOLE_IMAGE_DEFAULT_PATH = REPO_ROOT / "demo" / "sa_291195.jpg"
DEFAULT_VIDEO_PATH = REPO_ROOT / "demo" / "bedroom.mp4"


def _load_image_from_path(path: Path) -> Optional[np.ndarray]:
    if not path.exists():
        LOGGER.warning("Default image missing at %s", path)
        return None
    img_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if img_bgr is None:
        LOGGER.warning("Could not read default image at %s", path)
        return None
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


DEFAULT_IMAGE = _load_image_from_path(DEFAULT_IMAGE_PATH)
WHOLE_IMAGE_DEFAULT = _load_image_from_path(WHOLE_IMAGE_DEFAULT_PATH)

TMP_ROOT = REPO_ROOT / "_tmp"
TMP_ROOT.mkdir(exist_ok=True)


class ModelManager:
    """Keeps SAM2 models on each device and spawns lightweight predictors."""

    def __init__(self) -> None:
        self._models: dict[str, torch.nn.Module] = {}
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        LOGGER.info("Loading UnSAMv2 weights onto %s", device)
        return build_sam2(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
            mode="eval",
        )

    def get_model(self, device: torch.device) -> torch.nn.Module:
        key = (
            f"{device.type}:{device.index}"
            if device.type == "cuda"
            else device.type
        )
        with self._lock:
            if key not in self._models:
                self._models[key] = self._build(device)
            return self._models[key]

    def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
        return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)

    def make_auto_mask_generator(
        self,
        device: torch.device,
        **kwargs,
    ) -> SAM2AutomaticMaskGenerator:
        return SAM2AutomaticMaskGenerator(self.get_model(device), **kwargs)


MODEL_MANAGER = ModelManager()
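
# Usage sketch (illustrative; assumes SAM2ImagePredictor keeps the wrapped torch
# module on its `.model` attribute, as in upstream SAM2): the manager builds one
# model per device key and reuses it, so repeated predictor requests on the same
# device do not reload the checkpoint.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   predictor_a = MODEL_MANAGER.make_predictor(device)
#   predictor_b = MODEL_MANAGER.make_predictor(device)
#   assert predictor_a.model is predictor_b.model  # same cached module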


class VideoPredictorManager:
    """Caches heavy video predictors per device."""

    def __init__(self) -> None:
        self._predictors: dict[str, torch.nn.Module] = {}
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        LOGGER.info("Loading UnSAMv2 video predictor onto %s", device)
        return build_sam2_video_predictor(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
        )

    def get_predictor(self, device: torch.device) -> torch.nn.Module:
        key = (
            f"{device.type}:{device.index}"
            if device.type == "cuda"
            else device.type
        )
        with self._lock:
            if key not in self._predictors:
                self._predictors[key] = self._build(device)
            return self._predictors[key]


VIDEO_PREDICTOR_MANAGER = VideoPredictorManager()


def make_empty_video_state() -> Dict[str, Any]:
    return {
        "frame_dir": None,
        "frame_paths": [],
        "fps": 0.0,
        "frame_size": (0, 0),
    }


def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    if image is None:
        return None
    img = image[..., :3]  # drop alpha if present
    if img.dtype == np.float32 or img.dtype == np.float64:
        if img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        img = img.clip(0, 255).astype(np.uint8)
    return img


def make_temp_subdir(prefix: str) -> Path:
    TMP_ROOT.mkdir(exist_ok=True)
    return Path(tempfile.mkdtemp(prefix=prefix, dir=str(TMP_ROOT)))


def remove_dir_if_exists(path_str: Optional[str]) -> None:
    if not path_str:
        return
    path = Path(path_str)
    if path.exists():
        shutil.rmtree(path, ignore_errors=True)


def load_rgb_image(path: Path) -> np.ndarray:
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f"Failed to read frame at {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def resolve_video_path(video_value: Any) -> Optional[str]:
    if video_value is None:
        return None
    if isinstance(video_value, str):
        return video_value
    if isinstance(video_value, dict):
        return video_value.get("name") or video_value.get("path")
    # Gradio may pass a FileData/MediaData object with a .name attribute
    for attr in ("name", "path", "video", "data"):
        candidate = getattr(video_value, attr, None)
        if isinstance(candidate, str):
            return candidate
    return None


def match_mask_to_image(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    mask_arr = np.asarray(mask)
    if mask_arr.ndim == 3:
        mask_arr = mask_arr.squeeze()
    h, w = image.shape[:2]
    if mask_arr.shape[:2] != (h, w):
        mask_arr = cv2.resize(
            mask_arr.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_NEAREST,
        )
    return mask_arr.astype(bool)


def colorize_mask_collection(
    image: np.ndarray,
    masks: Sequence[np.ndarray],
    alpha: float = 0.55,
) -> np.ndarray:
    if not masks:
        return image
    canvas = image.astype(np.float32)
    rng = np.random.default_rng(1337)
    for mask in masks:
        mask_arr = match_mask_to_image(mask, image)
        if not mask_arr.any():
            continue
        color = rng.integers(20, 235, size=3)
        canvas[mask_arr] = (
            canvas[mask_arr] * (1.0 - alpha) + color * alpha
        )
    return canvas.clip(0, 255).astype(np.uint8)


def render_video_overlay(
    video_state: Dict[str, Any],
    frame_idx: int,
    pts: Sequence[Sequence[float]],
    lbls: Sequence[int],
) -> Optional[np.ndarray]:
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    frame = load_rgb_image(Path(frame_paths[safe_idx]))
    return draw_overlay(frame, None, pts, lbls)


def mask_entries_to_arrays(entries: Sequence[Dict[str, Any]]) -> List[np.ndarray]:
    arrays: List[np.ndarray] = []
    for entry in entries:
        seg = entry.get("segmentation", entry)
        if isinstance(seg, np.ndarray):
            mask = seg
        elif isinstance(seg, dict):
            from sam2.utils.amg import rle_to_mask

            mask = rle_to_mask(seg)
        else:
            mask = np.asarray(seg)
        arrays.append(mask.astype(bool))
    return arrays


def summarize_masks(entries: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    summary: List[Dict[str, Any]] = []
    for idx, entry in enumerate(entries, start=1):
        summary.append(
            {
                "mask": idx,
                "area": int(entry.get("area", 0)),
                "pred_iou": round(float(entry.get("predicted_iou", 0.0)), 3),
                "stability": round(float(entry.get("stability_score", 0.0)), 3),
            }
        )
    return summary


def extract_video_frames(video_path: str) -> Tuple[List[Path], float, Tuple[int, int], Path]:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open the uploaded video.")
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 1e-3:
        fps = 12.0
    frame_dir = make_temp_subdir("video_frames_")
    frame_paths: List[Path] = []
    height = width = 0
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if idx == 0:
            height, width = rgb.shape[:2]
        out_path = frame_dir / f"{idx:05d}.jpg"
        if not cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
            cap.release()
            raise RuntimeError(f"Failed to write frame {idx} to disk")
        frame_paths.append(out_path)
        idx += 1
        if idx >= MAX_VIDEO_FRAMES:
            LOGGER.warning(
                "Stopping frame extraction at %d frames per UNSAMV2_MAX_VIDEO_FRAMES",
                MAX_VIDEO_FRAMES,
            )
            break
    cap.release()
    if not frame_paths:
        remove_dir_if_exists(str(frame_dir))
        raise ValueError("No frames decoded from the provided video.")
    if height == 0 or width == 0:
        sample = load_rgb_image(frame_paths[0])
        height, width = sample.shape[:2]
    return frame_paths, float(fps), (height, width), frame_dir


def write_video_from_frames(frames: Sequence[np.ndarray], fps: float) -> Path:
    if not frames:
        raise ValueError("No frames available to write video output.")
    height, width = frames[0].shape[:2]
    safe_fps = fps if fps and fps > 0 else 12.0
    out_path = TMP_ROOT / f"video_seg_{uuid.uuid4().hex}.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, safe_fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError("Failed to initialize video writer. Check codec support.")
    for frame in frames:
        if frame.shape[:2] != (height, width):
            raise ValueError("All frames must share the same spatial resolution.")
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    return out_path


def choose_device() -> torch.device:
    preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
    if preference == "cpu":
        return torch.device("cpu")
    if preference.startswith("cuda") or preference == "gpu":
        if torch.cuda.is_available():
            return torch.device(preference if preference.startswith("cuda") else "cuda")
        LOGGER.warning("CUDA requested but not available; defaulting to CPU")
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def wrap_with_zero_gpu(
    fn: Callable[..., Any],
    duration: int,
) -> Callable[..., Any]:
    if spaces is None or not ZERO_GPU_ENABLED:
        return fn
    try:
        LOGGER.info("Enabling ZeroGPU (duration=%ss) for %s", duration, fn.__name__)
        return spaces.GPU(duration=duration)(fn)  # type: ignore[misc]
    except Exception:  # pragma: no cover - defensive logging
        LOGGER.exception("Failed to wrap %s with ZeroGPU; running on CPU", fn.__name__)
        return fn


def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
    tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
    return tensor
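
# A minimal sketch of what the granularity prompt looks like: a single scalar
# broadcast as a (1, 1, 1, 1) float32 tensor, e.g.
#   build_granularity_tensor(0.3, torch.device("cpu"))
#   -> tensor([[[[0.3000]]]])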


def apply_m2m_refinement(
    predictor,
    point_coords,
    point_labels,
    granularity,
    logits,
    best_mask_idx,
    use_m2m: bool = True,
):
    """Optionally run a second M2M pass using the best mask's logits."""
    if not use_m2m:
        return None
    logging.info("Applying M2M refinement...")
    try:
        if logits is None:
            raise ValueError("logits must be provided for M2M refinement.")
        low_res_logits = logits[best_mask_idx : best_mask_idx + 1]
        refined_masks, refined_scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
            gra=granularity,
            mask_input=low_res_logits,
        )
        refined_mask = refined_masks[0]
        refined_score = float(refined_scores[0])
        logging.info("M2M refinement completed with score: %.3f", refined_score)
        return refined_mask, refined_score
    except Exception as exc:  # pragma: no cover - logging only
        logging.error("M2M refinement failed: %s, using original mask", exc)
        return None


def draw_overlay(
    image: np.ndarray,
    mask: Optional[np.ndarray],
    points: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float = 0.55,
) -> np.ndarray:
    canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if mask is not None:
        mask_bool = match_mask_to_image(mask, image)
        overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
        overlay[mask_bool] = MASK_COLOR_BGR
        canvas_bgr = np.where(
            mask_bool[..., None],
            (canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
            canvas_bgr,
        )
    for (x, y), lbl in zip(points, labels):
        color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
        center = (int(round(x)), int(round(y)))
        cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
        cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)


def handle_image_upload(image: Optional[np.ndarray]):
    img = ensure_uint8(image)
    if img is None:
        return (
            None,
            None,
            [],
            [],
            "Upload an image to start adding clicks.",
        )
    return (
        img,
        img,
        [],
        [],
        "Image loaded. Choose click type, then tap on the image.",
    )


def handle_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    image: Optional[np.ndarray],
    evt: gr.SelectData,
):
    if image is None:
        return (
            gr.update(),
            pts,
            lbls,
            "Upload an image first.",
        )
    coord = evt.index  # (x, y)
    if coord is None:
        return (
            gr.update(),
            pts,
            lbls,
            "Couldn't read click position.",
        )
    x, y = coord
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    pts = pts + [[float(x), float(y)]]
    lbls = lbls + [label]
    overlay = draw_overlay(image, None, pts, lbls)
    status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
    return overlay, pts, lbls, status


def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
    if not pts:
        return (
            gr.update(),
            pts,
            lbls,
            "No clicks to undo.",
        )
    pts = pts[:-1]
    lbls = lbls[:-1]
    overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
    status = "Removed the last click."
    return overlay, pts, lbls, status


def clear_clicks(image: Optional[np.ndarray]):
    overlay = image if image is not None else None
    return overlay, [], [], "Cleared all clicks."


def _run_segmentation(
    image: Optional[np.ndarray],
    pts: List[Sequence[float]],
    lbls: List[int],
    granularity: float,
):
    img = ensure_uint8(image)
    if img is None:
        return None, "Upload an image to segment."
    if not pts:
        return draw_overlay(img, None, [], []), "Add at least one click before running segmentation."
    device = choose_device()
    predictor = MODEL_MANAGER.make_predictor(device)
    predictor.set_image(img)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    gran_tensor = build_granularity_tensor(granularity, predictor.device)
    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
        gra=float(granularity),
        granularity=gran_tensor,
    )
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(bool)
    status = (
        f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | "
        f"granularity={granularity:.2f}"
    )
    refinement = apply_m2m_refinement(
        predictor=predictor,
        point_coords=coords,
        point_labels=labels,
        granularity=float(granularity),
        logits=logits,
        best_mask_idx=best_idx,
        use_m2m=USE_M2M_REFINEMENT,
    )
    if refinement is not None:
        refined_mask, refined_score = refinement
        best_mask = refined_mask.astype(bool)
        status += f" | M2M IoU: {refined_score:.3f}"
    overlay = draw_overlay(img, best_mask, pts, lbls)
    return overlay, status
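
# Headless usage sketch (bypasses the Gradio UI; the click coordinates are
# illustrative and assume the bundled demo image loaded above):
#   if DEFAULT_IMAGE is not None:
#       overlay, status = _run_segmentation(DEFAULT_IMAGE, [[200.0, 150.0]], [1], 0.2)
#       print(status)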


def run_whole_image_segmentation(
    image: Optional[np.ndarray],
    granularity: float,
    pred_iou_thresh: float,
    stability_thresh: float,
):
    img = ensure_uint8(image)
    if img is None:
        return None, [], "Upload an image to run whole-image segmentation."
    device = choose_device()
    mask_generator = MODEL_MANAGER.make_auto_mask_generator(
        device=device,
        points_per_side=WHOLE_IMAGE_POINTS_PER_SIDE,
        points_per_batch=128,
        pred_iou_thresh=float(pred_iou_thresh),
        stability_score_thresh=float(stability_thresh),
        mask_threshold=-1.0,
        box_nms_thresh=0.7,
        crop_n_layers=0,
        min_mask_region_area=0,
        use_m2m=USE_M2M_REFINEMENT,
        output_mode="binary_mask",
    )
    try:
        masks = mask_generator.generate(img, gra=float(granularity))
    except Exception as exc:
        LOGGER.exception("Whole-image segmentation failed")
        return None, [], f"Whole-image segmentation failed: {exc}"
    if not masks:
        return img, [], "Mask generator did not return any regions. Try lowering thresholds."
    trimmed = masks[:WHOLE_IMAGE_MAX_MASKS]
    mask_arrays = mask_entries_to_arrays(trimmed)
    overlay = colorize_mask_collection(img, mask_arrays)
    table = summarize_masks(trimmed)
    status = (
        f"Generated {len(trimmed)} masks | granularity={granularity:.2f}, "
        f"IoU≥{pred_iou_thresh:.2f}, stability≥{stability_thresh:.2f}"
    )
    return overlay, table, status


def handle_video_upload(
    video_file: Any,
    current_state: Optional[Dict[str, Any]] = None,
):
    if current_state:
        remove_dir_if_exists(current_state.get("frame_dir"))
    state = make_empty_video_state()
    if isinstance(video_file, (list, tuple)):
        video_file = video_file[0] if video_file else None
    video_path = resolve_video_path(video_file)
    if not video_path:
        return (
            gr.update(value=None, visible=False),
            state,
            gr.update(value=0, minimum=0, maximum=0, interactive=False),
            [],
            [],
            0,
            "Upload a video to start adding clicks.",
        )
    try:
        frame_paths, fps, frame_size, frame_dir = extract_video_frames(video_path)
    except Exception as exc:
        LOGGER.exception("Video decoding failed")
        return (
            gr.update(value=None, visible=False),
            state,
            gr.update(value=0, minimum=0, maximum=0, interactive=False),
            [],
            [],
            0,
            f"Video decoding failed: {exc}",
        )
    state.update(
        {
            "frame_dir": str(frame_dir),
            "frame_paths": [str(p) for p in frame_paths],
            "fps": fps,
            "frame_size": frame_size,
        }
    )
    first_overlay = render_video_overlay(state, 0, [], [])
    slider_update = gr.update(
        value=0,
        minimum=0,
        maximum=len(frame_paths) - 1,
        step=1,
        interactive=True,
    )
    status = f"Loaded video with {len(frame_paths)} frames at {fps:.1f} FPS."
    return (
        gr.update(value=first_overlay, visible=True),
        state,
        slider_update,
        [],
        [],
        0,
        status,
    )


def handle_video_frame_change(
    frame_idx: int,
    video_state: Dict[str, Any],
):
    overlay = render_video_overlay(video_state, frame_idx, [], [])
    if overlay is None:
        return gr.update(), [], [], 0, "Upload a video first."
    safe_idx = int(np.clip(frame_idx, 0, len(video_state.get("frame_paths", [])) - 1))
    status = f"Annotating frame {safe_idx}."
    return overlay, [], [], safe_idx, status


def handle_video_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    video_state: Dict[str, Any],
    frame_idx: int,
    evt: gr.SelectData,
):
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    if overlay is None:
        return gr.update(), pts, lbls, "Upload a video first."
    if evt.index is None:
        return overlay, pts, lbls, "Couldn't read click position."
    x, y = evt.index
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    pts = pts + [[float(x), float(y)]]
    lbls = lbls + [label]
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    status = (
        f"Added {'positive' if label == 1 else 'negative'} click at "
        f"({int(x)}, {int(y)}) on frame {int(frame_idx)}."
    )
    return overlay, pts, lbls, status


def undo_video_click(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
):
    if not pts:
        return gr.update(), pts, lbls, "No clicks to undo."
    pts = pts[:-1]
    lbls = lbls[:-1]
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    return overlay, pts, lbls, "Removed the last click."


def clear_video_clicks(video_state: Dict[str, Any], frame_idx: int):
    overlay = render_video_overlay(video_state, frame_idx, [], [])
    return overlay, [], [], "Cleared all clicks for the selected frame."


def reset_video_interface(current_state: Dict[str, Any]):
    remove_dir_if_exists(current_state.get("frame_dir"))
    state = make_empty_video_state()
    return (
        gr.update(value=None, visible=False),
        state,
        gr.update(value=0, minimum=0, maximum=0, interactive=False),
        [],
        [],
        0,
        "Cleared video. Upload a new clip to continue.",
    )


def run_video_segmentation(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
    granularity: float,
):
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None, "Upload a video to segment."
    if not pts:
        return None, "Add at least one click on the annotation frame."
    frame_dir = video_state.get("frame_dir")
    if not frame_dir:
        return None, "Video frames are unavailable. Please re-upload the video."
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    device = choose_device()
    predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
    inference_state = predictor.init_state(video_path=frame_dir)
    predictor.reset_state(inference_state)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    try:
        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=safe_idx,
            obj_id=1,
            points=coords,
            labels=labels,
            gra=float(granularity),
        )
    except Exception as exc:
        LOGGER.exception("Video add_new_points_or_box failed")
        return None, f"Video segmentation failed during prompting: {exc}"
    video_masks: Dict[int, Dict[int, np.ndarray]] = {}
    video_masks[safe_idx] = {
        int(obj_id): (mask_logits[i] > -1.0).cpu().numpy()
        for i, obj_id in enumerate(obj_ids)
    }
    try:
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state,
            gra=float(granularity),
        ):
            video_masks[out_frame_idx] = {
                int(obj_id): (out_mask_logits[i] > -1.0).cpu().numpy()
                for i, obj_id in enumerate(out_obj_ids)
            }
    except Exception as exc:
        LOGGER.exception("Video propagation failed")
        return None, f"Video propagation failed: {exc}"
    overlays: List[np.ndarray] = []
    for idx, frame_path in enumerate(frame_paths):
        base = load_rgb_image(Path(frame_path))
        mask = video_masks.get(idx, {}).get(1)
        overlays.append(draw_overlay(base, mask, [], []))
    try:
        video_path = write_video_from_frames(overlays, video_state.get("fps", 12.0))
    except Exception as exc:
        LOGGER.exception("Failed to encode output video")
        return None, f"Tracking succeeded but video export failed: {exc}"
    status = (
        f"Tracked object from frame {safe_idx} across {len(frame_paths)} frames | "
        f"granularity={granularity:.2f}"
    )
    return str(video_path), status


def run_video_frame_segmentation(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
    granularity: float,
):
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None, "Upload a video to segment."
    if not pts:
        return None, "Add at least one click on the annotation frame."
    frame_dir = video_state.get("frame_dir")
    if not frame_dir:
        return None, "Video frames are unavailable. Please re-upload the video."
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    device = choose_device()
    predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
    inference_state = predictor.init_state(video_path=frame_dir)
    predictor.reset_state(inference_state)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    try:
        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=safe_idx,
            obj_id=1,
            points=coords,
            labels=labels,
            gra=float(granularity),
        )
    except Exception as exc:
        LOGGER.exception("Video frame segmentation failed")
        return None, f"Frame segmentation failed: {exc}"
    if not obj_ids:
        return None, "Predictor did not return a mask for this frame."
    mask = (mask_logits[0] > -1.0).cpu().numpy()
    base = load_rgb_image(Path(frame_paths[safe_idx]))
    overlay = draw_overlay(base, mask, pts, lbls)
    status = (
        f"Segmented frame {safe_idx} with {len(pts)} clicks | "
        f"granularity={granularity:.2f}"
    )
    return overlay, status


segment_fn = wrap_with_zero_gpu(_run_segmentation, ZERO_GPU_DURATION)
whole_image_fn = wrap_with_zero_gpu(
    run_whole_image_segmentation,
    ZERO_GPU_WHOLE_DURATION,
)
video_frame_fn = wrap_with_zero_gpu(
    run_video_frame_segmentation,
    ZERO_GPU_VIDEO_DURATION,
)
video_segmentation_fn = wrap_with_zero_gpu(
    run_video_segmentation,
    ZERO_GPU_VIDEO_DURATION,
)
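
# The ZeroGPU wrapping happens once at import time, so the Gradio event handlers
# below must be wired to these wrapped callables (segment_fn, whole_image_fn,
# video_frame_fn, video_segmentation_fn) rather than to the underlying functions;
# otherwise requests on Spaces would run without a GPU allocation.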


def build_demo() -> gr.Blocks:
    with gr.Blocks(title="UnSAMv2 Interactive + Whole Image + Video", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            <div style="text-align:center">
              <h2>UnSAMv2 · Segment Anything at Any Granularity</h2>
            </div>
            """
        )
        gr.HTML(
            """
            <style>
            #mode-tabs button[role="tab"] {
                flex: 0 0 auto;
                min-width: 160px;
            }
            #mode-tabs [role="tablist"],
            #mode-tabs .tab-nav,
            #mode-tabs > div:first-child {
                display: flex !important;
                justify-content: center !important;
                gap: 0.75rem;
            }
            </style>
            """
        )
        with gr.Tabs(elem_id="mode-tabs"):
            # Interactive Image Tab
            with gr.Tab("Interactive Image Segmentation"):
                image_state = gr.State(DEFAULT_IMAGE)
                points_state = gr.State([])
                labels_state = gr.State([])
                image_input = gr.Image(
                    label="Image · clicks & mask",
                    type="numpy",
                    height=480,
                    value=DEFAULT_IMAGE,
                    sources=["upload"],
                )
                with gr.Row(equal_height=True):
                    point_mode = gr.Radio(
                        choices=list(POINT_MODE_TO_LABEL.keys()),
                        value="Foreground (+)",
                        label="Click type",
                    )
                    granularity_slider = gr.Slider(
                        minimum=GRANULARITY_MIN,
                        maximum=GRANULARITY_MAX,
                        value=0.2,
                        step=0.01,
                        label="Granularity",
                        info="Lower = finer details, Higher = coarser regions",
                    )
                segment_button = gr.Button("Segment", variant="primary")
                with gr.Row():
                    undo_button = gr.Button("Undo last click")
                    clear_button = gr.Button("Clear clicks")
                status_markdown = gr.Markdown(" Ready for interactive clicks.")
                image_input.upload(
                    handle_image_upload,
                    inputs=[image_input],
                    outputs=[
                        image_input,
                        image_state,
                        points_state,
                        labels_state,
                        status_markdown,
                    ],
                )
                image_input.clear(
                    handle_image_upload,
                    inputs=[image_input],
                    outputs=[
                        image_input,
                        image_state,
                        points_state,
                        labels_state,
                        status_markdown,
                    ],
                )
                image_input.select(
                    handle_click,
                    inputs=[
                        point_mode,
                        points_state,
                        labels_state,
                        image_state,
                    ],
                    outputs=[
                        image_input,
                        points_state,
                        labels_state,
                        status_markdown,
                    ],
                )
                undo_button.click(
                    undo_last_click,
                    inputs=[image_state, points_state, labels_state],
                    outputs=[
                        image_input,
                        points_state,
                        labels_state,
                        status_markdown,
                    ],
                )
                clear_button.click(
                    clear_clicks,
                    inputs=[image_state],
                    outputs=[
                        image_input,
                        points_state,
                        labels_state,
                        status_markdown,
                    ],
                )
                segment_button.click(
                    segment_fn,
                    inputs=[image_state, points_state, labels_state, granularity_slider],
                    outputs=[image_input, status_markdown],
                )
            # Whole Image Tab
            with gr.Tab("Whole Image Segmentation"):
                whole_image_input = gr.Image(
                    label="Image · automatic masks",
                    type="numpy",
                    height=480,
                    value=WHOLE_IMAGE_DEFAULT if WHOLE_IMAGE_DEFAULT is not None else DEFAULT_IMAGE,
                    sources=["upload"],
                )
                whole_granularity = gr.Slider(
                    minimum=GRANULARITY_MIN,
                    maximum=GRANULARITY_MAX,
                    value=0.15,
                    step=0.01,
                    label="Granularity",
                )
                whole_generate_btn = gr.Button("Generate masks", variant="primary")
                with gr.Accordion("Advanced mask filtering", open=False):
                    pred_iou_thresh = gr.Slider(
                        minimum=0.1,
                        maximum=0.99,
                        value=0.77,
                        step=0.01,
                        label="Predicted IoU threshold",
                    )
                    stability_thresh = gr.Slider(
                        minimum=0.1,
                        maximum=0.99,
                        value=0.9,
                        step=0.01,
                        label="Stability threshold",
                    )
                whole_overlay = gr.Image(label="Mask overlay", height=480)
                whole_table = gr.Dataframe(
                    headers=["mask", "area", "pred_iou", "stability"],
                    datatype=["number", "number", "number", "number"],
                    label="Mask stats",
                    wrap=True,
                    visible=False,
                )
                whole_status = gr.Markdown(" Ready for whole-image masks.")
                whole_generate_btn.click(
                    whole_image_fn,
                    inputs=[
                        whole_image_input,
                        whole_granularity,
                        pred_iou_thresh,
                        stability_thresh,
                    ],
                    outputs=[whole_overlay, whole_table, whole_status],
                )
            # Video Tab
            with gr.Tab("Video Segmentation"):
                video_state = gr.State(make_empty_video_state())
                video_points_state = gr.State([])
                video_labels_state = gr.State([])
                annotation_frame_state = gr.State(0)
                with gr.Row(equal_height=True):
                    with gr.Column(scale=1, min_width=360):
                        upload_button = gr.UploadButton(
                            "Upload video",
                            file_types=["video"],
                            file_count="single",
                        )
                        frame_display = gr.Image(
                            label="Video · add clicks",
                            type="numpy",
                            height=420,
                            interactive=True,
                            visible=False,
                        )
                        frame_slider = gr.Slider(
                            minimum=0,
                            maximum=0,
                            value=0,
                            step=1,
                            interactive=False,
                            label="Select frame",
                        )
                        video_point_mode = gr.Radio(
                            choices=list(POINT_MODE_TO_LABEL.keys()),
                            value="Foreground (+)",
                            label="Click type",
                        )
                        with gr.Row():
                            video_undo = gr.Button("Undo click")
                            video_clear = gr.Button("Clear clicks")
                        video_granularity = gr.Slider(
                            minimum=GRANULARITY_MIN,
                            maximum=GRANULARITY_MAX,
                            value=0.33,
                            step=0.01,
                            label="Granularity",
                        )
                        with gr.Row():
                            video_frame_btn = gr.Button("Segment frame", variant="secondary")
                            video_segment_btn = gr.Button("Propagate video", variant="primary")
                    with gr.Column(scale=1, min_width=320):
                        video_output = gr.Video(
                            label="Segmented preview",
                            autoplay=False,
                            height=420,
                        )
                        video_status = gr.Markdown(" Ready for video segmentation.")
                upload_button.upload(
                    handle_video_upload,
                    inputs=[upload_button, video_state],
                    outputs=[
                        frame_display,
                        video_state,
                        frame_slider,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )
                if DEFAULT_VIDEO_PATH.exists():

                    def _load_default_video(state):
                        return handle_video_upload(str(DEFAULT_VIDEO_PATH), state)

                    demo.load(
                        _load_default_video,
                        inputs=[video_state],
                        outputs=[
                            frame_display,
                            video_state,
                            frame_slider,
                            video_points_state,
                            video_labels_state,
                            annotation_frame_state,
                            video_status,
                        ],
                        queue=False,
                    )
                frame_slider.change(
                    handle_video_frame_change,
                    inputs=[frame_slider, video_state],
                    outputs=[
                        frame_display,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )
                frame_display.select(
                    handle_video_click,
                    inputs=[
                        video_point_mode,
                        video_points_state,
                        video_labels_state,
                        video_state,
                        annotation_frame_state,
                    ],
                    outputs=[
                        frame_display,
                        video_points_state,
                        video_labels_state,
                        video_status,
                    ],
                )
                frame_display.clear(
                    reset_video_interface,
                    inputs=[video_state],
                    outputs=[
                        frame_display,
                        video_state,
                        frame_slider,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )
                video_frame_btn.click(
                    video_frame_fn,
                    inputs=[
                        video_state,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_granularity,
                    ],
                    outputs=[frame_display, video_status],
                )
                video_undo.click(
                    undo_video_click,
                    inputs=[
                        video_state,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                    ],
                    outputs=[
                        frame_display,
                        video_points_state,
                        video_labels_state,
                        video_status,
                    ],
                )
                video_clear.click(
                    clear_video_clicks,
                    inputs=[video_state, annotation_frame_state],
                    outputs=[
                        frame_display,
                        video_points_state,
                        video_labels_state,
                        video_status,
                    ],
                )
                video_segment_btn.click(
                    video_segmentation_fn,
                    inputs=[
                        video_state,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_granularity,
                    ],
                    outputs=[video_output, video_status],
                )
    demo.queue(max_size=8)
    return demo


demo = build_demo()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)