"""Visualisation utilities for the DRS application. This module contains functions to generate images and videos that illustrate the ball's flight and the outcome of the LBW decision. Using Matplotlib and OpenCV we create a 3D trajectory plot and an annotated replay video. These assets are returned to the Gradio interface for display to the user. """ from __future__ import annotations import cv2 import numpy as np import matplotlib matplotlib.use("Agg") # Use a non‑interactive backend import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa: F401 # needed for 3D plots from typing import List, Tuple, Callable def generate_trajectory_plot( centers: List[Tuple[int, int]], trajectory: dict, will_hit_stumps: bool, output_path: str, ) -> None: """Create a 3D plot of the observed and predicted ball trajectory. The x axis represents the horizontal pixel coordinate, the y axis represents the vertical coordinate (top at 0), and the z axis corresponds to the frame index (time). The predicted path is drawn on the x–y plane at z=0 for clarity. Parameters ---------- centers: list of tuple(int, int) Detected ball centre positions. trajectory: dict Output of :func:`modules.trajectory.estimate_trajectory`. will_hit_stumps: bool Whether the ball is predicted to hit the stumps; controls the colour of the predicted path. output_path: str Where to save the resulting PNG image. """ if not centers: # If no points, draw an empty figure fig = plt.figure(figsize=(6, 4)) ax = fig.add_subplot(111, projection="3d") ax.set_title("No ball detections") ax.set_xlabel("X (pixels)") ax.set_ylabel("Y (pixels)") ax.set_zlabel("Frame index") fig.tight_layout() fig.savefig(output_path) plt.close(fig) return xs = np.array([c[0] for c in centers]) ys = np.array([c[1] for c in centers]) zs = np.arange(len(centers)) # Compute predicted path along the full x range model: Callable[[float], float] = trajectory["model"] x_range = np.linspace(xs.min(), xs.max(), 100) y_pred = model(x_range) fig = plt.figure(figsize=(6, 4)) ax = fig.add_subplot(111, projection="3d") # Plot observed points ax.plot(xs, ys, zs, 'o-', label="Detected ball path", color="blue") # Plot predicted path on z=0 plane colour = "green" if will_hit_stumps else "red" ax.plot(x_range, y_pred, np.zeros_like(x_range), '--', label="Predicted path", color=colour) ax.set_xlabel("X (pixels)") ax.set_ylabel("Y (pixels)") ax.set_zlabel("Frame index") ax.set_title("Ball trajectory (observed vs predicted)") ax.legend() ax.invert_yaxis() # Invert y axis to match image coordinates fig.tight_layout() fig.savefig(output_path) plt.close(fig) def annotate_video_with_tracking( video_path: str, centers: List[Tuple[int, int]], trajectory: dict, will_hit_stumps: bool, impact_frame_idx: int, output_path: str, ) -> None: """Create an annotated replay video highlighting key elements. The function reads the trimmed input video and writes out a new video with the following overlays: * The detected ball centre (small filled circle). * A polyline showing the path of the ball up to the current frame. * The predicted trajectory across the frame, drawn as a dashed curve. * A rectangle representing the stumps zone at the bottom centre of the frame; coloured green if the ball is predicted to hit and red otherwise. * The text "OUT" or "NOT OUT" displayed after the impact frame. * Auto zoom effect on the impact frame by drawing a thicker circle around the ball. Parameters ---------- video_path: str Path to the trimmed input video. centers: list of tuple(int, int) Detected ball centres for each frame analysed. trajectory: dict Output of :func:`modules.trajectory.estimate_trajectory`. will_hit_stumps: bool Whether the ball is predicted to hit the stumps. impact_frame_idx: int Index of the frame considered as the impact frame. output_path: str Where to save the annotated video. """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise RuntimeError(f"Could not open video {video_path}") fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fourcc = cv2.VideoWriter_fourcc(*"mp4v") writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) model: Callable[[float], float] = trajectory["model"] # Precompute predicted path points for drawing on each frame x_vals = np.linspace(0, width - 1, 50) y_preds = model(x_vals) # Ensure predicted y values stay within frame y_preds_clamped = np.clip(y_preds, 0, height - 1).astype(int) # Define stumps zone coordinates stumps_width = int(width * 0.1) stumps_height = int(height * 0.3) stumps_x = int((width - stumps_width) / 2) stumps_y = int(height * 0.65) stumps_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255) frame_idx = 0 path_points: List[Tuple[int, int]] = [] while True: ret, frame = cap.read() if not ret: break # Draw stumps region on every frame cv2.rectangle( frame, (stumps_x, stumps_y), (stumps_x + stumps_width, stumps_y + stumps_height), stumps_color, 2, ) # Draw predicted trajectory line (dashed effect by skipping points) for i in range(len(x_vals) - 1): if i % 4 != 0: continue pt1 = (int(x_vals[i]), int(y_preds_clamped[i])) pt2 = (int(x_vals[i + 1]), int(y_preds_clamped[i + 1])) cv2.line(frame, pt1, pt2, stumps_color, 1, lineType=cv2.LINE_AA) # If we have a centre for this frame, draw it and update the path if frame_idx < len(centers): cx, cy = centers[frame_idx] path_points.append((cx, cy)) # Draw past trajectory as a polyline if len(path_points) > 1: cv2.polylines(frame, [np.array(path_points, dtype=np.int32)], False, (255, 0, 0), 2) # Draw the ball centre (bigger on impact frame) radius = 5 thickness = -1 colour = (255, 255, 255) if frame_idx == impact_frame_idx: # Auto zoom effect: larger circle and thicker outline radius = 10 thickness = 2 colour = (0, 255, 255) cv2.circle(frame, (cx, cy), radius, colour, thickness) else: # Continue drawing the path beyond detection frames if len(path_points) > 1: cv2.polylines(frame, [np.array(path_points, dtype=np.int32)], False, (255, 0, 0), 2) # After the impact frame, display the decision text if frame_idx >= impact_frame_idx and impact_frame_idx >= 0: decision_text = "OUT" if will_hit_stumps else "NOT OUT" font = cv2.FONT_HERSHEY_SIMPLEX cv2.putText( frame, decision_text, (50, 50), font, 1.5, (0, 255, 0) if will_hit_stumps else (0, 0, 255), 3, lineType=cv2.LINE_AA, ) writer.write(frame) frame_idx += 1 cap.release() writer.release()