Spaces:
Sleeping
Sleeping
import logging | |
import time | |
import torch | |
import numpy as np | |
torch.set_num_threads(2) | |
from my_models.clip_model.data_loader import pre_process_foo | |
from my_models.clip_model.classification import MosquitoClassifier | |
# --- Classifier preprocessing / inference configuration ---------------------
IMG_SIZE = 224, 224  # size argument handed to pre_process_foo
DATASET = "laion"  # dataset name handed to pre_process_foo (presumably picks its normalization — verify)
DEVICE = "cpu"  # torch device used for the classifier weights and its input tensors

# NOTE(review): the flags below are not referenced by any code visible in this
# file; presumably consumed by other modules — verify before removing.
USE_CHANNEL_LAST = False
PRESERVE_ASPECT_RATIO = False
SHIFT = 0
def detect_image(model, image: np.ndarray) -> dict:
    """Run the detector on ``image`` and return its first result table as a dict.

    Args:
        model: YOLOv5-style detector whose call result exposes
            ``.pandas().xyxy`` (a list of DataFrames, presumably one per input
            image — here the single image passed in).
        image: Image to run detection on.

    Returns:
        ``result.pandas().xyxy[0]`` converted with ``DataFrame.to_dict()``,
        i.e. a mapping ``{column: {row_index: value}}``, or an empty dict when
        the detector produced no rows.
    """
    result_df = model(image).pandas().xyxy[0]
    if result_df.empty:
        # Route through logging (as the rest of this module does) instead of
        # print, so handlers/levels configured by the app apply to it.
        logging.warning("No results from yolov5 model!")
        return {}
    return result_df.to_dict()
def classify_image(model: MosquitoClassifier, image: np.ndarray, bbox: list) -> tuple:
    """Classify the mosquito species inside ``bbox`` of ``image``.

    Args:
        model: Classifier applied to a batch of one preprocessed crop; its
            output is arg-maxed into the six labels below (presumably shape
            ``(1, 6)`` — TODO confirm).
        image: Image array indexed as ``image[row, col, channel]``.
        bbox: ``[xmin, ymin, xmax, ymax]`` pixel coordinates of the crop.

    Returns:
        ``(label, score)``: the label at the arg-max of the model output and
        the maximum output value. Whether ``score`` is a probability depends on
        the model's head — TODO confirm.
    """
    labels = [
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    ]
    # Rows are y and columns are x, hence bbox[1]:bbox[3] before bbox[0]:bbox[2].
    image_cropped = image[bbox[1] : bbox[3], bbox[0] : bbox[2], :]
    transform = pre_process_foo(IMG_SIZE, DATASET)
    x = torch.unsqueeze(transform(image_cropped), 0).to(device=DEVICE)
    # Inference only: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        p: torch.Tensor = model(x)
    ind = torch.argmax(p).item()
    return labels[ind], p.max().item()
def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the bounding box of the detection stored at row label 0.

    Args:
        extractedInformation: Column-oriented detection dict as produced by
            ``detect_image`` (``{column: {row_index: value}}``), or ``None``.
            Row 0 is presumably the highest-confidence detection — TODO confirm.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` converted with ``int`` (truncating), or an
        empty list when the input is ``None``/empty or any coordinate for row 0
        is missing.
    """
    # Treat None, {} and partially-missing data the same way instead of
    # crashing with AttributeError on an empty dict.
    if not extractedInformation:
        return []
    bbox = []
    for key in ("xmin", "ymin", "xmax", "ymax"):
        column = extractedInformation.get(key) or {}
        value = column.get(0)
        if value is None:
            return []
        bbox.append(int(value))
    return bbox
class YOLOV5CLIPModel:
    """Two-stage pipeline: a YOLOv5 detector locates the mosquito, then a
    classifier (CLIP-based, per the module/weights naming) labels the crop."""

    def __init__(
        self,
        trained_model_path: str = "my_models/yolo_weights/mosquitoalert-yolov5-baseline.pt",
        repo_path: str = "my_models/torch_hub_cache/yolov5",
        clip_model_path: str = "my_models/clip_weights/best_clf.ckpt",
    ):
        """Load the detector and classifier from local files.

        Args:
            trained_model_path: YOLOv5 weights passed to ``torch.hub.load``.
            repo_path: Local YOLOv5 repository used as the hub source.
            clip_model_path: Checkpoint for ``MosquitoClassifier``.
        """
        self.det = torch.hub.load(
            repo_path,
            "custom",
            path=trained_model_path,
            force_reload=True,
            source="local",
        )
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()

    def predict(self, image: np.ndarray) -> tuple:
        """Detect and classify the mosquito in ``image``.

        Args:
            image: Image array indexed as ``image[row, col, channel]``.

        Returns:
            ``(label, confidence, bbox)`` with ``bbox`` as
            ``[xmin, ymin, xmax, ymax]``. When nothing is detected the result
            is ``("albopictus", 0.0, <whole-image box>)``.
        """
        start = time.perf_counter()
        detections = detect_image(self.det, image)

        mosquito_class_name_predicted = "albopictus"
        mosquito_class_confidence = 0.0
        # Whole-image fallback in [xmin, ymin, xmax, ymax] order: x spans the
        # columns (shape[1]) and y spans the rows (shape[0]).
        mosquito_class_bbox = [0, 0, image.shape[1], image.shape[0]]

        if detections:
            bbox = extract_predicted_mosquito_bbox(detections)
            if bbox:
                mosquito_class_bbox = bbox
                mosquito_class_name_predicted, mosquito_class_confidence = classify_image(
                    self.cls, image, bbox
                )

        # perf_counter measures seconds; convert so the logged unit is truly ms.
        elapsed_ms = (time.perf_counter() - start) * 1000
        logging.info("[PREDICTION] Total time passed %sms", elapsed_ms)
        return (
            mosquito_class_name_predicted,
            mosquito_class_confidence,
            mosquito_class_bbox,
        )