Spaces:
Sleeping
Sleeping
"""Visualisation utilities for the DRS application.

This module contains functions that generate images and videos
illustrating the ball's flight and the outcome of the LBW decision.
Using Matplotlib and OpenCV, it creates a 3D trajectory plot and an
annotated replay video. These assets are returned to the Gradio
interface for display to the user.
"""
from __future__ import annotations | |
import cv2 | |
import numpy as np | |
import matplotlib | |
matplotlib.use("Agg") # Use a non‑interactive backend | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 # needed for 3D plots | |
from typing import List, Tuple, Callable | |
def generate_trajectory_plot(
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    output_path: str,
) -> None:
    """Create a 3D plot of the observed and predicted ball trajectory.

    The x axis is the horizontal pixel coordinate. The y axis is the
    vertical pixel coordinate, inverted so that 0 is at the top to match
    image coordinates. The z axis is the position of each detection in
    ``centers``, used as the frame index. The predicted path is drawn on
    the x–y plane at z=0 for clarity.

    If ``centers`` is empty, a labelled placeholder plot titled
    "No ball detections" is saved instead and ``trajectory`` is not read.

    Parameters
    ----------
    centers: list of tuple(int, int)
        Detected ball centre positions ``(x, y)`` in pixels.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`. Only the
        ``"model"`` entry is read. It is a callable mapping x to a
        predicted y and is called with a NumPy array, so it is assumed to
        be vectorised (e.g. a ``numpy.poly1d``). A scalar result is
        broadcast across the whole x range.
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps. This controls the
        colour of the predicted path (green if True, red otherwise).
    output_path: str
        Where to save the resulting image (e.g. a ``.png`` file).

    Raises
    ------
    KeyError
        If ``centers`` is non-empty and ``trajectory`` has no ``"model"``.
    """
    fig = plt.figure(figsize=(6, 4))
    # Always close the figure, even if plotting or saving fails. Otherwise
    # pyplot keeps it alive in its global registry and memory grows with
    # every request.
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.set_xlabel("X (pixels)")
        ax.set_ylabel("Y (pixels)")
        ax.set_zlabel("Frame index")

        if not centers:
            # Still emit a labelled image so the UI has something to show.
            ax.set_title("No ball detections")
        else:
            xs = np.array([c[0] for c in centers])
            ys = np.array([c[1] for c in centers])
            zs = np.arange(len(centers))

            # Evaluate the fitted model across the observed x range. The
            # broadcast guards against models that return a scalar.
            model = trajectory["model"]
            x_range = np.linspace(xs.min(), xs.max(), 100)
            y_pred = np.broadcast_to(
                np.asarray(model(x_range), dtype=float), x_range.shape
            )

            ax.plot(xs, ys, zs, "o-", label="Detected ball path", color="blue")
            colour = "green" if will_hit_stumps else "red"
            ax.plot(
                x_range,
                y_pred,
                np.zeros_like(x_range),
                "--",
                label="Predicted path",
                color=colour,
            )
            ax.set_title("Ball trajectory (observed vs predicted)")
            ax.legend()
            ax.invert_yaxis()  # Match image coordinates (y grows downward).

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        plt.close(fig)
def annotate_video_with_tracking(
    video_path: str,
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    impact_frame_idx: int,
    output_path: str,
) -> None:
    """Create an annotated replay video highlighting key elements.

    The function reads the input video frame by frame and writes a new
    video of the same size with the following overlays (colours are
    OpenCV BGR):

    * A rectangle marking the stumps zone, horizontally centred. It spans
      10% of the frame width and runs from 65% to 95% of the frame height.
      It is green if the ball is predicted to hit and red otherwise.
    * The predicted trajectory across the full frame width, drawn as a
      dashed curve in the same colour as the stumps zone.
    * A blue polyline through all ball centres seen so far.
    * The detected ball centre as a small filled white dot. On the impact
      frame it is instead a larger hollow yellow ring (the "auto zoom"
      highlight).
    * The text "OUT" or "NOT OUT", shown from the impact frame onward
      (inclusive) when ``impact_frame_idx >= 0``.

    ``centers[i]`` is assumed to belong to frame ``i`` of the video. Frames
    beyond ``len(centers)`` keep showing the accumulated path.

    Parameters
    ----------
    video_path: str
        Path to the (trimmed) input video.
    centers: list of tuple(int, int)
        Detected ball centres ``(x, y)`` in pixels, one per analysed frame.
        Values are converted to ``int`` before drawing.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`. Only the
        ``"model"`` entry is read. It is a callable mapping x to a
        predicted y and is called with a NumPy array, so it is assumed to
        be vectorised. A scalar result is broadcast, and segments with NaN
        predictions are skipped.
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps.
    impact_frame_idx: int
        Index of the frame considered as the impact frame. A negative value
        disables the decision text.
    output_path: str
        Where to save the annotated video (``mp4v`` FourCC, input FPS or
        30 if the input does not report a usable FPS).

    Raises
    ------
    RuntimeError
        If the input video cannot be opened, or the output writer cannot
        be created (e.g. unsupported codec or unwritable path).
    KeyError
        If ``trajectory`` has no ``"model"`` entry.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video {video_path}")

    writer = None
    # Release both handles on every exit path, including exceptions raised
    # by the model or by OpenCV mid-stream.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some backends report 0 or NaN for FPS. NaN is truthy, so a plain
        # ``or 30.0`` would not catch it.
        if not (np.isfinite(fps) and fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            # Without this check every write() is silently dropped and no
            # output file is produced.
            raise RuntimeError(f"Could not open video writer for {output_path}")

        # Decision colour doubles as the stumps-zone and prediction colour.
        decision_colour = (0, 255, 0) if will_hit_stumps else (0, 0, 255)
        decision_text = "OUT" if will_hit_stumps else "NOT OUT"

        # The predicted path is identical on every frame, so build its
        # dashed segments once. Drawing only every 4th step yields dashes.
        model = trajectory["model"]
        x_vals = np.linspace(0, width - 1, 50)
        y_preds = np.broadcast_to(
            np.asarray(model(x_vals), dtype=float), x_vals.shape
        )
        valid = ~np.isnan(y_preds)  # NaN would become a garbage int below.
        y_clamped = np.clip(np.where(valid, y_preds, 0.0), 0, height - 1).astype(int)
        dash_segments = [
            (
                (int(x_vals[i]), int(y_clamped[i])),
                (int(x_vals[i + 1]), int(y_clamped[i + 1])),
            )
            for i in range(0, len(x_vals) - 1, 4)
            if valid[i] and valid[i + 1]
        ]

        # Stumps zone: centred horizontally near the bottom of the frame.
        stumps_width = int(width * 0.1)
        stumps_height = int(height * 0.3)
        stumps_x = int((width - stumps_width) / 2)
        stumps_y = int(height * 0.65)
        stumps_top_left = (stumps_x, stumps_y)
        stumps_bottom_right = (stumps_x + stumps_width, stumps_y + stumps_height)

        path_points: List[Tuple[int, int]] = []
        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            cv2.rectangle(frame, stumps_top_left, stumps_bottom_right, decision_colour, 2)
            for pt1, pt2 in dash_segments:
                cv2.line(frame, pt1, pt2, decision_colour, 1, lineType=cv2.LINE_AA)

            has_centre = frame_idx < len(centers)
            if has_centre:
                # cv2 drawing functions reject float coordinates.
                cx, cy = map(int, centers[frame_idx])
                path_points.append((cx, cy))

            # The path is drawn on every frame, including those after the
            # last detection.
            if len(path_points) > 1:
                cv2.polylines(
                    frame,
                    [np.array(path_points, dtype=np.int32)],
                    False,
                    (255, 0, 0),
                    2,
                )

            if has_centre:
                if frame_idx == impact_frame_idx:
                    # "Auto zoom" highlight: larger hollow yellow ring.
                    cv2.circle(frame, (cx, cy), 10, (0, 255, 255), 2)
                else:
                    cv2.circle(frame, (cx, cy), 5, (255, 255, 255), -1)

            if impact_frame_idx >= 0 and frame_idx >= impact_frame_idx:
                cv2.putText(
                    frame,
                    decision_text,
                    (50, 50),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.5,
                    decision_colour,
                    3,
                    lineType=cv2.LINE_AA,
                )

            writer.write(frame)
            frame_idx += 1
    finally:
        cap.release()
        if writer is not None:
            writer.release()