from ultralytics import YOLO
import numpy as np
import time
import torch

torch.set_num_threads(2)

from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier

IMG_SIZE = (224, 224)
USE_CHANNEL_LAST = False
DATASET = "laion"
DEVICE = "cpu"
PRESERVE_ASPECT_RATIO = False
SHIFT = 0


@torch.no_grad()
def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
    """Detect the most confident mosquito box with YOLO, crop it, and classify
    the crop with the CLIP-based classifier. Returns a dict with the predicted
    class name, the classifier confidence, and the detection bbox (xyxy)."""
    s = time.time()
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]

    results = det(image, verbose=True, device=DEVICE, max_det=1)

    # numpy image arrays are (height, width, channels)
    img_h, img_w, _ = image.shape

    # Fall back to the full image if the detector finds nothing.
    bbox = [0, 0, img_w, img_h]
    label = "albopictus"
    conf = 0.0

    for result in results:
        _bbox = [0, 0, img_w, img_h]
        _label = "albopictus"
        _conf = 0.0

        bboxes_tmp = result.boxes.xyxy.tolist()
        labels_tmp = result.boxes.cls.tolist()
        confs_tmp = result.boxes.conf.tolist()

        # Keep the highest-confidence detection from this result.
        for bbox_tmp, label_tmp, conf_tmp in zip(bboxes_tmp, labels_tmp, confs_tmp):
            if conf_tmp > _conf:
                _bbox = bbox_tmp
                _label = labels[int(label_tmp)]
                _conf = conf_tmp

        if _conf > conf:
            bbox = _bbox
            label = _label
            conf = _conf

    bbox = [int(float(mcb)) for mcb in bbox]

    try:
        if conf < 1e-4:
            raise ValueError("no detection with usable confidence")
        image_cropped = image[bbox[1] : bbox[3], bbox[0] : bbox[2], :]
        bbox = [bbox[0] + SHIFT, bbox[1] + SHIFT, bbox[2] - SHIFT, bbox[3] - SHIFT]
    except Exception as e:
        print("Error", e)
        image_cropped = image

    if PRESERVE_ASPECT_RATIO:
        h, w = image_cropped.shape[:2]
        # Scale the shorter side so the crop keeps its aspect ratio (min 32 px).
        if h > w:
            x = torch.unsqueeze(
                pre_process_foo(
                    (IMG_SIZE[0], max(int(IMG_SIZE[1] * w / h), 32)), DATASET
                )(image_cropped),
                0,
            )
        else:
            x = torch.unsqueeze(
                pre_process_foo(
                    (max(int(IMG_SIZE[0] * h / w), 32), IMG_SIZE[1]), DATASET
                )(image_cropped),
                0,
            )
    else:
        x = torch.unsqueeze(pre_process_foo(IMG_SIZE, DATASET)(image_cropped), 0)

    x = x.to(device=DEVICE)

    if USE_CHANNEL_LAST:
        p = cls(x.to(memory_format=torch.channels_last))
    else:
        p = cls(x)

    # The classifier, not the detector, decides the final label.
    ind = torch.argmax(p).item()
    label = labels[ind]
    e = time.time()
    print("Time ", 1000 * (e - s), "ms")

    return {"name": label, "confidence": p.max().item(), "bbox": bbox}


# Getting the mosquito class name from the predicted result.
def extract_predicted_mosquito_class_name(extractedInformation):
    return extractedInformation.get("name", "albopictus")


def extract_predicted_mosquito_bbox(extractedInformation):
    return extractedInformation.get("bbox", [0, 0, 0, 0])


class YOLOV8CLIPModel:
    def __init__(self):
        trained_model_path = "my_models/yolo_weights/best-yolov8-s.pt"
        clip_model_path = "my_models/clip_weights/best_clf.ckpt"
        self.det = YOLO(trained_model_path, task="detect")
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()
        if USE_CHANNEL_LAST:
            self.cls.to(memory_format=torch.channels_last)

    def predict(self, image):
        predictedInformation = classify_image(self.det, self.cls, image)
        mosquito_class_name_predicted = extract_predicted_mosquito_class_name(
            predictedInformation
        )
        mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)
        bbox = [int(float(mcb)) for mcb in mosquito_class_bbox]
        return mosquito_class_name_predicted, predictedInformation["confidence"], bbox
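

# Usage sketch (not part of the original inference code): a minimal example of
# driving YOLOV8CLIPModel end to end. It assumes the weight files referenced in
# __init__ exist on disk and that "sample_mosquito.jpg" is a hypothetical test
# image; any uint8 numpy array of shape (H, W, 3) works as input.
if __name__ == "__main__":
    import cv2  # assumed available; used only to load the example image

    model = YOLOV8CLIPModel()
    image = cv2.imread("sample_mosquito.jpg")  # hypothetical input path
    if image is None:
        raise FileNotFoundError("sample_mosquito.jpg not found")
    name, confidence, bbox = model.predict(image)
    print(f"species={name} confidence={confidence:.3f} bbox={bbox}")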