# backend/utils/video_processing.py
import cv2
import numpy as np
from ultralytics import YOLO
import os
import torch

# Path to the YOLO model
MODEL_PATH = 'models/yolov8_model.pt'

# Check if model file exists
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"YOLO model file not found at {MODEL_PATH}. "
        "Please place a valid YOLOv8 .pt file (e.g., yolov8n.pt) in the 'models/' directory. "
        "You can download it using scripts/download_yolov8_model.py."
    )

# Load YOLO model. This runs at import time, so importing this module raises
# (FileNotFoundError above / RuntimeError below) when no usable model exists.
try:
    # Preferred path: let Ultralytics load the checkpoint itself.
    model = YOLO(MODEL_PATH)
except Exception as e:
    # Fallback: rebuild the network from an architecture YAML and copy the
    # checkpoint's weights into it.
    # NOTE(review): the YAML is the nano (yolov8n) architecture, so this path
    # presumably only works for yolov8n-sized checkpoints — confirm.
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Acceptable only because MODEL_PATH is a local file under our control;
        # never point this at an untrusted checkpoint.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
        model = YOLO('models/yolov8n.yaml')  # Load model architecture from YAML
        # The YOLO object is a wrapper; the network itself lives in `.model`.
        # Loading into the wrapper does not target the network's parameter
        # names (and, depending on the Ultralytics version, the wrapper may not
        # be an nn.Module at all), so load into the inner network instead.
        model.model.load_state_dict(checkpoint['model'].state_dict())  # Load weights
    except Exception as inner_e:
        raise RuntimeError(
            f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
            f"Manual loading also failed: {str(inner_e)}. "
            "The model file may be corrupted or not a valid YOLOv8 .pt file. "
            "Please replace it with a valid model, e.g., by running scripts/download_yolov8_model.py "
            "to download yolov8n.pt, or train a custom model using scripts/train_yolov8_model.py."
        ) from inner_e


def _init_tracker_from_detection(tracker, frame, ball_class_id: int):
    """Run YOLO on ``frame`` and initialise ``tracker`` on the first ball box.

    Returns the ball centre as an ``(x, y)`` float tuple, or ``None`` when this
    frame contains no usable detection of class ``ball_class_id``.
    """
    results = model(frame)
    for detection in results[0].boxes:
        if int(detection.cls[0]) != ball_class_id:
            continue
        # xywh is (centre-x, centre-y, width, height). Convert the tensor
        # values to plain floats so the trajectory holds Python numbers rather
        # than (possibly CUDA) tensors.
        x, y, w, h = (float(v) for v in detection.xywh[0])
        bbox = (int(x - w / 2), int(y - h / 2), int(w), int(h))
        if bbox[2] <= 0 or bbox[3] <= 0:
            # Degenerate box: nothing to track, keep looking.
            continue
        tracker.init(frame, bbox)
        return (x, y)
    return None


def track_ball(video_path: str, ball_class_id: int = 0) -> list:
    """
    Track the ball in the video and return its trajectory as a list of (x, y) coordinates.

    Frames are scanned with the YOLO ``model`` until a box of class
    ``ball_class_id`` is found; from then on an OpenCV KCF tracker follows
    that box frame by frame.

    Args:
        video_path: Path to the input video file.
        ball_class_id: YOLO class index treated as the ball. Defaults to 0,
            which assumes a custom model whose class 0 is the ball (in the
            COCO-pretrained YOLOv8 weights class 0 is "person").

    Returns:
        List of ``(x, y)`` float tuples giving the ball centre in pixels. One
        entry is added for the frame where the ball was first detected and one
        for every later frame where the tracker reports success. Frames before
        the first detection, or where tracking fails, add no entry, so list
        indices do not necessarily equal frame indices.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the video.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    trajectory = []
    try:
        # NOTE(review): in recent OpenCV 4.x releases TrackerKCF ships with the
        # contrib modules (opencv-contrib-python) — confirm the deployed package.
        tracker = cv2.TrackerKCF_create()
        initialized = False
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if initialized:
                ok, bbox = tracker.update(frame)
                if ok:
                    x, y, w, h = (int(v) for v in bbox)
                    trajectory.append((x + w / 2, y + h / 2))
                continue

            # Not yet locked on: look for the ball in this frame.
            first_point = _init_tracker_from_detection(tracker, frame, ball_class_id)
            if first_point is not None:
                trajectory.append(first_point)
                initialized = True
    finally:
        # Release the capture even if YOLO or the tracker raises mid-video.
        cap.release()
    return trajectory


def generate_replay(video_path: str, trajectory: list, decision: str,
                    output_dir: str = "replays") -> str:
    """
    Generate a slow-motion replay video with ball trajectory and decision overlay.

    Every frame of ``video_path`` is re-encoded (mp4v) at half the source
    frame rate. Frame ``i`` gets a red dot at ``trajectory[i]`` (if it exists),
    blue line segments joining ``trajectory[0..i]`` (the whole trajectory once
    ``i`` passes its end), and the text ``"Decision: <decision>"``.

    NOTE(review): ``trajectory[i]`` is drawn on frame ``i``. The trajectory
    produced by ``track_ball`` skips frames before the first detection and
    frames where tracking failed, so the overlay may be out of step with the
    video — confirm whether callers align it first.

    Args:
        video_path: Path to the source video.
        trajectory: Sequence of ``(x, y)`` pixel coordinates.
        decision: Text rendered on every frame.
        output_dir: Directory the replay is written to (created if missing).

    Returns:
        Path of the written replay, ``<output_dir>/replay_<source basename>``.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the source video.
        RuntimeError: If the output video writer cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS, which would yield an unusable
        # writer. Fall back to an assumed 30 FPS source in that case.
        if not source_fps > 0:
            source_fps = 30.0
        fps = source_fps / 2  # Slow motion (half speed)

        # VideoWriter does not create directories; without this it silently
        # fails to open and nothing is written.
        os.makedirs(output_dir, exist_ok=True)
        replay_path = os.path.join(output_dir, f"replay_{os.path.basename(video_path)}")
        out = cv2.VideoWriter(replay_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Failed to open video writer for: {replay_path}")

        try:
            # Convert the trajectory to integer pixel points once, not per frame.
            points = [(int(x), int(y)) for x, y in trajectory]
            frame_idx = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_idx < len(points):
                    cv2.circle(frame, points[frame_idx], 5, (0, 0, 255), -1)
                # Path travelled so far (the full path once past its end).
                visible = points[:frame_idx + 1]
                for start, end in zip(visible, visible[1:]):
                    cv2.line(frame, start, end, (255, 0, 0), 2)
                cv2.putText(frame, f"Decision: {decision}", (50, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                out.write(frame)
                frame_idx += 1
        finally:
            out.release()
    finally:
        cap.release()
    return replay_path