# NOTE(review): removed non-Python web-scrape residue (page header, file size,
# commit hash, and a line-number gutter) that preceded the module.
from ultralytics import YOLO
import numpy as np
import time
import torch
torch.set_num_threads(2)
from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier
# Input size (height, width) fed to the CLIP classifier.
IMG_SIZE = (224, 224)
# When True, run the classifier with channels-last memory format.
USE_CHANNEL_LAST = False
# Pretraining dataset tag passed to pre_process_foo (selects normalization).
DATASET = "laion"
# Inference device for both detector and classifier.
DEVICE = "cpu"
# When True, resize the crop keeping its aspect ratio instead of squashing to IMG_SIZE.
PRESERVE_ASPECT_RATIO = False
# Pixels to shrink the reported bbox by on each side (0 = disabled).
SHIFT = 0
@torch.no_grad()
def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
    """Detect the most confident mosquito in *image* and classify its species.

    Runs the YOLO detector, crops the highest-confidence box (falling back to
    the full frame when nothing is detected), preprocesses the crop, and runs
    the CLIP classifier. The classifier's label always overrides the
    detector's class.

    Args:
        det: YOLO detection model.
        cls: mosquito species classifier.
        image: H x W x C image array (numpy convention: shape is (h, w, c)).

    Returns:
        dict with keys "name" (species label), "confidence" (classifier max
        logit/probability as float), and "bbox" ([x1, y1, x2, y2] ints).
    """
    start = time.time()
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]
    results = det(image, verbose=True, device=DEVICE, max_det=1)

    # numpy images are (height, width, channels); the original code unpacked
    # this as (w, h) and produced a swapped fallback bbox for non-square images.
    img_h, img_w, _ = image.shape
    bbox = [0, 0, img_w, img_h]  # fallback: whole frame
    label = "albopictus"
    conf = 0.0

    # Keep the single best detection across all results (max_det=1 means at
    # most one box per result, but we scan defensively).
    for result in results:
        boxes = result.boxes.xyxy.tolist()
        classes = result.boxes.cls.tolist()
        scores = result.boxes.conf.tolist()
        for box, cls_idx, score in zip(boxes, classes, scores):
            if score > conf:
                bbox = box
                label = labels[int(cls_idx)]
                conf = score

    bbox = [int(float(coord)) for coord in bbox]

    if conf < 1e-4:
        # No usable detection: classify the full frame and keep the fallback
        # bbox. (Replaces the original raise-Exception/except control flow.)
        print("Error", "no detection above confidence threshold")
        image_cropped = image
    else:
        image_cropped = image[bbox[1] : bbox[3], bbox[0] : bbox[2], :]
        # SHIFT shrinks only the *reported* bbox, not the crop (no-op at SHIFT=0).
        bbox = [bbox[0] + SHIFT, bbox[1] + SHIFT, bbox[2] - SHIFT, bbox[3] - SHIFT]

    if PRESERVE_ASPECT_RATIO:
        # shape[:2] is (height, width); keep the longer side at IMG_SIZE and
        # scale the shorter side proportionally, clamped to at least 32 px.
        h, w = image_cropped.shape[:2]
        if h > w:
            target = (IMG_SIZE[0], max(int(IMG_SIZE[1] * w / h), 32))
        else:
            target = (max(int(IMG_SIZE[0] * h / w), 32), IMG_SIZE[1])
        x = torch.unsqueeze(pre_process_foo(target, DATASET)(image_cropped), 0)
    else:
        x = torch.unsqueeze(pre_process_foo(IMG_SIZE, DATASET)(image_cropped), 0)

    x = x.to(device=DEVICE)
    if USE_CHANNEL_LAST:
        p = cls(x.to(memory_format=torch.channels_last))
    else:
        p = cls(x)

    # The classifier's prediction replaces the detector's class label.
    label = labels[torch.argmax(p).item()]

    end = time.time()
    print("Time ", 1000 * (end - start), "ms")
    return {"name": label, "confidence": p.max().item(), "bbox": bbox}
def extract_predicted_mosquito_class_name(extractedInformation):
    """Return the predicted species name, defaulting to "albopictus"."""
    try:
        return extractedInformation["name"]
    except KeyError:
        return "albopictus"
def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the predicted [x1, y1, x2, y2] bbox, defaulting to all zeros."""
    try:
        return extractedInformation["bbox"]
    except KeyError:
        return [0, 0, 0, 0]
class YOLOV8CLIPModel:
    """Two-stage mosquito pipeline: YOLOv8 detection + CLIP species classification."""

    def __init__(self):
        """Load detector and classifier weights from the my_models directory."""
        trained_model_path = "my_models/yolo_weights/best-yolov8-s.pt"
        # plain string: the original used an f-string with no placeholders
        clip_model_path = "my_models/clip_weights/best_clf.ckpt"
        self.det = YOLO(trained_model_path, task="detect")
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()
        if USE_CHANNEL_LAST:
            self.cls.to(memory_format=torch.channels_last)

    def predict(self, image):
        """Run detection + classification on *image*.

        Returns:
            (species_name, confidence, [x1, y1, x2, y2] int bbox)
        """
        predictedInformation = classify_image(self.det, self.cls, image)
        mosquito_class_name_predicted = extract_predicted_mosquito_class_name(
            predictedInformation
        )
        mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)
        # original had a duplicated `bbox = bbox = ...` assignment
        bbox = [int(float(mcb)) for mcb in mosquito_class_bbox]
        return mosquito_class_name_predicted, predictedInformation["confidence"], bbox