kbarnard's picture
remove unused imports
b9ca522 unverified
raw
history blame contribute delete
No virus
1.91 kB
import glob
import numpy as np
import torch
import yolov5
from typing import Union, List, Optional
# -----------------------------------------------------------------------------
# Configs
# -----------------------------------------------------------------------------
# Path to the YOLOv5 weights file; passed to the YOLO wrapper when the
# module-level model is created further down in this file.
model_path = "models/mbari-mb-benthic-33k.pt"
# -----------------------------------------------------------------------------
# YOLOv5 class
# -----------------------------------------------------------------------------
class YOLO:
    """Wrapper class for loading and running a YOLOv5 model.

    The wrapped model object is stored on ``self.model``. Every call
    re-applies the confidence threshold, IoU threshold and class filter, so
    no setting from a previous call carries over into the next one.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load the model weights.

        Args:
            model_path: Path to the weights file, passed to ``yolov5.load``.
            device: Device argument forwarded unchanged to ``yolov5.load``
                (``None`` by default).
        """
        self.model = yolov5.load(model_path, device=device)

    def __call__(
        self,
        img: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run the model on ``img`` with the given per-call settings.

        Args:
            img: Image path or image array, handed to the model as-is.
            conf_threshold: Value assigned to ``self.model.conf``.
            iou_threshold: Value assigned to ``self.model.iou``.
            image_size: Passed to the model as its ``size`` argument.
            classes: Value assigned to ``self.model.classes``. ``None``
                (the default) clears any filter left by an earlier call.

        Returns:
            Whatever the wrapped model returns for ``img``.
            NOTE(review): with yolov5's autoshape wrapper this is presumably
            a detections object rather than a raw ``torch.Tensor`` -- confirm
            before relying on the annotation.
        """
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        # Assign unconditionally: guarding on ``classes is not None`` would
        # let a filter from a previous call leak into calls that omit it.
        self.model.classes = classes
        # pylint: disable=not-callable
        return self.model(img, size=image_size)
def run_inference(image_path):
    """Run the module-level ``model`` on ``image_path`` and return its output."""
    return model(image_path)
# -----------------------------------------------------------------------------
# Model Creation
# -----------------------------------------------------------------------------
# Built once at import time so run_inference() can use it as a module-level
# global; the device is pinned to CPU here.
model = YOLO(model_path, device='cpu')
if __name__ == "__main__":
    # Demo: run inference on every PNG found under images/.
    # NOTE(review): the predictions are only held in a variable here; nothing
    # in this script writes them to a folder.
    for image_file in glob.glob("images/*.png"):
        predictions = run_inference(image_file)
    print("Done.")