dschandra commited on
Commit
1fc0ef6
·
verified ·
1 Parent(s): 9eeb62b

Update utils/video_processing.py

Browse files
Files changed (1) hide show
  1. utils/video_processing.py +10 -5
utils/video_processing.py CHANGED
@@ -10,25 +10,30 @@ MODEL_PATH = 'models/yolov8_model.pt'
10
 
11
  # Check if model file exists
12
  if not os.path.exists(MODEL_PATH):
13
- raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
 
 
 
 
14
 
15
  # Load YOLO model
16
  try:
17
  # Load the model using Ultralytics YOLO
18
  model = YOLO(MODEL_PATH)
19
  except Exception as e:
20
- # If loading fails due to weights_only issue, try manual loading
21
  try:
22
  # Manually load the checkpoint with weights_only=False
23
  checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
24
- model = YOLO('yolov8n.yaml') # Load model architecture from YAML
25
  model.load_state_dict(checkpoint['model'].state_dict()) # Load weights
26
  except Exception as inner_e:
27
  raise RuntimeError(
28
  f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
29
  f"Manual loading also failed: {str(inner_e)}. "
30
- "Ensure the model is a valid YOLOv8 .pt file from a trusted source. "
31
- "You may need to re-save the model or use a pre-trained model like yolov8n.pt."
 
32
  )
33
 
34
  def track_ball(video_path: str) -> list:
 
10
 
11
  # Check if model file exists
12
  if not os.path.exists(MODEL_PATH):
13
+ raise FileNotFoundError(
14
+ f"YOLO model file not found at {MODEL_PATH}. "
15
+ "Please place a valid YOLOv8 .pt file (e.g., yolov8n.pt) in the 'models/' directory. "
16
+ "You can download it using scripts/download_yolov8_model.py."
17
+ )
18
 
19
  # Load YOLO model
20
  try:
21
  # Load the model using Ultralytics YOLO
22
  model = YOLO(MODEL_PATH)
23
  except Exception as e:
24
+ # If loading fails, try manual loading
25
  try:
26
  # Manually load the checkpoint with weights_only=False
27
  checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
28
+ model = YOLO('models/yolov8n.yaml') # Load model architecture from YAML
29
  model.load_state_dict(checkpoint['model'].state_dict()) # Load weights
30
  except Exception as inner_e:
31
  raise RuntimeError(
32
  f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
33
  f"Manual loading also failed: {str(inner_e)}. "
34
+ "The model file may be corrupted or not a valid YOLOv8 .pt file. "
35
+ "Please replace it with a valid model, e.g., by running scripts/download_yolov8_model.py "
36
+ "to download yolov8n.pt, or train a custom model using scripts/train_yolov8_model.py."
37
  )
38
 
39
  def track_ball(video_path: str) -> list: