File size: 4,284 Bytes
dd08014
 
 
 
 
e4cb859
dd08014
20f24a7
 
 
 
 
1fc0ef6
 
 
 
 
20f24a7
f47a8e5
20f24a7
f47a8e5
 
20f24a7
1fc0ef6
f47a8e5
 
 
1fc0ef6
f47a8e5
 
 
 
 
1fc0ef6
 
 
f47a8e5
dd08014
 
 
 
 
20f24a7
 
 
dd08014
20f24a7
 
 
dd08014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20f24a7
 
 
dd08014
20f24a7
 
 
dd08014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f47a8e5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# backend/utils/video_processing.py
import cv2
import numpy as np
from ultralytics import YOLO
import os
import torch

# Path to the YOLO model weights. This is a relative path, so it is resolved
# against the process's current working directory, not against this file.
MODEL_PATH = 'models/yolov8_model.pt'

# Fail fast at import time with an actionable message if the weights are missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"YOLO model file not found at {MODEL_PATH}. "
        "Please place a valid YOLOv8 .pt file (e.g., yolov8n.pt) in the 'models/' directory. "
        "You can download it using scripts/download_yolov8_model.py."
    )

# Load the YOLO model once at import time; the module-level `model` is shared
# by every function in this module.
try:
    # Preferred path: let Ultralytics load the .pt checkpoint directly.
    model = YOLO(MODEL_PATH)
except Exception as e:
    # Fallback: build the network from its YAML architecture and copy the
    # weights out of the raw checkpoint.
    try:
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects from the file; only use this with trusted model files.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
        model = YOLO('models/yolov8n.yaml')  # Load model architecture from YAML
        # The `YOLO` wrapper holds the actual network in `.model`. The
        # checkpoint's state dict is keyed relative to that inner network, so
        # load into it. Loading into the wrapper itself would fail on the
        # extra "model." key prefix.
        # NOTE(review): the YAML's class count must match the checkpoint's,
        # otherwise the detection head shapes will not line up.
        model.model.load_state_dict(checkpoint['model'].state_dict())  # Load weights
    except Exception as inner_e:
        raise RuntimeError(
            f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
            f"Manual loading also failed: {str(inner_e)}. "
            "The model file may be corrupted or not a valid YOLOv8 .pt file. "
            "Please replace it with a valid model, e.g., by running scripts/download_yolov8_model.py "
            "to download yolov8n.pt, or train a custom model using scripts/train_yolov8_model.py."
        ) from inner_e

def _detect_initial_bbox(frame, ball_class: int = 0):
    """Run the module-level YOLO model on ``frame`` and find the first ``ball_class`` box.

    Returns a pair ``(bbox, centre)``:
    - ``bbox`` is an integer ``(x, y, w, h)`` box (top-left corner plus size),
      in the form ``tracker.init`` accepts.
    - ``centre`` is the box centre as Python floats.

    Returns ``None`` when the frame contains no detection of that class.
    """
    results = model(frame)
    for detection in results[0].boxes:
        # `cls` is a one-element tensor; compare it as a plain int.
        if int(detection.cls) == ball_class:
            # xywh is (centre-x, centre-y, width, height). Convert to Python
            # floats so the trajectory holds no torch tensors (which may live
            # on the GPU).
            cx, cy, w, h = (float(v) for v in detection.xywh[0])
            bbox = (int(cx - w / 2), int(cy - h / 2), int(w), int(h))
            return bbox, (cx, cy)
    return None


def track_ball(video_path: str, ball_class: int = 0) -> list:
    """
    Track the ball in the video and return its trajectory as a list of (x, y) coordinates.

    YOLO runs on each frame until a detection of ``ball_class`` is found.
    That detection seeds a KCF tracker, which follows the ball through the
    remaining frames. Each point is the float pixel centre of the ball's box.

    NOTE: frames before the first detection, and frames where the tracker
    loses the ball, add no point. The list is therefore not guaranteed to be
    indexed by frame number.

    Args:
        video_path: Path to the input video file.
        ball_class: YOLO class id treated as the ball (default 0, as before).

    Returns:
        List of ``(x, y)`` float tuples.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the video.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    tracker = cv2.TrackerKCF_create()
    trajectory = []
    initialized = False

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if not initialized:
                # Keep running the detector until the ball first appears.
                found = _detect_initial_bbox(frame, ball_class)
                if found is not None:
                    bbox, centre = found
                    tracker.init(frame, bbox)
                    trajectory.append(centre)
                    initialized = True
            else:
                ok, bbox = tracker.update(frame)
                if ok:
                    x, y, w, h = [int(v) for v in bbox]
                    trajectory.append((x + w / 2, y + h / 2))
    finally:
        # Release the capture even if detection or tracking raises.
        cap.release()
    return trajectory

def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
    """
    Generate a slow-motion replay video with ball trajectory and decision overlay.

    The input video is re-encoded at half its frame rate as an ``mp4v`` file
    under ``replays/``. While ``trajectory`` still has points, each frame
    gets two overlays:
    - a red dot at ``trajectory[frame_idx]``;
    - a blue trail joining ``trajectory[0]`` through ``trajectory[frame_idx]``.

    Every frame is stamped with the decision text.

    Args:
        video_path: Path to the source video.
        trajectory: Sequence of ``(x, y)`` points, consumed one per frame.
        decision: Text drawn as ``"Decision: <decision>"``.

    Returns:
        Path of the written replay file.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If the source video or the output writer cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    source_fps = cap.get(cv2.CAP_PROP_FPS)
    if source_fps <= 0:
        # OpenCV reports 0 when the container has no FPS metadata; a 0-FPS
        # writer cannot produce a playable file.
        # NOTE(review): 30 FPS is an assumed default for such sources.
        source_fps = 30.0
    fps = source_fps / 2  # Slow motion (half speed)

    # VideoWriter does not create missing directories; without this it opens
    # nothing and every write is silently dropped.
    os.makedirs("replays", exist_ok=True)
    replay_path = f"replays/replay_{os.path.basename(video_path)}"
    out = cv2.VideoWriter(replay_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not out.isOpened():
        cap.release()
        raise ValueError(f"Failed to open video writer for: {replay_path}")

    # Convert the trajectory to integer pixel points once. The trail for
    # frame i is then a cheap slice drawn with one polylines call.
    trail = np.array(
        [(int(px), int(py)) for px, py in trajectory], dtype=np.int32
    ).reshape(-1, 1, 2)

    try:
        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx < len(trajectory):
                x, y = trajectory[frame_idx]
                cv2.circle(frame, (int(x), int(y)), 5, (0, 0, 255), -1)
                if frame_idx > 0:
                    # Open polyline through points 0..frame_idx, matching the
                    # original segment-by-segment cv2.line drawing.
                    cv2.polylines(frame, [trail[:frame_idx + 1]], False, (255, 0, 0), 2)
            cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            out.write(frame)
            frame_idx += 1
    finally:
        # Release both handles even if decoding or drawing raises.
        cap.release()
        out.release()
    return replay_path