from ultralytics import YOLO
import numpy as np
import time

import torch

torch.set_num_threads(2)

from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier

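# Inference configuration used throughout this module. DATASET selects the
# normalization preset handed to pre_process_foo (presumably the LAION CLIP
# statistics), and SHIFT is an inward pixel margin applied to the returned
# bounding box only.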
IMG_SIZE = (224, 224)
USE_CHANNEL_LAST = False
DATASET = "laion"
DEVICE = "cpu"
PRESERVE_ASPECT_RATIO = False
SHIFT = 0


@torch.no_grad()
def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
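    """Detect the highest-confidence mosquito with YOLO, crop it, and classify
    the crop with the CLIP-based classifier.

    Returns a dict with the predicted class name, the classifier's top score,
    and the bounding box as [x1, y1, x2, y2] pixel coordinates (the whole
    image when no confident detection is found).
    """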
    s = time.time()
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]

    results = det(image, verbose=True, device=DEVICE, max_det=1)
    img_h, img_w, _ = image.shape  # numpy images are (height, width, channels)
    bbox = [0, 0, img_w, img_h]
    label = "albopictus"
    conf = 0.0

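    # Keep the single highest-confidence detection across all results.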
    for result in results:
        _bbox = [0, 0, img_w, img_h]
        _label = "albopictus"
        _conf = 0.0

        bboxes_tmp = result.boxes.xyxy.tolist()
        labels_tmp = result.boxes.cls.tolist()
        confs_tmp = result.boxes.conf.tolist()

        for bbox_tmp, label_tmp, conf_tmp in zip(bboxes_tmp, labels_tmp, confs_tmp):
            if conf_tmp > _conf:
                _bbox = bbox_tmp
                _label = labels[int(label_tmp)]
                _conf = conf_tmp

        if _conf > conf:
            bbox = _bbox
            label = _label
            conf = _conf

    bbox = [int(float(mcb)) for mcb in bbox]

    try:
        if conf < 1e-4:
            raise ValueError("no confident detection")
        image_cropped = image[bbox[1] : bbox[3], bbox[0] : bbox[2], :]
        # Shrink the reported box by SHIFT pixels per side; the crop itself is
        # unaffected because it was taken before the shift.
        bbox = [bbox[0] + SHIFT, bbox[1] + SHIFT, bbox[2] - SHIFT, bbox[3] - SHIFT]
    except Exception as e:
        # Fall back to classifying the full image when detection fails or the
        # crop indices are invalid.
        print("Error", e)
        image_cropped = image

    if PRESERVE_ASPECT_RATIO:
        # shape[:2] is (height, width); resize so the longer side of the crop
        # maps to IMG_SIZE and the shorter side keeps the aspect ratio
        # (clamped to at least 32 px).
        h, w = image_cropped.shape[:2]
        if h > w:
            x = torch.unsqueeze(
                pre_process_foo(
                    (IMG_SIZE[0], max(int(IMG_SIZE[1] * w / h), 32)), DATASET
                )(image_cropped),
                0,
            )
        else:
            x = torch.unsqueeze(
                pre_process_foo(
                    (max(int(IMG_SIZE[0] * h / w), 32), IMG_SIZE[1]), DATASET
                )(image_cropped),
                0,
            )
    else:
        x = torch.unsqueeze(pre_process_foo(IMG_SIZE, DATASET)(image_cropped), 0)

    x = x.to(device=DEVICE)

    if USE_CHANNEL_LAST:
        p = cls(x.to(memory_format=torch.channels_last))
    else:
        p = cls(x)
    ind = torch.argmax(p).item()
    label = labels[ind]

    e = time.time()

    print("Time ", 1000 * (e - s), "ms")
    return {"name": label, "confidence": p.max().item(), "bbox": bbox}


# Extract the predicted mosquito class name from the result dict.
def extract_predicted_mosquito_class_name(extractedInformation):
    return extractedInformation.get("name", "albopictus")


def extract_predicted_mosquito_bbox(extractedInformation):
    return extractedInformation.get("bbox", [0, 0, 0, 0])


class YOLOV8CLIPModel:
    def __init__(self):
        trained_model_path = "my_models/yolo_weights/best-yolov8-s.pt"
        clip_model_path = "my_models/clip_weights/best_clf.ckpt"
        self.det = YOLO(trained_model_path, task="detect")

        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()

        if USE_CHANNEL_LAST:
            self.cls.to(memory_format=torch.channels_last)

    def predict(self, image):
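        """Run detection + classification on an (H, W, 3) image array and
        return (class_name, confidence, bbox), with bbox as [x1, y1, x2, y2]
        integer pixel coordinates."""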
        predictedInformation = classify_image(self.det, self.cls, image)

        mosquito_class_name_predicted = extract_predicted_mosquito_class_name(
            predictedInformation
        )
        mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)

        bbox = [int(float(mcb)) for mcb in mosquito_class_bbox]

        return mosquito_class_name_predicted, predictedInformation["confidence"], bbox
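

# Minimal usage sketch, not part of the original pipeline: it assumes the
# weight files referenced in YOLOV8CLIPModel exist locally, that OpenCV is
# installed, and that "sample_mosquito.jpg" is a placeholder path supplied by
# the caller.
if __name__ == "__main__":
    import cv2

    model = YOLOV8CLIPModel()
    # OpenCV loads images as BGR; convert to RGB here on the assumption that
    # the detector and classifier were trained on RGB inputs.
    image = cv2.cvtColor(cv2.imread("sample_mosquito.jpg"), cv2.COLOR_BGR2RGB)
    name, confidence, bbox = model.predict(image)
    print(name, confidence, bbox)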