backend / utils /video_processing.py
dschandra's picture
Update utils/video_processing.py
20f24a7 verified
raw
history blame
3.35 kB
# backend/utils/video_processing.py
import cv2
import numpy as np
from ultralytics import YOLO
import os
# Path to the YOLO weights, resolved relative to the process working directory.
MODEL_PATH = 'models/yolov8_model.pt'

# Fail fast at import time with an actionable message if the weights are absent,
# rather than letting the loader surface a less specific error.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")

# Load the YOLO model once at import; both helpers below share this instance.
try:
    model = YOLO(MODEL_PATH)
except Exception as e:
    # Chain the original exception so the underlying loader traceback is kept.
    raise RuntimeError(f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}") from e
def track_ball(video_path: str) -> list:
    """
    Track the ball in the video and return its trajectory as a list of (x, y) coordinates.

    Frames are read in order. Until a class-0 YOLO detection is found, each
    frame is run through the module-level ``model``; the first such detection
    seeds a KCF tracker, and every later frame is handled by the tracker only.

    Args:
        video_path: Path to the video file to analyse.

    Returns:
        A list of ``(x, y)`` centre points as plain Python floats: one for the
        frame that seeded the tracker, then one for each later frame on which
        the tracker update succeeded. Frames before the first detection and
        frames where the tracker fails add nothing, so a list index is not
        necessarily a frame index. The list is empty if nothing was detected.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the video.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    trajectory = []
    try:
        tracker = cv2.TrackerKCF_create()
        init = False
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if not init:
                # Detect ball using YOLO for the initial bounding box.
                results = model(frame)
                for detection in results[0].boxes:
                    if detection.cls == 0:  # Assume class 0 is the ball
                        # Convert the framework scalars to plain floats so the
                        # trajectory holds the same type as tracker-derived points.
                        x, y, w, h = (float(v) for v in detection.xywh[0])
                        # xywh is centre-based; the tracker wants top-left + size.
                        bbox = (int(x - w / 2), int(y - h / 2), int(w), int(h))
                        tracker.init(frame, bbox)
                        trajectory.append((x, y))
                        init = True
                        break
            else:
                # Update tracker
                ok, bbox = tracker.update(frame)
                if ok:
                    x, y, w, h = [int(v) for v in bbox]
                    trajectory.append((x + w / 2, y + h / 2))
    finally:
        # Release the capture even if detection or tracking raises.
        cap.release()

    return trajectory
def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
    """
    Generate a slow-motion replay video with ball trajectory and decision overlay.

    Each source frame is written once to ``replays/replay_<basename>`` at half
    the source frame rate. Frame ``i`` gets a red dot at ``trajectory[i]`` (if
    that index exists), a blue polyline through the trajectory points up to
    index ``i`` (capped at the end of the trajectory), and the decision text.

    Args:
        video_path: Path to the source video.
        trajectory: Sequence of points whose first two items are x and y;
            indexed by frame number when drawing.
        decision: Text rendered after ``"Decision: "`` on every frame.

    Returns:
        The path of the written replay file.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If the source video or the output writer cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found at {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        source_fps = cap.get(cv2.CAP_PROP_FPS)
        if not source_fps > 0:
            # Some containers report 0 (or NaN) FPS; assume a 30 fps source so
            # the writer is not created with an unusable frame rate.
            source_fps = 30.0
        fps = source_fps / 2  # Slow motion (half speed)

        replay_path = f"replays/replay_{os.path.basename(video_path)}"
        # VideoWriter does not create missing directories; it just fails to open.
        os.makedirs(os.path.dirname(replay_path), exist_ok=True)

        out = cv2.VideoWriter(replay_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            raise ValueError(f"Failed to open video writer for: {replay_path}")

        # Convert the points to integer pixel coordinates once, not per frame.
        points = [(int(p[0]), int(p[1])) for p in trajectory]

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx < len(points):
                cv2.circle(frame, points[frame_idx], 5, (0, 0, 255), -1)

            # Draw the path travelled so far; once the trajectory is exhausted
            # the full path stays on screen.
            visible = min(frame_idx + 1, len(points))
            for i in range(1, visible):
                cv2.line(frame, points[i - 1], points[i], (255, 0, 0), 2)

            cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            out.write(frame)
            frame_idx += 1
    finally:
        # Release both handles even if drawing or writing raises.
        cap.release()
        if out is not None:
            out.release()

    return replay_path