File size: 17,741 Bytes
1f6f52d
 
8a4d72c
1f6f52d
66c096b
8a4d72c
45d3ff2
8a4d72c
1f6f52d
61746ab
2c28e54
61746ab
8a4d72c
 
1f6f52d
61746ab
 
 
 
 
 
 
 
 
 
 
 
1f6f52d
 
8a4d72c
 
1f6f52d
61746ab
3da7a6d
 
61746ab
 
ecfb2f8
 
 
1f6f52d
8a4d72c
 
 
1f6f52d
8a4d72c
1f6f52d
 
 
 
8a4d72c
1f6f52d
61746ab
a568f96
 
 
 
ecfb2f8
61746ab
a568f96
 
 
ecfb2f8
61746ab
9d8c79c
a568f96
 
 
 
 
 
61746ab
a568f96
 
 
 
 
 
 
 
9d8c79c
 
8a4d72c
61746ab
45d3ff2
8a4d72c
3da7a6d
8a4d72c
a568f96
8a4d72c
a568f96
8a4d72c
45d3ff2
8c456c5
8a4d72c
 
 
61746ab
8a4d72c
 
61746ab
8a4d72c
61746ab
8a4d72c
 
 
 
ecfb2f8
8a4d72c
6eec350
8a4d72c
61746ab
8a4d72c
 
 
 
ecfb2f8
61746ab
 
8a4d72c
 
 
61746ab
 
8a4d72c
 
 
 
 
 
 
 
61746ab
 
 
6eec350
61746ab
6eec350
61746ab
 
6eec350
 
45d3ff2
8a4d72c
61746ab
9d8c79c
 
 
6eec350
 
9d8c79c
8a4d72c
 
61746ab
9d8c79c
 
8a4d72c
 
3da7a6d
61746ab
8a4d72c
a568f96
8a4d72c
 
 
239672f
61746ab
 
 
8a4d72c
61746ab
6eec350
9d8c79c
61746ab
 
 
 
6eec350
61746ab
 
 
 
 
 
 
45d3ff2
9d8c79c
8a4d72c
45d3ff2
61746ab
45d3ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f512c
45d3ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a4d72c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3da7a6d
8a4d72c
 
 
1f6f52d
61746ab
 
8a4d72c
61746ab
8a4d72c
 
61746ab
8a4d72c
 
 
 
61746ab
8a4d72c
 
9d8c79c
8a4d72c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45d3ff2
 
 
 
 
 
8a4d72c
 
45d3ff2
 
 
66c096b
8a4d72c
1f6f52d
8a4d72c
 
2c28e54
8a4d72c
 
45d3ff2
 
2c28e54
8a4d72c
45d3ff2
1f6f52d
 
 
61746ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import plotly.graph_objects as go
import uuid
import os
from scipy.ndimage import uniform_filter1d

# Load the trained YOLO weights from best.pt (a YOLOv8n model per the original
# author) once at import time and move the model to the GPU when one exists.
model = YOLO("best.pt")
model.to('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available

# Constants for LBW decision and video processing
STUMPS_WIDTH = 0.2286  # meters (width of stumps); also scaled to pixels in lbw_decision
BALL_DIAMETER = 0.073  # meters (approx. cricket ball diameter); scales the z axis in pixel_to_3d
FRAME_RATE = 20  # Default frames per second, read by trajectory timing and the replay writer
SLOW_MOTION_FACTOR = 1.5  # Replay FPS is FRAME_RATE / factor (e.g., 30 / 1.5 = 20 FPS); also densifies the trajectory
CONF_THRESHOLD = 0.15  # Minimum detection confidence passed to model.predict
IMPACT_ZONE_Y = 0.9  # Fraction of frame height for the impact zone (not referenced in the code visible here)
PITCH_LENGTH = 20.12  # meters (standard cricket pitch length)
STUMPS_HEIGHT = 0.71  # meters (stumps height)
CAMERA_HEIGHT = 2.0  # meters (assumed camera height; not referenced in the code visible here)
CAMERA_DISTANCE = 10.0  # meters (assumed camera distance from pitch; not referenced in the code visible here)
MAX_POSITION_JUMP = 250  # pixels; larger jumps between consecutive kept detections are discarded

def process_video(video_path):
    """Detect the ball in every frame of a video clip.

    Each frame is contrast-boosted and sharpened, passed to the module-level
    YOLO ``model``, and the highest-confidence class-0 (ball) box is kept.
    The enhanced frame, with the chosen box drawn in green, is what ends up
    in the returned frame list and in ``debug_frame_<n>.jpg`` (written to the
    current working directory for every frame).

    Side effect: the module-level ``FRAME_RATE`` is set to the clip's reported
    FPS (20 when the container reports none), so that trajectory timing and
    the replay writer use the clip's real rate instead of the default.

    Args:
        video_path: Path to the video file. ``None``, an empty string, or a
            path that does not exist is reported as an error.

    Returns:
        A 4-tuple ``(frames, ball_positions, detection_frames, debug_log)``:
        the processed frames, ``[x, y]`` ball centres in frame pixels, the
        0-based frame index of each centre, and a newline-joined log string.
        On failure the three lists are empty and the string is the error.
    """
    global FRAME_RATE

    # Gradio hands over None when nothing was uploaded; os.path.exists(None)
    # would raise TypeError, so treat falsy paths as "not found".
    if not video_path or not os.path.exists(video_path):
        return [], [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    detection_frames = []
    debug_log = []
    try:
        if not cap.isOpened():
            return [], [], [], "Error: Could not open video file"

        # Get native video resolution and frame rate
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        FRAME_RATE = cap.get(cv2.CAP_PROP_FPS) or 20  # Use actual frame rate or default

        # Round the inference size up to a multiple of the YOLO stride
        stride = 32
        img_width = ((frame_width + stride - 1) // stride) * stride
        img_height = ((frame_height + stride - 1) // stride) * stride

        # Loop-invariant sharpening kernel
        sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            # Enhance frame contrast and sharpness
            frame = cv2.convertScaleAbs(frame, alpha=1.5, beta=20)
            frame = cv2.filter2D(frame, -1, sharpen_kernel)

            results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(img_height, img_width), iou=0.5, max_det=5)
            ball_boxes = [box for box in results[0].boxes if box.cls == 0]  # Class 0 is the ball
            detections = len(ball_boxes)

            if ball_boxes:
                conf_scores = [box.conf.cpu().numpy()[0] for box in ball_boxes]
                best_idx = int(np.argmax(conf_scores))  # first box with the highest confidence
                max_conf = conf_scores[best_idx]
                # Ultralytics reports boxes in the pixel coordinates of the
                # image passed to predict(), so no rescaling by imgsz is needed.
                x1, y1, x2, y2 = ball_boxes[best_idx].xyxy[0].cpu().numpy()
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                detection_frames.append(frame_count - 1)
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                debug_log.append(f"Frame {frame_count}: {detections} ball detections, selected confidence={max_conf:.3f}, all confidences={conf_scores}")
            else:
                debug_log.append(f"Frame {frame_count}: {detections} ball detections")

            frames.append(frame)
            # Save debug frame
            cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
    finally:
        # Release the capture even if inference or drawing raises.
        cap.release()

    if not ball_positions:
        debug_log.append("No frames with ball detection")
    else:
        debug_log.append(f"Total frames with ball detection: {len(ball_positions)}")
        debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
        debug_log.append(f"Video frame rate: {FRAME_RATE}")

    return frames, ball_positions, detection_frames, "\n".join(debug_log)

def pixel_to_3d(x, y, frame_height, frame_width):
    """Map a pixel position to approximate pitch coordinates in meters.

    The mapping is a plain linear rescale of the normalised pixel position,
    not a camera projection.

    Args:
        x: Horizontal pixel coordinate.
        y: Vertical pixel coordinate.
        frame_height: Frame height in pixels.
        frame_width: Frame width in pixels.

    Returns:
        ``(x_3d, y_3d, z_3d)``: lateral offset centred on 0 across a 3.0 m
        span, distance along the pitch (0 to ``PITCH_LENGTH``), and a height
        that shrinks from ``5 * BALL_DIAMETER`` at the top of the frame to 0
        at the bottom.
    """
    horizontal_fraction = x / frame_width
    vertical_fraction = y / frame_height
    return (
        (horizontal_fraction - 0.5) * 3.0,  # centre x at the middle of the pitch
        vertical_fraction * PITCH_LENGTH,
        (1 - vertical_fraction) * BALL_DIAMETER * 5,  # approximate bounce height
    )

def _trajectory_failure(message):
    """Build the 10-tuple estimate_trajectory returns on failure (nine Nones plus the message)."""
    return (None,) * 9 + (message,)


def estimate_trajectory(ball_positions, frames, detection_frames):
    """Turn per-frame ball detections into a smoothed 2D/3D trajectory.

    Steps: drop detections that jump more than ``MAX_POSITION_JUMP`` pixels
    from the previously kept one, smooth the remaining coordinates with a
    3-sample moving average, pick the pitch point (smallest y) and the impact
    point (largest y among later detections), and linearly interpolate a
    dense path over the detection time span.

    Side effect: writes ``trajectory_debug.png`` to the working directory.

    Args:
        ball_positions: ``[x, y]`` pixel centres, one per detection.
        frames: Processed frames; only the shape of the first one is read.
        detection_frames: 0-based frame index of each entry in
            ``ball_positions``.

    Returns:
        A 10-tuple ``(trajectory_2d, pitch_point, impact_point, pitch_frame,
        impact_frame, detections_3d, trajectory_3d, pitch_point_3d,
        impact_point_3d, log)``. On failure the first nine items are ``None``
        and ``log`` holds the error message.
    """
    if len(ball_positions) < 2:
        return _trajectory_failure("Error: Fewer than 2 frames with one ball detection")
    frame_height, frame_width = frames[0].shape[:2]
    debug_log = []

    # Filter out sudden changes in position for continuous trajectory.
    # Distance is measured against the last *kept* detection.
    filtered_positions = [ball_positions[0]]
    filtered_frames = [detection_frames[0]]
    for curr_pos, frame_no in zip(ball_positions[1:], detection_frames[1:]):
        prev_pos = filtered_positions[-1]
        distance = np.sqrt((curr_pos[0] - prev_pos[0])**2 + (curr_pos[1] - prev_pos[1])**2)
        if distance <= MAX_POSITION_JUMP:
            filtered_positions.append(curr_pos)
            filtered_frames.append(frame_no)
        else:
            debug_log.append(f"Filtered out detection at frame {frame_no + 1}: large jump ({distance:.1f} pixels)")

    if len(filtered_positions) < 2:
        return _trajectory_failure("Error: Fewer than 2 valid ball detections after filtering")

    x_coords = [pos[0] for pos in filtered_positions]
    y_coords = [pos[1] for pos in filtered_positions]
    times = np.array(filtered_frames) / FRAME_RATE

    # Smooth coordinates to avoid sudden jumps
    x_coords = uniform_filter1d(x_coords, size=3)
    y_coords = uniform_filter1d(y_coords, size=3)

    # Convert to 3D for visualization
    detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in zip(x_coords, y_coords)]

    # Pitch point: Detection with lowest y-coordinate (near bowler's end)
    pitch_idx = min(range(len(filtered_positions)), key=lambda i: y_coords[i])
    pitch_point = (x_coords[pitch_idx], y_coords[pitch_idx])
    pitch_frame = filtered_frames[pitch_idx]

    # Impact point: Detection with highest y-coordinate after pitch point (near stumps)
    post_pitch_indices = [i for i in range(len(filtered_positions)) if filtered_frames[i] > pitch_frame]
    if not post_pitch_indices:
        return _trajectory_failure("Error: No detections after pitch point")
    impact_idx = max(post_pitch_indices, key=lambda i: y_coords[i])
    impact_point = (x_coords[impact_idx], y_coords[impact_idx])
    impact_frame = filtered_frames[impact_idx]

    try:
        # Use linear interpolation for stable trajectory
        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
        fy = interp1d(times, y_coords, kind='linear', fill_value="extrapolate")
    except Exception as e:
        return _trajectory_failure(f"Error in trajectory interpolation: {str(e)}")

    # Generate dense points for all frames between first and last detection.
    # The span deliberately uses the unfiltered detection_frames because the
    # replay writer aligns the trajectory to that same span.
    first_frame = min(detection_frames)
    last_frame = max(detection_frames)
    total_frames = last_frame - first_frame + 1
    t_full = np.linspace(first_frame / FRAME_RATE, last_frame / FRAME_RATE, int(total_frames * SLOW_MOTION_FACTOR))
    x_full = fx(t_full)
    y_full = fy(t_full)
    trajectory_2d = list(zip(x_full, y_full))

    trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
    pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
    impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)

    # Debug trajectory and points
    debug_log.extend([
        "Trajectory estimated successfully",
        f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}",
        f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}",
        f"Detections in frames: {filtered_frames}",
        f"Total filtered detections: {len(filtered_frames)}"
    ])

    # Save trajectory plot for debugging. A dedicated figure is created and
    # closed each call; drawing on the implicit pyplot figure would stack
    # every previous request's data into the image and leak figures.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    try:
        ax.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
        ax.plot(pitch_point[0], pitch_point[1], 'ro', label='Pitch Point')
        ax.plot(impact_point[0], impact_point[1], 'yo', label='Impact Point')
        ax.legend()
        fig.savefig("trajectory_debug.png")
    finally:
        plt.close(fig)

    return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "\n".join(debug_log)

def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
    """Build a Plotly 3D figure of the delivery with the wicket drawn at the far end.

    Args:
        detections_3d: Sequence of ``(x, y, z)`` detection points; may be empty.
        trajectory_3d: Sequence of ``(x, y, z)`` trajectory points; may be empty.
        pitch_point_3d: ``(x, y, z)`` of the pitch point, or a falsy value to omit it.
        impact_point_3d: ``(x, y, z)`` of the impact point, or a falsy value to omit it.
        plot_type: ``"detections"`` plots the detections as green markers;
            any other value plots the trajectory as a blue line.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """

    def _columns(points):
        # Split a list of (x, y, z) points into three parallel sequences.
        return zip(*points) if points else ([], [], [])

    def _single_marker(point, color, label):
        # One highlighted point, or an empty trace when the point is missing.
        return go.Scatter3d(
            x=[point[0]] if point else [],
            y=[point[1]] if point else [],
            z=[point[2]] if point else [],
            mode='markers', marker=dict(size=8, color=color), name=label
        )

    # Wicket: three vertical stumps at the end of the pitch plus one bail line.
    half_width = STUMPS_WIDTH / 2
    wicket_traces = [
        go.Scatter3d(
            x=[stump_x, stump_x], y=[PITCH_LENGTH, PITCH_LENGTH], z=[0, STUMPS_HEIGHT],
            mode='lines', line=dict(color='black', width=5), name=f'Stump {number}'
        )
        for number, stump_x in enumerate((-half_width, half_width, 0), start=1)
    ]
    wicket_traces.append(go.Scatter3d(
        x=[-half_width, half_width], y=[PITCH_LENGTH, PITCH_LENGTH], z=[STUMPS_HEIGHT, STUMPS_HEIGHT],
        mode='lines', line=dict(color='black', width=5), name='Bail'
    ))

    highlight_traces = [
        _single_marker(pitch_point_3d, 'red', 'Pitch Point'),
        _single_marker(impact_point_3d, 'yellow', 'Impact Point'),
    ]

    if plot_type == "detections":
        xs, ys, zs = _columns(detections_3d)
        main_trace = go.Scatter3d(
            x=xs, y=ys, z=zs, mode='markers',
            marker=dict(size=5, color='green'), name='Single Ball Detections'
        )
        title = "3D Single Ball Detections"
    else:
        xs, ys, zs = _columns(trajectory_3d)
        main_trace = go.Scatter3d(
            x=xs, y=ys, z=zs, mode='lines',
            line=dict(color='blue', width=4), name='Ball Trajectory (Single Detections)'
        )
        title = "3D Ball Trajectory (Single Detections)"

    scene = dict(
        xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
        xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
        zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
        aspectratio=dict(x=1, y=4, z=0.5)
    )
    return go.Figure(
        data=[main_trace] + highlight_traces + wicket_traces,
        layout=go.Layout(title=title, scene=scene, showlegend=True),
    )

def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
    """Give an LBW verdict from the pitch point, impact point and 2D trajectory.

    The stumps are modelled in pixel space: centred horizontally, at 90% of
    the frame height, with a width of ``STUMPS_WIDTH / 3.0`` of the frame
    width. The verdict is Not Out when the pitch point or the impact point
    lies outside that horizontal band, Out when any trajectory point falls
    inside the band and within 10% of the frame height of the stump line,
    and Not Out otherwise.

    Args:
        ball_positions: Detected ball centres; only the count is checked.
        trajectory: Sequence of ``(x, y)`` pixel points.
        frames: Processed frames; only the shape of the first one is read.
        pitch_point: ``(x, y)`` pixel position where the ball pitched.
        impact_point: ``(x, y)`` pixel position of the impact.

    Returns:
        ``(message, trajectory, pitch_point, impact_point)``; the last three
        are ``None`` when there is not enough data to decide.
    """
    if not frames:
        return "Error: No frames processed", None, None, None
    if not trajectory or len(ball_positions) < 2:
        return "Not enough data (insufficient ball detections)", None, None, None

    height, width = frames[0].shape[:2]
    centre_x = width / 2
    centre_y = height * 0.9
    half_span = width * (STUMPS_WIDTH / 3.0) / 2
    left_edge = centre_x - half_span
    right_edge = centre_x + half_span
    passthrough = (trajectory, pitch_point, impact_point)

    pitch_x, pitch_y = pitch_point
    impact_x, impact_y = impact_point

    # Guard clauses: either point outside the stump line ends the review.
    if pitch_x < left_edge or pitch_x > right_edge:
        return (f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})",) + passthrough
    if impact_x < left_edge or impact_x > right_edge:
        return (f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})",) + passthrough

    vertical_tolerance = height * 0.1
    hits_stumps = any(
        abs(x - centre_x) < half_span and abs(y - centre_y) < vertical_tolerance
        for x, y in trajectory
    )
    if hits_stumps:
        message = f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})"
    else:
        message = f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})"
    return (message,) + passthrough

def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
    """Write an annotated slow-motion replay of the delivery.

    Frames are written at ``FRAME_RATE / SLOW_MOTION_FACTOR`` FPS. Between the
    first and last detection frame the trajectory is drawn progressively as a
    blue polyline; the pitch and impact points are marked on their own
    frames. The frames are drawn on in place.

    Args:
        frames: Frames to write, all the same size.
        trajectory: Dense ``(x, y)`` pixel points spanning the first to the
            last detection frame, or a falsy value to skip the polyline.
        pitch_point: ``(x, y)`` of the pitch point, or a falsy value.
        impact_point: ``(x, y)`` of the impact point, or a falsy value.
        detection_frames: 0-based indices of frames with a detection.
        pitch_frame: Index of the frame on which to mark the pitch point.
        impact_frame: Index of the frame on which to mark the impact point.
        output_path: Destination file path for the MP4.

    Returns:
        ``output_path`` on success; ``None`` when there are no frames or the
        video writer could not be opened.
    """
    if not frames:
        return None
    frame_height, frame_width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
    if not out.isOpened():
        out.release()
        return None

    draw_trajectory = bool(trajectory) and bool(detection_frames)
    if draw_trajectory:
        min_frame = min(detection_frames)
        total_frames = max(detection_frames) - min_frame + 1
        trajectory_points = np.array(trajectory, dtype=np.int32).reshape((-1, 1, 2))
        last_point = len(trajectory_points) - 1

    # Write each frame at least once, even when the factor is below 1.
    repeats = max(1, int(SLOW_MOTION_FACTOR))

    try:
        for i, frame in enumerate(frames):
            if draw_trajectory:
                frame_idx = i - min_frame
                if 0 <= frame_idx < total_frames:
                    # The trajectory is evenly spaced in time over the
                    # detection span, so map the frame's position in that
                    # span proportionally onto the point list. This lets the
                    # line reach the final point on the last detection frame.
                    if total_frames > 1:
                        end_idx = int(round(frame_idx * last_point / (total_frames - 1))) + 1
                    else:
                        end_idx = last_point + 1
                    cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)  # Blue line in BGR
            if pitch_point and i == pitch_frame:
                x, y = pitch_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)  # Red circle
                cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
            if impact_point and i == impact_frame:
                x, y = impact_point
                cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)  # Yellow circle
                cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
            for _ in range(repeats):
                out.write(frame)
    finally:
        # Finalise the file even if drawing fails part-way through.
        out.release()
    return output_path

def drs_review(video):
    """Run the full review pipeline for one uploaded clip.

    Detects the ball, estimates the trajectory, makes the LBW decision,
    renders the slow-motion replay, and builds the two 3D plots.

    Args:
        video: Path to the uploaded video file, or ``None`` when nothing was
            uploaded.

    Returns:
        ``(text, replay_path, detections_fig, trajectory_fig)``. ``text``
        holds the decision and debug log, or an error message; the other
        three are ``None`` on failure.
    """
    # Gradio passes None when the form is submitted without a video.
    if not video:
        return "Error: No video provided", None, None, None

    frames, ball_positions, detection_frames, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None

    trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)

    if trajectory_2d is None:
        return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)

    decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)

    # Unique file name so concurrent or repeated reviews do not overwrite each other.
    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)

    detections_fig = None
    trajectory_fig = None
    if detections_3d:
        detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
        trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")

    debug_output = f"{debug_log}\n{trajectory_log}"
    return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
            slow_motion_path,
            detections_fig,
            trajectory_fig)

# Gradio UI: one uploaded clip in; decision text, replay video and two 3D plots out.
_drs_input = gr.Video(label="Upload Video Clip")
_drs_outputs = [
    gr.Textbox(label="DRS Decision and Debug Log"),
    gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
    gr.Plot(label="3D Single Ball Detections Plot"),
    gr.Plot(label="3D Ball Trajectory Plot (Single Detections)"),
]

iface = gr.Interface(
    fn=drs_review,
    inputs=_drs_input,
    outputs=_drs_outputs,
    title="AI-Powered DRS for LBW in Local Cricket",
    description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show single-detection frames (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow).",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()