backend / utils /video_processing.py
dschandra's picture
Update utils/video_processing.py
1fc0ef6 verified
raw
history blame
4.28 kB
# backend/utils/video_processing.py
import cv2
import numpy as np
from ultralytics import YOLO
import os
import torch
# Path to the YOLO weights file, resolved relative to the process working directory.
MODEL_PATH = 'models/yolov8_model.pt'

# Fail fast at import time if the weights file is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"YOLO model file not found at {MODEL_PATH}. "
        "Please place a valid YOLOv8 .pt file (e.g., yolov8n.pt) in the 'models/' directory. "
        "You can download it using scripts/download_yolov8_model.py."
    )

# Load the YOLO model once at import time; every function in this module shares it.
try:
    # Preferred path: let Ultralytics load the checkpoint directly.
    model = YOLO(MODEL_PATH)
except Exception as e:
    # Fallback: build the architecture from YAML and copy the checkpoint's weights into it.
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects and can
        # execute code. This is only acceptable because MODEL_PATH is a trusted local file;
        # never point it at user-supplied input.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
        # Ultralytics training checkpoints may store the trained network under 'ema'
        # (with 'model' left as None); prefer it when present, as Ultralytics' own loader does.
        network = checkpoint.get('ema') or checkpoint['model']
        model = YOLO('models/yolov8n.yaml')  # Load model architecture from YAML
        # The YOLO object wraps the actual network in `.model`. Loading into the wrapper
        # itself would mismatch every parameter key by a 'model.' prefix.
        model.model.load_state_dict(network.state_dict())
    except Exception as inner_e:
        raise RuntimeError(
            f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
            f"Manual loading also failed: {str(inner_e)}. "
            "The model file may be corrupted or not a valid YOLOv8 .pt file. "
            "Please replace it with a valid model, e.g., by running scripts/download_yolov8_model.py "
            "to download yolov8n.pt, or train a custom model using scripts/train_yolov8_model.py."
        ) from inner_e
def track_ball(video_path: str) -> list:
    """
    Track the ball in the video and return its trajectory as a list of (x, y) coordinates.

    The module-level YOLO ``model`` runs on each frame until it reports a class-0
    detection (assumed to be the ball). That detection seeds a KCF tracker, and every
    later frame is followed by the tracker alone. Each point is the centre of the
    ball's bounding box in pixel coordinates, as plain Python floats.

    Note: frames before the first detection, and frames where the tracker reports
    failure, contribute no point. List index ``i`` is therefore not necessarily frame ``i``.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        List of ``(x, y)`` float tuples. It is empty if the ball is never detected.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the video.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    tracker = cv2.TrackerKCF_create()
    trajectory = []
    initialized = False
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if initialized:
                # Follow the ball with the tracker; skip frames where it loses the target.
                ok, bbox = tracker.update(frame)
                if ok:
                    x, y, w, h = [int(v) for v in bbox]
                    trajectory.append((x + w / 2, y + h / 2))
                continue

            # Detect the ball using YOLO to get the initial bounding box.
            results = model(frame)
            for detection in results[0].boxes:
                if int(detection.cls) != 0:  # Assume class 0 is the ball
                    continue
                # (x, y) is treated as the box centre (see the top-left conversion below).
                # Convert the tensor values to Python floats so the trajectory never holds
                # torch tensors, which would break serialization downstream.
                x, y, w, h = (float(v) for v in detection.xywh[0])
                bbox = (int(x - w / 2), int(y - h / 2), int(w), int(h))
                tracker.init(frame, bbox)
                trajectory.append((x, y))
                initialized = True
                break
    finally:
        # Release the capture even if YOLO or the tracker raises mid-video.
        cap.release()
    return trajectory
def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
    """
    Generate a slow-motion replay video with ball trajectory and decision overlay.

    The replay is written with the 'mp4v' codec at half the source frame rate to
    ``replays/replay_<basename of video_path>``. The ``replays`` directory is created
    if needed. On frame ``i`` the overlay contains:

    - a filled dot (BGR red) at ``trajectory[i]``, when that point exists;
    - blue line segments through ``trajectory[0..i]``;
    - the text ``"Decision: <decision>"`` in white.

    Args:
        video_path: Path to the source video.
        trajectory: Sequence of (x, y) pixel coordinates, indexed by frame number.
        decision: Text shown in the overlay.

    Returns:
        Path of the written replay video.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the source video.
        RuntimeError: If the replay video writer cannot be created.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV reports 0 (or NaN) when the container has no FPS metadata, which would
        # give the writer an unusable rate; assume a common 30 fps in that case.
        if not (source_fps > 0):
            source_fps = 30.0
        fps = source_fps / 2  # Slow motion (half speed)

        # VideoWriter does not create directories and fails silently without one.
        os.makedirs("replays", exist_ok=True)
        replay_path = f"replays/replay_{os.path.basename(video_path)}"
        out = cv2.VideoWriter(replay_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Failed to create replay video: {replay_path}")

        # Convert the points to integer pixel tuples once, instead of on every frame.
        points = [(int(px), int(py)) for px, py in trajectory]

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx < len(points):
                cv2.circle(frame, points[frame_idx], 5, (0, 0, 255), -1)
            # Connect each consecutive pair of points seen up to this frame.
            for i in range(1, min(frame_idx + 1, len(points))):
                cv2.line(frame, points[i - 1], points[i], (255, 0, 0), 2)
            cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            out.write(frame)
            frame_idx += 1
    finally:
        # Release both handles even if drawing or writing raises.
        cap.release()
        if out is not None:
            out.release()
    return replay_path