# lbw_drs_app_new/drs_modules/visualization.py
"""Visualisation utilities for the DRS application.
This module contains functions to generate images and videos that
illustrate the ball's flight and the outcome of the LBW decision.
Using Matplotlib and OpenCV we create a 3D trajectory plot and an
annotated replay video. These assets are returned to the Gradio
interface for display to the user.
"""
from __future__ import annotations
import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg") # Use a non‑interactive backend
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 # needed for 3D plots
from typing import List, Tuple, Callable
def generate_trajectory_plot(
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    output_path: str,
) -> None:
    """Create a 3D plot of the observed and predicted ball trajectory.

    The x axis represents the horizontal pixel coordinate, the y axis
    represents the vertical coordinate (top at 0), and the z axis
    corresponds to the frame index (time). The predicted path is drawn
    on the x–y plane at z=0 for clarity.

    Parameters
    ----------
    centers: list of tuple(int, int)
        Detected ball centre positions.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`;
        only the ``"model"`` callable is used here (and only when
        ``centers`` is non-empty).
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps; controls the
        colour of the predicted path.
    output_path: str
        Where to save the resulting PNG image.
    """
    # Shared figure/axis setup for both the empty and populated cases.
    fig = plt.figure(figsize=(6, 4))
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.set_xlabel("X (pixels)")
        ax.set_ylabel("Y (pixels)")
        ax.set_zlabel("Frame index")
        if not centers:
            # No detections: still emit a labelled (empty) figure so the
            # caller always gets an image to display.
            ax.set_title("No ball detections")
            fig.tight_layout()
            fig.savefig(output_path)
            return
        xs = np.array([c[0] for c in centers])
        ys = np.array([c[1] for c in centers])
        zs = np.arange(len(centers))
        # Evaluate the fitted model over the observed horizontal extent.
        model: Callable[[float], float] = trajectory["model"]
        x_range = np.linspace(xs.min(), xs.max(), 100)
        y_pred = model(x_range)
        # Observed points, connected in frame order.
        ax.plot(xs, ys, zs, 'o-', label="Detected ball path", color="blue")
        # Predicted path drawn flat on the z=0 plane; colour encodes verdict.
        colour = "green" if will_hit_stumps else "red"
        ax.plot(x_range, y_pred, np.zeros_like(x_range), '--', label="Predicted path", color=colour)
        ax.set_title("Ball trajectory (observed vs predicted)")
        ax.legend()
        ax.invert_yaxis()  # Match image coordinates (y grows downwards)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if savefig/plotting raises,
        # to avoid accumulating open Matplotlib figures.
        plt.close(fig)
def annotate_video_with_tracking(
    video_path: str,
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    impact_frame_idx: int,
    output_path: str,
) -> None:
    """Create an annotated replay video highlighting key elements.

    The function reads the trimmed input video and writes out a new
    video with the following overlays:

    * The detected ball centre (small filled circle).
    * A polyline showing the path of the ball up to the current frame.
    * The predicted trajectory across the frame, drawn as a dashed curve.
    * A rectangle representing the stumps zone at the bottom centre of
      the frame; coloured green if the ball is predicted to hit and red
      otherwise.
    * The text "OUT" or "NOT OUT" displayed after the impact frame.
    * Auto zoom effect on the impact frame by drawing a thicker circle
      around the ball.

    Parameters
    ----------
    video_path: str
        Path to the trimmed input video.
    centers: list of tuple(int, int)
        Detected ball centres for each frame analysed.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`.
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps.
    impact_frame_idx: int
        Index of the frame considered as the impact frame.
    output_path: str
        Where to save the annotated video.

    Raises
    ------
    RuntimeError
        If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video {video_path}")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back when FPS metadata is missing
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            model: Callable[[float], float] = trajectory["model"]
            # Precompute predicted path points once; they are identical on
            # every frame, so this is hoisted out of the read loop.
            x_vals = np.linspace(0, width - 1, 50)
            y_preds = model(x_vals)
            # Clamp predicted y values so every drawn point stays in-frame.
            y_preds_clamped = np.clip(y_preds, 0, height - 1).astype(int)
            # Stumps zone: a fixed rectangle at the bottom centre of the frame.
            stumps_width = int(width * 0.1)
            stumps_height = int(height * 0.3)
            stumps_x = int((width - stumps_width) / 2)
            stumps_y = int(height * 0.65)
            # BGR: green for a predicted hit, red otherwise.
            stumps_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)
            frame_idx = 0
            path_points: List[Tuple[int, int]] = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Draw stumps region on every frame.
                cv2.rectangle(
                    frame,
                    (stumps_x, stumps_y),
                    (stumps_x + stumps_width, stumps_y + stumps_height),
                    stumps_color,
                    2,
                )
                # Dashed predicted trajectory: draw every fourth segment only.
                for i in range(len(x_vals) - 1):
                    if i % 4 != 0:
                        continue
                    pt1 = (int(x_vals[i]), int(y_preds_clamped[i]))
                    pt2 = (int(x_vals[i + 1]), int(y_preds_clamped[i + 1]))
                    cv2.line(frame, pt1, pt2, stumps_color, 1, lineType=cv2.LINE_AA)
                # Record this frame's centre if one was detected.
                has_center = frame_idx < len(centers)
                if has_center:
                    path_points.append(centers[frame_idx])
                # Draw the accumulated path (also persists beyond the last
                # detection frame). Drawn before the ball marker so the
                # marker sits on top.
                if len(path_points) > 1:
                    cv2.polylines(
                        frame,
                        [np.array(path_points, dtype=np.int32)],
                        False,
                        (255, 0, 0),
                        2,
                    )
                if has_center:
                    cx, cy = centers[frame_idx]
                    # Default: small filled white dot for the ball centre.
                    radius, thickness, colour = 5, -1, (255, 255, 255)
                    if frame_idx == impact_frame_idx:
                        # Auto zoom effect: larger yellow outline on impact.
                        radius, thickness, colour = 10, 2, (0, 255, 255)
                    cv2.circle(frame, (cx, cy), radius, colour, thickness)
                # From the impact frame onwards, overlay the decision text.
                if frame_idx >= impact_frame_idx and impact_frame_idx >= 0:
                    decision_text = "OUT" if will_hit_stumps else "NOT OUT"
                    cv2.putText(
                        frame,
                        decision_text,
                        (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.5,
                        (0, 255, 0) if will_hit_stumps else (0, 0, 255),
                        3,
                        lineType=cv2.LINE_AA,
                    )
                writer.write(frame)
                frame_idx += 1
        finally:
            # Release the writer even if annotation raises mid-stream,
            # so the partial output file is finalised and not left locked.
            writer.release()
    finally:
        cap.release()