File size: 2,117 Bytes
111eec1 14c6032 111eec1 14c6032 92f9a27 14c6032 92f9a27 14c6032 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# scripts/download_yolov8_model.py
import os
import shutil
import sys

import requests
import torch
from ultralytics import YOLO
# Local destination for the weights. MODEL_DIR is relative to the current
# working directory, and it is created as soon as this module runs.
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "yolov8_model.pt")

# Release asset on the Ultralytics GitHub "assets" repository (tag v8.2.0)
# that provides the yolov8n.pt checkpoint.
MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt"

os.makedirs(MODEL_DIR, exist_ok=True)
def download_model(url, output_path, timeout=60):
    """Download the file at *url* and save it to *output_path*.

    The response body is streamed into ``<output_path>.part``. It is moved onto
    *output_path* only after the whole body has been written, so a failed or
    interrupted transfer never leaves a truncated file at *output_path*.

    Args:
        url: HTTP(S) URL to fetch.
        output_path: Destination file path; overwritten on success.
        timeout: Passed to ``requests.get``. It is the number of seconds to
            wait for the server to accept the connection or to send more data.
            It does not cap the total download time. ``None`` disables it.

    Raises:
        Exception: "Failed to download model: HTTP <status>" for a non-200
            response, or "Download error: ..." (with the underlying error
            chained as ``__cause__``) for a network or file-system failure.
    """
    print(f"Downloading model from {url}...")
    try:
        response = requests.get(url, stream=True, timeout=timeout)
    except requests.RequestException as e:
        raise Exception(f"Download error: {e}") from e

    # Closing the response releases the connection on every exit path.
    with response:
        if response.status_code != 200:
            raise Exception(f"Failed to download model: HTTP {response.status_code}")
        # response.raw is the undecoded socket stream. Ask urllib3 to undo any
        # Content-Encoding so the bytes on disk are the actual file contents.
        response.raw.decode_content = True
        tmp_path = f"{output_path}.part"
        try:
            with open(tmp_path, "wb") as f:
                shutil.copyfileobj(response.raw, f)
            # On POSIX this rename is atomic: readers see either the previous
            # file or the complete new one, never a partial write.
            os.replace(tmp_path, output_path)
        except Exception as e:
            # Best-effort removal of the partial download; report the real error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise Exception(f"Download error: {e}") from e
    print(f"Model downloaded to {output_path}")
def verify_model(path):
    """Check that *path* can be deserialized with ``torch.load``.

    This only proves the file is a loadable PyTorch checkpoint. It does not
    inspect what the checkpoint contains.

    Args:
        path: Path to the checkpoint file.

    Returns:
        True when the file loads successfully.

    Raises:
        Exception: "Verification failed: ..." if ``torch.load`` fails. The
            original error is chained as ``__cause__``.
    """
    print(f"Verifying model at {path}...")
    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
        # malicious file can execute code here. Only call this on files from a
        # trusted source.
        # NOTE(review): Ultralytics checkpoints appear to pickle model classes,
        # which is presumably why weights_only=True is not used. Confirm before
        # tightening this.
        # The loaded object is not needed; a successful load is the check.
        torch.load(path, map_location="cpu", weights_only=False)
    except Exception as e:
        raise Exception(f"Verification failed: {e}") from e
    print("Model verified as valid PyTorch checkpoint")
    return True
# Download, verify and load-test the model. On any failure, delete whatever is
# left at MODEL_PATH so nothing downstream loads a broken file, print
# manual-download instructions to stderr, and exit with status 1 so callers
# (setup scripts, CI) can detect the failure.
try:
    # Start from a clean slate so a stale or corrupted file is never reused.
    if os.path.exists(MODEL_PATH):
        os.remove(MODEL_PATH)
        print(f"Removed existing file at {MODEL_PATH}")
    download_model(MODEL_URL, MODEL_PATH)
    verify_model(MODEL_PATH)
    # Constructing YOLO checks that Ultralytics itself can use the checkpoint,
    # not only that torch.load can deserialize it.
    YOLO(MODEL_PATH)
    print("Model successfully loaded with Ultralytics YOLO")
except Exception as e:
    try:
        os.remove(MODEL_PATH)
    except OSError:
        pass  # nothing was written, or it cannot be removed; report the real error
    print(f"Failed to download or verify model: {e}", file=sys.stderr)
    # Build the hint from the actual constants so it always names the URL and
    # the path this script really uses.
    print(
        f"Please manually download yolov8n.pt from {MODEL_URL} "
        f"and place it at {os.path.abspath(MODEL_PATH)}",
        file=sys.stderr,
    )
    sys.exit(1)