File size: 7,739 Bytes
2db7738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Visualisation utilities for the DRS application.

This module contains functions to generate images and videos that
illustrate the ball's flight and the outcome of the LBW decision.
Using Matplotlib and OpenCV we create a 3D trajectory plot and an
annotated replay video.  These assets are returned to the Gradio
interface for display to the user.
"""

from __future__ import annotations

import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg")  # Use a non‑interactive backend
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  # needed for 3D plots
from typing import List, Tuple, Callable


def _label_trajectory_axes(ax) -> None:
    """Set the axis labels shared by every trajectory plot produced here."""
    ax.set_xlabel("X (pixels)")
    ax.set_ylabel("Y (pixels)")
    ax.set_zlabel("Frame index")


def generate_trajectory_plot(
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    output_path: str,
) -> None:
    """Create a 3D plot of the observed and predicted ball trajectory.

    The x axis represents the horizontal pixel coordinate, the y axis
    represents the vertical coordinate (inverted so that 0 is at the top,
    matching image coordinates), and the z axis is the position of each
    detection in ``centers`` (used as the frame index / time).  The
    predicted path is drawn on the x–y plane at z=0 for clarity.

    When ``centers`` is empty, a labelled but otherwise empty 3D figure
    titled "No ball detections" is saved instead, and ``trajectory`` is not
    read.

    Parameters
    ----------
    centers: list of tuple(int, int)
        Detected ball centre positions as ``(x, y)`` pixel coordinates.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`.  Only the
        ``"model"`` entry is used: a callable that is evaluated on an array
        of x values and must return the matching y values (or a scalar,
        which is broadcast).
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps; controls the
        colour of the predicted path (green if True, red otherwise).
    output_path: str
        Where to save the resulting PNG image.

    Raises
    ------
    KeyError
        If ``centers`` is non-empty and ``trajectory`` has no ``"model"`` key.
    """
    fig = plt.figure(figsize=(6, 4))
    try:
        ax = fig.add_subplot(111, projection="3d")
        _label_trajectory_axes(ax)

        if not centers:
            ax.set_title("No ball detections")
        else:
            xs = np.array([c[0] for c in centers])
            ys = np.array([c[1] for c in centers])
            zs = np.arange(len(centers))

            # Sample the fitted model across the observed x range.  The result
            # is broadcast so that a model returning a scalar (e.g. a constant
            # fit) still yields one y value per sample.
            model: Callable[[np.ndarray], np.ndarray] = trajectory["model"]
            x_range = np.linspace(xs.min(), xs.max(), 100)
            y_pred = np.broadcast_to(
                np.asarray(model(x_range), dtype=float), x_range.shape
            )

            # Observed detections, connected in detection order.
            ax.plot(xs, ys, zs, "o-", label="Detected ball path", color="blue")

            # Predicted path on the z=0 plane, coloured by the decision.
            colour = "green" if will_hit_stumps else "red"
            ax.plot(
                x_range,
                y_pred,
                np.zeros_like(x_range),
                "--",
                label="Predicted path",
                color=colour,
            )

            ax.set_title("Ball trajectory (observed vs predicted)")
            ax.legend()
            ax.invert_yaxis()  # Invert y axis to match image coordinates

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not leave figures registered with pyplot.
        plt.close(fig)


def _predicted_path_segments(
    model: Callable[[np.ndarray], np.ndarray],
    width: int,
    height: int,
    n_points: int = 50,
    dash_step: int = 4,
) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Sample ``model`` across the frame and return the dashed-line segments.

    ``model`` is evaluated at ``n_points`` evenly spaced x positions from 0
    to ``width - 1``; its output is broadcast to that many y values and
    clamped into ``[0, height - 1]``.  Only every ``dash_step``-th segment
    is kept, which gives the dashed appearance.  Segments with a NaN
    endpoint are dropped rather than cast to an arbitrary integer.
    """
    x_vals = np.linspace(0, width - 1, n_points)
    y_raw = np.broadcast_to(np.asarray(model(x_vals), dtype=float), x_vals.shape)
    valid = ~np.isnan(y_raw)
    # Ensure predicted y values stay within frame (inf clamps to an edge).
    y_clamped = np.clip(np.where(valid, y_raw, 0.0), 0, height - 1).astype(int)

    segments: List[Tuple[Tuple[int, int], Tuple[int, int]]] = []
    for i in range(0, n_points - 1, dash_step):
        if valid[i] and valid[i + 1]:
            segments.append(
                (
                    (int(x_vals[i]), int(y_clamped[i])),
                    (int(x_vals[i + 1]), int(y_clamped[i + 1])),
                )
            )
    return segments


def annotate_video_with_tracking(
    video_path: str,
    centers: List[Tuple[int, int]],
    trajectory: dict,
    will_hit_stumps: bool,
    impact_frame_idx: int,
    output_path: str,
) -> None:
    """Create an annotated replay video highlighting key elements.

    The function reads the trimmed input video and writes out a new
    ``mp4v``-encoded video of the same size and frame rate (30 FPS if the
    source does not report a usable rate) with the following overlays:

      * A rectangle representing the stumps zone at the bottom centre
        of the frame; coloured green if the ball is predicted to hit
        and red otherwise.
      * The predicted trajectory across the full frame width, drawn as a
        dashed line in the same colour as the stumps zone.
      * A blue polyline showing the path of the ball up to the current
        frame (kept on screen after the detections run out).
      * The detected ball centre as a small filled white circle; on the
        impact frame it is instead a larger yellow outline ("auto zoom"
        highlight).
      * The text "OUT" (green) or "NOT OUT" (red) from the impact frame
        onwards, provided ``impact_frame_idx`` is non-negative.

    Parameters
    ----------
    video_path: str
        Path to the trimmed input video.
    centers: list of tuple(int, int)
        Detected ball centres; ``centers[i]`` is drawn on frame ``i``.
        Frames beyond ``len(centers)`` get no new centre.
    trajectory: dict
        Output of :func:`modules.trajectory.estimate_trajectory`.  Only the
        ``"model"`` entry is used: a callable evaluated on an array of x
        pixel values that returns the predicted y values.
    will_hit_stumps: bool
        Whether the ball is predicted to hit the stumps.
    impact_frame_idx: int
        Index of the frame considered as the impact frame.  A negative
        value disables the decision text.
    output_path: str
        Where to save the annotated video.

    Raises
    ------
    RuntimeError
        If the input video cannot be opened, or the output writer cannot
        be opened for ``output_path``.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video {video_path}")

    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS; fall back so the writer opens.
        if not (np.isfinite(fps) and fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            # Without this check a bad codec/path silently produces no video.
            raise RuntimeError(f"Could not open video writer for {output_path}")

        # Everything below is identical for every frame, so compute it once.
        model: Callable[[np.ndarray], np.ndarray] = trajectory["model"]
        dash_segments = _predicted_path_segments(model, width, height)

        # Stumps zone: 10% of the width, centred horizontally, starting at
        # 65% of the height.  Colours are BGR: green = hit, red = miss.
        stumps_width = int(width * 0.1)
        stumps_height = int(height * 0.3)
        stumps_x = int((width - stumps_width) / 2)
        stumps_y = int(height * 0.65)
        stumps_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)

        decision_text = "OUT" if will_hit_stumps else "NOT OUT"
        decision_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)
        show_decision = impact_frame_idx >= 0

        frame_idx = 0
        path_points: List[Tuple[int, int]] = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Draw stumps region on every frame.
            cv2.rectangle(
                frame,
                (stumps_x, stumps_y),
                (stumps_x + stumps_width, stumps_y + stumps_height),
                stumps_color,
                2,
            )

            # Predicted trajectory (dashed: only every 4th segment is drawn).
            for pt1, pt2 in dash_segments:
                cv2.line(frame, pt1, pt2, stumps_color, 1, lineType=cv2.LINE_AA)

            has_centre = frame_idx < len(centers)
            if has_centre:
                # cv2 drawing calls require plain ints for coordinates.
                cx, cy = map(int, centers[frame_idx])
                path_points.append((cx, cy))

            # Past trajectory as a blue polyline; it keeps being drawn on
            # frames after the last detection.
            if len(path_points) > 1:
                cv2.polylines(
                    frame,
                    [np.array(path_points, dtype=np.int32)],
                    False,
                    (255, 0, 0),
                    2,
                )

            if has_centre:
                if frame_idx == impact_frame_idx:
                    # Auto zoom effect: larger yellow outline on impact.
                    radius, thickness, colour = 10, 2, (0, 255, 255)
                else:
                    # Small filled white dot on ordinary frames.
                    radius, thickness, colour = 5, -1, (255, 255, 255)
                cv2.circle(frame, (cx, cy), radius, colour, thickness)

            # From the impact frame onwards, display the decision text.
            if show_decision and frame_idx >= impact_frame_idx:
                cv2.putText(
                    frame,
                    decision_text,
                    (50, 50),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.5,
                    decision_color,
                    3,
                    lineType=cv2.LINE_AA,
                )

            writer.write(frame)
            frame_idx += 1
    finally:
        # Release both handles even if decoding, drawing or writing fails.
        cap.release()
        if writer is not None:
            writer.release()