dschandra commited on
Commit
e4cb859
·
verified ·
1 Parent(s): 57c29bb

Update utils/video_processing.py

Browse files
Files changed (1) hide show
  1. utils/video_processing.py +6 -9
utils/video_processing.py CHANGED
@@ -3,6 +3,7 @@ import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
  import os
 
6
 
7
  # Path to the YOLO model
8
  MODEL_PATH = 'models/yolov8_model.pt'
@@ -11,11 +12,12 @@ MODEL_PATH = 'models/yolov8_model.pt'
11
  if not os.path.exists(MODEL_PATH):
12
  raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
13
 
14
- # Load YOLO model
15
  try:
16
- model = YOLO(MODEL_PATH)
 
17
  except Exception as e:
18
- raise RuntimeError(f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}")
19
 
20
  def track_ball(video_path: str) -> list:
21
  """
@@ -87,9 +89,4 @@ def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
87
  cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
88
  (int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
89
  cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
90
- out.write(frame)
91
- frame_idx += 1
92
-
93
- cap.release()
94
- out.release()
95
- return replay_path
 
3
  import numpy as np
4
  from ultralytics import YOLO
5
  import os
6
+ import torch
7
 
8
  # Path to the YOLO model
9
  MODEL_PATH = 'models/yolov8_model.pt'
 
12
  if not os.path.exists(MODEL_PATH):
13
  raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
14
 
15
+ # Load YOLO model with weights_only=False for compatibility with PyTorch 2.6
16
  try:
17
+ # Explicitly set weights_only=False to allow loading Ultralytics model metadata
18
+ model = YOLO(MODEL_PATH, weights_only=False)
19
  except Exception as e:
20
+ raise RuntimeError(f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. Ensure the model is a valid YOLOv8 .pt file from a trusted source.")
21
 
22
  def track_ball(video_path: str) -> list:
23
  """
 
89
  cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
90
  (int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
91
  cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
92
+ out