#!/usr/bin/env python3
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""
from __future__ import annotations
import logging
import os
import shutil
import sys
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import cv2
import gradio as gr
import numpy as np
import torch
try:
import spaces # type: ignore
except ImportError: # pragma: no cover - optional dependency on Spaces runtime
spaces = None
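# `spaces` is the Hugging Face ZeroGPU helper. When it is not installed (e.g. when
# running outside Spaces) it stays None, GPU decoration is skipped below, and every
# handler simply runs in-process.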
REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT / "sam2"
if SAM2_REPO.exists():
sys.path.insert(0, str(SAM2_REPO))
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator # noqa: E402
from sam2.build_sam import build_sam2, build_sam2_video_predictor # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor # noqa: E402
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("unsamv2-gradio")
USE_M2M_REFINEMENT = True
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
CKPT_PATH = Path(
os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
).resolve()
if not CKPT_PATH.exists():
raise FileNotFoundError(
f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
)
GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", "0.1"))
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", "1.0"))
ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
ZERO_GPU_WHOLE_DURATION = int(
os.getenv("UNSAMV2_ZEROGPU_WHOLE_DURATION", str(ZERO_GPU_DURATION))
)
ZERO_GPU_VIDEO_DURATION = int(
os.getenv("UNSAMV2_ZEROGPU_VIDEO_DURATION", str(max(120, ZERO_GPU_DURATION)))
)
MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
WHOLE_IMAGE_MAX_MASKS = 1000
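# The UNSAMV2_* settings above are read from environment variables. A local run might
# look like this (values are illustrative, not project defaults):
#   UNSAMV2_DEVICE=cuda:0 UNSAMV2_MAX_VIDEO_FRAMES=240 python app.py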
POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
POINT_COLORS_BGR = {
1: (72, 201, 127), # green-ish for positives
0: (64, 76, 225), # red-ish for negatives
}
MASK_COLOR_BGR = (0, 0, 255)
DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
WHOLE_IMAGE_DEFAULT_PATH = REPO_ROOT / "demo" / "sa_291195.jpg"
DEFAULT_VIDEO_PATH = REPO_ROOT / "demo" / "bedroom.mp4"
def _load_image_from_path(path: Path) -> Optional[np.ndarray]:
if not path.exists():
LOGGER.warning("Default image missing at %s", path)
return None
img_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
if img_bgr is None:
LOGGER.warning("Could not read default image at %s", path)
return None
return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
DEFAULT_IMAGE = _load_image_from_path(DEFAULT_IMAGE_PATH)
WHOLE_IMAGE_DEFAULT = _load_image_from_path(WHOLE_IMAGE_DEFAULT_PATH)
TMP_ROOT = REPO_ROOT / "_tmp"
TMP_ROOT.mkdir(exist_ok=True)
class ModelManager:
"""Keeps SAM2 models on each device and spawns lightweight predictors."""
def __init__(self) -> None:
self._models: dict[str, torch.nn.Module] = {}
self._lock = threading.Lock()
def _build(self, device: torch.device) -> torch.nn.Module:
LOGGER.info("Loading UnSAMv2 weights onto %s", device)
return build_sam2(
CONFIG_PATH,
ckpt_path=str(CKPT_PATH),
device=device,
mode="eval",
)
def get_model(self, device: torch.device) -> torch.nn.Module:
key = (
f"{device.type}:{device.index}"
if device.type == "cuda"
else device.type
)
with self._lock:
if key not in self._models:
self._models[key] = self._build(device)
return self._models[key]
def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)
def make_auto_mask_generator(
self,
device: torch.device,
**kwargs,
) -> SAM2AutomaticMaskGenerator:
return SAM2AutomaticMaskGenerator(self.get_model(device), **kwargs)
MODEL_MANAGER = ModelManager()
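# MODEL_MANAGER is the singleton used by the image tabs. Rough usage sketch (the heavy
# backbone is built once per device; predictors are cheap wrappers around it):
#   predictor = MODEL_MANAGER.make_predictor(choose_device())
#   predictor.set_image(rgb_array)
#   masks, scores, logits = predictor.predict(point_coords=..., point_labels=...)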
class VideoPredictorManager:
"""Caches heavy video predictors per device."""
def __init__(self) -> None:
self._predictors: dict[str, torch.nn.Module] = {}
self._lock = threading.Lock()
def _build(self, device: torch.device) -> torch.nn.Module:
LOGGER.info("Loading UnSAMv2 video predictor onto %s", device)
return build_sam2_video_predictor(
CONFIG_PATH,
ckpt_path=str(CKPT_PATH),
device=device,
)
def get_predictor(self, device: torch.device) -> torch.nn.Module:
key = (
f"{device.type}:{device.index}"
if device.type == "cuda"
else device.type
)
with self._lock:
if key not in self._predictors:
self._predictors[key] = self._build(device)
return self._predictors[key]
VIDEO_PREDICTOR_MANAGER = VideoPredictorManager()
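# The video predictor is cached the same way. Per-upload bookkeeping lives in the plain
# dict returned by make_empty_video_state() below: the extracted frame directory, the
# individual frame paths, the clip's FPS, and its (height, width).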
def make_empty_video_state() -> Dict[str, Any]:
return {
"frame_dir": None,
"frame_paths": [],
"fps": 0.0,
"frame_size": (0, 0),
}
def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
if image is None:
return None
img = image[..., :3] # drop alpha if present
if img.dtype == np.float32 or img.dtype == np.float64:
if img.max() <= 1.0:
img = (img * 255).clip(0, 255).astype(np.uint8)
else:
img = img.clip(0, 255).astype(np.uint8)
elif img.dtype != np.uint8:
img = img.clip(0, 255).astype(np.uint8)
return img
def make_temp_subdir(prefix: str) -> Path:
TMP_ROOT.mkdir(exist_ok=True)
return Path(tempfile.mkdtemp(prefix=prefix, dir=str(TMP_ROOT)))
def remove_dir_if_exists(path_str: Optional[str]) -> None:
if not path_str:
return
path = Path(path_str)
if path.exists():
shutil.rmtree(path, ignore_errors=True)
def load_rgb_image(path: Path) -> np.ndarray:
bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
if bgr is None:
raise FileNotFoundError(f"Failed to read frame at {path}")
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def resolve_video_path(video_value: Any) -> Optional[str]:
if video_value is None:
return None
if isinstance(video_value, str):
return video_value
if isinstance(video_value, dict):
return video_value.get("name") or video_value.get("path")
# Gradio may pass a FileData/MediaData object with a .name attribute
for attr in ("name", "path", "video", "data"):
candidate = getattr(video_value, attr, None)
if isinstance(candidate, str):
return candidate
return None
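# Predicted masks can come back with an extra channel dimension or at the model's working
# resolution; the helper below squeezes them and nearest-neighbour resizes them to the
# display image before converting to a boolean array.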
def match_mask_to_image(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
mask_arr = np.asarray(mask)
if mask_arr.ndim == 3:
mask_arr = mask_arr.squeeze()
h, w = image.shape[:2]
if mask_arr.shape[:2] != (h, w):
mask_arr = cv2.resize(
mask_arr.astype(np.float32),
(w, h),
interpolation=cv2.INTER_NEAREST,
)
return mask_arr.astype(bool)
def colorize_mask_collection(
image: np.ndarray,
masks: Sequence[np.ndarray],
alpha: float = 0.55,
) -> np.ndarray:
if not masks:
return image
canvas = image.astype(np.float32)
rng = np.random.default_rng(1337)
for mask in masks:
mask_arr = match_mask_to_image(mask, image)
if not mask_arr.any():
continue
color = rng.integers(20, 235, size=3)
canvas[mask_arr] = (
canvas[mask_arr] * (1.0 - alpha) + color * alpha
)
return canvas.clip(0, 255).astype(np.uint8)
def render_video_overlay(
video_state: Dict[str, Any],
frame_idx: int,
pts: Sequence[Sequence[float]],
lbls: Sequence[int],
) -> Optional[np.ndarray]:
frame_paths: List[str] = list(video_state.get("frame_paths", []))
if not frame_paths:
return None
safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
frame = load_rgb_image(Path(frame_paths[safe_idx]))
return draw_overlay(frame, None, pts, lbls)
def mask_entries_to_arrays(entries: Sequence[Dict[str, Any]]) -> List[np.ndarray]:
arrays: List[np.ndarray] = []
for entry in entries:
seg = entry.get("segmentation", entry)
if isinstance(seg, np.ndarray):
mask = seg
elif isinstance(seg, dict):
from sam2.utils.amg import rle_to_mask
mask = rle_to_mask(seg)
else:
mask = np.asarray(seg)
arrays.append(mask.astype(bool))
return arrays
def summarize_masks(entries: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
summary: List[Dict[str, Any]] = []
for idx, entry in enumerate(entries, start=1):
summary.append(
{
"mask": idx,
"area": int(entry.get("area", 0)),
"pred_iou": round(float(entry.get("predicted_iou", 0.0)), 3),
"stability": round(float(entry.get("stability_score", 0.0)), 3),
}
)
return summary
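# Videos are decoded up front into numbered JPEG frames so the SAM2 video predictor can be
# pointed at a directory; extraction stops at UNSAMV2_MAX_VIDEO_FRAMES to keep long uploads
# bounded.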
def extract_video_frames(video_path: str) -> Tuple[List[Path], float, Tuple[int, int], Path]:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise ValueError("Could not open the uploaded video.")
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps or fps <= 1e-3:
fps = 12.0
frame_dir = make_temp_subdir("video_frames_")
frame_paths: List[Path] = []
height = width = 0
idx = 0
while True:
ok, frame = cap.read()
if not ok:
break
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if idx == 0:
height, width = rgb.shape[:2]
out_path = frame_dir / f"{idx:05d}.jpg"
if not cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
cap.release()
raise RuntimeError(f"Failed to write frame {idx} to disk")
frame_paths.append(out_path)
idx += 1
if idx >= MAX_VIDEO_FRAMES:
LOGGER.warning(
"Stopping frame extraction at %d frames per UNSAMV2_MAX_VIDEO_FRAMES",
MAX_VIDEO_FRAMES,
)
break
cap.release()
if not frame_paths:
remove_dir_if_exists(str(frame_dir))
raise ValueError("No frames decoded from the provided video.")
if height == 0 or width == 0:
sample = load_rgb_image(frame_paths[0])
height, width = sample.shape[:2]
return frame_paths, float(fps), (height, width), frame_dir
def write_video_from_frames(frames: Sequence[np.ndarray], fps: float) -> Path:
if not frames:
raise ValueError("No frames available to write video output.")
height, width = frames[0].shape[:2]
safe_fps = fps if fps and fps > 0 else 12.0
out_path = TMP_ROOT / f"video_seg_{uuid.uuid4().hex}.mp4"
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(str(out_path), fourcc, safe_fps, (width, height))
if not writer.isOpened():
raise RuntimeError("Failed to initialize video writer. Check codec support.")
for frame in frames:
if frame.shape[:2] != (height, width):
raise ValueError("All frames must share the same spatial resolution.")
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
return out_path
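# Device selection honours UNSAMV2_DEVICE ("cpu", "gpu", "cuda", "cuda:1", ...); anything
# else falls back to auto-detection via torch.cuda.is_available().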
def choose_device() -> torch.device:
preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
if preference == "cpu":
return torch.device("cpu")
if preference.startswith("cuda") or preference == "gpu":
if torch.cuda.is_available():
return torch.device(preference if preference.startswith("cuda") else "cuda")
LOGGER.warning("CUDA requested but not available; defaulting to CPU")
return torch.device("cpu")
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
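# On Spaces, spaces.GPU(duration=...) attaches a GPU only for the duration of each
# decorated call; outside Spaces (or with UNSAMV2_ENABLE_ZEROGPU=0) the handler is
# returned unchanged.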
def wrap_with_zero_gpu(
fn: Callable[..., Any],
duration: int,
) -> Callable[..., Any]:
if spaces is None or not ZERO_GPU_ENABLED:
return fn
try:
LOGGER.info("Enabling ZeroGPU (duration=%ss) for %s", duration, fn.__name__)
return spaces.GPU(duration=duration)(fn) # type: ignore[misc]
except Exception: # pragma: no cover - defensive logging
LOGGER.exception("Failed to wrap %s with ZeroGPU; running on CPU", fn.__name__)
return fn
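# UnSAMv2 conditions its predictions on a scalar granularity value. The 1x1x1x1 tensor
# built below is the shape this build's predictor accepts for its `granularity` input
# (inferred from how it is passed to predictor.predict in _run_segmentation).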
def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
return tensor
def apply_m2m_refinement(
predictor,
point_coords,
point_labels,
granularity,
logits,
best_mask_idx,
use_m2m: bool = True,
):
"""Optionally run a second M2M pass using the best mask's logits."""
if not use_m2m:
return None
logging.info("Applying M2M refinement...")
try:
if logits is None:
raise ValueError("logits must be provided for M2M refinement.")
low_res_logits = logits[best_mask_idx : best_mask_idx + 1]
refined_masks, refined_scores, _ = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=False,
gra=granularity,
mask_input=low_res_logits,
)
refined_mask = refined_masks[0]
refined_score = float(refined_scores[0])
logging.info("M2M refinement completed with score: %.3f", refined_score)
return refined_mask, refined_score
except Exception as exc: # pragma: no cover - logging only
logging.error("M2M refinement failed: %s, using original mask", exc)
return None
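# draw_overlay alpha-blends the mask colour into the frame (in BGR space for OpenCV),
# stamps each click as a filled circle with a white outline, then converts back to RGB
# for Gradio.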
def draw_overlay(
image: np.ndarray,
mask: Optional[np.ndarray],
points: Sequence[Sequence[float]],
labels: Sequence[int],
alpha: float = 0.55,
) -> np.ndarray:
canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if mask is not None:
mask_bool = match_mask_to_image(mask, image)
overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
overlay[mask_bool] = MASK_COLOR_BGR
canvas_bgr = np.where(
mask_bool[..., None],
(canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
canvas_bgr,
)
for (x, y), lbl in zip(points, labels):
color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
center = (int(round(x)), int(round(y)))
cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)
def handle_image_upload(image: Optional[np.ndarray]):
img = ensure_uint8(image)
if img is None:
return (
None,
None,
[],
[],
"Upload an image to start adding clicks.",
)
return (
img,
img,
[],
[],
"Image loaded. Choose click type, then tap on the image.",
)
def handle_click(
point_mode: str,
pts: List[Sequence[float]],
lbls: List[int],
image: Optional[np.ndarray],
evt: gr.SelectData,
):
if image is None:
return (
gr.update(),
pts,
lbls,
"Upload an image first.",
)
coord = evt.index # (x, y)
if coord is None:
return (
gr.update(),
pts,
lbls,
"Couldn't read click position.",
)
x, y = coord
label = POINT_MODE_TO_LABEL.get(point_mode, 1)
pts = pts + [[float(x), float(y)]]
lbls = lbls + [label]
overlay = draw_overlay(image, None, pts, lbls)
status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
return overlay, pts, lbls, status
def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
if not pts:
return (
gr.update(),
pts,
lbls,
"No clicks to undo.",
)
pts = pts[:-1]
lbls = lbls[:-1]
overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
status = "Removed the last click."
return overlay, pts, lbls, status
def clear_clicks(image: Optional[np.ndarray]):
overlay = image if image is not None else None
return overlay, [], [], "Cleared all clicks."
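# Interactive flow: clicks accumulate in points_state/labels_state, then this handler runs
# a single multimask prediction and keeps the highest-scoring mask. The `gra=` and
# `granularity=` keywords are the granularity controls exposed by the bundled UnSAMv2
# build of SAM2, optionally followed by an M2M refinement pass on the best mask's logits.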
def _run_segmentation(
image: Optional[np.ndarray],
pts: List[Sequence[float]],
lbls: List[int],
granularity: float,
):
img = ensure_uint8(image)
if img is None:
return None, "Upload an image to segment."
if not pts:
return draw_overlay(img, None, [], []), "Add at least one click before running segmentation."
device = choose_device()
predictor = MODEL_MANAGER.make_predictor(device)
predictor.set_image(img)
coords = np.asarray(pts, dtype=np.float32)
labels = np.asarray(lbls, dtype=np.int32)
gran_tensor = build_granularity_tensor(granularity, predictor.device)
masks, scores, logits = predictor.predict(
point_coords=coords,
point_labels=labels,
multimask_output=True,
gra=float(granularity),
granularity=gran_tensor,
)
best_idx = int(np.argmax(scores))
best_mask = masks[best_idx].astype(bool)
status = (
f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | "
f"granularity={granularity:.2f}"
)
refinement = apply_m2m_refinement(
predictor=predictor,
point_coords=coords,
point_labels=labels,
granularity=float(granularity),
logits=logits,
best_mask_idx=best_idx,
use_m2m=USE_M2M_REFINEMENT,
)
if refinement is not None:
refined_mask, refined_score = refinement
best_mask = refined_mask.astype(bool)
status += f" | M2M IoU: {refined_score:.3f}"
overlay = draw_overlay(img, best_mask, pts, lbls)
return overlay, status
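# Whole-image mode uses the automatic mask generator: a grid of
# WHOLE_IMAGE_POINTS_PER_SIDE x WHOLE_IMAGE_POINTS_PER_SIDE prompts, filtered by the
# predicted-IoU and stability thresholds from the UI and capped at WHOLE_IMAGE_MAX_MASKS
# regions.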
def run_whole_image_segmentation(
image: Optional[np.ndarray],
granularity: float,
pred_iou_thresh: float,
stability_thresh: float,
):
img = ensure_uint8(image)
if img is None:
return None, [], "Upload an image to run whole-image segmentation."
device = choose_device()
mask_generator = MODEL_MANAGER.make_auto_mask_generator(
device=device,
points_per_side=WHOLE_IMAGE_POINTS_PER_SIDE,
points_per_batch=128,
pred_iou_thresh=float(pred_iou_thresh),
stability_score_thresh=float(stability_thresh),
mask_threshold=-1.0,
box_nms_thresh=0.7,
crop_n_layers=0,
min_mask_region_area=0,
use_m2m=USE_M2M_REFINEMENT,
output_mode="binary_mask",
)
try:
masks = mask_generator.generate(img, gra=float(granularity))
except Exception as exc:
LOGGER.exception("Whole-image segmentation failed")
return None, [], f"Whole-image segmentation failed: {exc}"
if not masks:
return img, [], "Mask generator did not return any regions. Try lowering thresholds."
trimmed = masks[:WHOLE_IMAGE_MAX_MASKS]
mask_arrays = mask_entries_to_arrays(trimmed)
overlay = colorize_mask_collection(img, mask_arrays)
table = summarize_masks(trimmed)
status = (
f"Generated {len(trimmed)} masks | granularity={granularity:.2f}, "
f"IoU≥{pred_iou_thresh:.2f}, stability≥{stability_thresh:.2f}"
)
return overlay, table, status
def handle_video_upload(
video_file: Any,
current_state: Optional[Dict[str, Any]] = None,
):
if current_state:
remove_dir_if_exists(current_state.get("frame_dir"))
state = make_empty_video_state()
if isinstance(video_file, (list, tuple)):
video_file = video_file[0] if video_file else None
video_path = resolve_video_path(video_file)
if not video_path:
return (
gr.update(value=None, visible=False),
state,
gr.update(value=0, minimum=0, maximum=0, interactive=False),
[],
[],
0,
"Upload a video to start adding clicks.",
)
try:
frame_paths, fps, frame_size, frame_dir = extract_video_frames(video_path)
except Exception as exc:
LOGGER.exception("Video decoding failed")
return (
gr.update(value=None, visible=False),
state,
gr.update(value=0, minimum=0, maximum=0, interactive=False),
[],
[],
0,
f"Video decoding failed: {exc}",
)
state.update(
{
"frame_dir": str(frame_dir),
"frame_paths": [str(p) for p in frame_paths],
"fps": fps,
"frame_size": frame_size,
}
)
first_overlay = render_video_overlay(state, 0, [], [])
slider_update = gr.update(
value=0,
minimum=0,
maximum=len(frame_paths) - 1,
step=1,
interactive=True,
)
status = f"Loaded video with {len(frame_paths)} frames at {fps:.1f} FPS."
return (
gr.update(value=first_overlay, visible=True),
state,
slider_update,
[],
[],
0,
status,
)
def handle_video_frame_change(
frame_idx: int,
video_state: Dict[str, Any],
):
overlay = render_video_overlay(video_state, frame_idx, [], [])
if overlay is None:
return gr.update(), [], [], 0, "Upload a video first."
safe_idx = int(np.clip(frame_idx, 0, len(video_state.get("frame_paths", [])) - 1))
status = f"Annotating frame {safe_idx}."
return overlay, [], [], safe_idx, status
def handle_video_click(
point_mode: str,
pts: List[Sequence[float]],
lbls: List[int],
video_state: Dict[str, Any],
frame_idx: int,
evt: gr.SelectData,
):
overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
if overlay is None:
return gr.update(), pts, lbls, "Upload a video first."
if evt.index is None:
return overlay, pts, lbls, "Couldn't read click position."
x, y = evt.index
label = POINT_MODE_TO_LABEL.get(point_mode, 1)
pts = pts + [[float(x), float(y)]]
lbls = lbls + [label]
overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
status = (
f"Added {'positive' if label == 1 else 'negative'} click at "
f"({int(x)}, {int(y)}) on frame {int(frame_idx)}."
)
return overlay, pts, lbls, status
def undo_video_click(
video_state: Dict[str, Any],
pts: List[Sequence[float]],
lbls: List[int],
frame_idx: int,
):
if not pts:
return gr.update(), pts, lbls, "No clicks to undo."
pts = pts[:-1]
lbls = lbls[:-1]
overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
return overlay, pts, lbls, "Removed the last click."
def clear_video_clicks(video_state: Dict[str, Any], frame_idx: int):
overlay = render_video_overlay(video_state, frame_idx, [], [])
return overlay, [], [], "Cleared all clicks for the selected frame."
def reset_video_interface(current_state: Dict[str, Any]):
remove_dir_if_exists(current_state.get("frame_dir"))
state = make_empty_video_state()
return (
gr.update(value=None, visible=False),
state,
gr.update(value=0, minimum=0, maximum=0, interactive=False),
[],
[],
0,
"Cleared video. Upload a new clip to continue.",
)
def run_video_segmentation(
video_state: Dict[str, Any],
pts: List[Sequence[float]],
lbls: List[int],
frame_idx: int,
granularity: float,
):
frame_paths: List[str] = list(video_state.get("frame_paths", []))
if not frame_paths:
return None, "Upload a video to segment."
if not pts:
return None, "Add at least one click on the annotation frame."
frame_dir = video_state.get("frame_dir")
if not frame_dir:
return None, "Video frames are unavailable. Please re-upload the video."
safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
device = choose_device()
predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
inference_state = predictor.init_state(video_path=frame_dir)
predictor.reset_state(inference_state)
coords = np.asarray(pts, dtype=np.float32)
labels = np.asarray(lbls, dtype=np.int32)
try:
_, obj_ids, mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=safe_idx,
obj_id=1,
points=coords,
labels=labels,
gra=float(granularity),
)
except Exception as exc:
LOGGER.exception("Video add_new_points_or_box failed")
return None, f"Video segmentation failed during prompting: {exc}"
video_masks: Dict[int, Dict[int, np.ndarray]] = {}
video_masks[safe_idx] = {
int(obj_id): (mask_logits[i] > -1.0).cpu().numpy()
for i, obj_id in enumerate(obj_ids)
}
try:
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state,
gra=float(granularity),
):
video_masks[out_frame_idx] = {
int(obj_id): (out_mask_logits[i] > -1.0).cpu().numpy()
for i, obj_id in enumerate(out_obj_ids)
}
except Exception as exc:
LOGGER.exception("Video propagation failed")
return None, f"Video propagation failed: {exc}"
overlays: List[np.ndarray] = []
for idx, frame_path in enumerate(frame_paths):
base = load_rgb_image(Path(frame_path))
mask = video_masks.get(idx, {}).get(1)
overlays.append(draw_overlay(base, mask, [], []))
try:
video_path = write_video_from_frames(overlays, video_state.get("fps", 12.0))
except Exception as exc:
LOGGER.exception("Failed to encode output video")
return None, f"Tracking succeeded but video export failed: {exc}"
status = (
f"Tracked object from frame {safe_idx} across {len(frame_paths)} frames | "
f"granularity={granularity:.2f}"
)
return str(video_path), status
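# Same prompting path as above, but without propagate_in_video: useful for previewing a
# single frame's mask and granularity before committing to full-video tracking.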
def run_video_frame_segmentation(
video_state: Dict[str, Any],
pts: List[Sequence[float]],
lbls: List[int],
frame_idx: int,
granularity: float,
):
frame_paths: List[str] = list(video_state.get("frame_paths", []))
if not frame_paths:
return None, "Upload a video to segment."
if not pts:
return None, "Add at least one click on the annotation frame."
frame_dir = video_state.get("frame_dir")
if not frame_dir:
return None, "Video frames are unavailable. Please re-upload the video."
safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
device = choose_device()
predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
inference_state = predictor.init_state(video_path=frame_dir)
predictor.reset_state(inference_state)
coords = np.asarray(pts, dtype=np.float32)
labels = np.asarray(lbls, dtype=np.int32)
try:
_, obj_ids, mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=safe_idx,
obj_id=1,
points=coords,
labels=labels,
gra=float(granularity),
)
except Exception as exc:
LOGGER.exception("Video frame segmentation failed")
return None, f"Frame segmentation failed: {exc}"
if not obj_ids:
return None, "Predictor did not return a mask for this frame."
mask = (mask_logits[0] > -1.0).cpu().numpy()
base = load_rgb_image(Path(frame_paths[safe_idx]))
overlay = draw_overlay(base, mask, pts, lbls)
status = (
f"Segmented frame {safe_idx} with {len(pts)} clicks | "
f"granularity={granularity:.2f}"
)
return overlay, status
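# Gradio is handed the ZeroGPU-wrapped callables so GPU time is only requested while a
# segmentation request is actually running; the video tasks get the longer duration budget.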
segment_fn = wrap_with_zero_gpu(_run_segmentation, ZERO_GPU_DURATION)
whole_image_fn = wrap_with_zero_gpu(
run_whole_image_segmentation,
ZERO_GPU_WHOLE_DURATION,
)
video_frame_fn = wrap_with_zero_gpu(
run_video_frame_segmentation,
ZERO_GPU_VIDEO_DURATION,
)
video_segmentation_fn = wrap_with_zero_gpu(
run_video_segmentation,
ZERO_GPU_VIDEO_DURATION,
)
def build_demo() -> gr.Blocks:
with gr.Blocks(title="UnSAMv2 Interactive + Whole Image + Video", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
<div style="text-align:center">
<h2>UnSAMv2 · Segment Anything at Any Granularity</h2>
</div>
"""
)
gr.HTML(
"""
<style>
#mode-tabs button[role="tab"] {
flex: 0 0 auto;
min-width: 160px;
}
#mode-tabs [role="tablist"],
#mode-tabs .tab-nav,
#mode-tabs > div:first-child {
display: flex !important;
justify-content: center !important;
gap: 0.75rem;
}
</style>
"""
)
with gr.Tabs(elem_id="mode-tabs"):
# Interactive Image Tab
with gr.Tab("Interactive Image Segmentation"):
image_state = gr.State(DEFAULT_IMAGE)
points_state = gr.State([])
labels_state = gr.State([])
image_input = gr.Image(
label="Image · clicks & mask",
type="numpy",
height=480,
value=DEFAULT_IMAGE,
sources=["upload"],
)
with gr.Row(equal_height=True):
point_mode = gr.Radio(
choices=list(POINT_MODE_TO_LABEL.keys()),
value="Foreground (+)",
label="Click type",
)
granularity_slider = gr.Slider(
minimum=GRANULARITY_MIN,
maximum=GRANULARITY_MAX,
value=0.2,
step=0.01,
label="Granularity",
info="Lower = finer details, Higher = coarser regions",
)
segment_button = gr.Button("Segment", variant="primary")
with gr.Row():
undo_button = gr.Button("Undo last click")
clear_button = gr.Button("Clear clicks")
status_markdown = gr.Markdown(" Ready for interactive clicks.")
image_input.upload(
handle_image_upload,
inputs=[image_input],
outputs=[
image_input,
image_state,
points_state,
labels_state,
status_markdown,
],
)
image_input.clear(
handle_image_upload,
inputs=[image_input],
outputs=[
image_input,
image_state,
points_state,
labels_state,
status_markdown,
],
)
image_input.select(
handle_click,
inputs=[
point_mode,
points_state,
labels_state,
image_state,
],
outputs=[
image_input,
points_state,
labels_state,
status_markdown,
],
)
undo_button.click(
undo_last_click,
inputs=[image_state, points_state, labels_state],
outputs=[
image_input,
points_state,
labels_state,
status_markdown,
],
)
clear_button.click(
clear_clicks,
inputs=[image_state],
outputs=[
image_input,
points_state,
labels_state,
status_markdown,
],
)
segment_button.click(
segment_fn,
inputs=[image_state, points_state, labels_state, granularity_slider],
outputs=[image_input, status_markdown],
)
# Whole Image Tab
with gr.Tab("Whole Image Segmentation"):
whole_image_input = gr.Image(
label="Image · automatic masks",
type="numpy",
height=480,
value=WHOLE_IMAGE_DEFAULT if WHOLE_IMAGE_DEFAULT is not None else DEFAULT_IMAGE,
sources=["upload"],
)
whole_granularity = gr.Slider(
minimum=GRANULARITY_MIN,
maximum=GRANULARITY_MAX,
value=0.15,
step=0.01,
label="Granularity",
)
whole_generate_btn = gr.Button("Generate masks", variant="primary")
with gr.Accordion("Advanced mask filtering", open=False):
pred_iou_thresh = gr.Slider(
minimum=0.1,
maximum=0.99,
value=0.77,
step=0.01,
label="Predicted IoU threshold",
)
stability_thresh = gr.Slider(
minimum=0.1,
maximum=0.99,
value=0.9,
step=0.01,
label="Stability threshold",
)
whole_overlay = gr.Image(label="Mask overlay", height=480)
whole_table = gr.Dataframe(
headers=["mask", "area", "pred_iou", "stability"],
datatype=["number", "number", "number", "number"],
label="Mask stats",
wrap=True,
visible=False,
)
whole_status = gr.Markdown(" Ready for whole-image masks.")
whole_generate_btn.click(
whole_image_fn,
inputs=[
whole_image_input,
whole_granularity,
pred_iou_thresh,
stability_thresh,
],
outputs=[whole_overlay, whole_table, whole_status],
)
# Video Tab
with gr.Tab("Video Segmentation"):
video_state = gr.State(make_empty_video_state())
video_points_state = gr.State([])
video_labels_state = gr.State([])
annotation_frame_state = gr.State(0)
with gr.Row(equal_height=True):
with gr.Column(scale=1, min_width=360):
upload_button = gr.UploadButton(
"Upload video",
file_types=["video"],
file_count="single",
)
frame_display = gr.Image(
label="Video · add clicks",
type="numpy",
height=420,
interactive=True,
visible=False,
)
frame_slider = gr.Slider(
minimum=0,
maximum=0,
value=0,
step=1,
interactive=False,
label="Select frame",
)
video_point_mode = gr.Radio(
choices=list(POINT_MODE_TO_LABEL.keys()),
value="Foreground (+)",
label="Click type",
)
with gr.Row():
video_undo = gr.Button("Undo click")
video_clear = gr.Button("Clear clicks")
video_granularity = gr.Slider(
minimum=GRANULARITY_MIN,
maximum=GRANULARITY_MAX,
value=0.33,
step=0.01,
label="Granularity",
)
with gr.Row():
video_frame_btn = gr.Button("Segment frame", variant="secondary")
video_segment_btn = gr.Button("Propagate video", variant="primary")
with gr.Column(scale=1, min_width=320):
video_output = gr.Video(
label="Segmented preview",
autoplay=False,
height=420,
)
video_status = gr.Markdown(" Ready for video segmentation.")
upload_button.upload(
handle_video_upload,
inputs=[upload_button, video_state],
outputs=[
frame_display,
video_state,
frame_slider,
video_points_state,
video_labels_state,
annotation_frame_state,
video_status,
],
)
if DEFAULT_VIDEO_PATH.exists():
def _load_default_video(state):
return handle_video_upload(str(DEFAULT_VIDEO_PATH), state)
demo.load(
_load_default_video,
inputs=[video_state],
outputs=[
frame_display,
video_state,
frame_slider,
video_points_state,
video_labels_state,
annotation_frame_state,
video_status,
],
queue=False,
)
frame_slider.change(
handle_video_frame_change,
inputs=[frame_slider, video_state],
outputs=[
frame_display,
video_points_state,
video_labels_state,
annotation_frame_state,
video_status,
],
)
frame_display.select(
handle_video_click,
inputs=[
video_point_mode,
video_points_state,
video_labels_state,
video_state,
annotation_frame_state,
],
outputs=[
frame_display,
video_points_state,
video_labels_state,
video_status,
],
)
frame_display.clear(
reset_video_interface,
inputs=[video_state],
outputs=[
frame_display,
video_state,
frame_slider,
video_points_state,
video_labels_state,
annotation_frame_state,
video_status,
],
)
video_frame_btn.click(
video_frame_fn,
inputs=[
video_state,
video_points_state,
video_labels_state,
annotation_frame_state,
video_granularity,
],
outputs=[frame_display, video_status],
)
video_undo.click(
undo_video_click,
inputs=[
video_state,
video_points_state,
video_labels_state,
annotation_frame_state,
],
outputs=[
frame_display,
video_points_state,
video_labels_state,
video_status,
],
)
video_clear.click(
clear_video_clicks,
inputs=[video_state, annotation_frame_state],
outputs=[
frame_display,
video_points_state,
video_labels_state,
video_status,
],
)
video_segment_btn.click(
video_segmentation_fn,
inputs=[
video_state,
video_points_state,
video_labels_state,
annotation_frame_state,
video_granularity,
],
outputs=[video_output, video_status],
)
demo.queue(max_size=8)
return demo
demo = build_demo()
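# Run locally with `python app.py`; PORT overrides the default 7860, and share=True
# additionally requests a public Gradio link where the runtime supports it.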
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)