|
|
|
import cv2 |
|
import numpy as np |
|
from ultralytics import YOLO |
|
import os |
|
import torch |
|
|
|
|
|
# Location of the YOLOv8 weights loaded once, at import time, into the
# module-level `model` used by track_ball().
MODEL_PATH = 'models/yolov8_model.pt'

# Fail fast with an actionable message if the weights file is absent.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"YOLO model file not found at {MODEL_PATH}. "
        "Please place a valid YOLOv8 .pt file (e.g., yolov8n.pt) in the 'models/' directory. "
        "You can download it using scripts/download_yolov8_model.py."
    )

try:
    # Normal path: let ultralytics parse the checkpoint itself.
    model = YOLO(MODEL_PATH)
except Exception as e:
    # Fallback: build the network from the yolov8n architecture config and copy
    # the weights out of the checkpoint's pickled model object.
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
        # only ever point MODEL_PATH at trusted .pt files.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
        model = YOLO('models/yolov8n.yaml')
        # Load into the wrapped network (`.model`), not the YOLO wrapper itself:
        # the checkpoint's state-dict keys match the inner network, whereas the
        # wrapper's keys (where it exposes a state dict at all) carry an extra
        # "model." prefix, so loading into the wrapper cannot succeed.
        model.model.load_state_dict(checkpoint['model'].state_dict())
    except Exception as inner_e:
        raise RuntimeError(
            f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
            f"Manual loading also failed: {str(inner_e)}. "
            "The model file may be corrupted or not a valid YOLOv8 .pt file. "
            "Please replace it with a valid model, e.g., by running scripts/download_yolov8_model.py "
            "to download yolov8n.pt, or train a custom model using scripts/train_yolov8_model.py."
        ) from inner_e
|
|
|
def track_ball(video_path: str, ball_class_id: int = 0) -> list:
    """
    Track the ball in the video and return its trajectory as a list of (x, y) coordinates.

    The module-level YOLO ``model`` is run on each frame until it reports a box
    of class ``ball_class_id``. That box seeds an OpenCV KCF tracker, which
    then follows the ball through the remaining frames.

    Args:
        video_path: Path to the input video file.
        ball_class_id: Detector class index that denotes the ball. Defaults to
            0, the previously hard-coded value. NOTE(review): with stock
            COCO-trained weights class 0 is "person", so this presumably
            assumes a custom-trained model — confirm.

    Returns:
        A list of ``(x, y)`` float tuples in pixel coordinates. The first point
        is the detected box centre; after that there is one tracked box centre
        for each later frame where the tracker succeeds. Frames before the first
        detection, and frames where tracking fails, add no point. The list is
        therefore not necessarily aligned one-to-one with video frames.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the video.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    tracker = cv2.TrackerKCF_create()
    trajectory = []
    initialized = False

    # try/finally so the capture handle is released even if detection or
    # tracking raises partway through the video.
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if initialized:
                ok, bbox = tracker.update(frame)
                if ok:
                    x, y, w, h = (int(v) for v in bbox)
                    trajectory.append((x + w / 2, y + h / 2))
                continue

            # Not locked on yet: run the detector and seed the tracker with the
            # first box of the ball class.
            results = model(frame)
            for detection in results[0].boxes:
                if int(detection.cls) != ball_class_id:
                    continue
                # xywh is centre-based (hence the -w/2, -h/2 below). .tolist()
                # yields plain Python floats, so the trajectory does not mix
                # torch tensors (possibly GPU-resident) with the float tuples
                # appended by the tracker branch.
                x, y, w, h = detection.xywh[0].tolist()
                bbox = (int(x - w / 2), int(y - h / 2), int(w), int(h))
                tracker.init(frame, bbox)
                trajectory.append((x, y))
                initialized = True
                break
    finally:
        cap.release()

    return trajectory
|
|
|
def generate_replay(video_path: str, trajectory: list, decision: str, output_dir: str = "replays") -> str:
    """
    Generate a slow-motion replay video with ball trajectory and decision overlay.

    The video is re-encoded (mp4v) at half the source frame rate. For frame
    ``i`` with ``i < len(trajectory)``, a red dot is drawn at ``trajectory[i]``
    and a blue polyline through ``trajectory[0..i]``. Every frame gets a
    ``"Decision: <decision>"`` caption.

    Args:
        video_path: Path to the source video.
        trajectory: Sequence of ``(x, y)`` pixel coordinates, indexed by frame number.
        decision: Text shown in the on-screen caption.
        output_dir: Directory the replay is written to; it is created if
            missing. Defaults to ``"replays"``, the previously hard-coded location.

    Returns:
        Path of the written replay: ``output_dir`` joined with
        ``replay_<basename of video_path>``.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If the source video cannot be opened or the output video
            writer cannot be created.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some streams report 0 (unknown) FPS, which VideoWriter rejects; fall
        # back to a nominal 30 FPS source. Written as `not (> 0)` to catch NaN too.
        if not (source_fps > 0):
            source_fps = 30.0
        fps = source_fps / 2  # half speed => slow-motion playback

        # VideoWriter does not create directories and fails silently otherwise.
        os.makedirs(output_dir, exist_ok=True)
        replay_path = os.path.join(output_dir, f"replay_{os.path.basename(video_path)}")
        out = cv2.VideoWriter(replay_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            raise ValueError(f"Failed to create replay video: {replay_path}")

        # Convert the trajectory once to integer pixels, shaped (N, 1, 2) as
        # cv2.polylines expects, instead of re-converting every frame.
        points = np.array(
            [(int(x), int(y)) for x, y in trajectory], dtype=np.int32
        ).reshape(-1, 1, 2)

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx < len(points):
                px, py = points[frame_idx, 0]
                cv2.circle(frame, (int(px), int(py)), 5, (0, 0, 255), -1)
                if frame_idx > 0:
                    # Path travelled so far: trajectory[0] .. trajectory[frame_idx].
                    cv2.polylines(frame, [points[: frame_idx + 1]], False, (255, 0, 0), 2)
            cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            out.write(frame)
            frame_idx += 1
    finally:
        # Release both handles even if reading/drawing/encoding raised.
        cap.release()
        if out is not None:
            out.release()

    return replay_path