File size: 1,935 Bytes
6b38cd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import cv2
import glob
import numpy as np
import torch
import yolov5
from typing import Dict, Tuple, Union, List, Optional


# -----------------------------------------------------------------------------
# Configs
# -----------------------------------------------------------------------------

# Relative path (from the working directory) to the YOLOv5 weights that are
# loaded into the module-level ``model`` below.
model_path: str = "models/mbari-mb-benthic-33k.pt"


# -----------------------------------------------------------------------------
# YOLOv5 class
# -----------------------------------------------------------------------------

class YOLO:
    """Wrapper class for loading and running YOLO model"""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load the checkpoint at ``model_path`` with ``yolov5.load``.

        Args:
            model_path: Path to the YOLOv5 weights file.
            device: Device string forwarded as-is to ``yolov5.load``
                (e.g. ``'cpu'``).
        """
        # load model
        self.model = yolov5.load(model_path, device=device)

    def __call__(
            self,
            img: Union[str, np.ndarray],
            conf_threshold: float = 0.25,
            iou_threshold: float = 0.45,
            image_size: int = 720,
            classes: Optional[List[int]] = None) -> torch.Tensor:
        """Run detection on a single image.

        The thresholds and class filter are written onto ``self.model`` on
        every call, so each call is configured only by its own arguments.

        Args:
            img: Image path or in-memory array, passed straight to the model.
            conf_threshold: Value assigned to ``self.model.conf``.
            iou_threshold: Value assigned to ``self.model.iou``.
            image_size: Forwarded to the model as ``size=``.
            classes: Class indices assigned to ``self.model.classes``.
                ``None`` is assigned too, which replaces any filter left by a
                previous call (presumably meaning "all classes" to yolov5 —
                confirm).

        Returns:
            Whatever ``self.model`` returns for ``img``.
            NOTE(review): annotated as ``torch.Tensor``, but yolov5 hub-style
            models usually return a ``Detections`` object — confirm and fix
            the annotation.
        """
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        # Assign unconditionally. If it were set only when ``classes`` is
        # given, a filter from an earlier call would silently stay active on
        # later calls that asked for every class.
        self.model.classes = classes

        # pylint: disable=not-callable
        return self.model(img, size=image_size)


def run_inference(image_path):
    """Run the module-level ``model`` on one image and hand back its output.

    Args:
        image_path: Image path (or image) forwarded unchanged to ``model``,
            which is called with its default thresholds and size.

    Returns:
        The object ``model`` returns for this input.
    """
    # ``model`` is a module global created further down the file; it is
    # looked up at call time, so it only needs to exist by the first call.
    return model(image_path)


# -----------------------------------------------------------------------------
# Model Creation
# -----------------------------------------------------------------------------
# Built once at import time and pinned to CPU; ``run_inference`` calls this
# shared instance.
model = YOLO(model_path=model_path, device="cpu")

if __name__ == "__main__":

    # Demo: run the model over every PNG in ``images/`` and print each
    # image's result to stdout. glob() returns matches in arbitrary order,
    # so sort them to make runs reproducible.
    test_images = sorted(glob.glob("images/*.png"))

    if not test_images:
        print("No images found matching images/*.png")

    for test_image in test_images:
        predictions = run_inference(test_image)
        print(f"{test_image}: {predictions}")

    print("Done.")