Mirko Trasciatti
Fix SAM2 window: use ALL frames when no kick detected
da567fd
from __future__ import annotations
import colorsys
import gc
from copy import deepcopy
import base64
import math
import statistics
from pathlib import Path
import json
import plotly.graph_objects as go
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")


def ensure_example_video() -> str:
    """
    Ensure the Kickit example video exists locally by decoding the base64 text file.

    An existing MP4 is reused as-is. Otherwise the base64 asset is decoded and
    written to a temporary sibling file, which is then moved into place. An
    interrupted write therefore never leaves a truncated MP4 that later calls
    would mistake for a valid cached copy.

    Returns:
        The path to the decoded MP4, as a string.

    Raises:
        FileNotFoundError: if neither the MP4 nor the base64 asset exists.
        binascii.Error: if the base64 asset is malformed.
    """
    if EXAMPLE_VIDEO_PATH.exists():
        return str(EXAMPLE_VIDEO_PATH)
    if not BASE64_VIDEO_PATH.exists():
        raise FileNotFoundError("Base64 video asset not found.")
    # Base64 text is ASCII; decoding the raw bytes avoids depending on the
    # platform's default text encoding.
    data = base64.b64decode(BASE64_VIDEO_PATH.read_bytes())
    tmp_path = EXAMPLE_VIDEO_PATH.with_name(EXAMPLE_VIDEO_PATH.name + ".part")
    tmp_path.write_bytes(data)
    # Only a fully written file ever appears under the final name.
    tmp_path.replace(EXAMPLE_VIDEO_PATH)
    return str(EXAMPLE_VIDEO_PATH)
from types import SimpleNamespace
from typing import Optional, Any
import cv2
import gradio as gr
import numpy as np
try:
    import spaces
except ImportError:
    class _SpacesFallback:
        """Minimal stand-in for the Hugging Face ``spaces`` package when it is not installed.

        Only ``GPU`` is provided, and it returns the decorated function unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both decorator forms: bare ``@spaces.GPU`` and
            parameterised ``@spaces.GPU(...)`` / ``@spaces.GPU()``.
            """
            # Bare form: Python passes the decorated function as the only argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
import torch
from gradio.themes import Soft
from PIL import Image, ImageDraw
from transformers import AutoModel, Sam2VideoProcessor
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# ---- YOLO detector configuration -------------------------------------------
# Loaded models keyed by weights filename; populated lazily by get_yolo_model().
YOLO_MODEL_CACHE: dict[str, YOLO] = {}
# Weights file fetched from the Hugging Face Hub repository below.
YOLO_DEFAULT_MODEL = "yolov13n.pt"
YOLO_REPO_ID = "atalaydenknalbant/Yolov13"
# Class label (compared with model.names lower-cased) that identifies the ball.
YOLO_TARGET_NAME = "sports ball"
# Passed as `conf=` to model.predict; 0.0 applies no confidence cut-off.
YOLO_CONF_THRESHOLD = 0.0
# Passed as `iou=` to model.predict.
YOLO_IOU_THRESHOLD = 0.02
# Class label used by detect_person_box().
PLAYER_TARGET_NAME = "person"

# ---- Tracked object ids ------------------------------------------------------
# BALL_OBJECT_ID keys state.kalman_centers / state.kalman_speeds for the ball.
BALL_OBJECT_ID = 1
# Presumably the id of the player's track (not referenced in this section).
PLAYER_OBJECT_ID = 2

# ---- Goal crossbar mapping ---------------------------------------------------
# States of the crossbar placement flow (state.goal_mode).
GOAL_MODE_IDLE = "idle"
GOAL_MODE_PLACING_FIRST = "placing_first"
GOAL_MODE_PLACING_SECOND = "placing_second"
GOAL_MODE_EDITING = "editing"
# Drawn handle radius in pixels (the overlay enforces a minimum of 4).
GOAL_HANDLE_RADIUS_PX = 8
# Maximum click-to-handle distance, in pixels, that still grabs a handle.
GOAL_HANDLE_HIT_RADIUS_PX = 28
# Crossbar line / handle outline colour, and handle fill colour (presumably RGB).
GOAL_LINE_COLOR = (255, 214, 64)
GOAL_HANDLE_FILL = (10, 10, 10)
def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
    """
    Return the YOLOv13 model for *model_filename*, downloading and loading it on first use.

    Loaded models are memoised in YOLO_MODEL_CACHE, so later calls with the same
    filename return the same instance without touching the Hub.
    """
    try:
        return YOLO_MODEL_CACHE[model_filename]
    except KeyError:
        pass
    weights_path = hf_hub_download(repo_id=YOLO_REPO_ID, filename=model_filename)
    loaded = YOLO(weights_path)
    YOLO_MODEL_CACHE[model_filename] = loaded
    return loaded
def detect_ball_center(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """
    Locate the single best sports-ball detection in *frame*.

    YOLO is asked for at most one box, which is the highest-confidence detection.

    Returns:
        ``(x_center, y_center, width, height, confidence)``, with the geometry
        rounded to whole pixels. Returns ``None`` when the model has no
        sports-ball class or nothing is detected.
    """
    model = get_yolo_model(model_filename)
    ball_classes = [
        class_id
        for class_id, label in model.names.items()
        if label.lower() == YOLO_TARGET_NAME
    ]
    if not ball_classes:
        return None

    predictions = model.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=1,
        classes=ball_classes,
        imgsz=640,
        device="cpu",
        verbose=False,
    )
    detections = predictions[0].boxes if predictions else None
    if detections is None or len(detections) == 0:
        return None

    best = detections[0]
    # Each xywh row is (x_center, y_center, width, height).
    cx, cy, w, h = best.xywh[0].cpu().tolist()
    score = 0.0 if best.conf is None else float(best.conf[0].cpu().item())
    return (*(int(round(v)) for v in (cx, cy, w, h)), score)
def detect_all_balls(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = 0.05,  # Minimum 5% confidence to filter noise
    iou_threshold: float = YOLO_IOU_THRESHOLD,
    max_detections: int = 10,  # Get more from YOLO, then filter to top 5
    max_candidates: int = 5,  # Return only top 5 by confidence
) -> list[dict]:
    """
    Collect every plausible ball in *frame* as a ranked candidate list.

    YOLO returns up to *max_detections* sports-ball boxes with confidence of at
    least *conf_threshold*. Boxes that become empty after clipping to the frame
    are dropped, and the remainder is ranked by confidence and cut to
    *max_candidates*. No region-of-interest filtering happens here; position
    scoring is left to later stages.

    Each candidate dict has:
        id (int, rank after sorting), center ((x, y) floats),
        box ((x_min, y_min, x_max, y_max) ints, clipped to the frame),
        width / height (floats, unclipped), conf (float),
        x_ratio / y_ratio (center as a fraction of the frame width / height).
    """
    model = get_yolo_model(model_filename)
    ball_classes = [
        class_id
        for class_id, label in model.names.items()
        if label.lower() == YOLO_TARGET_NAME
    ]
    if not ball_classes:
        return []

    predictions = model.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=max_detections,
        classes=ball_classes,
        imgsz=640,
        device="cpu",
        verbose=False,
    )
    detections = predictions[0].boxes if predictions else None
    if detections is None or len(detections) == 0:
        return []

    img_w, img_h = frame.size
    found: list[dict] = []
    for det in detections:
        cx, cy, w, h = det.xywh[0].cpu().tolist()
        score = 0.0 if det.conf is None else float(det.conf[0].cpu().item())
        left = int(round(max(0.0, cx - w / 2.0)))
        top = int(round(max(0.0, cy - h / 2.0)))
        right = int(round(min(img_w - 1.0, cx + w / 2.0)))
        bottom = int(round(min(img_h - 1.0, cy + h / 2.0)))
        if right <= left or bottom <= top:
            continue  # degenerate once clipped to the frame
        found.append({
            "id": len(found),
            "center": (float(cx), float(cy)),
            "box": (left, top, right, bottom),
            "width": float(w),
            "height": float(h),
            "conf": score,
            "x_ratio": cx / img_w,
            "y_ratio": cy / img_h,
        })

    # Stable sort keeps YOLO's order among equal confidences.
    candidates = sorted(found, key=lambda cand: cand["conf"], reverse=True)[:max_candidates]
    for rank, cand in enumerate(candidates):
        cand["id"] = rank

    print(f"[detect_all_balls] Found {len(candidates)} ball candidates (top {max_candidates}, conf >= {conf_threshold:.0%}):")
    for c in candidates:
        print(f" Ball {c['id']}: center={c['center']}, conf={c['conf']:.1%}, box={c['box']}")
    return candidates
def detect_person_box(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """
    Find the top-ranked person in *frame*.

    Returns ``(x_min, y_min, x_max, y_max, confidence)`` for the first
    (highest-confidence) person box, with corners rounded and clamped into the
    frame. Returns ``None`` when the model has no person class, nothing is
    detected, or the clamped box is empty.
    """
    model = get_yolo_model(model_filename)
    person_classes = [
        class_id
        for class_id, label in model.names.items()
        if label.lower() == PLAYER_TARGET_NAME
    ]
    if not person_classes:
        return None

    predictions = model.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=5,
        classes=person_classes,
        imgsz=640,
        device="cpu",
        verbose=False,
    )
    detections = predictions[0].boxes if predictions else None
    if detections is None or len(detections) == 0:
        return None

    best = detections[0]
    corners = best.xyxy[0].cpu().tolist()
    score = 0.0 if best.conf is None else float(best.conf[0].cpu().item())

    img_w, img_h = frame.size
    # x coordinates clamp to [0, width - 1], y coordinates to [0, height - 1].
    limits = (img_w, img_h, img_w, img_h)
    left, upper, right, lower = (
        max(0, min(limit - 1, int(round(value))))
        for value, limit in zip(corners, limits)
    )
    if right <= left or lower <= upper:
        return None
    return left, upper, right, lower, score
def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]:
total_frames = state.num_frames
if total_frames == 0:
return 0, 0
# If no kick detected, use ALL frames
if kick_frame is None:
start_idx = 0
end_idx = total_frames
print(f"[_compute_sam_window_from_kick] No kick detected → using ALL {total_frames} frames")
else:
# If kick detected, use 4-second window around kick
fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
target_window_frames = max(1, int(round(fps * 4.0)))
half_window = target_window_frames // 2
start_idx = max(0, int(kick_frame) - half_window)
end_idx = min(total_frames, start_idx + target_window_frames)
if end_idx <= start_idx:
end_idx = min(total_frames, start_idx + 1)
print(f"[_compute_sam_window_from_kick] Kick @ {kick_frame} → window [{start_idx}, {end_idx}] ({end_idx - start_idx} frames)")
state.sam_window = (start_idx, end_idx)
return start_idx, end_idx
def _goal_frame_dims(state: AppState, frame_idx: int | None = None) -> tuple[int, int]:
if state is None or not state.video_frames:
return 1, 1
idx = 0 if frame_idx is None else int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
frame = state.video_frames[idx]
return frame.size
def _goal_norm_from_xy(state: AppState, frame_idx: int, x: int, y: int) -> tuple[float, float]:
    """Convert pixel ``(x, y)`` on frame *frame_idx* into fractions of the frame size, clamped to [0, 1]."""
    width, height = _goal_frame_dims(state, frame_idx)
    # Guard against zero or negative sizes so the division below is always defined.
    safe_w = width if width > 0 else 1
    safe_h = height if height > 0 else 1
    nx = float(np.clip(x / safe_w, 0.0, 1.0))
    ny = float(np.clip(y / safe_h, 0.0, 1.0))
    return nx, ny
def _goal_xy_from_norm(state: AppState, frame_idx: int, pt: tuple[float, float]) -> tuple[int, int]:
    """Map a normalized point *pt* back to rounded pixel coordinates on frame *frame_idx*."""
    width, height = _goal_frame_dims(state, frame_idx)
    nx, ny = float(pt[0]), float(pt[1])
    return int(round(nx * width)), int(round(ny * height))
def _goal_points_for_drawing(state: AppState) -> list[tuple[float, float]]:
if state is None:
return []
if state.goal_mode in {GOAL_MODE_PLACING_FIRST, GOAL_MODE_PLACING_SECOND, GOAL_MODE_EDITING}:
return list(state.goal_points_norm)
return list(state.goal_overlay_points)
def _goal_clear_preview_cache(state: AppState) -> None:
if state is None:
return
state.composited_frames.clear()
def _goal_has_confirmed(state: AppState) -> bool:
    """Return True when *state* is an AppState holding exactly two confirmed crossbar points."""
    if not isinstance(state, AppState):
        return False
    return len(state.goal_confirmed_points_norm) == 2
def _goal_set_status(state: AppState, text: str) -> None:
if state is None:
return
state.goal_status_text = text
def _goal_status_text(state: AppState) -> str:
if state is None:
return "Goal crossbar unavailable."
if state.goal_status_text:
return state.goal_status_text
if _goal_has_confirmed(state):
return "Goal crossbar confirmed. Click Start Mapping to adjust."
return "Goal crossbar inactive."
def _goal_button_updates(state: AppState) -> tuple[Any, Any, Any, Any, Any]:
    """Build Gradio updates for the (start, confirm, clear, back, status) goal controls.

    * Start is enabled only while idle.
    * Confirm needs two in-progress points while placing the second point or editing.
    * Clear needs any in-progress or confirmed point.
    * Back needs a previously confirmed crossbar.
    """
    if state is None:
        disabled = [gr.update(interactive=False) for _ in range(4)]
        return (*disabled, gr.update(value="Goal crossbar unavailable.", visible=True))

    mode = state.goal_mode
    can_start = mode == GOAL_MODE_IDLE
    can_confirm = len(state.goal_points_norm) == 2 and mode in {
        GOAL_MODE_PLACING_SECOND,
        GOAL_MODE_EDITING,
    }
    can_clear = bool(state.goal_points_norm or state.goal_confirmed_points_norm)
    can_go_back = bool(state.goal_prev_confirmed_points_norm)
    status = gr.update(value=_goal_status_text(state), visible=True)
    return (
        gr.update(interactive=can_start),
        gr.update(interactive=can_confirm),
        gr.update(interactive=can_clear),
        gr.update(interactive=can_go_back),
        status,
    )
def _goal_handle_hit_index(state: AppState, frame_idx: int, x: int, y: int) -> int | None:
points = state.goal_points_norm
if state is None or len(points) == 0:
return None
width, height = _goal_frame_dims(state, frame_idx)
max_dist = GOAL_HANDLE_HIT_RADIUS_PX
for idx, pt in enumerate(points):
px, py = _goal_xy_from_norm(state, frame_idx, pt)
dist = math.hypot(px - x, py - y)
if dist <= max_dist:
return idx
return None
def _goal_current_frame_idx(state: AppState) -> int:
if state is None or state.num_frames == 0:
return 0
idx = int(getattr(state, "current_frame_idx", 0))
return int(np.clip(idx, 0, state.num_frames - 1))
def _goal_output_tuple(state: AppState, preview_img: Image.Image | None = None) -> tuple[Image.Image, Any, Any, Any, Any, Any]:
    """Assemble the (preview image, start, confirm, clear, back, status) outputs of the goal handlers.

    When *preview_img* is not supplied, the current frame is re-rendered with
    ``update_frame_display``. With no state, all four buttons are disabled.
    """
    if state is None:
        return (
            preview_img,
            *(gr.update(interactive=False) for _ in range(4)),
            gr.update(value="Goal crossbar unavailable.", visible=True),
        )
    if preview_img is None:
        preview_img = update_frame_display(state, _goal_current_frame_idx(state))
    return (preview_img, *_goal_button_updates(state))
def _goal_start_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]:
    """Enter crossbar placement mode, seeding the editable points from any confirmed crossbar.

    Raises:
        gr.Error: if no video is loaded.
    """
    if state is None or not state.video_frames:
        raise gr.Error("Load a video first, then map the goal crossbar.")
    confirmed = state.goal_confirmed_points_norm
    # Snapshot the confirmed crossbar so the Back action can restore it.
    state.goal_prev_confirmed_points_norm = list(confirmed)
    state.goal_points_norm = list(confirmed) if confirmed else []
    state.goal_overlay_points = []
    state.goal_mode = GOAL_MODE_PLACING_FIRST
    state.goal_dragging_idx = None
    _goal_set_status(state, "Click the left goalpost to start the crossbar.")
    _goal_clear_preview_cache(state)
    return _goal_output_tuple(state)
def _goal_confirm_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]:
    """Save the two in-progress endpoints as the confirmed crossbar and return to idle.

    When fewer or more than two points are in progress, only a status hint is set.
    """
    if state is None:
        return (None, *_goal_button_updates(state))
    if len(state.goal_points_norm) != 2:
        _goal_set_status(state, "Select both goal corners before confirming.")
        return _goal_output_tuple(state)
    chosen = list(state.goal_points_norm)
    state.goal_confirmed_points_norm = chosen
    # The overlay gets its own copy so later edits to one list never alias the other.
    state.goal_overlay_points = list(chosen)
    state.goal_points_norm = []
    state.goal_mode = GOAL_MODE_IDLE
    state.goal_dragging_idx = None
    _goal_set_status(state, "Goal crossbar saved. Click Start Mapping to adjust again.")
    _goal_clear_preview_cache(state)
    return _goal_output_tuple(state)
def _goal_clear_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]:
    """Erase every crossbar point list (in-progress, confirmed, previous, overlay) and go idle."""
    if state is None:
        return (None, *_goal_button_updates(state))
    point_fields = (
        "goal_points_norm",
        "goal_confirmed_points_norm",
        "goal_prev_confirmed_points_norm",
        "goal_overlay_points",
    )
    # Empty the existing list objects in place first (any other holder of them
    # sees the clear too), then install fresh empty lists.
    for field_name in point_fields:
        getattr(state, field_name).clear()
    for field_name in point_fields:
        setattr(state, field_name, [])
    state.goal_mode = GOAL_MODE_IDLE
    state.goal_dragging_idx = None
    _goal_set_status(state, "Goal crossbar cleared.")
    _goal_clear_preview_cache(state)
    return _goal_output_tuple(state)
def _goal_back_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]:
    """Restore the crossbar that was confirmed before the latest Start Mapping, then go idle."""
    if state is None:
        return (None, *_goal_button_updates(state))
    previous = state.goal_prev_confirmed_points_norm
    if not previous:
        _goal_set_status(state, "No previous goal crossbar to restore.")
        return _goal_output_tuple(state)
    state.goal_confirmed_points_norm = list(previous)
    state.goal_overlay_points = list(previous)
    state.goal_points_norm = []
    # The snapshot is consumed: a second Back has nothing to restore.
    state.goal_prev_confirmed_points_norm = []
    state.goal_mode = GOAL_MODE_IDLE
    state.goal_dragging_idx = None
    _goal_set_status(state, "Restored the previous goal crossbar.")
    _goal_clear_preview_cache(state)
    return _goal_output_tuple(state)
def _goal_process_preview_click(
    state: AppState,
    frame_idx: int,
    evt: gr.SelectData | None,
) -> tuple[Image.Image | None, bool]:
    """Apply a click on the preview image to the goal-crossbar placement state machine.

    The click position is read from *evt*: either ``evt.index`` as an ``[x, y]``
    pair, or ``evt.value`` as a dict with ``"x"``/``"y"`` keys. It is converted
    to normalized coordinates on frame *frame_idx*, and then:

    * PLACING_FIRST: the click becomes the first endpoint, and the mode moves to
      PLACING_SECOND.
    * PLACING_SECOND: clicking on an existing handle moves that handle. Otherwise
      the click becomes the first point (if none exists), appends the second
      point, or replaces it; adding or replacing the second point moves the mode
      to EDITING.
    * EDITING: the handle under the click, or failing that the nearer of the two
      endpoints, is moved to the click.

    At most two points are kept afterwards.

    Returns:
        ``(None, False)`` when *state* is None or mapping is idle. Otherwise the
        re-rendered preview image and True.
    """
    if state is None or state.goal_mode == GOAL_MODE_IDLE:
        return None, False
    x = y = None
    if evt is not None:
        # Best-effort parse: any failure below falls through to the
        # "could not read" status instead of raising into the UI.
        try:
            if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                x, y = int(evt.index[0]), int(evt.index[1])
            elif hasattr(evt, "value") and isinstance(evt.value, dict):
                data = evt.value
                if "x" in data and "y" in data:
                    x, y = int(data["x"]), int(data["y"])
        except Exception:
            x = y = None
    if x is None or y is None:
        _goal_set_status(state, "Could not read click coordinates. Please try again.")
        # NOTE(review): _goal_output_tuple re-renders state.current_frame_idx,
        # not frame_idx — confirm these always match for preview clicks.
        return _goal_output_tuple(state)[0], True
    norm_pt = _goal_norm_from_xy(state, frame_idx, x, y)
    # Snapshot of the list object *before* this click; branches below rely on
    # len(points) reflecting the pre-click state.
    points = state.goal_points_norm
    if state.goal_mode == GOAL_MODE_PLACING_FIRST:
        state.goal_points_norm = [norm_pt]
        state.goal_mode = GOAL_MODE_PLACING_SECOND
        _goal_set_status(state, "Click the right goalpost to finish the crossbar.")
    elif state.goal_mode == GOAL_MODE_PLACING_SECOND:
        handle_idx = _goal_handle_hit_index(state, frame_idx, x, y)
        if handle_idx is not None and handle_idx < len(points):
            # Clicked on an existing handle: move it instead of adding a point.
            state.goal_points_norm[handle_idx] = norm_pt
            _goal_set_status(state, "Adjusted the first corner. Click the other post.")
        else:
            if len(points) == 0:
                state.goal_points_norm = [norm_pt]
                _goal_set_status(state, "Click the next goalpost to finish the crossbar.")
            elif len(points) == 1:
                # append mutates the same list object that `points` refers to.
                state.goal_points_norm.append(norm_pt)
                state.goal_mode = GOAL_MODE_EDITING
                _goal_set_status(state, "Adjust handles if needed, then Confirm.")
            else:
                state.goal_points_norm[1] = norm_pt
                state.goal_mode = GOAL_MODE_EDITING
                _goal_set_status(state, "Adjust handles if needed, then Confirm.")
    elif state.goal_mode == GOAL_MODE_EDITING:
        handle_idx = _goal_handle_hit_index(state, frame_idx, x, y)
        if handle_idx is None and len(points) == 2:
            # fall back to whichever endpoint is closest to click
            px0, py0 = _goal_xy_from_norm(state, frame_idx, points[0])
            px1, py1 = _goal_xy_from_norm(state, frame_idx, points[1])
            dist0 = math.hypot(px0 - x, py0 - y)
            dist1 = math.hypot(px1 - x, py1 - y)
            handle_idx = 0 if dist0 <= dist1 else 1
        if handle_idx is not None and handle_idx < len(points):
            state.goal_points_norm[handle_idx] = norm_pt
            _goal_set_status(state, "Handle moved. Press Confirm to save.")
    # Never keep more than the two crossbar endpoints.
    state.goal_points_norm = state.goal_points_norm[:2]
    _goal_clear_preview_cache(state)
    preview_img = update_frame_display(state, frame_idx)
    return preview_img, True
def _draw_goal_overlay(state: AppState, frame_idx: int, image: Image.Image) -> None:
    """Draw the goal crossbar onto *image* in place.

    Draws a line between the first two points when both exist, plus a circular
    handle at each of the (at most two) points. Does nothing without state, an
    image, or points.
    """
    if state is None or image is None:
        return
    norm_points = _goal_points_for_drawing(state)
    if not norm_points:
        return
    canvas = ImageDraw.Draw(image)
    handles = [_goal_xy_from_norm(state, frame_idx, p) for p in norm_points[:2]]
    if len(handles) >= 2:
        canvas.line([handles[0], handles[1]], fill=GOAL_LINE_COLOR, width=4)
    r = max(4, GOAL_HANDLE_RADIUS_PX)
    for hx, hy in handles:
        canvas.ellipse(
            [(hx - r, hy - r), (hx + r, hy + r)],
            outline=GOAL_LINE_COLOR,
            fill=GOAL_HANDLE_FILL,
            width=2,
        )
def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None:
    """
    Track the ball with single-box YOLO on every frame and estimate the kick frame.

    Populates on *state*:
      * the raw per-frame detections (``yolo_ball_centers`` / ``_boxes`` / ``_conf``);
      * EMA-smoothed centers, speeds and distance from the first detection;
      * the speed-threshold statistics and the per-detection plot series;
      * Kalman results for BALL_OBJECT_ID;
      * ``yolo_status``;
      * ``kick_frame`` / ``sam_window`` when a kick is found.

    Speeds are in pixels per second when ``state.video_fps`` is known, otherwise
    pixels per frame.

    Kick rule, applied after a baseline of up to 10 detected frames: pick the
    first detected frame that meets all of the following.
      * Its speed reaches the threshold, which is
        ``max(baseline + 4*std, 3*baseline, 15)``.
      * The next 3 detections stay at 70 % of the threshold or more.
      * Its box area is at most 75 % of the median of the previous 4
        detections; this check is waived when the speed is at least 1.2× the
        threshold.
      * Within the 8 detections starting at that frame, the ball gets at least
        ``adaptive_return_distance`` pixels away from its first smoothed
        position.

    Args:
        state: application state with ``video_frames`` and ``video_fps``.
        progress: optional Gradio progress callback, advanced once per frame.

    Raises:
        gr.Error: if no video is loaded or the model lacks the sports-ball class.
    """
    if state is None or state.num_frames == 0:
        raise gr.Error("Load a video first, then track with YOLO.")
    model = get_yolo_model()
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
    ]
    if not class_ids:
        raise gr.Error("YOLO model does not contain the sports ball class.")

    # --- 1. Per-frame detection (best box only) ------------------------------
    frames = state.video_frames
    total = len(frames)
    centers: dict[int, tuple[float, float]] = {}
    boxes: dict[int, tuple[int, int, int, int]] = {}
    confs: dict[int, float] = {}
    areas: dict[int, float] = {}
    first_detection_frame: int | None = None
    for idx, frame in enumerate(frames):
        if progress is not None:
            progress((idx + 1) / total)
        results = model.predict(
            source=frame,
            conf=YOLO_CONF_THRESHOLD,
            iou=YOLO_IOU_THRESHOLD,
            max_det=1,
            classes=class_ids,
            imgsz=640,
            device="cpu",
            verbose=False,
        )
        if not results:
            continue
        boxes_result = results[0].boxes
        if boxes_result is None or len(boxes_result) == 0:
            continue
        box = boxes_result[0]
        x_center, y_center, width, height = box.xywh[0].cpu().tolist()
        conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
        x_center = float(x_center)
        y_center = float(y_center)
        width = max(1.0, float(width))
        height = max(1.0, float(height))
        frame_width, frame_height = frame.size
        x_min = int(round(max(0.0, x_center - width / 2.0)))
        y_min = int(round(max(0.0, y_center - height / 2.0)))
        x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
        y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
        if x_max <= x_min or y_max <= y_min:
            continue
        centers[idx] = (x_center, y_center)
        boxes[idx] = (x_min, y_min, x_max, y_max)
        confs[idx] = conf
        areas[idx] = float((x_max - x_min) * (y_max - y_min))
        if first_detection_frame is None:
            first_detection_frame = idx

    state.yolo_ball_centers = centers
    state.yolo_ball_boxes = boxes
    state.yolo_ball_conf = confs
    state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())]
    state.yolo_initial_frame = first_detection_frame

    if len(centers) < 3:
        # Too few points for speed estimation. Clear every derived field,
        # including the per-detection plot series, so values from an earlier run
        # cannot linger next to this run's raw detections.
        state.yolo_smoothed_centers = {}
        state.yolo_speeds = {}
        state.yolo_distance_from_start = {}
        state.yolo_threshold = None
        state.yolo_baseline_speed = None
        state.yolo_speed_std = None
        state.yolo_kick_frame = None
        state.yolo_kick_frames = []
        state.yolo_kick_speeds = []
        state.yolo_kick_distance = []
        state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually."
        state.sam_window = None
        return

    # --- 2. EMA smoothing and per-detection speed -----------------------------
    items = sorted(centers.items())
    dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
    alpha = 0.35
    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    prev_frame = None
    prev_smooth = None
    for frame_idx, (cx, cy) in items:
        if prev_smooth is None:
            smooth_x, smooth_y = float(cx), float(cy)
        else:
            smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
            smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
        smoothed[frame_idx] = (smooth_x, smooth_y)
        if prev_smooth is None or prev_frame is None:
            speeds[frame_idx] = 0.0
        else:
            # Frames without a detection stretch the time step accordingly.
            frame_delta = max(1, frame_idx - prev_frame)
            time_delta = frame_delta * dt
            dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
            speeds[frame_idx] = dist / time_delta if time_delta > 0 else dist
        prev_smooth = (smooth_x, smooth_y)
        prev_frame = frame_idx
    frames_ordered = [frame_idx for frame_idx, _ in items]
    speed_series = [speeds.get(f, 0.0) for f in frames_ordered]

    # --- 3. Speed threshold from the initial baseline -------------------------
    baseline_window = min(10, len(frames_ordered) // 3 or 1)
    baseline_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    base_threshold = baseline_speed + 4.0 * speed_std
    if base_threshold < baseline_speed * 3.0:
        base_threshold = baseline_speed * 3.0
    speed_threshold = max(base_threshold, 15.0)

    distance_dict: dict[int, float] = {}
    if smoothed:
        origin = smoothed[frames_ordered[0]]
        for frame_idx, (sx, sy) in smoothed.items():
            distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])
    areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered}
    initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0
    # Required displacement scales with the apparent ball size, bounded to [8, 40] px.
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))
    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75

    kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt)

    # --- 4. Kick search --------------------------------------------------------
    n_detected = len(frames_ordered)
    kick_frame: int | None = None
    for idx, frame_no in enumerate(frames_ordered[baseline_window:], start=baseline_window):
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue
        # Sustain: the next few detections (those that exist) must not drop below 70 %.
        if any(
            speed_series[idx + j] < speed_threshold * 0.7
            for j in range(1, sustain_frames + 1)
            if idx + j < n_detected
        ):
            continue
        # Area check: a kicked ball usually shrinks as it moves away from the camera.
        area_pass = True
        current_area = areas_dict.get(frame_no)
        if current_area:
            prev_areas = [areas_dict[f] for f in frames_ordered[max(0, idx - area_window):idx]]
            if prev_areas:
                median_prev = statistics.median(prev_areas)
                if median_prev > 0 and current_area / median_prev > area_drop_ratio:
                    area_pass = False
        if not area_pass and speed < speed_threshold * 1.2:
            continue
        # Displacement check: the ball must actually leave its starting spot.
        future_slice = frames_ordered[idx:idx + holdout_frames]
        max_future_dist = max((distance_dict.get(f, 0.0) for f in future_slice), default=0.0)
        if max_future_dist < adaptive_return_distance:
            continue
        kick_frame = frame_no
        break

    # --- 5. Publish results ---------------------------------------------------
    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = speeds
    state.yolo_distance_from_start = distance_dict
    state.yolo_threshold = speed_threshold
    state.yolo_baseline_speed = baseline_speed
    state.yolo_speed_std = speed_std
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = speed_series
    state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_kick_frame = kick_frame
    coverage = len(centers) / total if total else 0.0
    if kick_frame is not None:
        state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})."
    else:
        state.yolo_status = (
            f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick."
        )
    state.kalman_centers[BALL_OBJECT_ID] = kalman_pos
    state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed
    if kick_frame is not None:
        state.kick_frame = kick_frame
        _compute_sam_window_from_kick(state, kick_frame)
    else:
        # NOTE(review): _apply_selected_ball_to_yolo_state computes a full-video
        # window when there is no kick, while this path stores None. Confirm
        # downstream SAM code treats None as "all frames".
        state.sam_window = None
def _track_single_ball_candidate(
state: AppState,
candidate: dict,
progress: gr.Progress | None = None,
) -> dict:
"""
Track a single ball candidate across ALL frames using YOLO.
Uses proximity matching to follow the same ball.
Returns dict with tracking results:
- centers: dict[frame_idx, (x, y)]
- speeds: dict[frame_idx, speed]
- kick_frame: int | None
- max_velocity: float
- has_kick: bool
- coverage: float (fraction of frames with detection)
"""
model = get_yolo_model()
class_ids = [
idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
]
frames = state.video_frames
total = len(frames)
print(f"[_track_single_ball_candidate] Tracking Ball {candidate['id']} across {total} frames...")
# Initial position from candidate
last_center = candidate["center"]
max_distance_threshold = 100 # Max pixels to consider same ball
centers: dict[int, tuple[float, float]] = {}
boxes: dict[int, tuple[int, int, int, int]] = {}
confs: dict[int, float] = {}
areas: dict[int, float] = {}
for idx, frame in enumerate(frames):
if progress is not None:
progress((idx + 1) / total)
results = model.predict(
source=frame,
conf=0.05, # Lower threshold to catch more
iou=YOLO_IOU_THRESHOLD,
max_det=10, # Allow multiple detections
classes=class_ids,
imgsz=640,
device="cpu",
verbose=False,
)
if not results:
continue
boxes_result = results[0].boxes
if boxes_result is None or len(boxes_result) == 0:
continue
# Find the detection closest to last known position
best_box = None
best_distance = float("inf")
for box in boxes_result:
xywh = box.xywh[0].cpu().tolist()
x_center, y_center = xywh[0], xywh[1]
dist = math.hypot(x_center - last_center[0], y_center - last_center[1])
if dist < best_distance and dist < max_distance_threshold:
best_distance = dist
best_box = box
if best_box is None:
continue
xywh = best_box.xywh[0].cpu().tolist()
conf = float(best_box.conf[0].cpu().item()) if best_box.conf is not None else 0.0
x_center, y_center, width, height = xywh
x_center = float(x_center)
y_center = float(y_center)
width = max(1.0, float(width))
height = max(1.0, float(height))
frame_width, frame_height = frame.size
x_min = int(round(max(0.0, x_center - width / 2.0)))
y_min = int(round(max(0.0, y_center - height / 2.0)))
x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
if x_max <= x_min or y_max <= y_min:
continue
centers[idx] = (x_center, y_center)
boxes[idx] = (x_min, y_min, x_max, y_max)
confs[idx] = conf
areas[idx] = float((x_max - x_min) * (y_max - y_min))
last_center = (x_center, y_center)
# Compute speeds
if len(centers) < 3:
return {
"centers": centers,
"boxes": boxes,
"confs": confs,
"areas": areas,
"speeds": {},
"smoothed_centers": {},
"frames_ordered": [],
"speed_series": [],
"kick_frame": None,
"max_velocity": 0.0,
"has_kick": False,
"coverage": len(centers) / total if total else 0.0,
}
items = sorted(centers.items())
dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
alpha = 0.35
smoothed: dict[int, tuple[float, float]] = {}
speeds: dict[int, float] = {}
prev_frame = None
prev_smooth = None
for frame_idx, (cx, cy) in items:
if prev_smooth is None:
smooth_x, smooth_y = float(cx), float(cy)
else:
smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
smoothed[frame_idx] = (smooth_x, smooth_y)
if prev_smooth is None or prev_frame is None:
speeds[frame_idx] = 0.0
else:
frame_delta = max(1, frame_idx - prev_frame)
time_delta = frame_delta * dt
dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
speed = dist / time_delta if time_delta > 0 else dist
speeds[frame_idx] = speed
prev_smooth = (smooth_x, smooth_y)
prev_frame = frame_idx
frames_ordered = [frame_idx for frame_idx, _ in items]
speed_series = [speeds.get(f, 0.0) for f in frames_ordered]
# Detect kick (velocity spike)
baseline_window = min(10, len(frames_ordered) // 3 or 1)
baseline_speeds = speed_series[:baseline_window]
baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
base_threshold = baseline_speed + 4.0 * speed_std
if base_threshold < baseline_speed * 3.0:
base_threshold = baseline_speed * 3.0
speed_threshold = max(base_threshold, 15.0)
kick_frame: int | None = None
max_velocity = max(speed_series) if speed_series else 0.0
for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
speed = speed_series[idx]
if speed < speed_threshold:
continue
# Check sustain
sustain_ok = True
for j in range(1, 4):
if idx + j >= len(frames_ordered):
break
if speed_series[idx + j] < speed_threshold * 0.7:
sustain_ok = False
break
if sustain_ok:
kick_frame = frame
break
result = {
"centers": centers,
"boxes": boxes,
"confs": confs,
"areas": areas,
"speeds": speeds,
"smoothed_centers": smoothed,
"frames_ordered": frames_ordered,
"speed_series": speed_series,
"threshold": speed_threshold,
"baseline": baseline_speed,
"kick_frame": kick_frame,
"max_velocity": max_velocity,
"has_kick": kick_frame is not None,
"coverage": len(centers) / total if total else 0.0,
}
# Summary logging
kick_info = f"Kick @ frame {kick_frame}" if kick_frame else "No kick"
print(f"[_track_single_ball_candidate] Ball {candidate['id']} done: "
f"{len(centers)}/{total} frames ({result['coverage']:.0%}), "
f"max_vel={max_velocity:.1f}px/s, {kick_info}")
return result
def _detect_and_track_all_ball_candidates(
state: AppState,
progress: gr.Progress | None = None,
) -> None:
"""
Detect all ball candidates in first frame, track each with YOLO,
score them, and auto-select the best candidate.
"""
if state is None or state.num_frames == 0:
raise gr.Error("Load a video first.")
first_frame = state.video_frames[0]
frame_width, frame_height = first_frame.size
# Step 1: Detect all balls in first frame
candidates = detect_all_balls(first_frame)
if not candidates:
state.ball_candidates = []
state.multi_ball_status = "❌ No ball candidates detected in first frame."
return
state.multi_ball_status = f"🔍 Found {len(candidates)} ball candidate(s). Tracking..."
# Step 2: Track each candidate
tracking_results: dict[int, dict] = {}
for i, candidate in enumerate(candidates):
if progress is not None:
progress((i + 1) / len(candidates), desc=f"Tracking ball {i+1}/{len(candidates)}")
result = _track_single_ball_candidate(state, candidate, progress=None)
tracking_results[candidate["id"]] = result
# Add tracking summary to candidate
candidate["tracking"] = result
candidate["has_kick"] = result["has_kick"]
candidate["kick_frame"] = result["kick_frame"]
candidate["max_velocity"] = result["max_velocity"]
candidate["coverage"] = result["coverage"]
# Step 3: Score candidates
frame_center_x = frame_width / 2
for candidate in candidates:
score = 0.0
# 1. Has a detected kick (velocity spike) — most important
if candidate["has_kick"]:
score += 50
# 2. Higher max velocity — ball that moves most
score += min(30, candidate["max_velocity"] / 10)
# 3. Centered horizontally
x_offset = abs(candidate["center"][0] - frame_center_x) / frame_center_x
score += 20 * (1 - x_offset)
# 4. YOLO confidence as tiebreaker
score += candidate["conf"] * 10
# 5. Better coverage
score += candidate["coverage"] * 10
candidate["score"] = score
# Sort by score descending
candidates.sort(key=lambda c: c["score"], reverse=True)
# Re-assign IDs after sorting
for i, c in enumerate(candidates):
c["id"] = i
state.ball_candidates = candidates
state.ball_candidates_tracking = tracking_results
state.selected_ball_idx = 0 # Auto-select best candidate
state.ball_selection_confirmed = False
# Build status message
if len(candidates) == 1:
c = candidates[0]
kick_info = f"Kick @ frame {c['kick_frame']}" if c["has_kick"] else "No kick detected"
state.multi_ball_status = f"✅ 1 ball detected. {kick_info}."
else:
kicked_count = sum(1 for c in candidates if c["has_kick"])
state.multi_ball_status = (
f"⚠️ {len(candidates)} balls detected. "
f"{kicked_count} show movement. "
f"Best candidate auto-selected. Please confirm or change selection."
)
def _apply_selected_ball_to_yolo_state(state: AppState) -> None:
    """
    Copy the selected ball candidate's tracking data into the main YOLO state.

    This lets the rest of the pipeline, which reads ``state.yolo_*``, work
    unchanged regardless of which candidate was picked.

    Behavior:
    - Does nothing when ``state.ball_candidates`` is empty.
    - An out-of-range ``state.selected_ball_idx`` falls back to candidate 0.
      ``selected_ball_idx`` is updated to match, so the UI highlights the
      candidate that was actually applied.
    - Every copied field, including the derived distance series, is overwritten.
      Fields are reset to empty values when the candidate lacks the data, so
      nothing leaks over from a previously applied candidate.
    - Recomputes the SAM window via ``_compute_sam_window_from_kick``. The
      status text states that a ``None`` kick frame means SAM2 analyzes all frames.
    - Sets ``is_yolo_tracked`` and ``ball_selection_confirmed`` and writes a
      human-readable ``yolo_status``.
    """
    candidates = state.ball_candidates
    if not candidates:
        return
    idx = state.selected_ball_idx
    if idx < 0 or idx >= len(candidates):
        idx = 0
        state.selected_ball_idx = idx
    tracking = candidates[idx].get("tracking") or {}
    frames_ordered = tracking.get("frames_ordered") or []
    smoothed = tracking.get("smoothed_centers") or {}
    areas = tracking.get("areas") or {}
    kick_frame = tracking.get("kick_frame")

    # Copy the candidate's tracking results into the main YOLO state.
    state.yolo_ball_centers = tracking.get("centers") or {}
    state.yolo_ball_boxes = tracking.get("boxes") or {}
    state.yolo_ball_conf = tracking.get("confs") or {}
    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = tracking.get("speeds") or {}
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = tracking.get("speed_series") or []
    state.yolo_threshold = tracking.get("threshold")
    state.yolo_baseline_speed = tracking.get("baseline")
    state.yolo_kick_frame = kick_frame
    state.yolo_initial_frame = frames_ordered[0] if frames_ordered else None

    # Area proxy, aligned with frames_ordered (0.0 where no area was recorded).
    state.yolo_mask_area_proxy = [areas.get(f, 0.0) for f in frames_ordered]

    # Distance from the first frame (in frames_ordered order) that has a
    # smoothed center. Without such an origin there is nothing to measure, so
    # the series is zero-filled to stay aligned with frames_ordered.
    origin = next((smoothed[f] for f in frames_ordered if f in smoothed), None)
    if origin is None:
        state.yolo_distance_from_start = {}
        state.yolo_kick_distance = [0.0 for _ in frames_ordered]
    else:
        ox, oy = origin
        distance_by_frame = {
            f: math.hypot(sx - ox, sy - oy) for f, (sx, sy) in smoothed.items()
        }
        state.yolo_distance_from_start = distance_by_frame
        state.yolo_kick_distance = [distance_by_frame.get(f, 0.0) for f in frames_ordered]

    # Kick frame may be None. The SAM window is always recomputed.
    state.kick_frame = kick_frame
    _compute_sam_window_from_kick(state, kick_frame)

    state.is_yolo_tracked = True
    state.ball_selection_confirmed = True
    coverage = tracking.get("coverage") or 0.0
    if kick_frame is not None:
        state.yolo_status = f"✅ Ball {idx+1} tracked. Kick @ frame {kick_frame}."
    else:
        state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected. SAM2 will analyze ALL frames."
def draw_yolo_detections_on_frame(
    frame: Image.Image,
    candidates: list[dict],
    selected_idx: int = 0,
    show_all: bool = True,
) -> Image.Image:
    """
    Draw YOLO bounding boxes for ball candidates on a copy of the frame.

    Styling:
    - Selected candidate: green box with a thick border.
    - Other candidates with a detected kick: orange box.
    - Remaining candidates: yellow box with a thinner border.
    - Each box is labeled "Ball N (conf%)". The label is prefixed with "✓" when
      the candidate is selected and suffixed with "⚽" when it has a kick.

    Args:
        frame: Source frame. It is not modified.
        candidates: Candidate dicts. Reads "box" (x_min, y_min, x_max, y_max),
            "conf", "has_kick" and "center". Candidates without a box are skipped.
        selected_idx: Index of the highlighted candidate.
        show_all: When False, only the selected candidate is drawn.

    Returns:
        A new image with the annotations drawn on it.
    """
    from PIL import ImageDraw, ImageFont

    label_offset = 22  # distance from the label's text origin up to the box top
    label_padding = 3
    cross_size = 8

    result = frame.copy()
    draw = ImageDraw.Draw(result)
    # Prefer DejaVu Bold. ImageFont.truetype raises OSError when the file is
    # missing, in which case PIL's built-in font is used instead.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except OSError:
        font = ImageFont.load_default()

    for i, candidate in enumerate(candidates):
        is_selected = i == selected_idx
        if not show_all and not is_selected:
            continue
        box = candidate.get("box")
        if not box:
            continue
        x_min, y_min, x_max, y_max = box
        conf = candidate.get("conf", 0)
        has_kick = candidate.get("has_kick", False)

        if is_selected:
            box_color, width = (0, 255, 0), 4  # green
        elif has_kick:
            box_color, width = (255, 165, 0), 3  # orange: kicked but not selected
        else:
            box_color, width = (255, 255, 0), 2  # yellow: everything else
        text_color = box_color

        # Thick border drawn as nested 1px rectangles growing outward.
        for offset in range(width):
            draw.rectangle(
                [x_min - offset, y_min - offset, x_max + offset, y_max + offset],
                outline=box_color,
            )
        # Dark outline just outside the colored border for visibility.
        draw.rectangle(
            [x_min - width - 1, y_min - width - 1, x_max + width + 1, y_max + width + 1],
            outline=(0, 0, 0),
        )

        label = f"Ball {i + 1} ({conf:.0%})"
        if is_selected:
            label = f"✓ {label}"
        if has_kick:
            label += " ⚽"

        # Put the label above the box. If that would leave the image, place it
        # just below the box instead.
        label_y = y_min - label_offset
        if label_y < 0:
            label_y = y_max + width + 2
        text_bbox = draw.textbbox((x_min, label_y), label, font=font)
        bg_box = [
            text_bbox[0] - label_padding,
            text_bbox[1] - label_padding,
            text_bbox[2] + label_padding,
            text_bbox[3] + label_padding,
        ]
        # NOTE: on RGB frames the alpha component is ignored (opaque black).
        draw.rectangle(bg_box, fill=(0, 0, 0, 200))
        draw.text((x_min, label_y), label, fill=text_color, font=font)

        # Center crosshair. Fall back to the box center when "center" is absent.
        center = candidate.get("center")
        if center is None:
            center = ((x_min + x_max) / 2, (y_min + y_max) / 2)
        cx, cy = int(center[0]), int(center[1])
        draw.line([(cx - cross_size, cy), (cx + cross_size, cy)], fill=box_color, width=2)
        draw.line([(cx, cy - cross_size), (cx, cy + cross_size)], fill=box_color, width=2)
    return result
def _format_ball_candidates_for_radio(candidates: list[dict]) -> list[str]:
"""Format ball candidates as radio button choices."""
choices = []
for i, c in enumerate(candidates):
kick_info = f"⚽ Kick@{c['kick_frame']}" if c.get('has_kick') else "No kick"
vel_info = f"v={c.get('max_velocity', 0):.0f}px/s"
conf_info = f"conf={c.get('conf', 0):.0%}"
cov_info = f"cov={c.get('coverage', 0):.0%}"
pos_info = f"x={c.get('x_ratio', 0.5):.0%}"
label = f"Ball {i+1}: {kick_info} | {vel_info} | {pos_info} | {conf_info}"
choices.append(label)
return choices
def _format_ball_candidates_markdown(candidates: list[dict], selected_idx: int = 0) -> str:
"""Format ball candidates as markdown summary."""
if not candidates:
return ""
lines = [f"**{len(candidates)} ball candidate(s) detected:**\n"]
for i, c in enumerate(candidates):
marker = "✅" if i == selected_idx else "○"
kick_info = f"⚽ Kick @ frame {c['kick_frame']}" if c.get('has_kick') else "No kick detected"
vel_info = f"Max velocity: {c.get('max_velocity', 0):.0f} px/s"
conf_info = f"YOLO conf: {c.get('conf', 0):.0%}"
pos_x = c.get('x_ratio', 0.5)
pos_info = f"Position: {pos_x:.0%} from left"
lines.append(f"{marker} **Ball {i+1}**: {kick_info}")
lines.append(f" - {vel_info} | {pos_info} | {conf_info}")
return "\n".join(lines)
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
    """Return a deterministic pastel (R, G, B) color, 0..255, for ``obj_id``.

    Hues are spread by stepping the golden-ratio conjugate per id, which keeps
    consecutive ids visually distinct. Saturation is fixed at 0.45 and value at
    1.0, which gives the soft pastel look.
    """
    phi_conjugate = 0.61803398875
    hue = (obj_id * phi_conjugate) % 1.0
    unit_rgb = colorsys.hsv_to_rgb(hue, 0.45, 1.0)
    return tuple(int(component * 255) for component in unit_rgb)
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Decode every frame of a video into RGB PIL images using OpenCV.

    Args:
        video_path_or_url: Any source ``cv2.VideoCapture`` accepts (file path or URL).

    Returns:
        ``(frames, info)``. ``info["num_frames"]`` is ``len(frames)`` and
        ``info["fps"]`` is the reported frame rate as a float, or ``None`` when
        OpenCV reports no positive rate. An unreadable source yields no frames.
    """
    cap = cv2.VideoCapture(video_path_or_url)
    frames: list[Image.Image] = []
    print("loading video frames")
    try:
        while cap.isOpened():
            ok, frame_bgr = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
        fps_val = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the decoder even if a frame conversion raises.
        cap.release()
    print("loaded video frames")
    info = {
        "num_frames": len(frames),
        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
    }
    return frames, info
def overlay_masks_on_frame(
    frame: Image.Image,
    masks_per_object: dict[int, np.ndarray],
    color_by_obj: dict[int, tuple[int, int, int]],
    alpha: float = 0.5,
) -> Image.Image:
    """Alpha-blend per-object soft masks onto a frame and return an RGB image.

    Args:
        frame: Source frame. ``np.asarray`` must turn it into an H x W, H x W x 3
            or H x W x 4 array of 0..255 values. Single-channel input is
            replicated to RGB and a fourth (alpha) channel is dropped.
        masks_per_object: obj_id -> mask with values in [0, 1] (anything outside
            is clipped). Masks with extra singleton dimensions (e.g. 1 x H x W)
            are squeezed to H x W. ``None`` entries are skipped.
        color_by_obj: obj_id -> (R, G, B) in 0..255. Ids without an entry use red.
        alpha: Maximum blend weight. Each pixel is blended with ``alpha * mask``.

    Returns:
        The blended frame as an RGB ``PIL.Image``.
    """
    overlay = np.asarray(frame, dtype=np.float32) / 255.0  # fresh array in [0, 1]
    if overlay.ndim == 2:
        overlay = np.repeat(overlay[..., None], 3, axis=2)
    elif overlay.shape[2] == 4:
        overlay = overlay[..., :3]
    for obj_id, mask in masks_per_object.items():
        if mask is None:
            continue
        mask = np.asarray(mask, dtype=np.float32)
        if mask.ndim > 2:
            mask = mask.squeeze()
        mask = np.clip(mask, 0.0, 1.0)
        color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0
        # overlay = (1 - a*m) * overlay + (a*m) * color
        weight = alpha * mask[..., None]
        overlay = (1.0 - weight) * overlay + weight * color
    out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(out)
def get_device_and_dtype() -> tuple[str, torch.dtype]:
    """Return the ``(device, dtype)`` pair used for SAM2 inference.

    Inference is pinned to the CPU in bfloat16.
    """
    return "cpu", torch.bfloat16
class AppState:
    """Mutable per-session state shared by the Gradio handlers.

    Holds the decoded video, the SAM2 model/processor/session, per-frame masks
    and prompts, kick/impact analysis results, YOLO tracking caches,
    ball-candidate selection, cutout/ring FX settings and the goal-crossbar
    editor state. ``reset()`` restores every field to its default.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore every field to its default value (this also drops the model)."""
        self.video_frames: list[Image.Image] = []
        self.inference_session = None
        self.model: Optional[AutoModel] = None
        self.processor: Optional[Sam2VideoProcessor] = None
        self.device: str = "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.video_fps: float | None = None
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
        self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
        # Cache of composited frames (original + masks + clicks)
        self.composited_frames: dict[int, Image.Image] = {}
        # UI state for click handler
        self.current_frame_idx: int = 0
        self.current_obj_id: int = 1
        self.current_label: str = "positive"
        self.current_clear_old: bool = True
        self.current_prompt_type: str = "Points"  # or "Boxes"
        self.pending_box_start: tuple[int, int] | None = None
        self.pending_box_start_frame_idx: int | None = None
        self.pending_box_start_obj_id: int | None = None
        self.is_switching_model: bool = False
        self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
        self.mask_areas: dict[int, dict[int, float]] = {}
        self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.ball_speeds: dict[int, dict[int, float]] = {}
        self.distance_from_start: dict[int, dict[int, float]] = {}
        self.direction_change: dict[int, dict[int, float]] = {}
        self.kick_frame: int | None = None
        self.kick_debug_frames: list[int] = []
        self.kick_debug_speeds: list[float] = []
        self.kick_debug_threshold: float | None = None
        self.kick_debug_baseline: float | None = None
        self.kick_debug_speed_std: float | None = None
        self.kick_debug_area: list[float] = []
        self.kick_debug_kick_frame: int | None = None
        self.kick_debug_distance: list[float] = []
        self.kick_debug_kalman_speeds: list[float] = []
        self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.kalman_speeds: dict[int, dict[int, float]] = {}
        self.kalman_residuals: dict[int, dict[int, float]] = {}
        self.min_impact_speed_kmh: float = 20.0
        self.goal_distance_m: float = 18.0
        self.impact_frame: int | None = None
        self.impact_debug_frames: list[int] = []
        self.impact_debug_innovation: list[float] = []
        self.impact_debug_innovation_threshold: float | None = None
        self.impact_debug_direction: list[float] = []
        self.impact_debug_direction_threshold: float | None = None
        self.impact_debug_speed_kmh: list[float] = []
        self.impact_debug_speed_threshold_px: float | None = None
        self.impact_meters_per_px: float | None = None
        # Model selection
        self.model_repo_key: str = "tiny"
        self.model_repo_id: str | None = None
        self.session_repo_id: str | None = None
        self.player_obj_id: int | None = None
        self.player_detection_frame: int | None = None
        self.player_detection_conf: float | None = None
        # YOLO tracking caches
        self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
        self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
        self.yolo_ball_conf: dict[int, float] = {}
        self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
        self.yolo_speeds: dict[int, float] = {}
        self.yolo_distance_from_start: dict[int, float] = {}
        self.yolo_threshold: float | None = None
        self.yolo_baseline_speed: float | None = None
        self.yolo_speed_std: float | None = None
        self.yolo_kick_frame: int | None = None
        self.yolo_status: str = ""
        self.yolo_kick_frames: list[int] = []
        self.yolo_kick_speeds: list[float] = []
        self.yolo_kick_distance: list[float] = []
        self.yolo_mask_area_proxy: list[float] = []
        self.yolo_initial_frame: int | None = None
        # SAM window (start_idx inclusive, end_idx exclusive)
        self.sam_window: tuple[int, int] | None = None
        # Cutout / compositing effects
        self.fx_soft_matte_enabled: bool = True
        self.fx_soft_matte_feather: float = 4.0
        self.fx_soft_matte_erode: float = 0.5
        self.fx_blur_enabled: bool = True
        self.fx_blur_sigma: float = 0.0
        self.fx_bg_darkening: float = 0.75
        self.fx_light_wrap_enabled: bool = False
        self.fx_light_wrap_strength: float = 0.6
        self.fx_light_wrap_width: float = 15.0
        self.fx_glow_enabled: bool = False
        self.fx_glow_strength: float = 0.4
        self.fx_glow_radius: float = 10.0
        self.fx_ghost_trail_enabled: bool = False
        self.fx_ball_ring_enabled: bool = True
        self.show_click_marks: bool = False
        # Ring FX parameters (initialized with module defaults, overridable by the UI)
        self.fx_ring_thickness: float = BALL_RING_THICKNESS_PX
        self.fx_ring_alpha: float = BALL_RING_ALPHA
        self.fx_ring_feather: float = BALL_RING_FEATHER_SIGMA
        self.fx_ring_gamma: float = BALL_RING_INTENSITY_GAMMA
        self.fx_ring_duration: int = 30  # Default duration in frames
        self.fx_ring_scale_pct: float = RING_SIZE_SCALE_DEFAULT
        self.manual_kick_frame: int | None = None
        self.manual_impact_frame: int | None = None
        self.is_ball_detected: bool = False
        self.is_yolo_tracked: bool = False
        self.is_sam_tracked: bool = False
        self.is_player_detected: bool = False
        self.is_player_propagated: bool = False
        # Multi-ball candidate tracking
        self.ball_candidates: list[dict] = []  # All detected ball candidates
        self.ball_candidates_tracking: dict[int, dict] = {}  # Per-candidate tracking data
        self.selected_ball_idx: int = 0  # Currently selected candidate index
        self.ball_selection_confirmed: bool = False  # True after user confirms selection
        self.multi_ball_status: str = ""  # Status message for multi-ball detection
        self.goal_mode: str = GOAL_MODE_IDLE
        self.goal_points_norm: list[tuple[float, float]] = []
        self.goal_confirmed_points_norm: list[tuple[float, float]] = []
        self.goal_prev_confirmed_points_norm: list[tuple[float, float]] = []
        self.goal_overlay_points: list[tuple[float, float]] = []
        self.goal_status_text: str = "Goal crossbar inactive."
        self.goal_dragging_idx: int | None = None

    def __repr__(self) -> str:
        # Heavy containers (decoded frames, mask arrays, composited images) are
        # summarized by size. Printing their contents made the repr enormous.
        return (
            f"AppState(video_frames=<{len(self.video_frames)} frame(s)>, "
            f"inference_session={self.inference_session is not None}, "
            f"model={self.model is not None}, "
            f"processor={self.processor is not None}, "
            f"device={self.device}, dtype={self.dtype}, "
            f"video_fps={self.video_fps}, "
            f"masks_by_frame=<{len(self.masks_by_frame)} frame(s)>, "
            f"color_by_obj={self.color_by_obj}, "
            f"clicks_by_frame_obj={self.clicks_by_frame_obj}, "
            f"boxes_by_frame_obj={self.boxes_by_frame_obj}, "
            f"composited_frames=<{len(self.composited_frames)} frame(s)>, "
            f"current_frame_idx={self.current_frame_idx}, "
            f"current_obj_id={self.current_obj_id}, "
            f"current_label={self.current_label}, "
            f"current_clear_old={self.current_clear_old}, "
            f"current_prompt_type={self.current_prompt_type}, "
            f"pending_box_start={self.pending_box_start}, "
            f"pending_box_start_frame_idx={self.pending_box_start_frame_idx}, "
            f"pending_box_start_obj_id={self.pending_box_start_obj_id}, "
            f"is_switching_model={self.is_switching_model}, "
            f"model_repo_key={self.model_repo_key}, "
            f"model_repo_id={self.model_repo_id}, "
            f"session_repo_id={self.session_repo_id})"
        )

    @property
    def num_frames(self) -> int:
        """Number of decoded frames currently loaded."""
        return len(self.video_frames)
def _model_repo_from_key(key: str) -> str:
mapping = {
"tiny": "facebook/sam2.1-hiera-tiny",
"small": "facebook/sam2.1-hiera-small",
"base_plus": "facebook/sam2.1-hiera-base-plus",
"large": "facebook/sam2.1-hiera-large",
}
return mapping.get(key, mapping["base_plus"])
def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
    """Make sure the SAM2 model/processor for the selected repo are loaded.

    Reuses the cached model when it already matches
    ``_model_repo_from_key(GLOBAL_STATE.model_repo_key)``. Otherwise it drops
    the current one and loads the requested checkpoint. The model, processor,
    device, dtype and ``model_repo_id`` are stored on ``GLOBAL_STATE``.

    Returns:
        ``(model, processor, device, dtype)``. This is returned on both the
        cached path and the freshly-loaded path.
    """
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
        if GLOBAL_STATE.model_repo_id == desired_repo:
            return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
        # Different repo requested: drop the references to the current model
        # and collect, so its weights are freed before the new ones load.
        GLOBAL_STATE.model = None
        GLOBAL_STATE.processor = None
        gc.collect()
    print(f"Loading model from {desired_repo}")
    device, dtype = get_device_and_dtype()
    model = AutoModel.from_pretrained(desired_repo)
    processor = Sam2VideoProcessor.from_pretrained(desired_repo)
    model.to(device, dtype=dtype)
    GLOBAL_STATE.model = model
    GLOBAL_STATE.processor = processor
    GLOBAL_STATE.device = device
    GLOBAL_STATE.dtype = dtype
    GLOBAL_STATE.model_repo_id = desired_repo
    return model, processor, device, dtype
def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
    """Keep the model and the inference session in sync with the selected repo.

    First ensures the model/processor for ``model_repo_key`` are loaded. Then,
    if the inference session is missing or was built for another repo, and a
    video is loaded, it clears the mask/prompt/composite caches and opens a new
    session. Without loaded video frames the session is left untouched.
    """
    load_model_if_needed(GLOBAL_STATE)
    target_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    session_is_current = (
        GLOBAL_STATE.inference_session is not None
        and GLOBAL_STATE.session_repo_id == target_repo
    )
    if session_is_current or not GLOBAL_STATE.video_frames:
        return
    # Cached masks/prompts/composites belong to the old session.
    for cache in (
        GLOBAL_STATE.masks_by_frame,
        GLOBAL_STATE.clicks_by_frame_obj,
        GLOBAL_STATE.boxes_by_frame_obj,
        GLOBAL_STATE.composited_frames,
    ):
        cache.clear()
    GLOBAL_STATE.inference_session = None
    GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )
    GLOBAL_STATE.session_repo_id = target_repo
def _reset_video_state(GLOBAL_STATE: AppState) -> None:
    """Clear every per-video field of the state.

    The model/processor, model selection, analysis parameters (e.g.
    ``min_impact_speed_kmh``) and FX/UI preferences are left intact.
    """
    GLOBAL_STATE.video_frames = []
    GLOBAL_STATE.inference_session = None
    GLOBAL_STATE.video_fps = None
    # Per-frame masks, prompts and render cache (keyed by frame index, so they
    # would otherwise bleed into the next video).
    GLOBAL_STATE.masks_by_frame = {}
    GLOBAL_STATE.color_by_obj = {}
    GLOBAL_STATE.clicks_by_frame_obj = {}
    GLOBAL_STATE.boxes_by_frame_obj = {}
    GLOBAL_STATE.composited_frames = {}
    GLOBAL_STATE.current_frame_idx = 0
    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None
    # SAM-based ball kinematics and kick detection.
    GLOBAL_STATE.ball_centers = {}
    GLOBAL_STATE.mask_areas = {}
    GLOBAL_STATE.smoothed_centers = {}
    GLOBAL_STATE.ball_speeds = {}
    GLOBAL_STATE.distance_from_start = {}
    GLOBAL_STATE.direction_change = {}
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kalman_centers = {}
    GLOBAL_STATE.kalman_speeds = {}
    GLOBAL_STATE.kalman_residuals = {}
    GLOBAL_STATE.kick_debug_kalman_speeds = []
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    # Impact detection.
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None
    # Goal crossbar editor.
    GLOBAL_STATE.goal_mode = GOAL_MODE_IDLE
    GLOBAL_STATE.goal_points_norm = []
    GLOBAL_STATE.goal_confirmed_points_norm = []
    GLOBAL_STATE.goal_prev_confirmed_points_norm = []
    GLOBAL_STATE.goal_overlay_points = []
    GLOBAL_STATE.goal_status_text = "Goal crossbar inactive."
    GLOBAL_STATE.goal_dragging_idx = None
    # YOLO tracking caches and SAM window.
    GLOBAL_STATE.yolo_ball_centers = {}
    GLOBAL_STATE.yolo_ball_boxes = {}
    GLOBAL_STATE.yolo_ball_conf = {}
    GLOBAL_STATE.yolo_smoothed_centers = {}
    GLOBAL_STATE.yolo_speeds = {}
    GLOBAL_STATE.yolo_distance_from_start = {}
    GLOBAL_STATE.yolo_threshold = None
    GLOBAL_STATE.yolo_baseline_speed = None
    GLOBAL_STATE.yolo_speed_std = None
    GLOBAL_STATE.yolo_kick_frame = None
    GLOBAL_STATE.yolo_status = ""
    GLOBAL_STATE.yolo_kick_frames = []
    GLOBAL_STATE.yolo_kick_speeds = []
    GLOBAL_STATE.yolo_kick_distance = []
    GLOBAL_STATE.yolo_mask_area_proxy = []
    GLOBAL_STATE.yolo_initial_frame = None
    GLOBAL_STATE.sam_window = None
    # Player detection.
    GLOBAL_STATE.player_obj_id = None
    GLOBAL_STATE.player_detection_frame = None
    GLOBAL_STATE.player_detection_conf = None
    # Multi-ball candidates and manual frame overrides (frame indices of the old video).
    GLOBAL_STATE.ball_candidates = []
    GLOBAL_STATE.ball_candidates_tracking = {}
    GLOBAL_STATE.selected_ball_idx = 0
    GLOBAL_STATE.ball_selection_confirmed = False
    GLOBAL_STATE.multi_ball_status = ""
    GLOBAL_STATE.manual_kick_frame = None
    GLOBAL_STATE.manual_impact_frame = None
    # Pipeline progress flags.
    GLOBAL_STATE.is_ball_detected = False
    GLOBAL_STATE.is_yolo_tracked = False
    GLOBAL_STATE.is_sam_tracked = False
    GLOBAL_STATE.is_player_detected = False
    GLOBAL_STATE.is_player_propagated = False


def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
    """Gradio handler: load a video and start a fresh SAM2 inference session.

    Resets all per-video state while keeping the loaded model, decodes the video,
    and trims it to at most 8 seconds when the frame rate is known.

    Returns:
        ``(state, 0, last_frame_index, first_frame, status_text)``.

    Raises:
        gr.Error: if ``video`` yields no usable path or no frames can be decoded.
    """
    _reset_video_state(GLOBAL_STATE)
    load_model_if_needed(GLOBAL_STATE)
    # Gradio Video may provide a dict with 'name'/'path'/'data' or a direct file path.
    video_path: Optional[str] = None
    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    elif isinstance(video, str):
        video_path = video
    if not video_path:
        raise gr.Error("Invalid video input.")
    frames, info = try_load_video_frames(video_path)
    if len(frames) == 0:
        raise gr.Error("No frames could be loaded from the video.")
    # Enforce a max duration of 8 seconds (trim if longer). Trimming needs a
    # frame rate, so when the container reports none the clip is kept whole.
    MAX_SECONDS = 8.0
    trimmed_note = ""
    fps_in = info.get("fps")
    if fps_in:
        max_frames_allowed = max(1, int(MAX_SECONDS * fps_in))
        if len(frames) > max_frames_allowed:
            frames = frames[:max_frames_allowed]
            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
            info["num_frames"] = len(frames)
    GLOBAL_STATE.video_frames = frames
    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
    GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )
    # The session was built with the processor that load_model_if_needed just
    # ensured. Recording its repo stops ensure_session_for_current_model from
    # rebuilding it (and wiping masks/prompts) on its next call.
    GLOBAL_STATE.session_repo_id = GLOBAL_STATE.model_repo_id
    status = (
        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
    )
    return GLOBAL_STATE, 0, len(frames) - 1, frames[0], status
# Maximum relative change allowed between consecutive ring radii (±20%).
# This must be bound before `_clamp_radii`, whose default argument is evaluated
# at definition time. The constant is re-bound to the same value alongside the
# other ring constants further down in the file, so keep both in sync.
RING_RADIUS_CLAMP_RATIO = 0.2 # ±20%
def _speed_to_color(ratio: float) -> tuple[int, int, int]:
ratio = float(np.clip(ratio, 0.0, 1.0))
gradient = [
(255, 0, 0), # red
(255, 165, 0), # orange
(255, 255, 0), # yellow
(0, 255, 0), # green
]
segment = ratio * (len(gradient) - 1)
idx = int(segment)
frac = segment - idx
if idx >= len(gradient) - 1:
return gradient[-1]
c1 = np.array(gradient[idx], dtype=float)
c2 = np.array(gradient[idx + 1], dtype=float)
blended = (1 - frac) * c1 + frac * c2
return tuple(int(v) for v in blended)
def _speed_to_ring_color(speed_kmh: float) -> tuple[float, float, float]:
    """Map a speed in km/h to the app's discrete ring palette.

    Returns the color of the first entry in SPEED_COLOR_STOPS whose threshold is
    strictly greater than ``speed_kmh``. Speeds at or above every threshold get
    SPEED_COLOR_ABOVE_MAX.
    """
    return next(
        (color for threshold, color in SPEED_COLOR_STOPS if speed_kmh < threshold),
        SPEED_COLOR_ABOVE_MAX,
    )
def _get_prioritized_kick_frame(state: AppState) -> int | None:
if state is None:
return None
for attr in ("kick_frame", "kick_debug_kick_frame", "yolo_kick_frame"):
frame = getattr(state, attr, None)
if frame is not None:
return int(frame)
return None
def _median_smooth_radii(radii: list[float]) -> list[float]:
if not radii:
return []
if len(radii) < 3:
return radii[:]
smoothed: list[float] = []
n = len(radii)
for i in range(n):
window = radii[max(0, i - 1):min(n, i + 2)]
smoothed.append(float(statistics.median(window)))
return smoothed
def _clamp_radii(radii: list[float], clamp_ratio: float = RING_RADIUS_CLAMP_RATIO) -> list[float]:
    """Limit frame-to-frame radius changes to ±``clamp_ratio`` of the previous radius.

    Negative radii are raised to 0 first, and the first radius is kept as-is.
    When the previous radius is effectively zero (<= FX_EPS), there is no
    meaningful relative bound, so the new value passes through unchanged.
    """
    result: list[float] = []
    for raw in radii:
        current = max(0.0, float(raw))
        if result:
            previous = result[-1]
            if previous <= FX_EPS:
                low, high = 0.0, current
            else:
                low = previous * (1.0 - clamp_ratio)
                high = previous * (1.0 + clamp_ratio)
            current = min(max(current, low), high)
        result.append(current)
    return result
def _angle_between(v1: tuple[float, float], v2: tuple[float, float]) -> float:
x1, y1 = v1
x2, y2 = v2
mag1 = math.hypot(x1, y1)
mag2 = math.hypot(x2, y2)
if mag1 < 1e-6 or mag2 < 1e-6:
return 0.0
cos_val = (x1 * x2 + y1 * y2) / (mag1 * mag2)
cos_val = max(-1.0, min(1.0, cos_val))
return math.degrees(math.acos(cos_val))
# Width bounds enforced by _maybe_upscale_for_display: narrower frames are
# upscaled to the minimum and wider frames are downscaled to the maximum.
DISPLAY_MIN_WIDTH = 640
DISPLAY_MAX_WIDTH = 1280
# RGB color (floats in [0, 1]) that the glow effect in _apply_cutout_fx adds
# around the outside of the matte.
FX_GLOW_COLOR = np.array([1.0, 0.1, 0.6], dtype=np.float32)
# Near-zero threshold below which masks, sigmas and radii are treated as
# empty/disabled.
FX_EPS = 1e-6
# Ghost-trail color (RGB floats in [0, 1]) and alpha. Presumably consumed by
# the ball-trail renderer, which is not shown here.
GHOST_TRAIL_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
GHOST_TRAIL_ALPHA = 0.55
# Default ball-ring FX parameters. AppState.reset() copies them into the
# per-session fx_ring_* fields, which the UI may override.
# NOTE(review): an alpha of 3.0 exceeds 1, so it presumably acts as a gain on
# ring intensity rather than an opacity. Confirm in the ring renderer.
BALL_RING_ALPHA = 3.0 # Increased brightness
BALL_RING_THICKNESS_PX = 1.0 # Thinner rings
BALL_RING_FEATHER_SIGMA = 0.1 # Softer default blur
BALL_RING_INTENSITY_GAMMA = 2.0 # Contrast shaping
# Speed range palette (mirrors iOS app).
# Each entry is (exclusive upper bound in km/h, RGB color as floats in [0, 1]).
# _speed_to_ring_color returns the color of the first entry whose bound exceeds
# the speed, so entries must stay sorted by ascending bound.
SPEED_COLOR_STOPS = [
    (30.0, (0 / 255.0, 191 / 255.0, 255 / 255.0)), # Electric Blue
    (50.0, (0 / 255.0, 191 / 255.0, 255 / 255.0)), # Electric Blue (same band)
    (70.0, (92 / 255.0, 124 / 255.0, 250 / 255.0)), # Blue Violet
    (90.0, (154 / 255.0, 77 / 255.0, 255 / 255.0)), # Intense Violet
    (110.0, (214 / 255.0, 51 / 255.0, 132 / 255.0)), # Fuchsia
    (130.0, (255 / 255.0, 77 / 255.0, 109 / 255.0)), # Strong Pink
]
# Color for speeds at or above the last bound in SPEED_COLOR_STOPS.
SPEED_COLOR_ABOVE_MAX = (255 / 255.0, 162 / 255.0, 0 / 255.0) # Neon Orange
# Re-binds the same value defined earlier (before _clamp_radii, whose default
# argument captured that first binding). Keep the two in sync.
RING_RADIUS_CLAMP_RATIO = 0.2 # ±20%
# Default ring size scale in percent. It seeds AppState.fx_ring_scale_pct.
RING_SIZE_SCALE_DEFAULT = 125.0 # percent
def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
    """Resize ``image`` so its width lies within [DISPLAY_MIN_WIDTH, DISPLAY_MAX_WIDTH].

    Despite the name, images wider than DISPLAY_MAX_WIDTH are downscaled. The
    aspect ratio is preserved (height rounded) using bilinear resampling. None,
    zero-sized images and images already within range are returned unchanged
    (the same object).
    """
    if image is None:
        return image
    width, height = image.size
    if width <= 0 or height <= 0:
        return image
    if width < DISPLAY_MIN_WIDTH:
        new_width = DISPLAY_MIN_WIDTH
    elif width > DISPLAY_MAX_WIDTH:
        new_width = DISPLAY_MAX_WIDTH
    else:
        return image
    factor = new_width / float(width)
    new_height = int(round(height * factor))
    return image.resize((new_width, new_height), Image.BILINEAR)
def _annotate_frame_index(image: Image.Image, frame_idx: int) -> Image.Image:
    """Return a copy of ``image`` with a "Frame N" caption on a black box at the top-left.

    ``None`` is passed through unchanged.
    """
    if image is None:
        return image
    stamped = image.copy()
    painter = ImageDraw.Draw(stamped)
    caption = f"Frame {frame_idx}"
    pad = 6
    try:
        left, top, right, bottom = painter.textbbox((0, 0), caption)
        caption_w = right - left
        caption_h = bottom - top
    except AttributeError:
        # Older Pillow without textbbox: use the legacy textsize API.
        caption_w, caption_h = painter.textsize(caption)
    x0 = y0 = pad
    box_top_left = (x0 - pad // 2, y0 - pad // 2)
    box_bottom_right = (x0 + caption_w + pad, y0 + caption_h + pad)
    painter.rectangle([box_top_left, box_bottom_right], fill=(0, 0, 0))
    painter.text((x0, y0), caption, fill=(255, 255, 255))
    return stamped
def _apply_cutout_fx(state: "AppState", frame_np: np.ndarray, combined_mask: np.ndarray) -> np.ndarray:
    """Composite the masked foreground over a blurred/darkened background.

    Args:
        state: Supplies the fx_* settings (soft matte, blur, darkening, light
            wrap, glow).
        frame_np: Frame as float RGB. The final ``* 255.0`` conversion assumes
            values in [0, 1], which is how compose_frame builds it.
        combined_mask: Foreground mask with the frame's height/width. It is
            clipped to [0, 1].

    Returns:
        The composited frame as a uint8 array.

    If the mask is empty, only the (optionally blurred) darkened background is
    returned.
    """
    mask = np.clip(combined_mask.astype(np.float32), 0.0, 1.0)
    if mask.max() <= FX_EPS:
        # No foreground detected; fall back to darkened background choice
        bg = frame_np.copy()
        if state.fx_blur_enabled and state.fx_blur_sigma > FX_EPS:
            bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=state.fx_blur_sigma, sigmaY=state.fx_blur_sigma)
        bg = bg * (1.0 - np.clip(state.fx_bg_darkening, 0.0, 1.0))
        return np.clip(bg * 255.0, 0, 255).astype(np.uint8)
    # Soft matte: shrink the mask slightly (elliptical erode, kernel about
    # 2*erode+1 px), feather its edge with a Gaussian, then boost it by 5% so
    # the interior stays fully opaque after blurring.
    mask_soft = mask.copy()
    if state.fx_soft_matte_enabled:
        erode_px = max(0.0, float(state.fx_soft_matte_erode))
        if erode_px > FX_EPS:
            kernel_size = int(round(erode_px * 2 + 1))
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
            mask_soft = cv2.erode(mask_soft, kernel)
        feather = max(0.0, float(state.fx_soft_matte_feather))
        if feather > FX_EPS:
            mask_soft = cv2.GaussianBlur(mask_soft, (0, 0), sigmaX=feather, sigmaY=feather)
        mask_soft = np.clip(mask_soft * 1.05, 0.0, 1.0)
    # Background: optional blur, then scale brightness by (1 - darkening).
    bg_source = frame_np.copy()
    if state.fx_blur_enabled and state.fx_blur_sigma > FX_EPS:
        bg_source = cv2.GaussianBlur(bg_source, (0, 0), sigmaX=state.fx_blur_sigma, sigmaY=state.fx_blur_sigma)
    darkening = np.clip(state.fx_bg_darkening, 0.0, 1.0)
    bg = bg_source * (1.0 - darkening)
    # Standard alpha composite: foreground where the matte is 1, background where it is 0.
    alpha = mask_soft[..., None]
    out = frame_np * alpha + bg * (1.0 - alpha)
    # Light wrap: mask_soft - blur(mask_soft) is positive just inside the matte
    # edge. That band (normalized to max 1) receives an additive, more strongly
    # blurred copy of the un-darkened background.
    light_wrap_strength = float(state.fx_light_wrap_strength)
    light_wrap_width = max(0.0, float(state.fx_light_wrap_width))
    if state.fx_light_wrap_enabled and light_wrap_strength > FX_EPS and light_wrap_width > FX_EPS:
        inner_blur = cv2.GaussianBlur(mask_soft, (0, 0), sigmaX=light_wrap_width, sigmaY=light_wrap_width)
        inner_edge = np.clip(mask_soft - inner_blur, 0.0, 1.0)
        if inner_edge.max() > FX_EPS:
            inner_edge /= (inner_edge.max() + FX_EPS)
        bg_wrap = cv2.GaussianBlur(bg_source, (0, 0), sigmaX=light_wrap_width * 1.5, sigmaY=light_wrap_width * 1.5)
        out = np.clip(out + inner_edge[..., None] * bg_wrap * light_wrap_strength, 0.0, 1.0)
    # Glow: blur(mask_soft) - mask_soft is positive just outside the matte edge.
    # That band (normalized to max 1) is tinted additively with FX_GLOW_COLOR.
    glow_strength = float(state.fx_glow_strength)
    glow_radius = max(0.0, float(state.fx_glow_radius))
    if state.fx_glow_enabled and glow_strength > FX_EPS and glow_radius > FX_EPS:
        outer_blur = cv2.GaussianBlur(mask_soft, (0, 0), sigmaX=glow_radius, sigmaY=glow_radius)
        glow_band = np.clip(outer_blur - mask_soft, 0.0, 1.0)
        if glow_band.max() > FX_EPS:
            glow_band /= (glow_band.max() + FX_EPS)
        glow_color = FX_GLOW_COLOR[None, None, :]
        out = np.clip(out + glow_band[..., None] * glow_color * glow_strength, 0.0, 1.0)
    return np.clip(out * 255.0, 0, 255).astype(np.uint8)
def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> Image.Image | None:
    """
    Render one video frame with masks, effects and annotation markers.

    With ``remove_bg=True`` the tracked objects are cut out of the frame via
    ``_apply_cutout_fx``. Otherwise masks are drawn as colored overlays,
    followed by the feathered ball, the ghost trail and the speed rings.
    The goal overlay, click/box markers and trajectory crosses are drawn
    afterwards in both modes.

    Args:
        state: Application state. ``None`` or a state without frames yields ``None``.
        frame_idx: Requested frame index; clamped into the valid range.
        remove_bg: Render the background-removed cutout instead of overlays.

    Returns:
        The composed PIL image, or ``None`` when there is nothing to render.
        The result is stored in ``state.composited_frames`` only for the
        overlay view (``remove_bg=False``) without an active goal overlay.
    """
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    goal_overlay_active = bool(getattr(state, "goal_points_norm", [])) or bool(getattr(state, "goal_overlay_points", []))
    frame = state.video_frames[frame_idx]
    masks = state.masks_by_frame.get(frame_idx, {})
    out_img: Image.Image | None = state.composited_frames.get(frame_idx)
    if out_img is None:
        out_img = frame
    # Draw on a private copy. The ImageDraw calls below (and presumably
    # _draw_goal_overlay, whose return value is ignored) edit the image in
    # place. Without this copy, whenever no mask pass replaces ``out_img``
    # (e.g. frames without masks), markers would be painted permanently onto
    # ``state.video_frames``, and cached composites would accumulate markers
    # across calls.
    # NOTE(review): when a cached composite exists, the mask overlay below is
    # applied to it a second time — confirm every cache writer expects that.
    out_img = out_img.copy()
    # Union of all object masks, union of ball+player ("focus") masks, and the
    # raw ball mask, accumulated once for the effect passes below.
    current_union_mask: np.ndarray | None = None
    focus_mask: np.ndarray | None = None
    ball_mask_main: np.ndarray | None = None
    for obj_id, mask in masks.items():
        if mask is None:
            continue
        mask_np = mask.astype(np.float32)
        if mask_np.ndim == 3:
            mask_np = mask_np.squeeze()
        mask_np = np.clip(mask_np, 0.0, 1.0)
        if current_union_mask is None:
            current_union_mask = np.zeros_like(mask_np, dtype=np.float32)
        current_union_mask = np.maximum(current_union_mask, mask_np)
        if obj_id in (BALL_OBJECT_ID, PLAYER_OBJECT_ID):
            if focus_mask is None:
                focus_mask = np.zeros_like(mask_np, dtype=np.float32)
            focus_mask = np.maximum(focus_mask, mask_np)
        if obj_id == BALL_OBJECT_ID:
            ball_mask_main = mask
    ghost_mask = _build_ball_trail_mask(state, frame_idx)
    ring_result = _build_ball_ring_mask(state, frame_idx)
    if masks:
        if remove_bg:
            # Remove background - show only tracked objects
            frame_np = np.array(frame).astype(np.float32) / 255.0
            combined_mask = current_union_mask
            if combined_mask is None:
                combined_mask = np.zeros((frame_np.shape[0], frame_np.shape[1]), dtype=np.float32)
            # Apply falloff to ball component when rendering foreground
            if BALL_OBJECT_ID in masks:
                ball_mask = masks[BALL_OBJECT_ID]
                if ball_mask is not None:
                    combined_mask = np.maximum(
                        combined_mask,
                        _apply_radial_falloff(np.clip(ball_mask.astype(np.float32), 0.0, 1.0), strength=1.0, solid_ratio=0.8),
                    )
            result_np = _apply_cutout_fx(state, frame_np, combined_mask)
            out_img = Image.fromarray(result_np)
        else:
            out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65)
            # Overlay feathered ball on top
            if BALL_OBJECT_ID in masks:
                ball_mask = masks[BALL_OBJECT_ID]
                if ball_mask is not None:
                    ball_alpha = _apply_radial_falloff(ball_mask, strength=1.0, solid_ratio=0.8)
                    if ball_alpha is not None and ball_alpha.max() > FX_EPS:
                        base_np = np.array(out_img).astype(np.float32) / 255.0
                        color = np.array(state.color_by_obj.get(BALL_OBJECT_ID, (255, 255, 0)), dtype=np.float32) / 255.0
                        alpha = np.clip(ball_alpha[..., None], 0.0, 1.0)
                        base_np = (1.0 - alpha) * base_np + alpha * color
                        out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
            # NOTE(review): the ghost-trail and ring passes run only in the
            # overlay (non-cutout) view — confirm that is intended.
            if ghost_mask is not None:
                ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
                # Keep the trail out of areas covered by any current mask.
                if current_union_mask is not None:
                    ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
                if ghost_np.max() > FX_EPS:
                    base_np = np.array(out_img).astype(np.float32) / 255.0
                    ghost_alpha = ghost_np[..., None]
                    base_np = (1.0 - GHOST_TRAIL_ALPHA * ghost_alpha) * base_np + (
                        GHOST_TRAIL_ALPHA * ghost_alpha
                    ) * GHOST_TRAIL_COLOR
                    # Restore the original pixels of ball/player on top.
                    if focus_mask is not None:
                        focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
                        orig_np = np.array(frame).astype(np.float32) / 255.0
                        base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
                    out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
            if ring_result is not None:
                ring_presence, ring_color_map = ring_result
                ring_presence = np.clip(ring_presence.astype(np.float32), 0.0, 1.0)
                ring_color_map = np.clip(ring_color_map.astype(np.float32), 0.0, 1.0)
                # Hide rings behind every current mask except the ball itself.
                if current_union_mask is not None:
                    if ball_mask_main is not None:
                        ball_np = np.clip(ball_mask_main.astype(np.float32), 0.0, 1.0)
                        mask_block = np.maximum(current_union_mask - ball_np, 0.0)
                    else:
                        mask_block = current_union_mask
                    mask_keep = np.clip(1.0 - mask_block, 0.0, 1.0)
                    ring_presence = ring_presence * mask_keep
                    ring_color_map = ring_color_map * mask_keep[..., None]
                if ring_presence.max() > FX_EPS and ring_color_map.max() > FX_EPS:
                    base_np = np.array(out_img).astype(np.float32) / 255.0
                    alpha_val = getattr(state, "fx_ring_alpha", BALL_RING_ALPHA)
                    # Rings are added as light (additive blend).
                    added_light = np.clip(ring_color_map * alpha_val, 0.0, 1.0)
                    base_np = np.clip(base_np + added_light, 0.0, 1.0)
                    if focus_mask is not None:
                        focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
                        orig_np = np.array(frame).astype(np.float32) / 255.0
                        base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
                    out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
    _draw_goal_overlay(state, frame_idx, out_img)
    # Draw crosses for conditioning frames only (frames with recorded clicks)
    clicks_map = state.clicks_by_frame_obj.get(frame_idx)
    if state.show_click_marks and clicks_map:
        draw = ImageDraw.Draw(out_img)
        cross_half = 6
        for obj_id, pts in clicks_map.items():
            for x, y, lbl in pts:
                color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
                # horizontal
                draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
                # vertical
                draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
    # Draw temporary cross for first corner in box mode
    if (
        state.show_click_marks
        and state.pending_box_start is not None
        and state.pending_box_start_frame_idx == frame_idx
        and state.pending_box_start_obj_id is not None
    ):
        draw = ImageDraw.Draw(out_img)
        x, y = state.pending_box_start
        cross_half = 6
        color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
        draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
        draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
    # Draw boxes for conditioning frames
    box_map = state.boxes_by_frame_obj.get(frame_idx)
    if state.show_click_marks and box_map:
        draw = ImageDraw.Draw(out_img)
        for obj_id, boxes in box_map.items():
            color = state.color_by_obj.get(obj_id, (255, 255, 255))
            for x1, y1, x2, y2 in boxes:
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
    # Draw trajectory centers (all frames)
    if state.show_click_marks and state.ball_centers:
        draw = ImageDraw.Draw(out_img)
        cross_half = 4
        for obj_id, centers in state.ball_centers.items():
            if not centers:
                continue
            # Raw centroids in gray.
            raw_items = sorted(centers.items())
            for _, (rx, ry) in raw_items:
                draw.line([(rx - cross_half, ry), (rx + cross_half, ry)], fill=(160, 160, 160), width=1)
                draw.line([(rx, ry - cross_half), (rx, ry + cross_half)], fill=(160, 160, 160), width=1)
            smooth_dict = state.smoothed_centers.get(obj_id, {})
            if not smooth_dict:
                continue
            smooth_items = sorted(smooth_dict.items())
            # Per-step displacement between consecutive smoothed centers,
            # normalized by the largest step to pick a speed color.
            distances: list[float] = []
            prev_center = None
            for _, (sx, sy) in smooth_items:
                if prev_center is None:
                    distances.append(0.0)
                else:
                    dx = sx - prev_center[0]
                    dy = sy - prev_center[1]
                    distances.append(float(np.hypot(dx, dy)))
                prev_center = (sx, sy)
            max_dist = max(distances[1:], default=0.0)
            color_by_frame: dict[int, tuple[int, int, int]] = {}
            for (f_idx, _), dist in zip(smooth_items, distances):
                ratio = dist / max_dist if max_dist > 0 else 0.0
                color_by_frame[f_idx] = _speed_to_color(ratio)
            # Reverse order so earlier centers end up drawn on top; the
            # current frame's center is highlighted in red.
            for f_idx, (sx, sy) in reversed(smooth_items):
                highlight = (f_idx == frame_idx)
                color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
                line_width = 1 if not highlight else 2
                draw.line([(sx - cross_half, sy), (sx + cross_half, sy)], fill=color, width=line_width)
                draw.line([(sx, sy - cross_half), (sx, sy + cross_half)], fill=color, width=line_width)
    # Save to cache and return
    if not remove_bg and not goal_overlay_active:
        state.composited_frames[frame_idx] = out_img
    return out_img
def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
    """
    Return the image to show in the viewer for ``frame_idx``.

    A cached composite is preferred; otherwise the frame is composed (overlay
    view) on demand. Either way the result goes through
    ``_maybe_upscale_for_display``. Returns ``None`` when no frames are loaded.
    """
    no_frames = state is None or state.video_frames is None or len(state.video_frames) == 0
    if no_frames:
        return None
    idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    image = state.composited_frames.get(idx)
    if image is None:
        image = compose_frame(state, idx)
    return _maybe_upscale_for_display(image)
def _update_fx_controls(
state: AppState,
soft_enabled: bool,
soft_feather: float,
soft_erode: float,
blur_enabled: bool,
blur_sigma: float,
bg_darkening: float,
wrap_enabled: bool,
wrap_strength: float,
wrap_width: float,
glow_enabled: bool,
glow_strength: float,
glow_radius: float,
# New parameters
ring_thickness: float,
ring_alpha: float,
ring_feather: float,
ring_gamma: float,
ring_scale_pct: float,
ring_duration: float,
) -> Image.Image:
if state is None:
return None
state.fx_soft_matte_enabled = bool(soft_enabled)
state.fx_soft_matte_feather = max(0.0, float(soft_feather))
state.fx_soft_matte_erode = max(0.0, float(soft_erode))
state.fx_blur_enabled = bool(blur_enabled)
state.fx_blur_sigma = max(0.0, float(blur_sigma))
state.fx_bg_darkening = float(np.clip(bg_darkening, 0.0, 1.0))
state.fx_light_wrap_enabled = bool(wrap_enabled)
state.fx_light_wrap_strength = max(0.0, float(wrap_strength))
state.fx_light_wrap_width = max(0.0, float(wrap_width))
state.fx_glow_enabled = bool(glow_enabled)
state.fx_glow_strength = max(0.0, float(glow_strength))
state.fx_glow_radius = max(0.0, float(glow_radius))
# Update Ring FX
state.fx_ring_thickness = max(0.1, float(ring_thickness))
state.fx_ring_alpha = max(0.0, float(ring_alpha))
state.fx_ring_feather = max(0.0, float(ring_feather))
state.fx_ring_gamma = max(0.1, float(ring_gamma))
state.fx_ring_duration = int(max(0, float(ring_duration)))
state.fx_ring_scale_pct = float(np.clip(ring_scale_pct, 10.0, 200.0))
state.composited_frames.clear()
idx = int(getattr(state, "current_frame_idx", 0))
return update_frame_display(state, idx)
def _toggle_ghost_trail(state: AppState, enabled: bool) -> Image.Image:
if state is None:
return None
state.fx_ghost_trail_enabled = bool(enabled)
state.composited_frames.clear()
idx = int(getattr(state, "current_frame_idx", 0))
return update_frame_display(state, idx)
def _toggle_ball_ring(state: AppState, enabled: bool) -> Image.Image:
if state is None:
return None
state.fx_ball_ring_enabled = bool(enabled)
state.composited_frames.clear()
idx = int(getattr(state, "current_frame_idx", 0))
return update_frame_display(state, idx)
def _toggle_click_marks(state: AppState, enabled: bool) -> Image.Image:
if state is None:
return None
state.show_click_marks = bool(enabled)
state.composited_frames.clear()
idx = int(getattr(state, "current_frame_idx", 0))
return update_frame_display(state, idx)
def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None:
if (
state is None
or not state.fx_ghost_trail_enabled
or state.masks_by_frame is None
):
return None
if state.fx_ball_ring_enabled:
# When ring rendering is active we skip building the filled ghost mask.
return None
kick_candidate = _get_prioritized_kick_frame(state)
if kick_candidate is None:
return None
if int(frame_idx) <= int(kick_candidate):
start_idx = int(kick_candidate) + 1
else:
start_idx = max(int(kick_candidate) + 1, int(frame_idx))
end_idx = state.num_frames
if start_idx >= end_idx:
return None
trail_mask: np.ndarray | None = None
for idx in range(start_idx, end_idx):
frame_masks = state.masks_by_frame.get(idx)
if not frame_masks:
continue
mask = frame_masks.get(BALL_OBJECT_ID)
if mask is None:
continue
mask_np = mask.astype(np.float32)
if mask_np.ndim == 3:
mask_np = mask_np.squeeze()
mask_np = np.clip(mask_np, 0.0, 1.0)
mask_np = _apply_radial_falloff(mask_np, strength=1.0, solid_ratio=0.8)
if trail_mask is None:
trail_mask = np.zeros_like(mask_np, dtype=np.float32)
if trail_mask.shape != mask_np.shape:
continue
trail_mask = np.maximum(trail_mask, mask_np)
return trail_mask
def _build_ball_ring_mask(
state: AppState, frame_idx: int
) -> tuple[np.ndarray, np.ndarray] | None:
if (
state is None
or not state.fx_ball_ring_enabled
or state.masks_by_frame is None
):
return None
kick_candidate = _get_prioritized_kick_frame(state)
if kick_candidate is None:
return None
if int(frame_idx) <= int(kick_candidate):
start_idx = int(kick_candidate) + 1
else:
start_idx = max(int(kick_candidate) + 1, int(frame_idx))
# Determine end frame based on duration limit
duration = getattr(state, "fx_ring_duration", 16)
limit_idx = int(kick_candidate) + 1 + duration
end_idx = min(state.num_frames, limit_idx)
if start_idx >= end_idx:
return None
ring_entries: list[tuple[int, tuple[int, int], float, np.ndarray, float]] = []
canvas_shape: tuple[int, int] | None = None
ring_presence: np.ndarray | None = None
ring_color_map: np.ndarray | None = None
fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
distance_m = state.goal_distance_m if state.goal_distance_m and state.goal_distance_m > 0 else 16.5
# Iterate in REVERSE order so that later frames (further in time/distance) are drawn first,
# and earlier frames (closer in time/distance) are drawn on top.
# This ensures the "nearest" rings (temporally) obscure the "further" rings.
for idx in range(end_idx - 1, start_idx - 1, -1):
frame_masks = state.masks_by_frame.get(idx)
if not frame_masks:
continue
mask = frame_masks.get(BALL_OBJECT_ID)
if mask is None:
continue
mask_np = mask.astype(np.float32)
if mask_np.ndim == 3:
mask_np = mask_np.squeeze()
if mask_np.size == 0:
continue
mask_np = np.clip(mask_np, 0.0, 1.0)
if mask_np.max() <= FX_EPS:
continue
if canvas_shape is None:
canvas_shape = mask_np.shape
if canvas_shape != mask_np.shape:
continue
centroid = _compute_mask_centroid(mask_np)
if centroid is None:
continue
cx, cy = centroid
ys, xs = np.nonzero(mask_np > 0.05)
if xs.size == 0 or ys.size == 0:
continue
min_x, max_x = xs.min(), xs.max()
min_y, max_y = ys.min(), ys.max()
radius_x = (max_x - min_x + 1) / 2.0
radius_y = (max_y - min_y + 1) / 2.0
radius = float(max(radius_x, radius_y))
if radius <= 1.5:
continue
# Use dynamic parameters from state if available, else defaults
thick_val = getattr(state, "fx_ring_thickness", BALL_RING_THICKNESS_PX)
center = (int(round(cx)), int(round(cy)))
radius_int = max(1, int(round(radius)))
delta_frames = max(1, idx - int(kick_candidate))
time_s = max(delta_frames / fps, 1.0 / fps)
speed_kmh = max(0.0, (distance_m / time_s) * 3.6)
color_vec = np.array(_speed_to_ring_color(speed_kmh), dtype=np.float32)
ring_entries.append((idx, center, radius, color_vec, thick_val))
if not ring_entries or canvas_shape is None:
return None
raw_radii = [entry[2] for entry in ring_entries]
smoothed = _median_smooth_radii(raw_radii)
smoothed = _clamp_radii(smoothed)
base_radius = smoothed[0] if smoothed else 1.0
if base_radius <= FX_EPS:
base_radius = 1.0
h, w = canvas_shape
ring_presence = np.zeros((h, w), dtype=np.float32)
ring_color_map = np.zeros((h, w, 3), dtype=np.float32)
base_feather = getattr(state, "fx_ring_feather", BALL_RING_FEATHER_SIGMA)
base_gamma = getattr(state, "fx_ring_gamma", BALL_RING_INTENSITY_GAMMA)
scale_factor = float(getattr(state, "fx_ring_scale_pct", RING_SIZE_SCALE_DEFAULT)) / 100.0
scale_factor = np.clip(scale_factor, 0.1, 2.0)
for (entry, smooth_radius) in zip(ring_entries, smoothed):
_, center, _, color_vec, thick_val = entry
radius_ratio = smooth_radius / base_radius if base_radius > FX_EPS else 1.0
radius_ratio = float(np.clip(radius_ratio, 0.05, 1.0))
radius_val = max(1.0, smooth_radius * scale_factor)
radius_int = max(1, int(round(radius_val)))
ring_local = np.zeros((h, w), dtype=np.float32)
thickness_scale = max(0.1, radius_ratio)
t_glow = max(1, int(round(thick_val * 4.0 * thickness_scale)))
cv2.circle(ring_local, center, radius_int, 0.3, thickness=t_glow)
t_mid = max(1, int(round(thick_val * 2.0 * thickness_scale)))
cv2.circle(ring_local, center, radius_int, 0.6, thickness=t_mid)
t_core = max(1, int(round(thick_val * thickness_scale)))
cv2.circle(ring_local, center, radius_int, 1.0, thickness=t_core)
effective_feather = max(0.0, base_feather * radius_ratio)
ring_local = cv2.GaussianBlur(ring_local, (0, 0), sigmaX=effective_feather, sigmaY=effective_feather)
if ring_local.max() <= FX_EPS:
continue
effective_gamma = max(0.1, base_gamma * radius_ratio)
if abs(effective_gamma - 1.0) > 1e-6:
ring_local = np.power(np.clip(ring_local, 0.0, 1.0), effective_gamma)
ring_local = np.clip(ring_local * radius_ratio, 0.0, 1.0)
ring_presence = np.maximum(ring_presence, ring_local)
ring_color_map += ring_local[..., None] * color_vec
if ring_presence.max() <= FX_EPS or ring_color_map.max() <= FX_EPS:
return None
return np.clip(ring_presence, 0.0, 1.0), np.clip(ring_color_map, 0.0, 1.0)
def _ensure_color_for_obj(state: AppState, obj_id: int):
if obj_id not in state.color_by_obj:
state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
def _compute_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None:
if mask is None:
return None
mask_np = np.array(mask)
if mask_np.ndim == 3:
mask_np = mask_np.squeeze()
if mask_np.size == 0:
return None
mask_float = np.clip(mask_np, 0.0, 1.0).astype(np.float32)
moments = cv2.moments(mask_float)
if moments["m00"] == 0:
return None
cx = int(moments["m10"] / moments["m00"])
cy = int(moments["m01"] / moments["m00"])
return cx, cy
def _apply_radial_falloff(mask: np.ndarray, strength: float = 1.0, solid_ratio: float = 0.8) -> np.ndarray:
if mask is None:
return None
mask_np = np.clip(mask.astype(np.float32), 0.0, 1.0)
if mask_np.ndim == 3:
mask_np = mask_np.squeeze()
if mask_np.max() <= FX_EPS:
return mask_np
centroid = _compute_mask_centroid(mask_np)
if centroid is None:
return mask_np
cx, cy = centroid
h, w = mask_np.shape
yy, xx = np.ogrid[:h, :w]
dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
max_dist = dist[mask_np > FX_EPS].max() if np.any(mask_np > FX_EPS) else 0.0
if max_dist <= FX_EPS:
return mask_np
if solid_ratio >= 1.0:
return mask_np
clipped_dist = np.clip((dist / max_dist - solid_ratio) / (1.0 - solid_ratio), 0.0, 1.0)
falloff = 1.0 - np.power(clipped_dist, strength)
return np.clip(mask_np * falloff, 0.0, 1.0)
def _update_centroids_for_frame(state: AppState, frame_idx: int):
if state is None:
return
masks = state.masks_by_frame.get(int(frame_idx), {})
seen_obj_ids: set[int] = set()
for obj_id, mask in masks.items():
centroid = _compute_mask_centroid(mask)
centers = state.ball_centers.setdefault(int(obj_id), {})
if centroid is not None:
centers[int(frame_idx)] = centroid
else:
centers.pop(int(frame_idx), None)
seen_obj_ids.add(int(obj_id))
_ensure_color_for_obj(state, int(obj_id))
mask_np = np.array(mask)
if mask_np.ndim == 3:
mask_np = mask_np.squeeze()
mask_np = np.clip(mask_np, 0.0, 1.0)
area = float(np.count_nonzero(mask_np > 0.3))
areas = state.mask_areas.setdefault(int(obj_id), {})
areas[int(frame_idx)] = area
# Remove frames for objects without masks at this frame
for obj_id, centers in state.ball_centers.items():
if obj_id not in seen_obj_ids:
centers.pop(int(frame_idx), None)
for obj_id, areas in state.mask_areas.items():
if obj_id not in seen_obj_ids:
areas.pop(int(frame_idx), None)
_recompute_motion_metrics(state)
def _run_kalman_filter(
ordered_items: list[tuple[int, tuple[float, float]]],
base_dt: float,
) -> tuple[dict[int, tuple[float, float]], dict[int, float], dict[int, float]]:
if not ordered_items:
return {}, {}, {}
H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=float)
R = np.eye(2, dtype=float) * 25.0
state_vec = np.array(
[ordered_items[0][1][0], ordered_items[0][1][1], 0.0, 0.0], dtype=float
)
P = np.eye(4, dtype=float) * 50.0
positions: dict[int, tuple[float, float]] = {}
speeds: dict[int, float] = {}
residuals: dict[int, float] = {}
prev_frame = ordered_items[0][0]
for frame_idx, (cx, cy) in ordered_items:
frame_delta = max(1, frame_idx - prev_frame) if frame_idx != prev_frame else 1
dt = frame_delta * base_dt
F = np.array(
[
[1, 0, dt, 0],
[0, 1, 0, dt],
[0, 0, 1, 0],
[0, 0, 0, 1],
],
dtype=float,
)
q = 0.5 * dt**2
Q = np.array(
[
[q, 0, dt, 0],
[0, q, 0, dt],
[dt, 0, 1, 0],
[0, dt, 0, 1],
],
dtype=float,
) * 0.05
state_vec = F @ state_vec
P = F @ P @ F.T + Q
z = np.array([cx, cy], dtype=float)
innovation = z - H @ state_vec
S = H @ P @ H.T + R
K = P @ H.T @ np.linalg.inv(S)
state_vec = state_vec + K @ innovation
P = (np.eye(4) - K @ H) @ P
positions[frame_idx] = (state_vec[0], state_vec[1])
speeds[frame_idx] = float(math.hypot(state_vec[2], state_vec[3]))
residuals[frame_idx] = float(math.hypot(innovation[0], innovation[1]))
prev_frame = frame_idx
return positions, speeds, residuals
def _build_kick_plot(state: AppState):
    """
    Build the Plotly "Kick & impact diagnostics" figure from SAM2 debug series.

    The figure shows:
    - per-frame speed, with the adaptive threshold and the baseline;
    - mask area, on a secondary y-axis;
    - distance from the start, rescaled onto the primary axis;
    - when present on ``state``: Kalman speed, Kalman innovation, direction
      change, their thresholds and the minimum impact speed;
    - vertical markers at the detected kick and impact frames.

    Returns an empty, titled figure when ``state`` is ``None`` or has no
    ``kick_debug_frames`` / ``kick_debug_speeds``.
    """
    fig = go.Figure()
    if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
        fig.update_layout(
            title="Kick & impact diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig
    frames = state.kick_debug_frames
    speeds = state.kick_debug_speeds
    # Missing optional series are replaced by zeros aligned with ``frames``.
    areas = state.kick_debug_area if state.kick_debug_area else [0.0] * len(frames)
    threshold = state.kick_debug_threshold or 0.0
    baseline = state.kick_debug_baseline or 0.0
    kick_frame = state.kick_debug_kick_frame
    distance = state.kick_debug_distance if state.kick_debug_distance else [0.0] * len(frames)
    # Impact series use their own frame axis when available, else the kick frames.
    impact_frames = state.impact_debug_frames if state.impact_debug_frames else frames
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=speeds,
            mode="lines+markers",
            name="Speed (px/s)",
            line=dict(color="#1f77b4"),
        )
    )
    # Horizontal reference lines spanning the first..last frame.
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[threshold, threshold],
            mode="lines",
            name="Adaptive threshold",
            line=dict(color="#d62728", dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[baseline, baseline],
            mode="lines",
            name="Baseline speed",
            line=dict(color="#ff7f0e", dash="dot"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=areas,
            mode="lines",
            name="Mask area",
            line=dict(color="#2ca02c"),
            yaxis="y2",
        )
    )
    # Largest value plotted on the primary axis (never below 1.0). It is used
    # to rescale the distance curve and to size the vertical kick/impact markers.
    max_primary = max(
        max(speeds) if speeds else 0.0,
        threshold,
        baseline,
        max(state.kick_debug_kalman_speeds) if state.kick_debug_kalman_speeds else 0.0,
        state.impact_debug_innovation_threshold or 0.0,
        state.impact_debug_direction_threshold or 0.0,
        state.impact_debug_speed_threshold_px or 0.0,
        1.0,
    )
    # Rescale distance so its peak matches max_primary (shape only, not units).
    max_distance = max(distance) if distance else 0.0
    if max_distance > 0 and max_primary > 0:
        distance_scaled = [d * (max_primary / max_distance) for d in distance]
    else:
        distance_scaled = distance
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=distance_scaled,
            mode="lines",
            name="Distance from start (scaled)",
            line=dict(color="#9467bd"),
        )
    )
    if state.kick_debug_kalman_speeds:
        fig.add_trace(
            go.Scatter(
                x=frames,
                y=state.kick_debug_kalman_speeds,
                mode="lines",
                name="Kalman speed",
                line=dict(color="#8c564b"),
            )
        )
    if state.impact_debug_innovation:
        fig.add_trace(
            go.Scatter(
                x=impact_frames,
                y=state.impact_debug_innovation,
                mode="lines",
                name="Kalman innovation",
                line=dict(color="#17becf"),
            )
        )
        max_primary = max(max_primary, max(state.impact_debug_innovation))
    # Threshold lines need at least two x positions to span the axis.
    if (
        state.impact_debug_innovation_threshold is not None
        and impact_frames
        and len(impact_frames) >= 2
    ):
        fig.add_trace(
            go.Scatter(
                x=[impact_frames[0], impact_frames[-1]],
                y=[
                    state.impact_debug_innovation_threshold,
                    state.impact_debug_innovation_threshold,
                ],
                mode="lines",
                name="Innovation threshold",
                line=dict(color="#17becf", dash="dash"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_innovation_threshold or 0.0)
    if state.impact_debug_direction:
        fig.add_trace(
            go.Scatter(
                x=impact_frames,
                y=state.impact_debug_direction,
                mode="lines",
                name="Direction change (deg)",
                line=dict(color="#bcbd22"),
            )
        )
        max_primary = max(max_primary, max(state.impact_debug_direction))
    if (
        state.impact_debug_direction_threshold is not None
        and impact_frames
        and len(impact_frames) >= 2
    ):
        fig.add_trace(
            go.Scatter(
                x=[impact_frames[0], impact_frames[-1]],
                y=[
                    state.impact_debug_direction_threshold,
                    state.impact_debug_direction_threshold,
                ],
                mode="lines",
                name="Direction threshold (deg)",
                line=dict(color="#bcbd22", dash="dot"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_direction_threshold or 0.0)
    if state.impact_debug_speed_threshold_px:
        fig.add_trace(
            go.Scatter(
                x=[frames[0], frames[-1]],
                y=[state.impact_debug_speed_threshold_px] * 2,
                mode="lines",
                name="Min impact speed (px/s)",
                line=dict(color="#b82e2e", dash="dot"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_speed_threshold_px or 0.0)
    # Vertical markers slightly taller than everything else on the primary axis.
    if kick_frame is not None:
        fig.add_trace(
            go.Scatter(
                x=[kick_frame, kick_frame],
                y=[0, max_primary * 1.05],
                mode="lines",
                name="Detected kick",
                line=dict(color="#ff00ff", dash="solid", width=3),
            )
        )
    impact_frame = state.impact_frame
    if impact_frame is not None:
        fig.add_trace(
            go.Scatter(
                x=[impact_frame, impact_frame],
                y=[0, max_primary * 1.05],
                mode="lines",
                name="Detected impact",
                line=dict(color="#ff1493", width=3),
            )
        )
    fig.update_layout(
        title="Kick & impact diagnostics",
        xaxis_title="Frame",
        yaxis_title="Speed (px/s)",
        yaxis=dict(side="left"),
        yaxis2=dict(
            # NOTE(review): only the "Mask area" trace is on y2; distance and
            # direction are drawn on the primary axis despite this title.
            title="Mask area / Distance / Direction",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig
def _ensure_ball_prompt_from_yolo(state: AppState):
if (
state is None
or state.inference_session is None
or not state.yolo_ball_centers
):
return
# Check if we already have clicks for the ball
for frame_clicks in state.clicks_by_frame_obj.values():
if frame_clicks.get(BALL_OBJECT_ID):
return
anchor_frame = state.yolo_initial_frame
if anchor_frame is None and state.yolo_ball_centers:
anchor_frame = min(state.yolo_ball_centers.keys())
if anchor_frame is None or anchor_frame >= state.num_frames:
return
center = state.yolo_ball_centers.get(anchor_frame)
if center is None:
return
x_center, y_center = center
frame_width, frame_height = state.video_frames[anchor_frame].size
x_center = int(np.clip(round(x_center), 0, frame_width - 1))
y_center = int(np.clip(round(y_center), 0, frame_height - 1))
event = SimpleNamespace(
index=(x_center, y_center),
value={"x": x_center, "y": y_center},
)
state.current_obj_id = BALL_OBJECT_ID
state.current_label = "positive"
state.current_frame_idx = anchor_frame
on_image_click(
update_frame_display(state, anchor_frame),
state,
anchor_frame,
BALL_OBJECT_ID,
"positive",
False,
event,
)
def _build_yolo_plot(state: AppState):
    """
    Build the Plotly "YOLO kick diagnostics" figure.

    The primary axis shows the YOLO speed series with its adaptive threshold
    and baseline. A secondary axis shows distance from the start and the
    box-area proxy. A labelled vertical line marks the detected kick frame.
    Returns an empty, titled figure when no YOLO kick series is available.
    """
    fig = go.Figure()
    if not (state is not None and state.yolo_kick_frames and state.yolo_kick_speeds):
        fig.update_layout(
            title="YOLO kick diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig
    frames = state.yolo_kick_frames
    count = len(frames)
    threshold = state.yolo_threshold or 0.0
    baseline = state.yolo_baseline_speed or 0.0
    # (y values, mode, legend name, line style, y-axis or None for primary)
    series = [
        (state.yolo_kick_speeds, "lines+markers", "YOLO speed", dict(color="#4caf50"), None),
        ([threshold] * count, "lines", "Adaptive threshold", dict(color="#ff9800", dash="dash"), None),
        ([baseline] * count, "lines", "Baseline speed", dict(color="#9e9e9e", dash="dot"), None),
        (state.yolo_kick_distance or [0.0] * count, "lines", "Distance from start", dict(color="#03a9f4"), "y2"),
        (state.yolo_mask_area_proxy or [0.0] * count, "lines", "Box area proxy", dict(color="#ab47bc", dash="dot"), "y2"),
    ]
    for y_values, mode, name, line, axis in series:
        trace_kwargs = dict(x=frames, y=y_values, mode=mode, name=name, line=line)
        if axis is not None:
            trace_kwargs["yaxis"] = axis
        fig.add_trace(go.Scatter(**trace_kwargs))
    kick_frame = state.yolo_kick_frame
    if kick_frame is not None:
        fig.add_vline(
            x=kick_frame,
            line=dict(color="#e91e63", width=2),
            annotation_text=f"Kick {kick_frame}",
            annotation_position="top right",
        )
    fig.update_layout(
        title="YOLO kick diagnostics",
        xaxis=dict(title="Frame"),
        yaxis=dict(title="Speed (px/s)"),
        yaxis2=dict(
            title="Distance / Area",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig
def _build_multi_ball_chart(state: AppState):
    """
    Build a combined speed chart showing all ball candidates.

    Line styles:
    - selected candidate: thick green line;
    - other candidates with a detected kick: orange;
    - remaining candidates: gray.

    Each candidate's kick frame gets a vertical line. Only the selected
    candidate's line is solid and annotated.

    Args:
        state: Application state. Reads ``ball_candidates`` and
            ``selected_ball_idx``. Each candidate is a dict with a
            ``"tracking"`` dict (``"frames_ordered"``, ``"speed_series"``)
            and optional ``"has_kick"`` / ``"kick_frame"`` keys.

    Returns:
        A Plotly figure; an empty, titled one when there are no candidates.
    """
    fig = go.Figure()
    if state is None or not state.ball_candidates:
        fig.update_layout(
            title="Ball Candidates Speed Comparison",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig
    selected_idx = state.selected_ball_idx
    max_speed = 0.0
    kick_frames_to_mark = []
    for i, candidate in enumerate(state.ball_candidates):
        tracking = candidate.get("tracking", {})
        frames = tracking.get("frames_ordered", [])
        speeds = tracking.get("speed_series", [])
        if not frames or not speeds:
            continue
        max_speed = max(max_speed, max(speeds))
        is_selected = (i == selected_idx)
        is_kicked = candidate.get("has_kick", False)
        # Determine color and style
        if is_selected:
            color, width, opacity = "#4caf50", 3, 1.0  # Green
        elif is_kicked:
            color, width, opacity = "#ff9800", 2, 0.7  # Orange for other kicked balls
        else:
            color, width, opacity = "#9e9e9e", 1, 0.5  # Gray
        # Build label
        label_parts = [f"Ball {i+1}"]
        if is_kicked:
            label_parts.append("⚽")
        if is_selected:
            label_parts.append("✓")
        fig.add_trace(
            go.Scatter(
                x=frames,
                y=speeds,
                mode="lines",
                name=" ".join(label_parts),
                line=dict(color=color, width=width),
                opacity=opacity,
            )
        )
        kick_frame = candidate.get("kick_frame")
        if kick_frame is not None:
            kick_frames_to_mark.append((kick_frame, i, is_selected))
    # Add vertical lines for kick frames
    for kick_frame, ball_idx, is_selected in kick_frames_to_mark:
        if is_selected:
            fig.add_vline(
                x=kick_frame,
                line=dict(color="#e91e63", width=3, dash="solid"),
                annotation_text=f"Ball {ball_idx+1} kick",
                annotation_position="top right",
            )
        else:
            # No annotation_* kwargs here: plotly attaches an annotation to the
            # line whenever any of them is passed, even an empty text.
            fig.add_vline(
                x=kick_frame,
                line=dict(color="#ffcdd2", width=1, dash="dot"),
            )
    fig.update_layout(
        title="Ball Candidates Speed Comparison",
        xaxis=dict(title="Frame"),
        yaxis=dict(title="Speed (px/s)", range=[0, max_speed * 1.1] if max_speed > 0 else None),
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
        margin=dict(t=60, l=40, r=40, b=40),
        hovermode="x unified",
    )
    return fig
def _jump_to_frame(state: AppState, target: int | None):
    """
    Move the viewer to ``target`` (clamped to the video) and return UI updates.

    Returns ``(image, slider_update)``. When there is no state, no frames or no
    target, two no-op ``gr.update()`` values are returned instead.
    """
    nothing_to_do = state is None or state.num_frames == 0 or target is None
    if nothing_to_do:
        return gr.update(), gr.update()
    clamped = int(np.clip(int(target), 0, state.num_frames - 1))
    state.current_frame_idx = clamped
    image = update_frame_display(state, clamped)
    return image, gr.update(value=clamped)
def _jump_to_yolo_kick(state: AppState):
    """Seek to the YOLO-detected kick frame (no-op when unknown)."""
    target = getattr(state, "yolo_kick_frame", None)
    return _jump_to_frame(state, target)
def _jump_to_sam_kick(state: AppState):
    """Seek to the prioritized (SAM/manual) kick frame (no-op when unknown)."""
    target = _get_prioritized_kick_frame(state)
    return _jump_to_frame(state, target)
def _jump_to_sam_impact(state: AppState):
    """
    Seek to the SAM impact frame.

    Falls back to the last impact-debug frame when no impact was detected.
    This is a no-op when neither is available.
    """
    impact = getattr(state, "impact_frame", None)
    if impact is None:
        debug_frames = getattr(state, "impact_debug_frames", [])
        impact = debug_frames[-1] if debug_frames else None
    return _jump_to_frame(state, impact)
def _jump_to_manual_kick(state: AppState):
    """Seek to the manually marked kick frame (no-op when unset)."""
    target = getattr(state, "manual_kick_frame", None)
    return _jump_to_frame(state, target)
def _jump_to_manual_impact(state: AppState):
    """Seek to the manually marked impact frame (no-op when unset)."""
    target = getattr(state, "manual_impact_frame", None)
    return _jump_to_frame(state, target)
def _format_impact_status(state: AppState) -> str:
    """
    Return a two-line kick/impact summary for the YOLO and SAM2 detectors.

    Missing values are shown as ``N/A``. The SAM impact falls back to the last
    impact-debug frame when no impact was detected.
    """
    def fmt(value: int | None) -> str:
        if value is None:
            return "N/A"
        return str(int(value))

    def impact_value(st: AppState | None) -> int | None:
        if st is None:
            return None
        if st.impact_frame is not None:
            return st.impact_frame
        debug_frames = getattr(st, "impact_debug_frames", [])
        return debug_frames[-1] if debug_frames else None

    yolo_kick = fmt(getattr(state, "yolo_kick_frame", None) if state else None)
    sam_kick = fmt(_get_prioritized_kick_frame(state))
    sam_impact = fmt(impact_value(state))
    return "\n".join(
        [
            f"YOLO13 · Kick ⚽ {yolo_kick} · Impact 🚩 N/A",
            f"SAM2 · Kick ⚽ {sam_kick} · Impact 🚩 {sam_impact}",
        ]
    )
def _format_kick_text(state: AppState) -> str:
if state is None:
return "Kick: n/a"
parts = []
if getattr(state, "yolo_kick_frame", None) is not None:
parts.append(f"YOLO ⚽ {state.yolo_kick_frame}")
sam_kick = _get_prioritized_kick_frame(state)
if sam_kick is not None:
parts.append(f"SAM ⚽ {sam_kick}")
if parts:
return " | ".join(parts)
return "Kick: n/a"
def _format_impact_text(state: AppState) -> str:
if state is None:
return "Impact: n/a"
impact = getattr(state, "impact_frame", None)
if impact is None:
frames = getattr(state, "impact_debug_frames", [])
if frames:
impact = frames[-1]
return f"SAM 🚩 {impact}" if impact is not None else "Impact: n/a"
def _format_kick_status(state: AppState) -> str:
if state is None or not isinstance(state, AppState):
return "Kick frame: not computed"
frame = state.kick_frame
if frame is None:
frame = getattr(state, "kick_debug_kick_frame", None)
if frame is None:
if state.kick_debug_frames:
return "Kick frame: not detected"
return "Kick frame: not computed"
if state.kick_frame is None and frame is not None:
state.kick_frame = frame
time_part = ""
if state.video_fps and state.video_fps > 1e-6:
time_part = f" (~{frame / state.video_fps:.2f}s)"
return f"Kick frame ≈ {frame}{time_part}"
def _mark_kick_frame(state: AppState, frame_value: float):
    """
    Record a user-chosen kick frame and move the UI to it.

    ``frame_value`` is clamped to the loaded video. The frame is stored as both
    ``state.kick_frame`` and ``state.manual_kick_frame``, the SAM2 window is
    recomputed from it via ``_compute_sam_window_from_kick``, and it becomes
    the current frame.

    Returns the preview update, status message, slider update, kick plot, the
    three action-button updates and the status-button updates. Without a
    loaded video, the preview and slider are left unchanged and a hint is shown.
    """
    if state is None or state.num_frames == 0:
        main_btn, detect_btn, player_btn = _button_updates(state)
        status_buttons = _ui_status_updates(state)
        leading = (
            gr.update(),
            gr.update(value="Load a video first.", visible=True),
            gr.update(),
        )
    else:
        chosen = int(np.clip(int(frame_value), 0, state.num_frames - 1))
        state.kick_frame = chosen
        state.manual_kick_frame = chosen
        _compute_sam_window_from_kick(state, chosen)
        state.current_frame_idx = chosen
        main_btn, detect_btn, player_btn = _button_updates(state)
        status_buttons = _ui_status_updates(state)
        leading = (
            update_frame_display(state, chosen),
            gr.update(value=f"⚽ Kick frame manually set to {chosen}", visible=True),
            gr.update(value=chosen),
        )
    return (*leading, _build_kick_plot(state), main_btn, detect_btn, player_btn, *status_buttons)
def _mark_impact_frame(state: AppState, frame_value: float):
    """
    Record a user-chosen impact frame and move the UI to it.

    ``frame_value`` is clamped to the loaded video. The frame is stored as both
    ``state.impact_frame`` and ``state.manual_impact_frame`` and becomes the
    current frame. This matches ``_mark_kick_frame``: the preview and slider are
    moved to it, so ``current_frame_idx`` must follow.

    Returns the preview update, status message, slider update, kick plot, the
    three action-button updates and the status-button updates. Without a
    loaded video, the preview and slider are left unchanged and a hint is shown.
    """
    if state is None or state.num_frames == 0:
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
        status_updates = _ui_status_updates(state)
        return (
            gr.update(),
            gr.update(value="Load a video first.", visible=True),
            gr.update(),
            _build_kick_plot(state),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
        )
    idx = int(np.clip(int(frame_value), 0, state.num_frames - 1))
    state.impact_frame = idx
    state.manual_impact_frame = idx
    # Keep the tracked current frame in sync with the preview/slider we move to.
    state.current_frame_idx = idx
    msg = f"🚩 Impact frame manually set to {idx}"
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
    status_updates = _ui_status_updates(state)
    return (
        update_frame_display(state, idx),
        gr.update(value=msg, visible=True),
        gr.update(value=idx),
        _build_kick_plot(state),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
        *status_updates,
    )
def _kick_button_updates(state: AppState) -> tuple[Any, ...]:
    """
    Build the six kick/impact shortcut-button updates.

    Order: YOLO kick, YOLO impact (always "N/A" and disabled), SAM kick, SAM
    impact (falls back to the last ``impact_debug_frames`` entry), manual kick,
    manual impact. Labels read ``"<symbol>: <frame>"`` (or ``N/A``), and a
    button is clickable only when it has a frame.
    """
    def button(symbol: str, frame: int | None, clickable: bool = True):
        shown = "N/A" if frame is None else frame
        return gr.update(value=f"{symbol}: {shown}", interactive=clickable and frame is not None)

    yolo_kick = getattr(state, "yolo_kick_frame", None)
    sam_kick = _get_prioritized_kick_frame(state)
    sam_impact = getattr(state, "impact_frame", None)
    if sam_impact is None:
        debug_frames = getattr(state, "impact_debug_frames", [])
        sam_impact = debug_frames[-1] if debug_frames else None
    manual_kick = getattr(state, "manual_kick_frame", None)
    manual_impact = getattr(state, "manual_impact_frame", None)
    specs = (
        ("⚽", yolo_kick, True),
        ("🚩", None, False),
        ("⚽", sam_kick, True),
        ("🚩", sam_impact, True),
        ("⚽", manual_kick, True),
        ("🚩", manual_impact, True),
    )
    return tuple(button(symbol, frame, clickable) for symbol, frame, clickable in specs)
def _impact_status_update(state: AppState):
    """Hidden Gradio update carrying the text from ``_format_impact_status``."""
    summary = _format_impact_status(state)
    return gr.update(visible=False, value=summary)
def _ball_has_masks(state: AppState, target_obj_id: int = BALL_OBJECT_ID) -> bool:
    """Return True when ``target_obj_id`` has a mask on at least one frame."""
    if state is None:
        return False
    return any(target_obj_id in frame_masks for frame_masks in state.masks_by_frame.values())
def _player_has_masks(state: AppState) -> bool:
if state is None or state.player_obj_id is None:
return False
player_id = state.player_obj_id
for masks in state.masks_by_frame.values():
if player_id in masks:
return True
return False
def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
    """
    Build updates for the three main action buttons.

    Returns, in order:
    - Main (ball) propagation: enabled once the ball has any mask or a YOLO kick frame exists.
    - Player detection: enabled once a YOLO kick frame exists.
    - Player propagation: enabled once the player has any mask.

    A button uses the "secondary" variant when its step is flagged done
    (``is_sam_tracked`` / ``is_player_detected`` / ``is_player_propagated``)
    and "stop" otherwise.
    """
    valid_state = isinstance(state, AppState)
    has_yolo_kick = valid_state and state.yolo_kick_frame is not None

    def variant_for(flag_name: str) -> str:
        finished = valid_state and getattr(state, flag_name, False)
        return "secondary" if finished else "stop"

    main_enabled = _ball_has_masks(state) or has_yolo_kick
    player_propagate_enabled = _player_has_masks(state)
    return (
        gr.update(interactive=main_enabled, variant=variant_for("is_sam_tracked")),
        gr.update(interactive=has_yolo_kick, variant=variant_for("is_player_detected")),
        gr.update(interactive=player_propagate_enabled, variant=variant_for("is_player_propagated")),
    )
def _ball_button_updates(state: AppState) -> tuple[Any, Any]:
    """
    Variant updates for the ball-detection and YOLO-tracking buttons.

    Each update is "secondary" once the step is flagged done
    (``is_ball_detected`` / ``is_yolo_tracked``) and "stop" otherwise.
    """
    valid_state = isinstance(state, AppState)
    updates = []
    for flag_name in ("is_ball_detected", "is_yolo_tracked"):
        done = valid_state and getattr(state, flag_name, False)
        updates.append(gr.update(variant="secondary" if done else "stop"))
    return tuple(updates)
def _ui_status_updates(state: AppState) -> tuple[Any, ...]:
    """Concatenate kick/impact, ball and goal button updates, in that UI order."""
    kick_part = _kick_button_updates(state)
    ball_part = _ball_button_updates(state)
    goal_part = _goal_button_updates(state)
    return kick_part + ball_part + goal_part
def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
centers = state.ball_centers.get(target_obj_id)
if not centers or len(centers) < 3:
state.smoothed_centers[target_obj_id] = {}
state.ball_speeds[target_obj_id] = {}
state.kick_frame = None
state.kick_debug_frames = []
state.kick_debug_speeds = []
state.kick_debug_threshold = None
state.kick_debug_baseline = None
state.kick_debug_speed_std = None
state.kick_debug_area = []
state.kick_debug_kick_frame = None
state.kick_debug_distance = []
state.kalman_centers[target_obj_id] = {}
state.kalman_speeds[target_obj_id] = {}
state.kalman_residuals[target_obj_id] = {}
state.kick_debug_kalman_speeds = []
state.distance_from_start[target_obj_id] = {}
state.direction_change[target_obj_id] = {}
state.impact_frame = None
state.impact_debug_frames = []
state.impact_debug_innovation = []
state.impact_debug_innovation_threshold = None
state.impact_debug_direction = []
state.impact_debug_direction_threshold = None
state.impact_debug_speed_kmh = []
state.impact_debug_speed_threshold_px = None
state.impact_meters_per_px = None
return
items = sorted(centers.items())
dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
alpha = 0.35
smoothed: dict[int, tuple[float, float]] = {}
speeds: dict[int, float] = {}
prev_frame = None
prev_smooth = None
for frame_idx, (cx, cy) in items:
if prev_smooth is None:
smooth_x, smooth_y = float(cx), float(cy)
else:
smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
smoothed[frame_idx] = (smooth_x, smooth_y)
if prev_smooth is None or prev_frame is None:
speeds[frame_idx] = 0.0
else:
frame_delta = max(1, frame_idx - prev_frame)
time_delta = frame_delta * dt
dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
speed = dist / time_delta if time_delta > 0 else dist
speeds[frame_idx] = speed
prev_smooth = (smooth_x, smooth_y)
prev_frame = frame_idx
state.smoothed_centers[target_obj_id] = smoothed
state.ball_speeds[target_obj_id] = speeds
if smoothed:
first_frame = min(smoothed.keys())
origin = smoothed[first_frame]
distance_dict: dict[int, float] = {}
for frame_idx, (sx, sy) in smoothed.items():
distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])
state.distance_from_start[target_obj_id] = distance_dict
state.kick_debug_distance = [distance_dict.get(f, 0.0) for f in sorted(smoothed.keys())]
kalman_pos, kalman_speed, kalman_res = _run_kalman_filter(items, dt)
state.kalman_centers[target_obj_id] = kalman_pos
state.kalman_speeds[target_obj_id] = kalman_speed
state.kalman_residuals[target_obj_id] = kalman_res
state.kick_frame = _detect_kick_frame(state, target_obj_id)
state.impact_frame = _detect_impact_frame(state, target_obj_id)
def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
    """
    Heuristically locate the frame at which the ball is kicked.

    Reads the smoothed centres, speeds and mask areas stored for
    ``target_obj_id``. The first frames (a third of the samples, at most 10 and
    at least 1) form a speed baseline. Each later frame ``f`` is accepted as the
    kick only when all of these hold:

    * its speed is >= ``speed_threshold``. The threshold is the baseline median
      + 4 x population std, raised to at least 3 x the median and never below
      15.0. Units are those of ``state.ball_speeds``.
    * the next ``sustain_frames`` speeds that exist stay >= 70% of the threshold;
    * the mask area has shrunk to <= ``area_drop_ratio`` of the median of the
      previous ``area_window`` known areas. This check is skipped without area
      data, and waived when the speed is >= 1.2 x the threshold.
    * within ``holdout_frames`` samples starting at ``f``, the ball gets at least
      ``adaptive_return_distance`` away from its first smoothed position.

    Side effects: refreshes the ``state.kick_debug_*`` traces (read e.g. by
    ``_format_kick_status``). It does not touch them when there are fewer than
    5 smoothed samples.

    Returns:
        The first frame passing every check, or ``None``.
    """
    smoothed = state.smoothed_centers.get(target_obj_id, {})
    speeds = state.ball_speeds.get(target_obj_id, {})
    # NOTE(review): this early exit leaves any kick_debug_* traces from a
    # previous run in place.
    if len(smoothed) < 5:
        return None
    frames = sorted(smoothed.keys())
    speed_series = [speeds.get(f, 0.0) for f in frames]
    # Baseline = first third of the samples, capped at 10, at least 1.
    baseline_window = min(10, len(frames) // 3 or 1)
    baseline_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    # Threshold: median + 4 std, raised to at least 3x the median, floor 15.
    base_threshold = baseline_speed + 4.0 * speed_std
    if base_threshold < baseline_speed * 3.0:
        base_threshold = baseline_speed * 3.0
    speed_threshold = max(base_threshold, 15.0)
    sustain_frames = 3  # following samples that must stay >= 70% of the threshold
    holdout_frames = 8  # look-ahead window (including the candidate) for the displacement check
    area_window = 4  # number of preceding samples for the area-median comparison
    area_drop_ratio = 0.75  # area must shrink to <= 75% of that median (unless very fast)
    areas_dict = state.mask_areas.get(target_obj_id, {})
    initial_center = smoothed[frames[0]]
    # Missing or zero area falls back to 1.0 so the radius estimate stays finite.
    initial_area = areas_dict.get(frames[0], 1.0) or 1.0
    # Approximate the ball radius from the first mask area (area = pi * r^2) and
    # require a displacement of 1.5 radii, clamped to [8, 40] (centre units,
    # presumably pixels).
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))
    # Publish debug traces aligned with `frames`.
    state.kick_debug_frames = frames
    state.kick_debug_speeds = speed_series
    state.kick_debug_threshold = speed_threshold
    state.kick_debug_baseline = baseline_speed
    state.kick_debug_speed_std = speed_std
    state.kick_debug_area = [areas_dict.get(f, 0.0) for f in frames]
    state.kick_debug_distance = [
        math.hypot(smoothed[f][0] - initial_center[0], smoothed[f][1] - initial_center[1])
        for f in frames
    ]
    kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {})
    state.kick_debug_kalman_speeds = [kalman_speed_dict.get(f, 0.0) for f in frames]
    state.kick_debug_kick_frame = None
    for idx in range(baseline_window, len(frames)):
        frame = frames[idx]
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue
        # Sustain check; samples past the end of the track do not fail it.
        sustain_ok = True
        for j in range(1, sustain_frames + 1):
            if idx + j >= len(frames):
                break
            if speed_series[idx + j] < speed_threshold * 0.7:
                sustain_ok = False
                break
        if not sustain_ok:
            continue
        # Area check: only evaluated when the current area is known and non-zero.
        current_area = areas_dict.get(frame)
        area_pass = True
        if current_area:
            prev_areas = [
                areas_dict.get(f)
                for f in frames[max(0, idx - area_window):idx]
                if areas_dict.get(f) is not None
            ]
            if prev_areas:
                median_prev = statistics.median(prev_areas)
                if median_prev > 0:
                    ratio = current_area / median_prev
                    if ratio > area_drop_ratio:
                        area_pass = False
        # A failed area check is forgiven for clearly fast frames (>= 1.2x threshold).
        if not area_pass and speed < speed_threshold * 1.2:
            continue
        # Displacement check: the ball must actually leave its starting spot.
        future_frames = frames[idx:min(len(frames), idx + holdout_frames)]
        max_future_dist = 0.0
        for future_frame in future_frames:
            cx, cy = smoothed[future_frame]
            dist = math.hypot(cx - initial_center[0], cy - initial_center[1])
            if dist > max_future_dist:
                max_future_dist = dist
        if max_future_dist < adaptive_return_distance:
            continue
        state.kick_debug_kick_frame = frame
        return frame
    state.kick_debug_kick_frame = None
    return None
def _detect_impact_frame(state: AppState, target_obj_id: int) -> int | None:
    """
    Heuristically locate the frame where the kicked ball hits something.

    Works on the Kalman outputs stored for ``target_obj_id`` (residuals,
    positions, speeds) plus ``state.distance_from_start``, and requires
    ``state.kick_frame``. A frame is a candidate when it lies more than
    ``post_kick_buffer`` frames after the kick and:

    * its residual ("innovation") is >= ``innovation_threshold``. The threshold
      comes from residuals up to the kick frame: median + 4 x population std,
      raised to at least 3 x the median and at least 5.0.
    * its heading change versus the previous Kalman step is >= 25. The value
      comes from ``_angle_between``; presumably degrees, TODO confirm.
    * when a minimum speed is active (see below), its Kalman speed reaches it;
    * when a pixel->metre scale exists, its distance from the start is
      <= 1.1 x ``state.goal_distance_m``;
    * its residual is not a strict local minimum relative to its neighbours.

    The pixel->metre scale assumes the ball's largest distance from its start
    equals ``state.goal_distance_m``. The minimum-speed check applies only when
    that scale exists and ``state.min_impact_speed_kmh > 0``.

    The candidate with the largest residual wins; ties go to the earliest frame.

    Side effects:
    * Always resets and refreshes the ``state.impact_debug_*`` traces and
      ``state.impact_meters_per_px``, and sets ``state.impact_frame``.
    * Writes ``state.direction_change[target_obj_id]`` only when a kick frame
      is known.

    Returns:
        The detected impact frame, or ``None``.
    """
    residuals = state.kalman_residuals.get(target_obj_id, {})
    frames = sorted(residuals.keys())
    state.impact_debug_frames = frames
    state.impact_debug_innovation = [residuals.get(f, 0.0) for f in frames]
    state.impact_debug_innovation_threshold = None
    state.impact_debug_direction = []
    state.impact_debug_direction_threshold = None
    state.impact_debug_speed_kmh = []
    state.impact_debug_speed_threshold_px = None
    state.impact_meters_per_px = None
    # An impact is only searched for after a detected kick.
    if not frames or state.kick_frame is None:
        state.impact_frame = None
        return None
    # Per-frame heading change between consecutive Kalman displacement vectors.
    kalman_positions = state.kalman_centers.get(target_obj_id, {})
    direction_dict: dict[int, float] = {}
    prev_pos: tuple[float, float] | None = None
    prev_vec: tuple[float, float] | None = None
    for frame in frames:
        pos = kalman_positions.get(frame)
        if pos is None:
            direction_dict[frame] = 0.0
            continue
        if prev_pos is None:
            direction_dict[frame] = 0.0
            prev_vec = (0.0, 0.0)
        else:
            vec = (pos[0] - prev_pos[0], pos[1] - prev_pos[1])
            if prev_vec is None:
                direction_dict[frame] = 0.0
            else:
                direction_dict[frame] = _angle_between(prev_vec, vec)
            prev_vec = vec
        prev_pos = pos
    state.direction_change[target_obj_id] = direction_dict
    state.impact_debug_direction = [direction_dict.get(f, 0.0) for f in frames]
    # Pixel->metre scale: assumes the farthest point reached equals the goal distance.
    distance_dict = state.distance_from_start.get(target_obj_id, {})
    max_distance_px = max(distance_dict.values()) if distance_dict else 0.0
    goal_distance_m = max(state.goal_distance_m, 0.0)
    meters_per_px = goal_distance_m / max_distance_px if goal_distance_m > 0 and max_distance_px > 1e-6 else None
    state.impact_meters_per_px = meters_per_px
    kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {})
    if meters_per_px:
        # x 3.6 converts m/s to km/h; assumes Kalman speeds are px per second.
        state.impact_debug_speed_kmh = [
            kalman_speed_dict.get(f, 0.0) * meters_per_px * 3.6 for f in frames
        ]
        if state.min_impact_speed_kmh > 0:
            state.impact_debug_speed_threshold_px = (state.min_impact_speed_kmh / 3.6) / meters_per_px
    else:
        state.impact_debug_speed_kmh = [0.0 for _ in frames]
        state.impact_debug_speed_threshold_px = None
    # Residual baseline from frames up to the kick (else the first <= 10 frames).
    baseline_frames = [f for f in frames if f <= state.kick_frame]
    if not baseline_frames:
        baseline_frames = frames[: max(1, min(len(frames), 10))]
    baseline_vals = [residuals.get(f, 0.0) for f in baseline_frames]
    baseline_median = statistics.median(baseline_vals) if baseline_vals else 0.0
    baseline_std = statistics.pstdev(baseline_vals) if len(baseline_vals) > 1 else 0.0
    innovation_threshold = baseline_median + 4.0 * baseline_std
    innovation_threshold = max(innovation_threshold, baseline_median * 3.0, 5.0)
    state.impact_debug_innovation_threshold = innovation_threshold
    direction_threshold = 25.0
    state.impact_debug_direction_threshold = direction_threshold
    post_kick_buffer = 3  # ignore the first frames right after the kick itself
    # Candidates are (innovation, -frame, frame) so a reverse sort ranks by
    # innovation and breaks ties towards the earliest frame.
    candidates: list[tuple[float, float, int]] = []
    meters_limit = goal_distance_m * 1.1 if goal_distance_m > 0 else None
    frame_list_len = len(frames)
    for idx, frame in enumerate(frames):
        if frame <= state.kick_frame + post_kick_buffer:
            continue
        innovation = residuals.get(frame, 0.0)
        if innovation < innovation_threshold:
            continue
        direction_delta = direction_dict.get(frame, 0.0)
        if direction_delta < direction_threshold:
            continue
        speed_px = kalman_speed_dict.get(frame, 0.0)
        if state.impact_debug_speed_threshold_px and speed_px < state.impact_debug_speed_threshold_px:
            continue
        if meters_per_px and meters_limit is not None:
            distance_m = distance_dict.get(frame, 0.0) * meters_per_px
            if distance_m > meters_limit:
                continue
        # approximate local peak filter
        # NOTE(review): only strict local minima are rejected, so rising or
        # plateau frames also pass. Confirm this is the intended strictness.
        prev_innovation = residuals.get(frames[idx - 1], innovation) if idx > 0 else innovation
        next_innovation = residuals.get(frames[idx + 1], innovation) if idx + 1 < frame_list_len else innovation
        if innovation < prev_innovation and innovation < next_innovation:
            continue
        candidates.append((innovation, -frame, frame))
    if not candidates:
        state.impact_frame = None
        return None
    candidates.sort(reverse=True)
    impact_frame = candidates[0][2]
    state.impact_frame = impact_frame
    return impact_frame
def on_image_click(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
) -> Image.Image:
    """
    Apply a click on the preview as a SAM2 prompt and return the refreshed preview.

    In "Boxes" prompt mode, the first click records a pending box corner and
    draws a temporary marker. The second click completes the box and adds it as
    the object's only prompt on this frame. Otherwise, each click adds a point
    prompt: positive when ``label`` starts with "pos", negative otherwise.
    ``clear_old`` clears earlier prompts for the object on this frame.

    After a prompt is added, the model is run on ``frame_idx``. The masks for
    every object in the session are stored in ``state.masks_by_frame``, and
    centroids are refreshed via ``_update_centroids_for_frame``.

    Returns ``img`` unchanged when there is no inference session. Raises
    ``gr.Error`` when the click coordinates cannot be read from ``evt``.
    """
    if state is None or state.inference_session is None:
        return img  # no-op preview when not ready
    if state.is_switching_model:
        # Gracefully ignore input during model switch; return current preview unchanged
        return update_frame_display(state, int(frame_idx))
    # Parse click coordinates from event
    x = y = None
    if evt is not None:
        # Try different gradio event data shapes for robustness
        try:
            if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                x, y = int(evt.index[0]), int(evt.index[1])
            elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value:
                x, y = int(evt.value["x"]), int(evt.value["y"])
        except Exception:
            # Unexpected payload shape: treated as "no coordinates" and reported below.
            x = y = None
    if x is None or y is None:
        raise gr.Error("Could not read click coordinates.")
    _ensure_color_for_obj(state, int(obj_id))
    processor = state.processor
    model = state.model
    inference_session = state.inference_session
    original_size = None
    pixel_values = None
    # Only preprocess the frame when the session has not already processed it.
    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
        original_size = inputs.original_sizes[0]
        pixel_values = inputs.pixel_values[0]
    if state.current_prompt_type == "Boxes":
        # Two-click box input
        if state.pending_box_start is None:
            # For boxes, always clear old inputs (points) for this object on this frame
            frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
            frame_clicks[int(obj_id)] = []
            state.composited_frames.pop(int(frame_idx), None)
            state.pending_box_start = (int(x), int(y))
            state.pending_box_start_frame_idx = int(frame_idx)
            state.pending_box_start_obj_id = int(obj_id)
            # Invalidate cache so temporary cross is drawn
            state.composited_frames.pop(int(frame_idx), None)
            return update_frame_display(state, int(frame_idx))
        else:
            x1, y1 = state.pending_box_start
            x2, y2 = int(x), int(y)
            start_frame_idx = state.pending_box_start_frame_idx
            # Clear temporary state and invalidate cache
            state.pending_box_start = None
            state.pending_box_start_frame_idx = None
            state.pending_box_start_obj_id = None
            state.composited_frames.pop(int(frame_idx), None)
            # The temporary cross was composited onto the frame of the first click.
            # If the user navigated between the two clicks, that frame's cached
            # composite would otherwise keep showing the stale cross.
            if start_frame_idx is not None:
                state.composited_frames.pop(int(start_frame_idx), None)
            x_min, y_min = min(x1, x2), min(y1, y2)
            x_max, y_max = max(x1, x2), max(y1, y2)
            processor.add_inputs_to_inference_session(
                inference_session=inference_session,
                frame_idx=int(frame_idx),
                obj_ids=int(obj_id),
                input_boxes=[[[x_min, y_min, x_max, y_max]]],
                clear_old_inputs=True,  # For boxes, always clear old inputs
                original_size=original_size,
            )
            frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
            obj_boxes = frame_boxes.setdefault(int(obj_id), [])
            # For boxes, always clear old inputs
            obj_boxes.clear()
            obj_boxes.append((x_min, y_min, x_max, y_max))
            state.composited_frames.pop(int(frame_idx), None)
    else:
        # Points mode
        label_int = 1 if str(label).lower().startswith("pos") else 0
        # If clear_old is enabled, clear prior boxes for this object on this frame
        if bool(clear_old):
            frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
            frame_boxes[int(obj_id)] = []
            state.composited_frames.pop(int(frame_idx), None)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=int(frame_idx),
            obj_ids=int(obj_id),
            input_points=[[[[int(x), int(y)]]]],
            input_labels=[[[int(label_int)]]],
            original_size=original_size,
            clear_old_inputs=bool(clear_old),
        )
        frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
        obj_clicks = frame_clicks.setdefault(int(obj_id), [])
        if bool(clear_old):
            obj_clicks.clear()
        obj_clicks.append((int(x), int(y), int(label_int)))
        state.composited_frames.pop(int(frame_idx), None)
    # Forward on that frame
    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
    H = inference_session.video_height
    W = inference_session.video_width
    # Detach and move off GPU as early as possible to reduce GPU memory pressure
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
    # Map returned masks to object ids. For single object forward, it's [1, 1, H, W]
    # But to be safe, iterate over session.obj_ids order.
    masks_for_frame: dict[int, np.ndarray] = {}
    obj_ids_order = list(inference_session.obj_ids)
    for i, oid in enumerate(obj_ids_order):
        mask_i = video_res_masks[i]
        # mask_i shape could be (1, H, W) or (H, W); squeeze to 2D
        mask_2d = mask_i.cpu().numpy().squeeze()
        masks_for_frame[int(oid)] = mask_2d
    state.masks_by_frame[int(frame_idx)] = masks_for_frame
    _update_centroids_for_frame(state, int(frame_idx))
    # Invalidate cache for this frame to force recomposition
    state.composited_frames.pop(int(frame_idx), None)
    # Return updated preview
    return update_frame_display(state, int(frame_idx))
def _on_image_click_with_updates(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
):
    """
    Preview-click handler that also refreshes the button states.

    While a goal-placement mode is active, the click is first offered to
    ``_goal_process_preview_click``. If that consumes it and returns a preview,
    the click is not forwarded. Otherwise it is applied as a SAM2 prompt via
    ``on_image_click``.

    Returns the preview followed by the three action-button updates and the
    status-button updates.
    """
    frame_idx = int(frame_idx)
    preview_img = None
    if state is not None and state.goal_mode != GOAL_MODE_IDLE:
        goal_preview, consumed = _goal_process_preview_click(state, frame_idx, evt)
        if consumed and goal_preview is not None:
            preview_img = goal_preview
    if preview_img is None:
        preview_img = on_image_click(img, state, frame_idx, obj_id, label, clear_old, evt)
    main_btn, detect_btn, player_btn = _button_updates(state)
    status_buttons = _ui_status_updates(state)
    return (preview_img, main_btn, detect_btn, player_btn, *status_buttons)
@spaces.GPU()
def propagate_masks(GLOBAL_STATE: gr.State):
    """
    Propagate the SAM2 prompts through the video and stream UI updates.

    This is a generator. Every yielded value is the full output tuple:
    state, status text, slider update, kick plot, YOLO plot, impact status,
    kick-status update, the three action-button updates, then the
    status-button updates from ``_ui_status_updates``.

    Propagation runs on deep copies of the processor, model and inference
    session, with the model moved to CUDA. It covers ``GLOBAL_STATE.sam_window``.
    When that window is unset, it is first computed from
    ``_get_prioritized_kick_frame``; if it is still empty, all frames are used.
    Masks and centroids are written back into ``GLOBAL_STATE``. The UI ends on
    the kick frame when one is known, else on the last processed frame.
    """

    def _ui_outputs(message: str, slider_update: Any) -> tuple[Any, ...]:
        # Button/status updates are built before the plots and kick status,
        # which matches the evaluation order used elsewhere in this module.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        status_updates = _ui_status_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            message,
            slider_update,
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _impact_status_update(GLOBAL_STATE),
            gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
        )

    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
        # This function is a generator: a plain ``return <tuple>`` would be
        # swallowed (it only ends iteration), so the user would never see the
        # message. Yield it, then stop.
        yield _ui_outputs("Load a video first.", gr.update())
        return
    _ensure_ball_prompt_from_yolo(GLOBAL_STATE)
    processor = deepcopy(GLOBAL_STATE.processor)
    model = deepcopy(GLOBAL_STATE.model)
    inference_session = deepcopy(GLOBAL_STATE.inference_session)
    # set inference device to cuda to use zero gpu
    inference_session.inference_device = "cuda"
    inference_session.cache.inference_device = "cuda"
    model.to("cuda")
    if not GLOBAL_STATE.sam_window:
        _compute_sam_window_from_kick(
            GLOBAL_STATE,
            _get_prioritized_kick_frame(GLOBAL_STATE),
        )
    start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
    start_idx = max(0, int(start_idx))
    end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
    total = max(1, end_idx - start_idx)
    processed = 0
    # NOTE(review): second call; the deep-copied session above already captured
    # whatever the first call added. Confirm whether this is still needed.
    _ensure_ball_prompt_from_yolo(GLOBAL_STATE)
    # Initial status; no slider change yet
    GLOBAL_STATE.is_sam_tracked = False
    yield _ui_outputs(f"Propagating masks: {processed}/{total}", gr.update())
    last_frame_idx = start_idx
    with torch.inference_mode():
        for frame_idx in range(start_idx, end_idx):
            frame = GLOBAL_STATE.video_frames[frame_idx]
            pixel_values = None
            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
            last_frame_idx = frame_idx
            masks_for_frame: dict[int, np.ndarray] = {}
            obj_ids_order = list(inference_session.obj_ids)
            for i, oid in enumerate(obj_ids_order):
                mask_2d = video_res_masks[i].cpu().numpy().squeeze()
                masks_for_frame[int(oid)] = mask_2d
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
            _update_centroids_for_frame(GLOBAL_STATE, frame_idx)
            # Invalidate cache for that frame to force recomposition
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)
            processed += 1
            # Every 30th frame (or the last one), move the slider to the current
            # frame so the preview refreshes via the slider binding.
            if processed % 30 == 0 or processed == total:
                yield _ui_outputs(f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx))
    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
    # Focus UI on kick frame if available; otherwise stick to last processed frame.
    # Explicit ``is None`` checks: frame 0 is a valid kick frame and must not be
    # skipped the way ``kick_frame or ...`` would skip it.
    target_frame = GLOBAL_STATE.kick_frame
    if target_frame is None:
        target_frame = getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
    if target_frame is None:
        target_frame = last_frame_idx
    target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
    GLOBAL_STATE.current_frame_idx = target_frame
    # Final status; ensure slider points to the target frame (kick frame when detected)
    GLOBAL_STATE.is_sam_tracked = True
    yield _ui_outputs(text, gr.update(value=target_frame))
def reset_session(GLOBAL_STATE: gr.State) -> tuple[Any, ...]:
    """
    Clear prompts, masks and all derived tracking metrics, keeping the loaded
    video and model.

    The inference session is disposed (best effort) and re-created via
    ``ensure_session_for_current_model``, and the preview is redrawn at the
    current frame.

    Returns the UI output tuple: state, preview image, slider range update,
    slider value update, status text, hidden text update, kick plot, YOLO plot,
    impact status, the three action-button updates, then the status-button
    updates from ``_ui_status_updates``. The "nothing loaded" path returns the
    same number and order of outputs as the normal path.
    """

    def _reset_progress_flags(app_state) -> None:
        # Mark every pipeline step as not done.
        app_state.is_ball_detected = False
        app_state.is_yolo_tracked = False
        app_state.is_sam_tracked = False
        app_state.is_player_detected = False
        app_state.is_player_propagated = False

    # Reset only session-related state, keep uploaded video and model
    if GLOBAL_STATE is None or not GLOBAL_STATE.video_frames:
        # Nothing loaded. Flags are reset *before* the button updates are built,
        # so the buttons reflect the reset state. The YOLO plot output is
        # included so the tuple lines up with the normal path's outputs.
        if GLOBAL_STATE is not None:
            _reset_progress_flags(GLOBAL_STATE)
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        status_updates = _ui_status_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            None,
            0,
            0,
            "Session reset. Load a new video.",
            gr.update(visible=False, value=""),
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _impact_status_update(GLOBAL_STATE),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
        )
    # Clear prompts and caches
    GLOBAL_STATE.masks_by_frame.clear()
    GLOBAL_STATE.clicks_by_frame_obj.clear()
    GLOBAL_STATE.boxes_by_frame_obj.clear()
    GLOBAL_STATE.composited_frames.clear()
    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None
    # Clear motion metrics and kick/impact detection results
    GLOBAL_STATE.ball_centers.clear()
    GLOBAL_STATE.mask_areas.clear()
    GLOBAL_STATE.smoothed_centers.clear()
    GLOBAL_STATE.ball_speeds.clear()
    GLOBAL_STATE.distance_from_start.clear()
    GLOBAL_STATE.direction_change.clear()
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kalman_centers.clear()
    GLOBAL_STATE.kalman_speeds.clear()
    GLOBAL_STATE.kalman_residuals.clear()
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    GLOBAL_STATE.kick_debug_kalman_speeds = []
    _reset_progress_flags(GLOBAL_STATE)
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None
    # Dispose and re-init inference session for current model with existing frames
    try:
        if GLOBAL_STATE.inference_session is not None:
            GLOBAL_STATE.inference_session.reset_inference_session()
    except Exception:
        # Best effort: the old session is dropped below regardless.
        pass
    GLOBAL_STATE.inference_session = None
    gc.collect()
    ensure_session_for_current_model(GLOBAL_STATE)
    # Keep current slider index if possible
    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
    slider_value = gr.update(value=current_idx)
    status = "Session reset. Prompts cleared; video preserved."
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
    status_updates = _ui_status_updates(GLOBAL_STATE)
    return (
        GLOBAL_STATE,
        preview_img,
        slider_minmax,
        slider_value,
        status,
        gr.update(visible=False, value=""),
        _build_kick_plot(GLOBAL_STATE),
        _build_yolo_plot(GLOBAL_STATE),
        _impact_status_update(GLOBAL_STATE),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
        *status_updates,
    )
def create_annotation_preview(video_file, annotations):
    """
    Create a preview image showing annotation points on video frames.

    Annotations are grouped by frame index, and at most the first three
    annotated frames (in ascending order) are read from the video. Each point
    is drawn as a crosshair plus a circle in the colour from
    ``pastel_color_for_object``, labelled with its object id and frame. The
    frames are then pasted side by side.

    Args:
        video_file: Path to video file (passed directly to ``cv2.VideoCapture``).
        annotations: List of annotation dicts.
            - ``frame`` selects the frame and defaults to 0 when the key is absent.
            - When ``frame`` is explicitly ``None``, the frame is derived from
              ``timestamp_ms`` and the video FPS. The annotation is skipped if
              that is impossible or the index is invalid.
            - ``x``/``y`` default to 0 and ``object_id`` to 1.

    Returns:
        PIL Image with annotations visualized, or ``None`` if the video cannot
        be opened or none of the selected frames could be read.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            return None
        # NaN/0 FPS is treated as unknown (fails the "> 0" check below).
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        # Group annotations by integer frame index
        frames_to_show: dict[int, list[dict]] = {}
        for ann in annotations:
            raw_frame = ann.get("frame", 0)
            try:
                if raw_frame is None:
                    # Timestamp-only annotation: derive the frame like the API does.
                    timestamp_ms = ann.get("timestamp_ms")
                    if timestamp_ms is None or not fps > 0:
                        continue
                    raw_frame = (float(timestamp_ms) / 1000.0) * fps
                frame_idx = int(raw_frame)
            except (TypeError, ValueError, OverflowError):
                continue
            if frame_idx < 0:
                continue
            frames_to_show.setdefault(frame_idx, []).append(ann)
        # Read and annotate frames
        annotated_frames = []
        for frame_idx in sorted(frames_to_show)[:3]:  # Show max 3 frames
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                continue
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            draw = ImageDraw.Draw(pil_img)
            # Draw annotations
            for ann in frames_to_show[frame_idx]:
                x, y = ann.get("x", 0), ann.get("y", 0)
                obj_id = ann.get("object_id", 1)
                # Color based on object ID
                color = pastel_color_for_object(obj_id)
                # Draw crosshair
                size = 20
                draw.line([(x - size, y), (x + size, y)], fill=color, width=3)
                draw.line([(x, y - size), (x, y + size)], fill=color, width=3)
                draw.ellipse([(x - 10, y - 10), (x + 10, y + 10)], outline=color, width=3)
                # Draw label
                text = f"Obj{obj_id} F{frame_idx}"
                draw.text((x + 15, y - 15), text, fill=color)
            # Add frame number label
            draw.text((10, 10), f"Frame {frame_idx}", fill=(255, 255, 255))
            annotated_frames.append(pil_img)
    finally:
        # Release the capture on every path, including errors while decoding/drawing.
        cap.release()
    # Combine frames horizontally
    if not annotated_frames:
        return None
    total_width = sum(img.width for img in annotated_frames)
    max_height = max(img.height for img in annotated_frames)
    combined = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in annotated_frames:
        combined.paste(img, (x_offset, 0))
        x_offset += img.width
    return combined
@spaces.GPU(duration=120)  # Allocate GPU for up to 2 minutes
def process_video_api(
    video_file,
    annotations_json_str: str,
    checkpoint: str = "base_plus",
    remove_background: bool = True,
):
    """
    Single-endpoint API for programmatic video processing.

    Pipeline: load the clip, apply optional point annotations, YOLO13 ball
    detection + tracking, SAM2 ball propagation (re-run over the whole clip
    when the YOLO and SAM2 kick frames disagree), YOLO13 player detection with
    optional SAM2 player propagation, then render the final video.

    Args:
        video_file: Uploaded video file
        annotations_json_str: Optional JSON string containing helper annotations.
            Either an object ``{"annotations": [...], "fps": <client fps>}`` or a
            bare list of annotations. Each annotation may carry ``object_id``,
            ``timestamp_ms`` (preferred over ``frame``), ``frame``, ``x``, ``y``
            and ``label``.
        checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large)
        remove_background: Whether to remove the background in the render

    Returns:
        Tuple of (preview_image, processed_video_path, progress_log). The
        preview is the annotation overview when annotations were supplied,
        otherwise the last click preview, otherwise the first frame.

    Raises:
        gr.Error: for any failure; the original exception is chained as cause.
    """
    try:
        log_entries: list[str] = []

        def log_msg(message: str):
            text = f"[API] {message}"
            print(text)
            log_entries.append(text)

        # Parse annotations (optional)
        annotations_payload = annotations_json_str or ""
        annotations_data = json.loads(annotations_payload) if annotations_payload.strip() else {}
        if isinstance(annotations_data, list):
            # Accept a bare list as shorthand for {"annotations": [...]}
            annotations_data = {"annotations": annotations_data}
        annotations = annotations_data.get("annotations", []) or []
        raw_client_fps = annotations_data.get("fps", None)
        try:
            client_fps = float(raw_client_fps) if raw_client_fps is not None else None
        except (TypeError, ValueError):
            # A malformed FPS value only disables the mismatch warning below
            client_fps = None
        log_msg(f"Received {len(annotations)} annotations")
        log_msg(f"Checkpoint: {checkpoint} | Remove background: {remove_background}")
        # Overview of where the annotations land (computed once, returned as the preview)
        annotation_preview = create_annotation_preview(video_file, annotations) if annotations else None

        # Create a temporary state for this API call
        api_state = AppState()
        api_state.model_repo_key = checkpoint

        # Step 1: Initialize session with video
        log_msg("Loading video...")
        api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file)
        space_fps = api_state.video_fps
        num_frames = api_state.num_frames
        log_msg(status)
        log_msg(f"Client FPS={client_fps} | Space FPS={space_fps}")
        # If FPS mismatch, warn about potential frame offset
        if client_fps and client_fps > 0 and space_fps and abs(client_fps - space_fps) > 0.5:
            offset_estimate = abs(int((client_fps - space_fps) * (num_frames / client_fps)))
            log_msg(f"⚠️ FPS mismatch detected. Frame indices may be off by ~{offset_estimate} frames.")
            log_msg("ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy.")

        # Step 2: Apply each annotation
        click_preview = None
        for i, ann in enumerate(annotations):
            object_id = ann.get("object_id", 1)
            timestamp_ms = ann.get("timestamp_ms", None)
            frame_idx = ann.get("frame", None)
            x = ann.get("x", 0)
            y = ann.get("y", 0)
            label = ann.get("label", "positive")
            # Calculate frame from timestamp using Space's FPS (more accurate)
            if timestamp_ms is not None and space_fps and space_fps > 0:
                calculated_frame = int((timestamp_ms / 1000.0) * space_fps)
                if frame_idx is not None and calculated_frame != frame_idx:
                    log_msg(f"Annotation {i+1}: using timestamp {timestamp_ms}ms → Frame {calculated_frame} (client sent {frame_idx})")
                else:
                    log_msg(f"Annotation {i+1}: timestamp {timestamp_ms}ms → Frame {calculated_frame}")
                frame_idx = calculated_frame
            elif frame_idx is None:
                log_msg(f"Annotation {i+1}: ⚠️ No timestamp/frame provided, defaulting to frame 0")
                frame_idx = 0
            frame_idx = int(frame_idx)
            # Timestamps past the end (or FPS rounding) can point outside the clip
            if num_frames:
                clamped_idx = min(max(frame_idx, 0), num_frames - 1)
                if clamped_idx != frame_idx:
                    log_msg(f"Annotation {i+1}: ⚠️ Frame {frame_idx} outside clip, clamped to {clamped_idx}")
                    frame_idx = clamped_idx
            log_msg(f"Adding annotation {i+1}/{len(annotations)} | Obj {object_id} | Frame {frame_idx}")
            # Sync state
            api_state.current_frame_idx = frame_idx
            api_state.current_obj_id = int(object_id)
            api_state.current_label = str(label)
            # Add the point annotation through the regular click handler
            click_preview = on_image_click(
                first_frame,
                api_state,
                frame_idx,
                object_id,
                label,
                clear_old=False,
                evt=SimpleNamespace(index=(x, y)),
            )

        if annotation_preview is not None:
            preview_img = annotation_preview
        elif click_preview is not None:
            preview_img = click_preview
        else:
            preview_img = first_frame

        # Helper to consume generator-based steps and capture log messages
        def _run_generator(gen, step_label: str):
            final_state = None
            for outputs in gen:
                if not outputs:
                    continue
                final_state = outputs[0]
                status_msg = outputs[1] if len(outputs) > 1 else ""
                if status_msg:
                    log_msg(f"{step_label}: {status_msg}")
            if final_state is not None:
                return final_state
            raise gr.Error(f"{step_label} did not produce any output.")

        # Step 3: YOLO13 detect ball
        api_state.current_obj_id = BALL_OBJECT_ID
        api_state.current_label = "positive"
        log_msg("YOLO13 · Detect ball (single-frame search)")
        _auto_detect_ball(api_state, BALL_OBJECT_ID, "positive", False)
        if not api_state.is_ball_detected:
            raise gr.Error("YOLO13 could not detect the ball automatically.")

        # Step 4: YOLO13 track ball
        log_msg("YOLO13 · Track ball across clip")
        _track_ball_yolo(api_state)
        if not api_state.is_yolo_tracked:
            raise gr.Error("YOLO13 tracking failed.")

        # Step 5: SAM2 track ball around kick window
        log_msg("SAM2 · Track ball around kick window")
        api_state = _run_generator(propagate_masks(api_state), "SAM2 · Ball")
        sam_kick = _get_prioritized_kick_frame(api_state)
        yolo_kick = api_state.yolo_kick_frame
        if sam_kick is not None:
            log_msg(f"SAM2 kick frame ≈ {sam_kick}")
        if yolo_kick is not None:
            log_msg(f"YOLO kick frame ≈ {yolo_kick}")
        # Fallback: re-run SAM2 on entire video if kicks disagree
        if (
            yolo_kick is not None
            and sam_kick is not None
            and int(yolo_kick) != int(sam_kick)
        ):
            log_msg("Kick disagreement detected → re-running SAM2 across entire video.")
            api_state.sam_window = (0, api_state.num_frames)
            api_state = _run_generator(propagate_masks(api_state), "SAM2 · Full sweep")
            sam_kick = _get_prioritized_kick_frame(api_state)
            log_msg(f"SAM2 full sweep kick frame ≈ {sam_kick}")
        else:
            log_msg("Kick frames aligned. No full sweep required.")

        # Step 6: YOLO detect player on SAM2 kick frame
        log_msg("YOLO13 · Detect player on SAM2 kick frame")
        _auto_detect_player(api_state)
        if api_state.is_player_detected:
            log_msg("YOLO13 · Player detected successfully.")
        else:
            log_msg("YOLO13 · Player detection failed; continuing without player propagation.")

        # Step 7: SAM2 track player if detection succeeded
        if api_state.is_player_detected:
            log_msg("SAM2 · Track player around kick window")
            try:
                api_state = _run_generator(propagate_player_masks(api_state), "SAM2 · Player")
            except gr.Error as player_error:
                # Player propagation is optional; keep going with ball-only output
                log_msg(f"SAM2 player propagation warning: {player_error}")

        # Step 8: Render the final video
        log_msg(f"Rendering video (remove_background={remove_background})")
        result_video_path = _render_video(api_state, remove_background, log_fn=log_msg)
        log_msg("Processing complete 🎉")
        return preview_img, result_video_path, "\n".join(log_entries)
    except Exception as e:
        print(f"[API] ❌ Error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Processing failed: {str(e)}") from e
# Soft Gradio theme: blue primary, rose secondary, slate neutrals.
theme = Soft(
    neutral_hue="slate",
    primary_hue="blue",
    secondary_hue="rose",
)
CUSTOM_CSS = """
.gr-button-stop {
background-color: #f97316 !important;
border-color: #ea580c !important;
color: #fff !important;
}
.gr-button-stop:hover {
background-color: #ea580c !important;
border-color: #c2410c !important;
}
.gr-button-stop:disabled {
opacity: 0.7 !important;
color: #fff !important;
}
.model-row {
display: flex;
align-items: center;
gap: 0.4rem;
flex-wrap: nowrap !important;
}
.model-label {
min-width: 68px;
font-weight: 600;
}
.model-label p {
margin: 0 !important;
}
.model-section {
background: rgba(255, 255, 255, 0.02);
border-radius: 0.4rem;
padding: 0.45rem 0.65rem;
margin-bottom: 0.45rem;
display: flex;
flex-direction: column;
gap: 0.3rem;
}
.model-actions {
flex: 1 1 auto;
display: flex;
flex-wrap: nowrap;
gap: 0.35rem;
}
.model-actions .gr-button {
flex: 0 0 auto;
min-width: unset;
width: fit-content;
padding: 0.32rem 0.7rem;
}
.model-status {
flex: 0 0 auto;
display: flex;
gap: 0.25rem;
margin-left: auto;
}
.model-status .gr-button {
min-width: unset;
width: fit-content;
padding: 0.25rem 0.55rem;
}
"""
# Hover text for the main action buttons, keyed by each button's elem_id.
# Injected as HTML `title` attributes by the script from _build_tooltip_script().
BUTTON_TOOLTIPS = {
    "btn-reset-session": (
        "Clears the entire workspace: YOLO detections, SAM2 masks, manual kick/impact overrides, and FX settings all "
        "return to defaults so you can load a new clip without leftover state."
    ),
    "btn-mark-kick": (
        "Stores the current frame as the definitive kick moment. We override YOLO or SAM guesses immediately so SAM2 "
        "propagation, ring rendering, impact-speed math, and player workflows all pivot around this human-confirmed "
        "timestamp until it is cleared."
    ),
    "btn-mark-impact": (
        "Declares the current frame as the impact (goal crossing or contact). Automatic impact detection is still in "
        "progress, so this manual anchor feeds the diagnostics plot and tells the renderer when to fade rings or ghost trails."
    ),
    # Detection only inspects the first frame (see _auto_detect_ball) and may
    # return several candidates, so the text describes that behavior.
    "btn-detect-ball": (
        "Runs YOLO13 on the first frame to find the stationary ball before it moves. If several balls are visible, the "
        "highest-confidence candidate is pre-selected and Track Ball later compares every candidate to find the one that is "
        "kicked. A successful detection seeds the ball prompt and enables the tracking and player steps."
    ),
    "btn-track-ball-yolo": (
        "Sweeps YOLO13 tracking across every frame while forcing exactly one plausible ball trajectory. We smooth detections, "
        "locate the velocity spike that marks the kick, and cache the radius there. This fast scout tells SAM2 where to focus "
        "later and populates the future ring / ghost trail data."
    ),
    "btn-detect-player": (
        "Samples the prioritized kick frame (manual > SAM > YOLO) and runs player detection there. Aligning the player mask "
        "with the kick ensures SAM2 can later propagate the athlete through the same window, unlocking the Track Player step."
    ),
    "btn-track-ball-sam": (
        "Runs the SAM2 transformer on a tight window centered on the prioritized kick frame. We seed it with YOLO’s latest "
        "ball mask so SAM2 delivers high-fidelity segmentation only where it matters, refreshing ring radii without scanning "
        "the full clip. If no kick frame is known yet, SAM2 falls back to tracking across all frames."
    ),
    "btn-track-player-sam": (
        "After a player mask exists, SAM2 propagates it within the same kick-centric window. This keeps athlete and ball masks "
        "time-synced, enabling combined overlays, exports, and analytics comparing foot position to ball contact."
    ),
    "btn-goal-start": (
        "Enters goal mapping mode so you can click the two crossbar corners. After the first click a handle appears; the second "
        "click closes the bar and exposes draggable anchors before you confirm."
    ),
    "btn-goal-confirm": (
        "Locks the currently placed crossbar across the entire video. The line and handles stay visible on every frame and can "
        "be re-edited later by tapping Map Goal again."
    ),
    "btn-goal-clear": (
        "Removes the current crossbar (and any in-progress points) so you can restart the goal alignment workflow from scratch."
    ),
    "btn-goal-back": (
        "Restores the previously confirmed crossbar if the latest edits missed the mark. Useful when you want to compare two "
        "placements without re-clicking both corners."
    ),
}
def _build_tooltip_script() -> str:
data = json.dumps(BUTTON_TOOLTIPS)
return f"""
<script>
const KT_TOOLTIPS = {data};
function applyKTTitles() {{
Object.entries(KT_TOOLTIPS).forEach(([id, text]) => {{
const el = document.getElementById(id);
if (el && !el.dataset.ktTooltip) {{
el.dataset.ktTooltip = "1";
el.setAttribute("title", text);
}}
}});
}}
const observer = new MutationObserver(() => applyKTTitles());
observer.observe(document.body, {{ childList: true, subtree: true }});
document.addEventListener("DOMContentLoaded", applyKTTitles);
applyKTTitles();
</script>
"""
# Built once at import time from BUTTON_TOOLTIPS; embedded in the UI through a
# hidden gr.HTML component inside the Blocks layout.
TOOLTIP_SCRIPT: str = _build_tooltip_script()
with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme, css=CUSTOM_CSS) as demo:
# Per-session application state; passed into and returned from the event handlers.
GLOBAL_STATE = gr.State(AppState())
# Title banner and project description.
gr.Markdown(
"""
### KickTrimmer Lab · Ball-Speed Video Twin
This Space acts as a desktop twin of the KickTrimmer mobile app: load a football clip, detect the kick, and estimate the ball speed frame-by-frame. It previews future ball rings color-coded by hypothetical impact velocity as the ball travels toward the goal, letting you experiment with FX settings before shipping them to the phone build.
⚠️ **Work in progress:** we are still closing the gap with the mobile feature set (automatic horizon & goal finding, diagonal speed correction, etc.), so the numbers you see here are prototypes—not final certified speeds.
"""
)
# Usage notes laid out in two columns.
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Quick start**
- **Load a video**: Upload your own or pick an example below.
- **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy).
- **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames.
- **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically.
"""
)
with gr.Column():
gr.Markdown(
"""
**Working with results**
- **Preview**: Use the slider to navigate frames and see the current masks.
- **Track**: Click “Track ball (SAM2)” to track all defined objects across the selected window. The preview follows progress periodically to keep things responsive.
- **Export**: Render an MP4 for smooth playback using the original video FPS.
- **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
"""
)
# Main workspace: input column | preview column | model-control column.
with gr.Row(equal_height=True):
# Input column: video upload, SAM2.1 checkpoint choice, load status, reset.
with gr.Column(scale=1):
video_in = gr.Video(
label="Upload video",
sources=["upload", "webcam"],
interactive=True,
elem_id="video-pane",
)
ckpt_radio = gr.Radio(
choices=["tiny", "small", "base_plus", "large"],
value="tiny",
label="SAM2.1 checkpoint",
)
ckpt_progress = gr.Markdown(visible=False)
load_status = gr.Markdown(visible=True)
reset_btn = gr.Button("Reset Session", variant="secondary", elem_id="btn-reset-session")
# Preview column: clickable frame image plus frame slider (range set on video load).
with gr.Column(scale=1):
gr.Markdown("**Preview**")
preview = gr.Image(
interactive=True,
elem_id="preview-pane",
container=False,
show_label=False,
)
frame_slider = gr.Slider(
label="Frame",
minimum=0,
maximum=0,
step=1,
value=0,
interactive=True,
elem_id="frame-slider",
)
# Model-control column: Manual, YOLO13 and SAM2 sections.
with gr.Column():
# Manual section: mark kick/impact, jump-to status buttons, goal crossbar tools.
with gr.Column(elem_classes=["model-section"]):
with gr.Row(elem_classes=["model-row"]):
gr.Markdown("Manual", elem_classes=["model-label"])
with gr.Row(elem_classes=["model-actions"]):
mark_kick_btn = gr.Button("⚽ Mark Kick", variant="primary", elem_id="btn-mark-kick")
mark_impact_btn = gr.Button("🚩 Mark Impact", variant="primary", elem_id="btn-mark-impact")
# Status buttons start disabled with "N/A" labels.
with gr.Row(elem_classes=["model-status"]):
manual_kick_btn = gr.Button("⚽: N/A", interactive=False)
manual_impact_btn = gr.Button("🚩: N/A", interactive=False)
with gr.Row(elem_classes=["model-actions"]):
goal_start_btn = gr.Button(
"Map Goal",
variant="secondary",
elem_id="btn-goal-start",
)
goal_confirm_btn = gr.Button(
"Confirm",
variant="primary",
interactive=False,
elem_id="btn-goal-confirm",
)
goal_clear_btn = gr.Button(
"Clear",
variant="secondary",
interactive=False,
elem_id="btn-goal-clear",
)
goal_back_btn = gr.Button(
"Back",
variant="secondary",
interactive=False,
elem_id="btn-goal-back",
)
goal_status = gr.Markdown("Goal crossbar inactive.", elem_id="goal-status-text")
# YOLO13 section: ball detection/tracking, player detection, diagnostics.
with gr.Column(elem_classes=["model-section"]):
with gr.Row(elem_classes=["model-row"]):
gr.Markdown("YOLO13", elem_classes=["model-label"])
with gr.Row(elem_classes=["model-actions"]):
detect_ball_btn = gr.Button("Detect Ball", variant="stop", elem_id="btn-detect-ball")
track_ball_yolo_btn = gr.Button("Track Ball", variant="stop", elem_id="btn-track-ball-yolo")
detect_player_btn = gr.Button(
"Detect Player",
variant="stop",
interactive=False,
elem_id="btn-detect-player",
)
with gr.Row(elem_classes=["model-status"]):
yolo_kick_btn = gr.Button("⚽: N/A", interactive=False)
yolo_impact_btn = gr.Button("🚩: N/A", interactive=False)
# Multi-ball candidate selection UI
with gr.Column(visible=False) as multi_ball_selection_col:
multi_ball_status_md = gr.Markdown("", visible=True)
ball_candidate_radio = gr.Radio(
choices=[],
value=None,
label="Select Ball Candidate",
interactive=True,
)
with gr.Row():
confirm_ball_btn = gr.Button("Confirm Selection", variant="primary")
multi_ball_chart = gr.Plot(label="Ball Candidates Speed Comparison", show_label=True)
yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
# SAM2 section: ball/player propagation (disabled until prerequisites exist).
with gr.Column(elem_classes=["model-section"]):
with gr.Row(elem_classes=["model-row"]):
gr.Markdown("SAM2", elem_classes=["model-label"])
with gr.Row(elem_classes=["model-actions"]):
propagate_btn = gr.Button(
"Track Ball", variant="stop", interactive=False, elem_id="btn-track-ball-sam"
)
propagate_player_btn = gr.Button(
"Track Player",
variant="stop",
interactive=False,
elem_id="btn-track-player-sam",
)
with gr.Row(elem_classes=["model-status"]):
sam_kick_btn = gr.Button("⚽: N/A", interactive=False)
sam_impact_btn = gr.Button("🚩: N/A", interactive=False)
kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True)
# Hidden carrier for the tooltip script.
# NOTE(review): browsers do not run <script> inserted via innerHTML, and an
# invisible component may not be rendered at all — confirm tooltips appear.
gr.HTML(value=TOOLTIP_SCRIPT, visible=False)
# Motion-metric settings (km/h threshold and distance to goal in metres).
with gr.Row():
min_impact_speed_slider = gr.Slider(
label="Min impact speed (km/h)",
minimum=0,
maximum=120,
step=1,
value=20,
interactive=True,
)
goal_distance_slider = gr.Slider(
label="Distance to goal (m)",
minimum=1,
maximum=60,
step=0.5,
value=18,
interactive=True,
)
# Status markdown targets updated by the handlers.
ball_status = gr.Markdown(visible=False)
propagate_status = gr.Markdown(visible=True)
impact_status = gr.Markdown("Impact frame: not computed", visible=False)
# Click-prompt options: object id, point label, clear-old flag, prompt type.
with gr.Row():
obj_id_inp = gr.Number(value=1, precision=0, label="Object ID", scale=0)
label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
# Wire events
def _on_video_change(GLOBAL_STATE: gr.State, video):
    """Start a new session for *video* and refresh every widget that depends on it.

    Returns updates in the order of the video-load outputs: state, frame
    slider (range reset to the clip), preview (first frame), load status,
    hidden ball status, kick plot, YOLO plot, the `_ui_status_updates`
    entries, then the three propagate/detect button updates.
    """
    GLOBAL_STATE, first_idx, last_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
    main_btn, detect_btn, player_btn = _button_updates(GLOBAL_STATE)
    status_updates = _ui_status_updates(GLOBAL_STATE)
    slider_update = gr.update(minimum=first_idx, maximum=last_idx, value=first_idx, interactive=True)
    hidden_ball_status = gr.update(visible=False, value="")
    outputs = [
        GLOBAL_STATE,
        slider_update,
        first_frame,
        status,
        hidden_ball_status,
        _build_kick_plot(GLOBAL_STATE),
        _build_yolo_plot(GLOBAL_STATE),
    ]
    outputs.extend(status_updates)
    outputs.extend((main_btn, detect_btn, player_btn))
    return tuple(outputs)
# Components refreshed whenever a video is loaded (upload or example click);
# order matches the tuple returned by _on_video_change.
_video_load_outputs = [
    GLOBAL_STATE,
    frame_slider,
    preview,
    load_status,
    ball_status,
    kick_plot,
    yolo_plot,
    yolo_kick_btn,
    yolo_impact_btn,
    sam_kick_btn,
    sam_impact_btn,
    manual_kick_btn,
    manual_impact_btn,
    detect_ball_btn,
    track_ball_yolo_btn,
    goal_start_btn,
    goal_confirm_btn,
    goal_clear_btn,
    goal_back_btn,
    goal_status,
    propagate_btn,
    detect_player_btn,
    propagate_player_btn,
]
video_in.change(
    _on_video_change,
    inputs=[GLOBAL_STATE, video_in],
    outputs=list(_video_load_outputs),
    show_progress=True,
)
# Bundled demo clip (decoded from base64 on first use).
example_video_path = ensure_example_video()
examples_list = [[None, example_video_path]]
with gr.Row():
    gr.Examples(
        examples=examples_list,
        inputs=[GLOBAL_STATE, video_in],
        fn=_on_video_change,
        outputs=list(_video_load_outputs),
        label="Examples",
        cache_examples=False,
        examples_per_page=5,
    )
# Render options: background removal and overlay toggles.
with gr.Row():
with gr.Column(scale=1):
remove_bg_checkbox = gr.Checkbox(
label="Remove Background",
value=True,
info="If checked, shows only tracked objects on black background. If unchecked, overlays colored masks on original video.",
)
with gr.Column(scale=1):
ghost_trail_chk = gr.Checkbox(
label="Ghost trail (ball)",
value=False,
info="Overlay post-kick SAM2 ball masks in magenta to visualize trajectory.",
)
with gr.Column(scale=1):
ball_ring_chk = gr.Checkbox(
label="Ball rings (future)",
value=True,
info="Replace the ghost trail fill with magenta rings at future ball positions.",
)
with gr.Column(scale=1):
click_marks_chk = gr.Checkbox(
label="Show annotation '+'",
value=False,
info="If unchecked, hides the '+' markers from clicks in preview and renders.",
)
# Collapsible cutout FX controls (soft matte, background blur/darkening, light wrap, glow).
with gr.Accordion("Cutout FX", open=False):
gr.Markdown("These options apply when rendering with background removal.")
with gr.Row():
with gr.Column(scale=1):
soft_matte_chk = gr.Checkbox(label="Soft matte", value=True)
with gr.Column(scale=2):
soft_matte_feather = gr.Slider(
label="Feather radius (px)",
minimum=0.0,
maximum=12.0,
step=0.5,
value=4.0,
)
with gr.Column(scale=2):
soft_matte_erode = gr.Slider(
label="Edge shrink (px)",
minimum=0.0,
maximum=5.0,
step=0.5,
value=0.5,
)
with gr.Row():
with gr.Column(scale=1):
blur_bg_chk = gr.Checkbox(label="Blur background", value=True)
with gr.Column(scale=2):
blur_radius = gr.Slider(
label="Background blur (px)",
minimum=0.0,
maximum=45.0,
step=1.0,
value=0.0,
)
with gr.Column(scale=2):
bg_darkening = gr.Slider(
label="Darken background",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.75,
info="0 keeps original brightness, 1 turns the background black.",
)
with gr.Row():
with gr.Column(scale=1):
light_wrap_chk = gr.Checkbox(label="Light wrap", value=False)
with gr.Column(scale=2):
light_wrap_strength = gr.Slider(
label="Wrap strength",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.6,
)
with gr.Column(scale=2):
light_wrap_width = gr.Slider(
label="Wrap width (px)",
minimum=0.0,
maximum=25.0,
step=0.5,
value=15.0,
)
with gr.Row():
with gr.Column(scale=1):
glow_chk = gr.Checkbox(label="Neon glow", value=False)
with gr.Column(scale=2):
glow_strength = gr.Slider(
label="Glow strength",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.4,
)
with gr.Column(scale=2):
glow_radius = gr.Slider(
label="Glow radius (px)",
minimum=0.0,
maximum=35.0,
step=0.5,
value=10.0,
)
# New Ring FX Controls
gr.Markdown("### Ring FX Settings")
with gr.Row():
with gr.Column(scale=1):
ring_thickness = gr.Slider(
label="Ring Thickness",
minimum=0.1,
maximum=2.0,
step=0.1,
value=1.0,
)
with gr.Column(scale=1):
ring_alpha = gr.Slider(
label="Ring Intensity (Alpha)",
minimum=0.1,
maximum=3.0,
step=0.1,
value=3.0,
)
with gr.Column(scale=1):
ring_feather = gr.Slider(
label="Ring Softness (Blur)",
minimum=0.0,
maximum=5.0,
step=0.1,
value=0.1,
)
with gr.Column(scale=1):
ring_gamma = gr.Slider(
label="Ring Gamma (Contrast)",
minimum=0.1,
maximum=2.0,
step=0.05,
value=2.0,
info="Lower values = higher contrast/sharper falloff"
)
with gr.Column(scale=1):
ring_duration = gr.Slider(
label="Rings Duration (frames)",
minimum=0,
maximum=120,
step=1,
value=30,
info="Limit how many frames after the kick to show rings (approx 0-4s)"
)
with gr.Column(scale=1):
ring_scale_pct = gr.Slider(
label="Ring Size Scale (%)",
minimum=10,
maximum=200,
step=5,
value=125,
info="Adjust overall ring size relative to detected ball radius."
)
# MP4 render trigger and output player.
with gr.Row():
render_btn = gr.Button("Render MP4 for smooth playback", variant="primary")
playback_video = gr.Video(label="Rendered Playback", interactive=False)
# FX controls forwarded (after GLOBAL_STATE) to _update_fx_controls. Gradio
# passes inputs positionally, so this order must match that handler's
# parameters — note ring_scale_pct comes before ring_duration here.
fx_inputs = [
soft_matte_chk,
soft_matte_feather,
soft_matte_erode,
blur_bg_chk,
blur_radius,
bg_darkening,
light_wrap_chk,
light_wrap_strength,
light_wrap_width,
glow_chk,
glow_strength,
glow_radius,
# New inputs
ring_thickness,
ring_alpha,
ring_feather,
ring_gamma,
ring_scale_pct,
ring_duration,
]
# Any FX control change re-renders the preview with the full ordered FX set.
_fx_handler_inputs = [GLOBAL_STATE, *fx_inputs]
for fx_control in fx_inputs:
    fx_control.change(
        _update_fx_controls,
        inputs=list(_fx_handler_inputs),
        outputs=preview,
    )
# Overlay toggles: each checkbox drives its own handler and refreshes the preview.
for overlay_chk, overlay_handler in (
    (ghost_trail_chk, _toggle_ghost_trail),
    (ball_ring_chk, _toggle_ball_ring),
    (click_marks_chk, _toggle_click_marks),
):
    overlay_chk.change(
        overlay_handler,
        inputs=[GLOBAL_STATE, overlay_chk],
        outputs=preview,
    )
def _on_ckpt_change(s: AppState, key: str):
    """
    Switch the SAM2 checkpoint and stream a loading message while it loads.

    When *key* differs from the current checkpoint, the cached model and
    processor are dropped and `ensure_session_for_current_model` is called to
    load the new one. This is a generator: the first yield shows the loading
    text, and the final yield hides it.

    Args:
        s: Session state (may be None).
        key: Checkpoint key selected in the radio (e.g. "tiny", "base_plus").

    Yields:
        gr.update objects for the `ckpt_progress` markdown.
    """
    if s is not None and key:
        key = str(key)
        if key != s.model_repo_key:
            # Update and drop current model to reload lazily next time
            s.is_switching_model = True
            s.model_repo_key = key
            s.model_repo_id = None
            s.model = None
            s.processor = None
            try:
                # Stream progress text while loading (first yield shows text)
                yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
                ensure_session_for_current_model(s)
            finally:
                # Always clear the flag, even when loading raises or the
                # generator is closed early; otherwise the session would stay
                # marked as "switching" indefinitely.
                s.is_switching_model = False
    # Final yield hides the text
    yield gr.update(visible=False, value="")
# Changing the SAM2 checkpoint streams loading progress into ckpt_progress.
ckpt_radio.change(
    _on_ckpt_change,
    inputs=[GLOBAL_STATE, ckpt_radio],
    outputs=[ckpt_progress],
)
def _sync_frame_idx(state_in: AppState, idx: int):
    """Store the slider position on the session (if any) and render that frame."""
    frame_no = int(idx)
    if state_in is not None:
        state_in.current_frame_idx = frame_no
    return update_frame_display(state_in, frame_no)


frame_slider.change(_sync_frame_idx, inputs=[GLOBAL_STATE, frame_slider], outputs=preview)
# Status buttons jump the preview and slider to the frame they report.
for jump_btn, jump_handler in (
    (yolo_kick_btn, _jump_to_yolo_kick),
    (sam_kick_btn, _jump_to_sam_kick),
    (sam_impact_btn, _jump_to_sam_impact),
    (manual_kick_btn, _jump_to_manual_kick),
    (manual_impact_btn, _jump_to_manual_impact),
):
    jump_btn.click(
        jump_handler,
        inputs=[GLOBAL_STATE],
        outputs=[preview, frame_slider],
    )
# Manual kick/impact marking refreshes the same set of widgets.
_mark_frame_outputs = [
    preview,
    ball_status,
    frame_slider,
    kick_plot,
    propagate_btn,
    detect_player_btn,
    propagate_player_btn,
    yolo_kick_btn,
    yolo_impact_btn,
    sam_kick_btn,
    sam_impact_btn,
    manual_kick_btn,
    manual_impact_btn,
    detect_ball_btn,
    track_ball_yolo_btn,
    goal_start_btn,
    goal_confirm_btn,
    goal_clear_btn,
    goal_back_btn,
    goal_status,
]
for mark_btn, mark_handler in (
    (mark_kick_btn, _mark_kick_frame),
    (mark_impact_btn, _mark_impact_frame),
):
    mark_btn.click(
        mark_handler,
        inputs=[GLOBAL_STATE, frame_slider],
        outputs=list(_mark_frame_outputs),
    )
def _sync_obj_id(s: AppState, oid):
    """Mirror the Object ID field onto the session; produces no visible change."""
    if s is None or oid is None:
        return gr.update()
    s.current_obj_id = int(oid)
    return gr.update()


obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])
def _sync_label(s: AppState, lab: str):
    """Mirror the point-label radio onto the session; produces no visible change."""
    if s is None or lab is None:
        return gr.update()
    s.current_label = str(lab)
    return gr.update()


label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])
def _sync_prompt_type(s: AppState, val: str):
    """Switch between point and box prompting.

    Stores the prompt type and drops any half-drawn box on the session. In
    points mode the label radio is shown and "clear old" is editable; in boxes
    mode the label radio is hidden and "clear old" is forced on and locked.

    Returns:
        [label_radio update, clear_old_chk update]
    """
    if s is not None and val is not None:
        s.current_prompt_type = str(val)
        s.pending_box_start = None
    points_mode = str(val).lower() == "points"
    label_update = gr.update(visible=points_mode)
    if points_mode:
        clear_old_update = gr.update(interactive=True)
    else:
        clear_old_update = gr.update(value=True, interactive=False)
    return [label_update, clear_old_update]


prompt_type.change(
    _sync_prompt_type,
    inputs=[GLOBAL_STATE, prompt_type],
    outputs=[label_radio, clear_old_chk],
)
def _refresh_motion_outputs(s: AppState, attr_name: str, val):
    """Store a motion setting on *s*, recompute metrics, and build the shared outputs.

    Returns updates for: kick plot, impact status, ball status, propagate
    button, detect-player button, propagate-player button.
    """
    if s is not None and val is not None:
        setattr(s, attr_name, float(val))
        _recompute_motion_metrics(s)
    main_update, detect_update, player_update = _button_updates(s)
    return (
        _build_kick_plot(s),
        _impact_status_update(s),
        gr.update(value=_format_kick_status(s), visible=True),
        main_update,
        detect_update,
        player_update,
    )


def _update_min_impact_speed(s: AppState, val: float):
    """Apply the "Min impact speed (km/h)" slider."""
    return _refresh_motion_outputs(s, "min_impact_speed_kmh", val)


def _update_goal_distance(s: AppState, val: float):
    """Apply the "Distance to goal (m)" slider."""
    return _refresh_motion_outputs(s, "goal_distance_m", val)


_motion_setting_outputs = [kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn]
min_impact_speed_slider.change(
    _update_min_impact_speed,
    inputs=[GLOBAL_STATE, min_impact_speed_slider],
    outputs=list(_motion_setting_outputs),
)
goal_distance_slider.change(
    _update_goal_distance,
    inputs=[GLOBAL_STATE, goal_distance_slider],
    outputs=list(_motion_setting_outputs),
)
def _auto_detect_ball(
    state_in: AppState,
    obj_id,
    label_value: str,
    clear_old_value: bool,
):
    """
    Detect the ball on frame 0 with YOLO13 and register it as a click prompt.

    Multi-ball detection (`detect_all_balls`) is tried first. If it finds
    nothing, it falls back to `detect_ball_center`. The top candidate is
    clamped to the frame and fed to `on_image_click` as a synthetic click.
    When several candidates are found, the multi-ball panel is revealed; the
    candidate radio stays hidden until Track Ball has analyzed them.

    Args:
        state_in: Session state with a loaded video.
        obj_id: Object ID for the prompt; falls back to
            `state_in.current_obj_id` when None.
        label_value: Point label; falls back to `state_in.current_label`
            when empty.
        clear_old_value: Whether to clear previous inputs for this object.

    Returns:
        Tuple in the order of the `detect_ball_btn.click` outputs: preview,
        ball status, frame slider, kick plot, the three `_button_updates`
        entries, the `_ui_status_updates` entries, then the multi-ball column,
        status markdown, candidate radio and chart updates.

    Raises:
        gr.Error: if no video is loaded.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then try auto-detect.")
    # Reset all detection state up front, so a failed or single-ball attempt
    # cannot leave candidates / multi-ball status from a previous run behind
    # (stale candidates would otherwise send Track Ball into multi-ball mode).
    state_in.is_ball_detected = False
    state_in.ball_candidates = []
    state_in.selected_ball_idx = 0
    state_in.multi_ball_status = ""
    frame_idx = 0
    frame = state_in.video_frames[frame_idx]
    print(f"[_auto_detect_ball] Frame size: {frame.size}")
    # First, try multi-ball detection
    candidates = detect_all_balls(frame)
    print(f"[_auto_detect_ball] detect_all_balls returned {len(candidates)} candidates")
    # Default multi-ball UI updates (hidden); replaced below when >1 ball is found
    multi_ball_col_update = gr.update(visible=False)
    multi_ball_status_update = gr.update(value="")
    multi_ball_radio_update = gr.update(choices=[], value=None, visible=False)
    multi_ball_chart_update = gr.update(value=None)

    def _outputs(preview_img, status_text: str):
        # Shared output tuple; reads the multi-ball updates at call time.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
        status_updates = _ui_status_updates(state_in)
        return (
            preview_img,
            gr.update(value=status_text, visible=True),
            gr.update(value=frame_idx),
            _build_kick_plot(state_in),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
            multi_ball_col_update,
            multi_ball_status_update,
            multi_ball_radio_update,
            multi_ball_chart_update,
        )

    if not candidates:
        # Fallback to single-ball detection
        print("[_auto_detect_ball] No candidates from detect_all_balls, trying detect_ball_center...")
        detection = detect_ball_center(frame)
        print(f"[_auto_detect_ball] detect_ball_center returned: {detection}")
        if detection is None:
            return _outputs(
                update_frame_display(state_in, frame_idx),
                "❌ Unable to auto-detect the ball. Please add a point manually.",
            )
        x_center, y_center, _, _, conf = detection
    else:
        # Store all candidates; the first one is the best (highest confidence)
        state_in.ball_candidates = candidates
        best = candidates[0]
        x_center, y_center = best["center"]
        conf = best["conf"]
        if len(candidates) > 1:
            state_in.multi_ball_status = f"⚠️ {len(candidates)} balls detected in frame. Click 'Track Ball' to analyze which one is kicked."
            # Show multi-ball UI with candidate list
            multi_ball_col_update = gr.update(visible=True)
            multi_ball_status_update = gr.update(
                value=f"**{len(candidates)} balls detected!** YOLO found multiple balls in the first frame.\n\n"
                f"The best candidate (highest confidence) is auto-selected.\n"
                f"Click **Track Ball** to analyze all candidates and find the one being kicked."
            )
            # Don't show radio yet - will show after tracking
            multi_ball_radio_update = gr.update(choices=[], value=None, visible=False)

    frame_width, frame_height = frame.size
    x_center = max(0, min(frame_width - 1, int(x_center)))
    y_center = max(0, min(frame_height - 1, int(y_center)))
    obj_id_int = int(obj_id) if obj_id is not None else state_in.current_obj_id
    label_str = label_value if label_value else state_in.current_label
    clear_old_flag = bool(clear_old_value)
    # Build a synthetic click event to reuse existing handler
    synthetic_evt = SimpleNamespace(
        index=(x_center, y_center),
        value={"x": x_center, "y": y_center},
    )
    state_in.current_frame_idx = frame_idx
    preview_img = on_image_click(
        update_frame_display(state_in, frame_idx),
        state_in,
        frame_idx,
        obj_id_int,
        label_str,
        clear_old_flag,
        synthetic_evt,
    )
    state_in.is_ball_detected = True
    num_candidates = len(state_in.ball_candidates)
    # Draw YOLO bounding boxes on preview if we have candidates
    if num_candidates > 0 and isinstance(preview_img, Image.Image):
        preview_img = draw_yolo_detections_on_frame(
            preview_img,
            state_in.ball_candidates,
            selected_idx=0,
        )
    if num_candidates > 1:
        status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f}). Click 'Track Ball' to analyze all."
    else:
        status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
    status_text += f" | {_format_kick_status(state_in)}"
    return _outputs(preview_img, status_text)
# Detect Ball refreshes the preview, status widgets and the multi-ball panel;
# order matches the tuple returned by _auto_detect_ball.
_detect_ball_outputs = [
    preview,
    ball_status,
    frame_slider,
    kick_plot,
    propagate_btn,
    detect_player_btn,
    propagate_player_btn,
    yolo_kick_btn,
    yolo_impact_btn,
    sam_kick_btn,
    sam_impact_btn,
    manual_kick_btn,
    manual_impact_btn,
    detect_ball_btn,
    track_ball_yolo_btn,
    goal_start_btn,
    goal_confirm_btn,
    goal_clear_btn,
    goal_back_btn,
    goal_status,
]
# Multi-ball UI outputs
_detect_ball_outputs += [
    multi_ball_selection_col,
    multi_ball_status_md,
    ball_candidate_radio,
    multi_ball_chart,
]
detect_ball_btn.click(
    _auto_detect_ball,
    inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
    outputs=_detect_ball_outputs,
)
def _track_ball_yolo(state_in: AppState):
    """Run YOLO ball tracking and refresh every widget wired to "Track Ball".

    With at most one detected ball candidate the single-ball tracker runs and
    the multi-ball comparison panel stays hidden. With several candidates all
    of them are tracked, the currently selected one is applied to the YOLO
    state, and (if more than one candidate remains) the comparison panel,
    radio, summary and chart are populated.

    The preview jumps to the YOLO kick frame, else the YOLO initial frame,
    else frame 0, and is overlaid with the candidate boxes.

    Returns:
        The tuple expected by ``track_ball_yolo_btn.click`` (preview, status,
        slider, kick plot, YOLO plot, three button updates, the
        ``_ui_status_updates`` entries and four multi-ball widget updates).

    Raises:
        gr.Error: if no video is loaded.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then track the ball with YOLO.")
    progress = gr.Progress(track_tqdm=False)
    state_in.is_yolo_tracked = False

    # Multi-ball widgets default to hidden/empty; filled in only when several
    # candidates are still present after tracking.
    col_upd = gr.update(visible=False)
    md_upd = gr.update(value="")
    radio_upd = gr.update(choices=[], value=None, visible=False)
    chart_upd = gr.update(value=None)

    if len(getattr(state_in, 'ball_candidates', [])) <= 1:
        # Single ball: original tracking path.
        _perform_yolo_ball_tracking(state_in, progress=progress)
        base_msg = state_in.yolo_status or ""
    else:
        # Several balls: track them all, then apply the selected/best one.
        _detect_and_track_all_ball_candidates(state_in, progress=progress)
        _apply_selected_ball_to_yolo_state(state_in)
        base_msg = state_in.multi_ball_status or state_in.yolo_status or ""
        tracked = state_in.ball_candidates
        if len(tracked) > 1:
            labels = _format_ball_candidates_for_radio(tracked)
            chosen = state_in.selected_ball_idx
            chosen_label = labels[chosen] if labels else None
            col_upd = gr.update(visible=True)
            md_upd = gr.update(
                value=_format_ball_candidates_markdown(tracked, chosen)
            )
            radio_upd = gr.update(choices=labels, value=chosen_label, visible=True)
            chart_upd = gr.update(value=_build_multi_ball_chart(state_in))

    # Prefer the YOLO kick frame, then the initial detection frame, then 0.
    if state_in.yolo_kick_frame is not None:
        target_frame = state_in.yolo_kick_frame
    elif state_in.yolo_initial_frame is not None:
        target_frame = state_in.yolo_initial_frame
    else:
        target_frame = 0
    if state_in.num_frames:
        target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = target_frame

    preview_img = update_frame_display(state_in, target_frame)
    # Overlay candidate boxes (post-tracking, so kick info is available).
    overlay = getattr(state_in, 'ball_candidates', [])
    if len(overlay) > 0 and isinstance(preview_img, Image.Image):
        preview_img = draw_yolo_detections_on_frame(
            preview_img,
            overlay,
            selected_idx=state_in.selected_ball_idx,
        )

    kick_msg = _format_kick_status(state_in)
    status_text = kick_msg if not base_msg else f"{base_msg} | {kick_msg}"
    state_in.is_yolo_tracked = True

    main_upd, detect_upd, player_upd = _button_updates(state_in)
    ui_updates = _ui_status_updates(state_in)
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=target_frame),
        _build_kick_plot(state_in),
        _build_yolo_plot(state_in),
        main_upd,
        detect_upd,
        player_upd,
        *ui_updates,
        col_upd,
        md_upd,
        radio_upd,
        chart_upd,
    )
# "Track Ball (YOLO)" button. Output order mirrors the tuple returned by
# `_track_ball_yolo`.
track_ball_yolo_btn.click(
    fn=_track_ball_yolo,
    inputs=[GLOBAL_STATE],
    outputs=[
        # Preview, status text, slider and both plots
        preview,
        ball_status,
        frame_slider,
        kick_plot,
        yolo_plot,
        # Primary action buttons (from _button_updates)
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Slots filled by *status_updates (_ui_status_updates)
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
        # Multi-ball selection widgets
        multi_ball_selection_col,
        multi_ball_status_md,
        ball_candidate_radio,
        multi_ball_chart,
    ],
)
# Multi-ball selection handlers
def _on_ball_candidate_change(state_in: AppState, selected_label: str):
    """React to a new choice in the ball-candidate radio.

    Stores the chosen candidate index on the state (falling back to 0 when the
    label is not among the current radio labels), then re-renders the preview,
    the comparison chart and the candidate summary with that candidate
    highlighted. This handler does not apply the candidate to the YOLO state;
    `_on_confirm_ball_selection` does that.

    Returns:
        (preview image, chart update, summary markdown update).
    """
    if state_in is None or not state_in.ball_candidates:
        # Nothing to select: leave all three outputs untouched.
        return gr.update(), gr.update(), gr.update()

    labels = _format_ball_candidates_for_radio(state_in.ball_candidates)
    new_idx = labels.index(selected_label) if selected_label in labels else 0
    state_in.selected_ball_idx = new_idx

    shown = update_frame_display(state_in, state_in.current_frame_idx)
    if isinstance(shown, Image.Image):
        shown = draw_yolo_detections_on_frame(
            shown,
            state_in.ball_candidates,
            selected_idx=new_idx,
        )

    chart = gr.update(value=_build_multi_ball_chart(state_in))
    summary = gr.update(
        value=_format_ball_candidates_markdown(state_in.ball_candidates, new_idx)
    )
    return shown, chart, summary
# Re-highlight the chosen candidate whenever the radio value changes.
ball_candidate_radio.change(
    fn=_on_ball_candidate_change,
    inputs=[GLOBAL_STATE, ball_candidate_radio],
    outputs=[
        preview,
        multi_ball_chart,
        multi_ball_status_md,
    ],
)
def _on_confirm_ball_selection(state_in: AppState):
    """Lock in the highlighted ball candidate and hide the multi-ball panel.

    Applies the selected candidate to the YOLO tracking state, jumps to its
    kick frame (frame 0 when it has none), sets
    ``state_in.ball_selection_confirmed`` and refreshes the widgets wired to
    ``confirm_ball_btn``. The candidate list itself is left in place.

    Raises:
        gr.Error: if there are no ball candidates to confirm.
    """
    if state_in is None or not state_in.ball_candidates:
        raise gr.Error("No ball candidates to confirm.")

    _apply_selected_ball_to_yolo_state(state_in)
    chosen_idx = state_in.selected_ball_idx
    chosen = state_in.ball_candidates[chosen_idx]
    kick_frame = chosen.get('kick_frame')

    jump_to = kick_frame or 0
    if state_in.num_frames:
        jump_to = int(np.clip(jump_to, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = jump_to

    # Mark the selection as done via a flag (the candidates are not cleared).
    state_in.ball_selection_confirmed = True

    preview_img = update_frame_display(state_in, jump_to)
    if chosen.get('has_kick'):
        kick_info = f"Kick @ frame {kick_frame}"
    else:
        kick_info = "No kick detected"
    status_text = f"✅ Ball {chosen_idx + 1} confirmed. {kick_info}"

    main_upd, detect_upd, player_upd = _button_updates(state_in)
    ui_updates = _ui_status_updates(state_in)
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=jump_to),
        _build_kick_plot(state_in),
        _build_yolo_plot(state_in),
        main_upd,
        detect_upd,
        player_upd,
        *ui_updates,
        # Hide the multi-ball selection column once confirmed.
        gr.update(visible=False),
    )
# "Confirm ball" button. The last output hides the multi-ball column, matching
# the final element of `_on_confirm_ball_selection`'s return tuple.
confirm_ball_btn.click(
    fn=_on_confirm_ball_selection,
    inputs=[GLOBAL_STATE],
    outputs=[
        # Preview, status text, slider and both plots
        preview,
        ball_status,
        frame_slider,
        kick_plot,
        yolo_plot,
        # Primary action buttons (from _button_updates)
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Slots filled by *status_updates (_ui_status_updates)
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
        # Multi-ball panel visibility
        multi_ball_selection_col,
    ],
)
def _auto_detect_player(state_in: AppState):
    """Detect the kicking player with YOLO and seed a SAM2 box prompt for them.

    Candidate frames are tried in priority order: the SAM2 kick frame (or
    ``kick_debug_kick_frame`` when that is unset), then the YOLO kick frame.
    The first frame on which `detect_person_box` returns a box is used.

    On success, all previous prompts and masks for ``PLAYER_OBJECT_ID`` are
    removed from every frame. A box prompt is then added to the inference
    session on the detection frame (``clear_old_inputs=True``), the model is
    run on that frame, and the masks for every tracked object on it are
    stored on the state.

    Returns:
        The tuple expected by ``detect_player_btn.click``.

    Raises:
        gr.Error: if no video/model session is loaded, or no kick frame is known.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then try auto-detect.")
    if state_in.inference_session is None or state_in.processor is None or state_in.model is None:
        raise gr.Error("Model session is not ready. Load a video and propagate masks first.")
    state_in.is_player_detected = False

    last_frame = state_in.num_frames - 1
    priority_frames: list[int] = []
    # Explicit `is None` checks: frame 0 is a valid kick frame and must not fall
    # through to the fallback the way `kick_frame or ...` would.
    sam_frame = state_in.kick_frame
    if sam_frame is None:
        sam_frame = getattr(state_in, "kick_debug_kick_frame", None)
    if sam_frame is not None:
        priority_frames.append(int(sam_frame))
    yolo_frame = getattr(state_in, "yolo_kick_frame", None)
    if yolo_frame is not None and int(yolo_frame) not in priority_frames:
        priority_frames.append(int(yolo_frame))
    if not priority_frames:
        raise gr.Error("Detect the kick frame first by propagating the ball masks.")

    # Try candidates in order; on failure we still show the top-priority frame.
    detection = None
    frame = None
    frame_idx = int(np.clip(priority_frames[0], 0, last_frame))
    for candidate in priority_frames:
        candidate_idx = int(np.clip(candidate, 0, last_frame))
        candidate_frame = state_in.video_frames[candidate_idx]
        detection = detect_person_box(candidate_frame)
        if detection is not None:
            frame_idx, frame = candidate_idx, candidate_frame
            break
    state_in.current_frame_idx = frame_idx

    def _result(preview_img, status_text):
        # Output order must match detect_player_btn.click's outputs list.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
        status_updates = _ui_status_updates(state_in)
        return (
            preview_img,
            gr.update(value=status_text, visible=True),
            gr.update(value=frame_idx),
            _build_kick_plot(state_in),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            gr.update(),  # obj_id_inp left unchanged
            _impact_status_update(state_in),
            *status_updates,
        )

    if detection is None:
        state_in.is_player_detected = False
        status_text = (
            f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
            "Please add a box manually."
        )
        return _result(update_frame_display(state_in, frame_idx), status_text)

    x_min, y_min, x_max, y_max, conf = detection
    state_in.player_obj_id = PLAYER_OBJECT_ID
    state_in.player_detection_frame = frame_idx
    state_in.player_detection_conf = conf
    state_in.current_obj_id = PLAYER_OBJECT_ID
    state_in.is_player_detected = True

    # Clear previous player-specific prompts/masks. Every frame that loses one
    # must also drop its cached composite; otherwise `composited_frames` keeps
    # serving the stale player overlay for those frames.
    # NOTE(review): centroids for those frames are not recomputed here.
    stale_frames: set[int] = set()
    for per_frame in (
        state_in.boxes_by_frame_obj,
        state_in.clicks_by_frame_obj,
        state_in.masks_by_frame,
    ):
        for f_idx, by_obj in per_frame.items():
            if PLAYER_OBJECT_ID in by_obj:
                del by_obj[PLAYER_OBJECT_ID]
                stale_frames.add(f_idx)
    for f_idx in stale_frames:
        state_in.composited_frames.pop(f_idx, None)
    _ensure_color_for_obj(state_in, PLAYER_OBJECT_ID)

    processor = state_in.processor
    model = state_in.model
    inference_session = state_in.inference_session
    inputs = processor(images=frame, device=state_in.device, return_tensors="pt")
    original_size = inputs.original_sizes[0]
    pixel_values = inputs.pixel_values[0]
    processor.add_inputs_to_inference_session(
        inference_session=inference_session,
        frame_idx=frame_idx,
        obj_ids=PLAYER_OBJECT_ID,
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        clear_old_inputs=True,
        original_size=original_size,
    )
    frame_boxes = state_in.boxes_by_frame_obj.setdefault(frame_idx, {})
    frame_boxes[PLAYER_OBJECT_ID] = [(x_min, y_min, x_max, y_max)]
    state_in.composited_frames.pop(frame_idx, None)

    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
        H = inference_session.video_height
        W = inference_session.video_width
        pred_masks = outputs.pred_masks.detach().cpu()
        video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

    # Store the fresh mask of every object tracked on this frame.
    masks_for_frame = state_in.masks_by_frame.get(frame_idx, {}).copy()
    for i, oid in enumerate(list(inference_session.obj_ids)):
        masks_for_frame[int(oid)] = video_res_masks[i].cpu().numpy().squeeze()
    state_in.masks_by_frame[frame_idx] = masks_for_frame
    _update_centroids_for_frame(state_in, frame_idx)
    state_in.composited_frames.pop(frame_idx, None)
    state_in.current_frame_idx = frame_idx

    status_text = (
        f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
    )
    return _result(update_frame_display(state_in, frame_idx), status_text)
# "Detect Player" button. Output order mirrors `_auto_detect_player`'s
# `_result` tuple (note obj_id_inp and impact_status before the status slots).
detect_player_btn.click(
    fn=_auto_detect_player,
    inputs=[GLOBAL_STATE],
    outputs=[
        # Preview, status text, slider and kick plot
        preview,
        ball_status,
        frame_slider,
        kick_plot,
        # Primary action buttons (from _button_updates)
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Object id input and impact status
        obj_id_inp,
        impact_status,
        # Slots filled by *status_updates (_ui_status_updates)
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
    ],
)
@spaces.GPU()
def propagate_player_masks(GLOBAL_STATE: gr.State):
    """Propagate the player's SAM2 mask over the SAM window, streaming progress.

    This is a generator wired to ``propagate_player_btn``. Every yielded tuple
    carries, in order:
    - the state;
    - the propagation status text;
    - a frame-slider update;
    - the kick and YOLO plots;
    - the impact status and ball status updates;
    - the three primary button updates;
    - the ``_ui_status_updates`` entries.

    The processor, model and inference session are deep-copied and the copies
    are moved to CUDA. SAM2 runs on the copies, and the resulting per-frame
    player masks are written back to ``GLOBAL_STATE``.

    Frames ``[start, end)`` of ``GLOBAL_STATE.sam_window`` are processed. The
    window is computed from the prioritized kick frame when it is not set yet;
    if it is still unset, the whole video is processed.
    """

    def _snapshot(message, slider_update):
        # One output tuple built from the current state (order = outputs list).
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        status_updates = _ui_status_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            message,
            slider_update,
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _impact_status_update(GLOBAL_STATE),
            gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
        )

    # This function is a generator: `return <tuple>` would only become
    # StopIteration.value and never reach the UI, so guard messages are yielded.
    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
        yield _snapshot("Load a video first.", gr.update())
        return
    if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
        yield _snapshot("Detect the player before propagating.", gr.update())
        return

    processor = deepcopy(GLOBAL_STATE.processor)
    model = deepcopy(GLOBAL_STATE.model)
    inference_session = deepcopy(GLOBAL_STATE.inference_session)
    inference_session.inference_device = "cuda"
    inference_session.cache.inference_device = "cuda"
    model.to("cuda")

    if not GLOBAL_STATE.sam_window:
        _compute_sam_window_from_kick(
            GLOBAL_STATE,
            _get_prioritized_kick_frame(GLOBAL_STATE),
        )
    start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
    start_idx = max(0, int(start_idx))
    end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
    total = max(1, end_idx - start_idx)
    processed = 0
    last_frame_idx = start_idx

    GLOBAL_STATE.is_player_propagated = False
    yield _snapshot(f"Propagating player: {processed}/{total}", gr.update())

    player_id = int(GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID)
    with torch.inference_mode():
        for frame_idx in range(start_idx, end_idx):
            frame = GLOBAL_STATE.video_frames[frame_idx]
            pixel_values = None
            # Only preprocess frames the session has not already encoded.
            if (
                inference_session.processed_frames is None
                or frame_idx not in inference_session.processed_frames
            ):
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            sam2_video_output = model(
                inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx
            )
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

            # Write back only the player's mask; other objects keep theirs.
            masks_for_frame = GLOBAL_STATE.masks_by_frame.get(frame_idx, {}).copy()
            for i, oid in enumerate(list(inference_session.obj_ids)):
                if int(oid) == player_id:
                    masks_for_frame[player_id] = video_res_masks[i].cpu().numpy().squeeze()
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
            _update_centroids_for_frame(GLOBAL_STATE, frame_idx)
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)

            processed += 1
            last_frame_idx = frame_idx
            # Stream progress every 30 frames and on the final frame.
            if processed % 30 == 0 or processed == total:
                yield _snapshot(
                    f"Propagating player: {processed}/{total}",
                    gr.update(value=frame_idx),
                )

    # Land on the player detection frame, else the kick frame, else the last
    # processed frame.
    target_frame = GLOBAL_STATE.player_detection_frame
    if target_frame is None:
        target_frame = _get_prioritized_kick_frame(GLOBAL_STATE)
    if target_frame is None:
        target_frame = last_frame_idx
    target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
    GLOBAL_STATE.current_frame_idx = target_frame
    GLOBAL_STATE.is_player_propagated = True
    yield _snapshot(
        f"Propagated player across {processed} frames.",
        gr.update(value=target_frame),
    )
# "Propagate Player" button. `propagate_player_masks` is a generator, so every
# yielded tuple updates these outputs in this order.
propagate_player_btn.click(
    fn=propagate_player_masks,
    inputs=[GLOBAL_STATE],
    outputs=[
        # State, progress text, slider, plots and status panels
        GLOBAL_STATE,
        propagate_status,
        frame_slider,
        kick_plot,
        yolo_plot,
        impact_status,
        ball_status,
        # Primary action buttons (from _button_updates)
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Slots filled by *status_updates (_ui_status_updates)
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
    ],
)
# Clicking the preview adds a point prompt on the shown frame and runs the
# model forward on that frame (handled by `_on_image_click_with_updates`).
preview.select(
    fn=_on_image_click_with_updates,
    inputs=[
        preview,
        GLOBAL_STATE,
        frame_slider,
        obj_id_inp,
        label_radio,
        clear_old_chk,
    ],
    outputs=[
        # Refreshed preview
        preview,
        # Primary action buttons
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Status-driven buttons and goal widgets
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
    ],
)

# Goal-mapping controls: each handler redraws the preview and refreshes the
# four goal buttons plus the goal status text.
goal_start_btn.click(
    fn=_goal_start_mapping,
    inputs=[GLOBAL_STATE],
    outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status],
)
goal_confirm_btn.click(
    fn=_goal_confirm_mapping,
    inputs=[GLOBAL_STATE],
    outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status],
)
goal_clear_btn.click(
    fn=_goal_clear_mapping,
    inputs=[GLOBAL_STATE],
    outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status],
)
goal_back_btn.click(
    fn=_goal_back_mapping,
    inputs=[GLOBAL_STATE],
    outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status],
)
# Playback via MP4 rendering only.
# Render an MP4 with OpenCV (mp4v) first, falling back to imageio/pyav (libx264).
def _render_video(s: AppState, remove_bg: bool = False, log_fn=None):
    """Render a ~4 s MP4 clip around the kick frame and return its file path.

    Window selection:
    - The window is ``round(fps * 4)`` frames long; fps falls back to 12 when
      unknown.
    - With a kick frame, the window is centred on ``s.kick_frame`` (or
      ``kick_debug_kick_frame`` when that is unset) and slid left if it would
      run past the last frame.
    - Without a kick frame, the window starts at frame 0.

    Frame composition and writing:
    - Cached composites are reused only when ``remove_bg`` is False.
    - Each frame is stamped with its index.
    - Frames are written with OpenCV's ``mp4v`` writer. If OpenCV fails, the
      clip is written with imageio/libx264 instead.

    Args:
        s: application state holding the decoded video and masks.
        remove_bg: compose frames with the background removed.
        log_fn: optional callable receiving progress strings.

    Returns:
        Path of the rendered temporary ``.mp4`` file (not deleted here).

    Raises:
        gr.Error: if no video is loaded or both writers fail.
    """
    import contextlib
    import os
    import tempfile

    if s is None or s.num_frames == 0:
        raise gr.Error("Load a video first.")

    fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
    trim_duration_sec = 4.0
    target_window_frames = max(1, int(round(fps * trim_duration_sec)))
    half_window = target_window_frames // 2

    # `is None` rather than `or`: frame 0 is a valid kick frame.
    kick_frame = s.kick_frame
    if kick_frame is None:
        kick_frame = getattr(s, "kick_debug_kick_frame", None)

    start_idx = 0 if kick_frame is None else max(0, int(kick_frame) - half_window)
    # Clamp to the video end, sliding the window left if it overflows.
    end_idx = min(s.num_frames, start_idx + target_window_frames)
    start_idx = max(0, end_idx - target_window_frames)
    if end_idx <= start_idx:
        end_idx = min(s.num_frames, start_idx + 1)
    total_frames = max(1, end_idx - start_idx)

    # Compose all frames in the trimmed window (BGR for cv2).
    frames_np = []
    for processed, idx in enumerate(range(start_idx, end_idx), start=1):
        if remove_bg:
            # The composite cache is only used for the non-remove_bg path.
            img = compose_frame(s, idx, remove_bg=True)
        else:
            img = s.composited_frames.get(idx)
            if img is None:
                img = compose_frame(s, idx, remove_bg=False)
        img_with_idx = _annotate_frame_index(img, idx)
        frames_np.append(np.array(img_with_idx)[:, :, ::-1])
        # Periodically release CPU memory to reduce pressure.
        if (idx + 1) % 60 == 0:
            gc.collect()
        if log_fn and (processed % 20 == 0 or processed == total_frames):
            log_fn(f"Rendering frames {processed}/{total_frames}")

    if not frames_np:
        raise gr.Error("Failed to render video: no frames to write.")
    # Writer size comes from the frames actually written (no extra compose).
    h, w = frames_np[0].shape[:2]

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as out_file:
        out_path = out_file.name

    def _write_with_opencv():
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
        try:
            if not writer.isOpened():
                raise RuntimeError("OpenCV VideoWriter failed to open (missing codec?).")
            for fr_bgr in frames_np:
                writer.write(fr_bgr)
        finally:
            # Release even if a write fails mid-stream.
            writer.release()

    def _write_with_imageio():
        import imageio
        with imageio.get_writer(out_path, fps=fps, codec="libx264", mode="I", quality=8) as writer:
            for fr_bgr in frames_np:
                writer.append_data(fr_bgr[:, :, ::-1])  # convert back to RGB

    try:
        _write_with_opencv()
    except Exception as cv_err:
        print(f"OpenCV writer failed: {cv_err}")
        try:
            if log_fn:
                log_fn("OpenCV writer unavailable, falling back to imageio/pyav.")
            _write_with_imageio()
        except Exception as io_err:
            print(f"Failed to render video: {io_err}")
            # Do not leave an empty/partial temp file behind.
            with contextlib.suppress(OSError):
                os.remove(out_path)
            raise gr.Error(f"Failed to render video: {io_err}") from io_err
    return out_path
# Render the trimmed clip into the playback player (log_fn is not wired here).
render_btn.click(
    fn=_render_video,
    inputs=[GLOBAL_STATE, remove_bg_checkbox],
    outputs=[playback_video],
)

# Ball mask propagation. The outputs match propagate_player_btn's list; per
# the original note, `propagate_masks` streams status text and slider updates
# while it runs.
propagate_btn.click(
    fn=propagate_masks,
    inputs=[GLOBAL_STATE],
    outputs=[
        # State, progress text, slider, plots and status panels
        GLOBAL_STATE,
        propagate_status,
        frame_slider,
        kick_plot,
        yolo_plot,
        impact_status,
        ball_status,
        # Primary action buttons
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Status-driven buttons and goal widgets
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
    ],
)

# Reset the whole session.
# NOTE(review): frame_slider is listed twice; presumably reset_session returns
# two slider updates in those positions. Confirm against its return tuple.
reset_btn.click(
    fn=reset_session,
    inputs=GLOBAL_STATE,
    outputs=[
        # State, preview and slider (twice)
        GLOBAL_STATE,
        preview,
        frame_slider,
        frame_slider,
        # Status texts and plots
        load_status,
        ball_status,
        kick_plot,
        yolo_plot,
        impact_status,
        # Primary action buttons
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
        # Status-driven buttons and goal widgets
        yolo_kick_btn,
        yolo_impact_btn,
        sam_kick_btn,
        sam_impact_btn,
        manual_kick_btn,
        manual_impact_btn,
        detect_ball_btn,
        track_ball_yolo_btn,
        goal_start_btn,
        goal_confirm_btn,
        goal_clear_btn,
        goal_back_btn,
        goal_status,
    ],
)
# ============================================================================
# COMBINED INTERFACE WITH EXPLICIT API ENDPOINT
# ============================================================================
# Create API interface with explicit endpoint.
# `gr.Interface` wrapping `process_video_api`: its four input components are
# passed positionally (video, annotations JSON text, SAM2 checkpoint name,
# remove-background flag) and its three return values fill the outputs
# (preview image, processed video, processing log).
# NOTE: the description below is user-facing runtime text; edit it deliberately.
api_interface = gr.Interface(
fn=process_video_api,
inputs=[
gr.Video(label="Video File"),
# Optional JSON string; the placeholder shows the expected shape.
gr.Textbox(
label="Annotations JSON (optional)",
placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}',
lines=5
),
# SAM2 checkpoint size selector; defaults to "base_plus".
gr.Radio(
choices=["tiny", "small", "base_plus", "large"],
value="base_plus",
label="SAM2 Checkpoint"
),
gr.Checkbox(label="Remove Background", value=True)
],
outputs=[
gr.Image(label="Annotation Preview / First Frame"),
gr.Video(label="Processed Video"),
gr.Textbox(label="Processing Log", lines=12)
],
title="SAM2 API",
description="""
## Programmatic KickTrimmer Pipeline
Submitting a video here runs the same automated workflow as the Interactive UI:
1. **Upload** the raw MP4.
2. `YOLO13` **detects** and **tracks** the ball to get the first kick estimate.
3. `SAM2` **tracks the ball** around that kick window. If SAM2's kick disagrees with YOLO's, it automatically re-tracks **the entire clip** for better accuracy.
4. `YOLO13` **detects the player** on the SAM2 kick frame, then `SAM2` propagates the player masks around that window.
5. The Space **renders a default cutout video** and returns it together with the processing log below.
### Optional annotations
You can still send helper points via JSON:
```json
{
"annotations": [
{"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"},
{"object_id": 1, "frame": 187, "x": 296, "y": 485, "label": "positive"},
{"object_id": 2, "frame": 187, "x": 296, "y": 412, "label": "positive"}
]
}
```
- **Object 1** = ball, **Object 2** = player. Use timestamps when possible; the API will reconcile timestamps and frame indices for you.
"""
)
# Top-level app: the interactive UI and the API interface as two tabs, plus a
# hidden trigger that registers `process_video_api` under api_name="predict".
with gr.Blocks(title="SAM2 Video Tracking") as combined_demo:
    gr.Markdown("# SAM2 Video Tracking")
    with gr.Tabs():
        with gr.TabItem("Interactive UI"):
            demo.render()
        with gr.TabItem("API"):
            api_interface.render()

    # Invisible components backing the root-level endpoint. Their order
    # mirrors `process_video_api`'s inputs and outputs.
    # NOTE(review): `api_interface` may also register a "predict" endpoint;
    # Gradio can suffix duplicate api names. Confirm which route clients hit.
    api_video_input_hidden = gr.Video(visible=False)
    api_annotations_input_hidden = gr.Textbox(visible=False)
    api_checkpoint_input_hidden = gr.Radio(
        choices=["tiny", "small", "base_plus", "large"],
        visible=False,
    )
    api_remove_bg_input_hidden = gr.Checkbox(visible=False)
    api_preview_output_hidden = gr.Image(visible=False)
    api_video_output_hidden = gr.Video(visible=False)
    api_logs_output_hidden = gr.Textbox(visible=False)

    # Never clicked in the UI; exists only to expose the named API route.
    api_dummy_btn = gr.Button("API", visible=False)
    api_dummy_btn.click(
        fn=process_video_api,
        inputs=[
            api_video_input_hidden,
            api_annotations_input_hidden,
            api_checkpoint_input_hidden,
            api_remove_bg_input_hidden,
        ],
        outputs=[
            api_preview_output_hidden,
            api_video_output_hidden,
            api_logs_output_hidden,
        ],
        api_name="predict",
    )

# Launch with the request queue enabled and open to API callers.
if __name__ == "__main__":
    combined_demo.queue(api_open=True).launch()