|
|
|
from ultralytics import YOLO |
|
import os |
|
import torch |
|
import requests |
|
import shutil |
|
|
|
|
|
# Folder for downloaded weights. It is a relative path, so it resolves against
# the current working directory, and it is created eagerly when this module runs.
MODEL_DIR = "models"

# Local file name under which the downloaded weights are stored.
MODEL_PATH = os.path.join(MODEL_DIR, "yolov8_model.pt")

# Release asset (yolov8n.pt) published under the ultralytics/assets v8.2.0 tag.
MODEL_URL = (
    "https://github.com/ultralytics/assets/releases/download/"
    "v8.2.0/yolov8n.pt"
)

os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
def download_model(url, output_path, timeout=60, chunk_size=1024 * 1024):
    """Download the file at *url* and save it to *output_path*.

    The response body is streamed to disk in chunks. If the transfer fails
    after *output_path* has been opened for writing, the truncated file is
    removed so that no corrupt file is left behind.

    Args:
        url: URL to fetch with ``requests.get``.
        output_path: Destination file path; overwritten if it already exists.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.
            Without one, a stalled server can hang the call forever.
        chunk_size: Number of bytes requested per chunk while streaming.

    Raises:
        Exception: ``"Download error: ..."`` wrapping any failure, including
            a non-200 HTTP status. The original error is chained as
            ``__cause__``.
    """
    print(f"Downloading model from {url}...")
    file_opened = False
    try:
        # The context manager closes the response. With stream=True the
        # connection otherwise stays checked out until the body is consumed.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise Exception(
                    f"Failed to download model: HTTP {response.status_code}"
                )
            with open(output_path, "wb") as f:
                file_opened = True
                # iter_content (unlike response.raw) undoes gzip/deflate
                # Content-Encoding, so the bytes on disk are the real payload.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
    except Exception as e:
        if file_opened:
            try:
                os.remove(output_path)
            except OSError:
                pass  # Best-effort cleanup; the download error is what matters.
        raise Exception(f"Download error: {e}") from e
    print(f"Model downloaded to {output_path}")
|
|
|
def verify_model(path):
    """Check that the file at *path* can be deserialized with ``torch.load``.

    The checkpoint is loaded onto the CPU and then discarded. Only its
    loadability is checked, not its contents.

    Args:
        path: Filesystem path of the checkpoint to check.

    Returns:
        True if ``torch.load`` succeeds.

    Raises:
        Exception: ``"Verification failed: ..."`` wrapping whatever
            ``torch.load`` raised (missing file, corrupt or non-pickle data,
            etc.). The original error is chained as ``__cause__``.
    """
    print(f"Verifying model at {path}...")
    try:
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code embedded in the file. Here the file comes from a network
        # download, so only point this at trusted sources. weights_only=False is
        # presumably needed because Ultralytics checkpoints contain pickled model
        # objects that weights_only=True rejects. Confirm this before switching.
        torch.load(path, map_location="cpu", weights_only=False)
    except Exception as e:
        raise Exception(f"Verification failed: {e}") from e
    print("Model verified as valid PyTorch checkpoint")
    return True
|
|
|
|
|
# Fetch a fresh copy of the weights, verify it, and load it with Ultralytics.
# The download goes to a temporary file and only replaces MODEL_PATH after
# verification passes. A failed download or a corrupt file therefore never
# destroys a previously working model.

# Stays None if any step fails, so later references to `model` do not raise
# NameError.
model = None
_TMP_MODEL_PATH = MODEL_PATH + ".part"

try:
    download_model(MODEL_URL, _TMP_MODEL_PATH)
    verify_model(_TMP_MODEL_PATH)

    _had_existing_model = os.path.exists(MODEL_PATH)
    # os.replace is atomic when source and destination are on the same
    # filesystem (both live in MODEL_DIR here).
    os.replace(_TMP_MODEL_PATH, MODEL_PATH)
    if _had_existing_model:
        print(f"Replaced existing file at {MODEL_PATH}")

    model = YOLO(MODEL_PATH)
    print("Model successfully loaded with Ultralytics YOLO")
except Exception as e:
    # Remove any leftover partial download; a missing file is fine.
    try:
        os.remove(_TMP_MODEL_PATH)
    except OSError:
        pass
    print(f"Failed to download or verify model: {e}")
    # Report the real target path. MODEL_PATH is relative to the CWD, so a
    # hard-coded "backend/..." hint could point at the wrong place.
    print(
        "Please manually download yolov8n.pt from "
        "https://github.com/ultralytics/assets/releases and place it in "
        f"{os.path.abspath(MODEL_PATH)}"
    )