Spaces:
Sleeping
Sleeping
File size: 7,739 Bytes
2db7738 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
"""Visualisation utilities for the DRS application.
This module contains functions to generate images and videos that
illustrate the ball's flight and the outcome of the LBW decision.
Using Matplotlib and OpenCV we create a 3D trajectory plot and an
annotated replay video. These assets are returned to the Gradio
interface for display to the user.
"""
from __future__ import annotations
import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg") # Use a non‑interactive backend
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 # needed for 3D plots
from typing import List, Tuple, Callable
def generate_trajectory_plot(
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    output_path: str,
) -> None:
    """Create a 3D plot of the observed and predicted ball trajectory.

    The x axis represents the horizontal pixel coordinate, the y axis
    represents the vertical coordinate (inverted so that 0 is at the top,
    as in image coordinates), and the z axis corresponds to the frame
    index (time). The predicted path is drawn on the x–y plane at z=0 for
    clarity.

    Parameters
    ----------
    centers: list of tuple(int, int)
        Detected ball centre positions. The position of each entry in the
        list is used as its frame index.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`. Only the
        ``"model"`` entry is read: a callable that is called with a NumPy
        array of x pixel coordinates and returns the predicted y values
        (a scalar result is broadcast to every x).
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps; controls the
        colour of the predicted path (green if True, red otherwise).
    output_path: str
        Where to save the resulting PNG image.

    Notes
    -----
    If ``centers`` is empty, an empty set of labelled axes titled
    "No ball detections" is saved and ``trajectory`` is not accessed.
    The figure is always closed, even if plotting or saving raises.
    """
    fig = plt.figure(figsize=(6, 4))
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.set_xlabel("X (pixels)")
        ax.set_ylabel("Y (pixels)")
        ax.set_zlabel("Frame index")

        if not centers:
            ax.set_title("No ball detections")
        else:
            xs = np.array([c[0] for c in centers])
            ys = np.array([c[1] for c in centers])
            zs = np.arange(len(centers))

            # Evaluate the fitted model across the observed x range.
            model: Callable[[np.ndarray], np.ndarray] = trajectory["model"]
            x_range = np.linspace(xs.min(), xs.max(), 100)
            # Broadcast so a model returning a scalar (e.g. a constant fit)
            # still yields one y value per x for the 3D plot call.
            y_pred = np.broadcast_to(
                np.asarray(model(x_range), dtype=float), x_range.shape
            )

            # Observed detections, spread along z by frame index.
            ax.plot(xs, ys, zs, "o-", label="Detected ball path", color="blue")
            # Predicted path projected onto the z=0 plane.
            colour = "green" if will_hit_stumps else "red"
            ax.plot(
                x_range,
                y_pred,
                np.zeros_like(x_range),
                "--",
                label="Predicted path",
                color=colour,
            )
            ax.set_title("Ball trajectory (observed vs predicted)")
            ax.legend()
            ax.invert_yaxis()  # Match image coordinates (y grows downwards)

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure even on failure so repeated calls (e.g. from a
        # long-running web app) do not accumulate open figures in pyplot.
        plt.close(fig)
def annotate_video_with_tracking(
    video_path: str,
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    impact_frame_idx: int,
    output_path: str,
) -> None:
    """Create an annotated replay video highlighting key elements.

    The function reads the trimmed input video and writes out a new
    video (``mp4v`` codec, same size and frame rate as the input, falling
    back to 30 fps when the input reports none) with the following
    overlays:

    * The detected ball centre (small filled white circle).
    * A polyline showing the path of the ball up to the current
      frame; it stays on screen after detections run out.
    * The predicted trajectory across the frame, drawn as a dashed
      curve (every fourth segment of a 50-point sampling).
    * A rectangle representing the stumps zone at the bottom centre
      of the frame; coloured green if the ball is predicted to hit
      and red otherwise.
    * The text "OUT" or "NOT OUT" displayed from the impact frame
      onwards (only when ``impact_frame_idx`` is non-negative).
    * Auto zoom effect on the impact frame by drawing a larger yellow
      outlined circle around the ball.

    Colours are given in OpenCV's BGR order.

    Parameters
    ----------
    video_path: str
        Path to the trimmed input video.
    centers: list of tuple(int, int)
        Detected ball centres; entry ``i`` is drawn on frame ``i``.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`. Only the
        ``"model"`` entry is read: a callable that is called with a NumPy
        array of x pixel coordinates and returns predicted y values (a
        scalar result is broadcast to every x).
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps.
    impact_frame_idx: int
        Index of the frame considered as the impact frame. A negative
        value disables the impact highlight and decision text.
    output_path: str
        Where to save the annotated video.

    Raises
    ------
    RuntimeError
        If the input video cannot be opened or the output video writer
        cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video {video_path}")

    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # Without this check a failed writer silently drops every frame and
        # the caller receives a missing/empty video.
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {output_path}")

        # Precompute the predicted path; it is identical on every frame.
        model: Callable[[np.ndarray], np.ndarray] = trajectory["model"]
        x_vals = np.linspace(0, width - 1, 50)
        y_preds = np.broadcast_to(
            np.asarray(model(x_vals), dtype=float), x_vals.shape
        )
        # Keep predicted y values inside the frame.
        y_preds_clamped = np.clip(y_preds, 0, height - 1).astype(int)
        # Dashed effect: draw only every fourth segment of the sampled curve.
        dash_segments = [
            (
                (int(x_vals[i]), int(y_preds_clamped[i])),
                (int(x_vals[i + 1]), int(y_preds_clamped[i + 1])),
            )
            for i in range(0, len(x_vals) - 1, 4)
        ]

        # Stumps zone: centred horizontally, from 65% to 95% of the height.
        stumps_width = int(width * 0.1)
        stumps_height = int(height * 0.3)
        stumps_x = int((width - stumps_width) / 2)
        stumps_y = int(height * 0.65)
        stumps_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)

        decision_text = "OUT" if will_hit_stumps else "NOT OUT"
        decision_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX

        frame_idx = 0
        path_points: List[Tuple[int, int]] = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Stumps region on every frame.
            cv2.rectangle(
                frame,
                (stumps_x, stumps_y),
                (stumps_x + stumps_width, stumps_y + stumps_height),
                stumps_color,
                2,
            )
            # Dashed predicted trajectory.
            for pt1, pt2 in dash_segments:
                cv2.line(frame, pt1, pt2, stumps_color, 1, lineType=cv2.LINE_AA)

            has_centre = frame_idx < len(centers)
            if has_centre:
                # OpenCV drawing functions require integer pixel coordinates.
                cx, cy = (int(v) for v in centers[frame_idx])
                path_points.append((cx, cy))

            # Past path; kept on screen after detections run out.
            if len(path_points) > 1:
                cv2.polylines(
                    frame,
                    [np.array(path_points, dtype=np.int32)],
                    False,
                    (255, 0, 0),
                    2,
                )

            if has_centre:
                if frame_idx == impact_frame_idx:
                    # Auto zoom effect: larger, outlined yellow circle.
                    radius, thickness, colour = 10, 2, (0, 255, 255)
                else:
                    radius, thickness, colour = 5, -1, (255, 255, 255)
                cv2.circle(frame, (cx, cy), radius, colour, thickness)

            # Decision text from the impact frame onwards.
            if 0 <= impact_frame_idx <= frame_idx:
                cv2.putText(
                    frame,
                    decision_text,
                    (50, 50),
                    font,
                    1.5,
                    decision_color,
                    3,
                    lineType=cv2.LINE_AA,
                )

            writer.write(frame)
            frame_idx += 1
    finally:
        # Always release native handles, even if drawing or the model fails.
        cap.release()
        if writer is not None:
            writer.release()