Spaces:
Sleeping
Sleeping
from ultralytics import YOLO | |
import numpy as np | |
import time | |
import torch | |
torch.set_num_threads(2) | |
from my_models.clip_model.data_loader import pre_process_foo | |
from my_models.clip_model.classification import MosquitoClassifier | |
# --- Preprocessing ---------------------------------------------------------
# Target size handed to pre_process_foo for the classifier input.
IMG_SIZE = (224, 224)
# Second argument to pre_process_foo (selects its preprocessing variant).
DATASET = "laion"
# When True, the resize target is adapted to the crop's aspect ratio
# instead of always using IMG_SIZE.
PRESERVE_ASPECT_RATIO = False

# --- Runtime ---------------------------------------------------------------
# Device string used for both the detector call and the classifier tensors.
DEVICE = "cpu"
# When True, model and input tensor use torch.channels_last memory format.
USE_CHANNEL_LAST = False

# --- Post-processing -------------------------------------------------------
# Pixels added to (x1, y1) and subtracted from (x2, y2) of a detected box
# after cropping, i.e. the reported box is shrunk inward by SHIFT.
SHIFT = 0
def _select_best_detection(results, default_bbox):
    """Return ``(bbox, conf)`` for the highest-confidence box in ``results``.

    ``results`` is what calling the YOLO detector returns; each item exposes
    ``boxes.xyxy`` and ``boxes.conf``. Ties keep the first box seen. If no box
    has a confidence above 0.0, ``(list(default_bbox), 0.0)`` is returned.
    """
    best_bbox, best_conf = list(default_bbox), 0.0
    for result in results:
        boxes = result.boxes
        for box, box_conf in zip(boxes.xyxy.tolist(), boxes.conf.tolist()):
            if box_conf > best_conf:
                best_bbox, best_conf = box, box_conf
    return best_bbox, best_conf


def _crop_to_bbox(image, bbox, conf, min_conf=1e-4):
    """Crop ``image`` to ``bbox`` given as ``[x1, y1, x2, y2]`` pixels.

    Returns ``(crop, out_bbox)``. When ``conf`` is below ``min_conf`` or the
    clamped box selects no pixels, the full image and its full-frame box
    ``[0, 0, width, height]`` are returned instead. Otherwise ``out_bbox`` is
    the clamped box shrunk inward by ``SHIFT`` on every side.
    """
    img_h, img_w = image.shape[:2]
    full_frame = [0, 0, img_w, img_h]
    if conf < min_conf:
        print("Error", "no confident detection; classifying the full image")
        return image, full_frame

    x1, y1, x2, y2 = (int(float(v)) for v in bbox)
    # Clamp to the image: negative indices would silently wrap around in
    # numpy slicing and select the wrong region.
    x1, x2 = (min(max(v, 0), img_w) for v in (x1, x2))
    y1, y2 = (min(max(v, 0), img_h) for v in (y1, y2))
    crop = image[y1:y2, x1:x2]
    if crop.size == 0:
        # A degenerate box would make preprocessing fail downstream.
        print("Error", "empty crop for bbox", bbox, "; classifying the full image")
        return image, full_frame
    return crop, [x1 + SHIFT, y1 + SHIFT, x2 - SHIFT, y2 - SHIFT]


def _prepare_input(image_cropped):
    """Preprocess ``image_cropped`` into a batch of one tensor on ``DEVICE``."""
    if PRESERVE_ASPECT_RATIO:
        rows, cols = image_cropped.shape[:2]
        # Presumably pre_process_foo takes (height, width) like torchvision's
        # Resize -- TODO confirm. The longer side keeps its IMG_SIZE entry and
        # the shorter side is scaled down, never below 32 pixels.
        if rows > cols:
            size = (IMG_SIZE[0], max(int(IMG_SIZE[1] * cols / rows), 32))
        else:
            size = (max(int(IMG_SIZE[0] * rows / cols), 32), IMG_SIZE[1])
    else:
        size = IMG_SIZE
    x = torch.unsqueeze(pre_process_foo(size, DATASET)(image_cropped), 0)
    return x.to(device=DEVICE)


def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
    """Detect the most confident mosquito in ``image`` and classify its species.

    The detector ``det`` is called with ``max_det=1``. The highest-confidence
    box, if any, is used to crop ``image``, and the crop is passed to the
    classifier ``cls``. When no box reaches a confidence of 1e-4, or the crop
    is empty, the whole image is classified instead.

    Args:
        det: Ultralytics YOLO detector.
        cls: Classifier whose output is indexed by the ``labels`` list below.
        image: Array indexed as ``image[row, col, ...]`` (assumed HxWxC --
            TODO confirm the channel order expected by ``pre_process_foo``).

    Returns:
        dict with ``"name"`` (predicted label), ``"confidence"`` (maximum
        value of the classifier output; whether that is a probability or a
        logit depends on ``cls``), and ``"bbox"`` (``[x1, y1, x2, y2]`` ints:
        the detection shrunk by ``SHIFT``, or the full frame on fallback).
    """
    start = time.time()
    # Index i of the classifier output maps to labels[i].
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]
    results = det(image, verbose=True, device=DEVICE, max_det=1)

    # numpy images are (rows, cols, ...) == (height, width, ...), so the
    # full-frame box in xyxy order is [0, 0, width, height].
    img_h, img_w = image.shape[:2]
    bbox, conf = _select_best_detection(results, (0, 0, img_w, img_h))
    image_cropped, bbox = _crop_to_bbox(image, bbox, conf)

    x = _prepare_input(image_cropped)
    if USE_CHANNEL_LAST:
        x = x.to(memory_format=torch.channels_last)
    # Inference only: eval() does not disable autograd, no_grad() does.
    with torch.no_grad():
        p = cls(x)
    label = labels[torch.argmax(p).item()]

    print("Time ", 1000 * (time.time() - start), "ms")
    return {"name": label, "confidence": p.max().item(), "bbox": bbox}
# Read the predicted species name out of a classify_image()-style result dict.
def extract_predicted_mosquito_class_name(extractedInformation):
    """Return the ``"name"`` entry of ``extractedInformation``.

    Falls back to ``"albopictus"`` when the key is absent.
    """
    fallback = "albopictus"
    if "name" in extractedInformation:
        return extractedInformation["name"]
    return fallback
def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the ``"bbox"`` entry of ``extractedInformation``.

    Falls back to a fresh ``[0, 0, 0, 0]`` list when the key is absent.
    """
    if "bbox" not in extractedInformation:
        return [0, 0, 0, 0]
    return extractedInformation["bbox"]
class YOLOV8CLIPModel:
    """Two-stage mosquito model: a YOLO detector followed by a classifier.

    ``predict`` delegates to ``classify_image`` and returns a
    ``(name, confidence, bbox)`` tuple.
    """

    def __init__(
        self,
        *,
        trained_model_path="my_models/yolo_weights/best-yolov8-s.pt",
        clip_model_path="my_models/clip_weights/best_clf.ckpt",
    ):
        """Load both models.

        Args:
            trained_model_path: YOLO weights file, loaded as a detect model.
            clip_model_path: MosquitoClassifier checkpoint, loaded onto
                ``DEVICE`` with ``head_version=7`` and put in eval mode.
        """
        self.det = YOLO(trained_model_path, task="detect")
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()
        if USE_CHANNEL_LAST:
            self.cls.to(memory_format=torch.channels_last)

    def predict(self, image):
        """Classify the most prominent mosquito in ``image``.

        Returns:
            ``(species_name, confidence, bbox)`` where ``bbox`` is a list of
            four ints ``[x1, y1, x2, y2]``.
        """
        info = classify_image(self.det, self.cls, image)
        name = extract_predicted_mosquito_class_name(info)
        # Normalise to ints in case the stored box holds floats or strings.
        bbox = [int(float(v)) for v in extract_predicted_mosquito_bbox(info)]
        return name, info["confidence"], bbox