# Mosquito-Detection / my_models / yolov8_clip_model.py
# Last commit (eedca6c) by hca97: "adding yolov8s model as well"
from ultralytics import YOLO
import numpy as np
import time
import torch
torch.set_num_threads(2)
from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier
# Target size handed to pre_process_foo when building the classifier input.
IMG_SIZE = (224, 224)
# When True, the classifier and its input tensor are converted to
# torch.channels_last memory format.
USE_CHANNEL_LAST = False
# Dataset identifier forwarded to pre_process_foo.
# NOTE(review): presumably selects normalization statistics - confirm in
# my_models.clip_model.data_loader.
DATASET = "laion"
# Device used for the detector call, the classifier input, and as
# map_location when loading the classifier checkpoint.
DEVICE = "cpu"
# When True, the pre-processing target size is scaled by the crop's side
# ratio (shorter side floored at 32) instead of always using IMG_SIZE.
PRESERVE_ASPECT_RATIO = False
# Number of pixels by which the reported bounding box is shrunk on every
# side; it does not affect the region that is cropped for classification.
SHIFT = 0
def _best_detection(results):
    """Return ``(xyxy_box, confidence)`` of the most confident detection.

    Scans every box of every result and keeps the first one with the highest
    confidence. Returns ``(None, 0.0)`` when no box has a confidence above 0.
    """
    best_box, best_conf = None, 0.0
    for result in results:
        boxes = result.boxes
        for box, box_conf in zip(boxes.xyxy.tolist(), boxes.conf.tolist()):
            if box_conf > best_conf:
                best_box, best_conf = box, box_conf
    return best_box, best_conf


@torch.no_grad()
def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
    """Detect the most confident box in *image* and classify its contents.

    The detector is run on the full image. If it yields a usable box (a
    confidence of at least 1e-4 and a non-empty crop) the classifier sees only
    that crop; otherwise it sees the whole image and the reported box covers
    the whole image.

    Args:
        det: Detector; called as ``det(image, ...)`` and expected to return
            results exposing ``boxes.xyxy`` and ``boxes.conf``.
        cls: Classifier; called on a single-item batch and expected to return
            one score per entry of the label list below.
        image: Image array whose first two axes are (height, width).

    Returns:
        A dict with keys ``"name"`` (predicted label), ``"confidence"`` (the
        maximum value of the classifier output) and ``"bbox"`` (a list of four
        ints ``[x1, y1, x2, y2]``).
    """
    started = time.perf_counter()
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]

    results = det(image, verbose=True, device=DEVICE, max_det=1)

    # ndarray images are laid out (rows, cols, ...) == (height, width, ...),
    # while boxes are [x1, y1, x2, y2]; the full-image box is therefore
    # [0, 0, width, height].
    img_h, img_w = image.shape[:2]
    bbox = [0, 0, img_w, img_h]
    image_cropped = image

    best_box, best_conf = _best_detection(results)
    crop = None
    if best_box is not None and best_conf >= 1e-4:
        x1, y1, x2, y2 = (int(float(coord)) for coord in best_box)
        crop = image[y1:y2, x1:x2]

    if crop is not None and crop.size:
        image_cropped = crop
        # SHIFT only tightens the reported box, not the classified region.
        bbox = [x1 + SHIFT, y1 + SHIFT, x2 - SHIFT, y2 - SHIFT]
    else:
        # A zero-area crop cannot be resized, so it is treated like a miss.
        print("No usable detection; classifying the full image")

    if PRESERVE_ASPECT_RATIO:
        rows, cols = image_cropped.shape[:2]
        # Scale the second or first target dimension by the crop's side
        # ratio, never going below 32.
        if rows > cols:
            target_size = (IMG_SIZE[0], max(int(IMG_SIZE[1] * cols / rows), 32))
        else:
            target_size = (max(int(IMG_SIZE[0] * rows / cols), 32), IMG_SIZE[1])
    else:
        target_size = IMG_SIZE

    # Add the leading batch dimension expected by the classifier call.
    x = torch.unsqueeze(pre_process_foo(target_size, DATASET)(image_cropped), 0)
    x = x.to(device=DEVICE)
    if USE_CHANNEL_LAST:
        x = x.to(memory_format=torch.channels_last)

    p = cls(x)
    label = labels[torch.argmax(p).item()]

    elapsed_ms = 1000 * (time.perf_counter() - started)
    print("Time ", elapsed_ms, "ms")
    return {"name": label, "confidence": p.max().item(), "bbox": bbox}
def extract_predicted_mosquito_class_name(extractedInformation):
    """Return the predicted class name stored under the ``"name"`` key.

    Falls back to ``"albopictus"`` when the key is absent.
    """
    fallback_class_name = "albopictus"
    return extractedInformation.get("name", fallback_class_name)
def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the bounding box stored under the ``"bbox"`` key.

    Falls back to a fresh ``[0, 0, 0, 0]`` list when the key is absent.
    """
    # Built per call so callers can never share (and mutate) one default list.
    empty_box = [0, 0, 0, 0]
    return extractedInformation.get("bbox", empty_box)
class YOLOV8CLIPModel:
    """Two-stage predictor: a YOLO detector followed by a MosquitoClassifier."""

    def __init__(
        self,
        trained_model_path="my_models/yolo_weights/best-yolov8-s.pt",
        clip_model_path="my_models/clip_weights/best_clf.ckpt",
    ):
        """Load the detector weights and the classifier checkpoint.

        Args:
            trained_model_path: Path to the YOLO detector weights.
            clip_model_path: Path to the classifier checkpoint.
        """
        self.det = YOLO(trained_model_path, task="detect")
        # The checkpoint is mapped onto DEVICE and switched to eval mode.
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()
        if USE_CHANNEL_LAST:
            self.cls.to(memory_format=torch.channels_last)

    def predict(self, image):
        """Run detection and classification on *image*.

        Returns:
            A tuple ``(class_name, confidence, bbox)`` where ``bbox`` is a list
            of ints taken from the ``"bbox"`` entry of the prediction.
        """
        predictedInformation = classify_image(self.det, self.cls, image)
        mosquito_class_name_predicted = extract_predicted_mosquito_class_name(
            predictedInformation
        )
        mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)
        # Normalise every coordinate to a plain int before returning it.
        bbox = [int(float(mcb)) for mcb in mosquito_class_bbox]
        return mosquito_class_name_predicted, predictedInformation["confidence"], bbox