File size: 3,010 Bytes
9093750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import logging
import time

import torch
import numpy as np

# Cap PyTorch's CPU intra-op parallelism at 2 threads for the whole process.
# This runs as a side effect of importing this module.
torch.set_num_threads(2)

from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier

# Classifier preprocessing settings: classify_image passes IMG_SIZE and
# DATASET to pre_process_foo when building the input transform.
DATASET = "laion"
IMG_SIZE = (224, 224)

# Not referenced elsewhere in this module; possibly consumed by importers of
# this module -- verify before removing.
PRESERVE_ASPECT_RATIO = False
SHIFT = 0
USE_CHANNEL_LAST = False

# Device for classifier inputs and for loading the classifier checkpoint.
DEVICE = "cpu"


@torch.no_grad()
def detect_image(model, image: np.ndarray) -> dict:
    """Run the detector on one image and return its detections as a dict.

    Args:
        model: Detector (loaded via ``torch.hub.load(..., "custom", ...)`` in
            this module); calling it must return an object whose
            ``.pandas().xyxy[0]`` is a pandas DataFrame.
        image: Image passed straight through to ``model``.

    Returns:
        ``DataFrame.to_dict()`` of the first image's detections, i.e. a
        column-oriented mapping ``{column: {row_index: value}}``; an empty
        dict when the detector found nothing.
    """
    result = model(image)
    # xyxy[0] holds the detections for the first (and only) image passed in.
    result_df = result.pandas().xyxy[0]
    if result_df.empty:
        # Report through logging (as the rest of the module does) rather than
        # print(), so the message follows the application's log configuration.
        logging.warning("No results from yolov5 model!")
        return {}
    return result_df.to_dict()


@torch.no_grad()
def classify_image(model: MosquitoClassifier, image: np.ndarray, bbox: list) -> tuple:
    """Classify the part of ``image`` that lies inside ``bbox``.

    Args:
        model: Classifier, called on a batch holding the single preprocessed crop.
        image: Array sliced as ``image[row, column, channel]``.
        bbox: ``[xmin, ymin, xmax, ymax]`` pixel box; rows ``ymin:ymax`` and
            columns ``xmin:xmax`` are kept.

    Returns:
        ``(label, score)``: the class name at the argmax of the model output
        and the maximum output value. NOTE(review): whether ``score`` is a
        probability or a raw logit depends on MosquitoClassifier -- confirm
        before thresholding on it.
    """
    # Entry i names the class predicted when output unit i is the argmax.
    class_names = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]

    x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
    crop = image[y_min:y_max, x_min:x_max, :]

    transform = pre_process_foo(IMG_SIZE, DATASET)
    # Add a leading batch dimension of size 1 before moving to the device.
    batch = torch.unsqueeze(transform(crop), 0).to(device=DEVICE)

    scores: torch.Tensor = model(batch)
    best = torch.argmax(scores).item()
    return class_names[best], scores.max().item()


def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the box of detection row 0 as integer ``[xmin, ymin, xmax, ymax]``.

    Args:
        extractedInformation: Column-oriented detections as produced by
            ``detect_image`` -- ``{column: {row_index: value}}`` containing the
            ``xmin``/``ymin``/``xmax``/``ymax`` columns and a row ``0``.
            May be ``None`` or empty.

    Returns:
        The integer box of row 0, or ``[]`` when there is no detection.
    """
    # detect_image returns {} when nothing was detected; treat that like None
    # instead of failing on ``None.get(0)``.
    if extractedInformation is None or len(extractedInformation) == 0:
        return []
    # int() truncates the float pixel coordinates toward zero.
    return [
        int(extractedInformation.get(column).get(0))
        for column in ("xmin", "ymin", "xmax", "ymax")
    ]


class YOLOV5CLIPModel:
    """Two-stage pipeline: a YOLOv5 detector proposes a box, a classifier labels it.

    When the detector finds nothing, the whole frame is classified instead.
    """

    def __init__(self):
        """Load the detector and the classifier checkpoint from local paths.

        The paths are relative, so they resolve against the current working
        directory.
        """
        trained_model_path = "my_models/yolo_weights/mosquitoalert-yolov5-baseline.pt"
        repo_path = "my_models/torch_hub_cache/yolov5"
        # source="local" loads the hub entry point from repo_path rather than
        # GitHub; per the torch.hub docs, force_reload has no effect then.
        self.det = torch.hub.load(
            repo_path,
            "custom",
            path=trained_model_path,
            force_reload=True,
            source="local",
        )

        clip_model_path = "my_models/clip_weights/best_clf.ckpt"
        # head_version is forwarded to load_from_checkpoint; presumably it
        # selects the classifier head variant -- TODO confirm.
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()

    @staticmethod
    def _full_frame_bbox(image: np.ndarray) -> list:
        """Return an ``[xmin, ymin, xmax, ymax]`` box covering the whole image."""
        # numpy images are indexed (row, column, ...) = (height, width, ...),
        # whereas the box is x-first, so width must come before height.
        height, width = image.shape[:2]
        return [0, 0, width, height]

    def predict(self, image: np.ndarray):
        """Detect and classify the mosquito in ``image``.

        Args:
            image: Image array indexed as (row, column, channel).

        Returns:
            ``(class_name, confidence, bbox)`` where ``bbox`` is
            ``[xmin, ymin, xmax, ymax]`` in pixels: the box of detection row 0
            (presumably the highest-confidence detection -- confirm against
            the detector's ordering), or the full frame when nothing was
            detected. ``confidence`` is the classifier's maximum output value.
        """
        start = time.perf_counter()

        detections = detect_image(self.det, image)
        # Fall back to classifying the entire frame when the detector is empty.
        bbox = self._full_frame_bbox(image)
        if detections:
            bbox = extract_predicted_mosquito_bbox(detections)

        class_name, confidence = classify_image(self.cls, image, bbox)

        elapsed_ms = (time.perf_counter() - start) * 1000.0
        logging.info("[PREDICTION] Total time passed %.1fms", elapsed_ms)
        return (
            class_name,
            confidence,
            bbox,
        )