| """ | |
| Vehicle trajectory extractor powered by SAM3. | |
| The app takes an aerial video, segments small and large vehicles frame-by-frame | |
| with text prompts (`small-vehicle`, `large-vehicle`), and draws their | |
| trajectories on top of the footage. | |
| """ | |
from __future__ import annotations

import math
import os
import tempfile
import uuid
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, pipeline

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "facebook/sam3"
TEXT_PROMPTS = ["small-vehicle", "large-vehicle"]
MIN_MASK_PIXELS = 150  # filter spurious detections
MAX_TRACK_GAP = 3  # frames a track may go unmatched before it stops accepting detections
DEFAULT_FRAME_STRIDE = 5
MAX_PROCESSED_FRAMES = 720
CLASS_COLORS: Dict[str, Tuple[int, int, int]] = {
    "small-vehicle": (20, 148, 245),  # RGB
    "large-vehicle": (255, 120, 30),
}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------------------------------------------------------
# Model + processor
# -----------------------------------------------------------------------------
# Use the mask-generation pipeline as suggested in the Hugging Face guidance,
# then pull out the model and a processor for text-prompt support. Text prompts
# need a full processor (image processor + tokenizer), so prefer AutoProcessor
# and fall back to the pipeline's own image processor if it cannot be loaded.
mask_pipe = pipeline("mask-generation", model=MODEL_ID, device=0 if DEVICE == "cuda" else -1)
model = mask_pipe.model
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID)
except Exception:
    processor = mask_pipe.image_processor or getattr(mask_pipe, "feature_extractor", None)

# -----------------------------------------------------------------------------
# Tracking utilities
# -----------------------------------------------------------------------------
@dataclass
class Track:
    track_id: int
    label: str
    points: List[Tuple[int, float, float]]  # (frame_index, centroid_x, centroid_y)
    last_frame: int
    score: float | None


def _extract_detections(frame_rgb: np.ndarray) -> List[Dict]:
    pil_image = Image.fromarray(frame_rgb)
    detections: List[Dict] = []
    for label in TEXT_PROMPTS:
        # Use processor and model directly with the text prompt.
        try:
            inputs = processor(images=pil_image, text=label, return_tensors="pt")
            inputs = {
                k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v)
                for k, v in inputs.items()
            }
            with torch.inference_mode():
                outputs = model(**inputs)
            # Extract masks from the outputs; the SAM3 output structure may vary.
            if hasattr(outputs, "pred_masks"):
                masks = outputs.pred_masks
            elif hasattr(outputs, "masks"):
                masks = outputs.masks
            elif isinstance(outputs, dict):
                # Avoid `or` here: truth-testing a multi-element tensor raises.
                masks = outputs.get("pred_masks")
                if masks is None:
                    masks = outputs.get("masks")
            else:
                masks = outputs
            if masks is None:
                continue
            # Handle different mask formats.
            if isinstance(masks, torch.Tensor):
                if masks.ndim == 4:  # [batch, num_masks, H, W]
                    masks = masks[0]  # drop the batch dimension
                elif masks.ndim == 3:  # [num_masks, H, W]
                    pass
                else:
                    continue
            for mask_tensor in masks:
                mask_np = mask_tensor.squeeze().detach().cpu().numpy()
                if mask_np.ndim == 3:
                    mask_np = mask_np[0]
                # Treat values above 0.5 as foreground (assumes probabilities or
                # binary masks rather than raw logits).
                binary_mask = mask_np > 0.5
                area = int(binary_mask.sum())
                if area < MIN_MASK_PIXELS:
                    continue
                ys, xs = np.nonzero(binary_mask)
                if len(xs) == 0:
                    continue
                centroid = (float(xs.mean()), float(ys.mean()))
                detections.append(
                    {
                        "label": label,
                        "mask": binary_mask,
                        "score": None,
                        "centroid": centroid,
                        "area": area,
                    }
                )
        except Exception:
            # Fall back to the class-agnostic pipeline if direct access fails.
            # Note: these masks carry no text prompt, so they are attributed to
            # the current `label`, which can duplicate detections across prompts.
            try:
                results = mask_pipe(pil_image)
                if isinstance(results, dict) and "masks" in results:
                    # mask-generation pipelines typically return
                    # {"masks": [...], "scores": [...]}.
                    scores = results.get("scores")
                    results = [
                        {"mask": m, "score": scores[i] if scores is not None else None}
                        for i, m in enumerate(results["masks"])
                    ]
                elif not isinstance(results, list):
                    results = [results]
                for result in results:
                    if isinstance(result, dict):
                        mask = result.get("mask")
                        score = result.get("score")
                    else:
                        mask = result
                        score = None
                    if isinstance(mask, Image.Image):
                        mask_np = np.array(mask.convert("L"))
                    elif isinstance(mask, torch.Tensor):
                        mask_np = mask.squeeze().detach().cpu().numpy()
                    elif isinstance(mask, np.ndarray):
                        mask_np = mask
                    else:
                        continue
                    if mask_np.ndim == 3:
                        mask_np = mask_np[:, :, 0] if mask_np.shape[2] == 1 else mask_np.max(axis=2)
                    if mask_np.max() > 1.0:
                        mask_np = mask_np / 255.0
                    binary_mask = mask_np > 0.5
                    area = int(binary_mask.sum())
                    if area < MIN_MASK_PIXELS:
                        continue
                    ys, xs = np.nonzero(binary_mask)
                    if len(xs) == 0:
                        continue
                    centroid = (float(xs.mean()), float(ys.mean()))
                    detections.append(
                        {
                            "label": label,
                            "mask": binary_mask,
                            "score": float(score) if score is not None else None,
                            "centroid": centroid,
                            "area": area,
                        }
                    )
            except Exception as exc:
                raise gr.Error(f"Both direct model access and the pipeline fallback failed: {exc}")
    return detections
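
# A single detection produced by `_extract_detections` is a plain dict; the
# values below are illustrative, not taken from a real run:
#     {"label": "small-vehicle", "mask": <bool HxW ndarray>, "score": 0.87,
#      "centroid": (412.3, 188.9), "area": 642}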


def _update_tracks(
    tracks: List[Track],
    detections: Sequence[Dict],
    frame_idx: int,
    max_distance: float,
) -> None:
    """Greedily assign each detection to the nearest recent track of the same label."""
    for detection in detections:
        centroid = np.array(detection["centroid"])
        best_track = None
        best_distance = math.inf
        for track in tracks:
            if track.label != detection["label"]:
                continue
            if frame_idx - track.last_frame > MAX_TRACK_GAP:
                continue
            prev_point = np.array(track.points[-1][1:])  # (x, y) of the last centroid
            dist = np.linalg.norm(centroid - prev_point)
            if dist < best_distance and dist <= max_distance:
                best_distance = dist
                best_track = track
        if best_track is not None:
            best_track.points.append((frame_idx, *detection["centroid"]))
            best_track.last_frame = frame_idx
            best_track.score = detection["score"]
        else:
            new_track = Track(
                track_id=len(tracks) + 1,
                label=detection["label"],
                points=[(frame_idx, *detection["centroid"])],
                last_frame=frame_idx,
                score=detection["score"],
            )
            tracks.append(new_track)
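
# Minimal sketch of how the matcher behaves (hypothetical numbers, not executed
# by the app; only the dict keys the matcher reads are shown):
#
#     tracks: List[Track] = []
#     _update_tracks(tracks, [{"label": "small-vehicle", "centroid": (100.0, 198.0), "score": None}], 4, 50.0)
#     _update_tracks(tracks, [{"label": "small-vehicle", "centroid": (105.0, 200.0), "score": None}], 5, 50.0)
#     # The ~5.4 px gap is within max_distance, so both points join one track:
#     # tracks[0].points == [(4, 100.0, 198.0), (5, 105.0, 200.0)]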


def _blend_mask(frame: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int], alpha: float = 0.45):
    overlay = frame.copy()
    overlay[mask] = (1 - alpha) * overlay[mask] + alpha * np.array(color, dtype=np.float32)
    return overlay


def _draw_annotations(
    frame_rgb: np.ndarray,
    detections: Sequence[Dict],
    tracks: Sequence[Track],
    frame_idx: int,
):
    annotated = frame_rgb.astype(np.float32)
    for det in detections:
        # The working image stays in RGB until the end of the pipeline, so draw
        # with the RGB class colour directly (a BGR tuple would swap channels).
        color = CLASS_COLORS.get(det["label"], (255, 255, 255))
        annotated = _blend_mask(annotated, det["mask"], color)
        cx, cy = det["centroid"]
        cv2.circle(annotated, (int(cx), int(cy)), 4, color, -1)
        cv2.putText(
            annotated,
            det["label"],
            (int(cx) + 4, int(cy) - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            cv2.LINE_AA,
        )
    for track in tracks:
        if len(track.points) < 2:
            continue
        if track.points[-1][0] < frame_idx - MAX_TRACK_GAP:
            continue
        color = CLASS_COLORS.get(track.label, (255, 255, 255))
        pts = [
            (int(x), int(y))
            for (f_idx, x, y) in track.points
            if f_idx <= frame_idx
        ]
        for i in range(1, len(pts)):
            cv2.line(annotated, pts[i - 1], pts[i], color, 2, cv2.LINE_AA)
        cv2.circle(annotated, pts[-1], 5, color, -1)
    return np.clip(annotated, 0, 255).astype(np.uint8)


def _summarize_tracks(tracks: Sequence[Track]) -> List[Dict]:
    summary = []
    for track in tracks:
        if len(track.points) < 2:
            continue
        distances = []
        for (_, x1, y1), (_, x2, y2) in zip(track.points, track.points[1:]):
            distances.append(math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2))
        summary.append(
            {
                "track_id": track.track_id,
                "label": track.label,
                "frames": len(track.points),
                "start_frame": track.points[0][0],
                "end_frame": track.points[-1][0],
                "path_px": round(float(sum(distances)), 2),
            }
        )
    return summary
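
# Example summary row (illustrative values): a small vehicle seen in 42
# processed frames between source frames 35 and 240, covering ~318 px of path:
#     {"track_id": 3, "label": "small-vehicle", "frames": 42,
#      "start_frame": 35, "end_frame": 240, "path_px": 318.47}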


# -----------------------------------------------------------------------------
# Video processing
# -----------------------------------------------------------------------------
def analyze_video(
    video_path: str,
    frame_stride: int = DEFAULT_FRAME_STRIDE,
    max_frames: int = MAX_PROCESSED_FRAMES,
    resize_long_edge: int = 1280,
) -> Tuple[str, List[List]]:
    if not video_path:
        raise gr.Error("Please upload an aerial video (MP4, MOV, ...).")
    # Gradio sliders may deliver floats; normalise to ints up front.
    frame_stride = max(int(frame_stride), 1)
    max_frames = int(max_frames)
    resize_long_edge = int(resize_long_edge)
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise gr.Error("Unable to read the uploaded video.")
    fps = capture.get(cv2.CAP_PROP_FPS) or 15
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    diag = math.sqrt(width**2 + height**2)
    # Allow matches up to ~4% of the frame diagonal between processed frames.
    max_assign_distance = 0.04 * diag
    processed_frames = []
    tracks: List[Track] = []
    frame_index = 0
    processed_count = 0
    while processed_count < max_frames:
        ret, frame_bgr = capture.read()
        if not ret:
            break
        if frame_index % frame_stride != 0:
            frame_index += 1
            continue
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        frame_rgb = _resize_long_edge(frame_rgb, resize_long_edge)
        detections = _extract_detections(frame_rgb)
        _update_tracks(tracks, detections, frame_index, max_assign_distance)
        annotated = _draw_annotations(frame_rgb, detections, tracks, frame_index)
        processed_frames.append(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
        processed_count += 1
        frame_index += 1
    capture.release()
    if not processed_frames:
        raise gr.Error("No frames were processed. Try lowering the stride or uploading a different video.")
    output_path = _write_video(processed_frames, fps / frame_stride)
    # Order the rows to match the Dataframe headers defined in the UI below.
    summary = _summarize_tracks(tracks)
    rows = [
        [row["track_id"], row["label"], row["frames"], row["start_frame"], row["end_frame"], row["path_px"]]
        for row in summary
    ]
    return output_path, rows


def _resize_long_edge(frame_rgb: np.ndarray, target_long_edge: int) -> np.ndarray:
    h, w, _ = frame_rgb.shape
    long_edge = max(h, w)
    if long_edge <= target_long_edge:
        return frame_rgb
    scale = target_long_edge / long_edge
    new_size = (int(w * scale), int(h * scale))
    resized = cv2.resize(frame_rgb, new_size, interpolation=cv2.INTER_AREA)
    return resized


def _write_video(frames: Sequence[np.ndarray], fps: float) -> str:
    height, width, _ = frames[0].shape
    tmp_path = os.path.join(tempfile.gettempdir(), f"sam3-trajectories-{uuid.uuid4().hex}.mp4")
    # mp4v keeps the dependency footprint small; some browsers may refuse to
    # play it inline and need the file re-encoded to H.264.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(tmp_path, fourcc, max(fps, 1.0), (width, height))
    for frame in frames:
        writer.write(frame)
    writer.release()
    return tmp_path


# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
with gr.Blocks(title="SAM3 Vehicle Trajectories") as demo:
    gr.Markdown(
        """
        ### SAM3 for Vehicle Trajectories
        1. Upload an aerial surveillance video.
        2. The app prompts SAM3 with `small-vehicle` and `large-vehicle`.
        3. Segmentations are linked across frames to render motion trails.
        """
    )
    with gr.Row():
        video_input = gr.Video(label="Aerial video (MP4/MOV)")
        with gr.Column():
            stride_slider = gr.Slider(
                label="Frame stride",
                minimum=1,
                maximum=12,
                value=DEFAULT_FRAME_STRIDE,
                step=1,
                info="Process one frame every N frames",
            )
            max_frames_slider = gr.Slider(
                label="Max frames to process",
                minimum=30,
                maximum=1000,
                value=MAX_PROCESSED_FRAMES,
                step=10,
            )
            resize_slider = gr.Slider(
                label="Resize longest edge (px)",
                minimum=640,
                maximum=1920,
                value=1280,
                step=40,
            )
    output_video = gr.Video(label="Overlay with trajectories")
    track_table = gr.Dataframe(
        headers=["track_id", "label", "frames", "start_frame", "end_frame", "path_px"],
        datatype=["number", "str", "number", "number", "number", "number"],
        wrap=True,
        label="Track summary",
    )
    run_button = gr.Button("Extract trajectories", variant="primary")
    run_button.click(
        fn=analyze_video,
        inputs=[video_input, stride_slider, max_frames_slider, resize_slider],
        outputs=[output_video, track_table],
        api_name="analyze",
    )


if __name__ == "__main__":
    demo.launch()
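
# Example client call against the `analyze` endpoint registered above. This is
# an illustrative sketch, not executed by the app; it assumes a recent
# `gradio_client` (which provides `handle_file`), the default local URL, and a
# hypothetical input file name:
#
#     from gradio_client import Client, handle_file
#
#     client = Client("http://127.0.0.1:7860")
#     overlay_path, track_rows = client.predict(
#         handle_file("aerial_clip.mp4"),  # hypothetical local video
#         5,       # frame stride
#         300,     # max frames to process
#         1280,    # resize longest edge (px)
#         api_name="/analyze",
#     )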