Spaces:
Sleeping
Sleeping
File size: 3,010 Bytes
9093750 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import logging
import time
import torch
import numpy as np
torch.set_num_threads(2)
from my_models.clip_model.data_loader import pre_process_foo
from my_models.clip_model.classification import MosquitoClassifier
# --- Inference runtime ----------------------------------------------------
# Device the classifier checkpoint is mapped onto and classifier inputs are
# moved to before the forward pass.
DEVICE = "cpu"

# --- Classifier preprocessing ---------------------------------------------
# Both values are forwarded to pre_process_foo(IMG_SIZE, DATASET) for every
# classification; their exact interpretation lives in that helper.
IMG_SIZE = (224, 224)
DATASET = "laion"

# --- Currently unused options ---------------------------------------------
# NOTE(review): none of these are referenced by the code in this module;
# confirm whether other modules import them before removing.
USE_CHANNEL_LAST = False
PRESERVE_ASPECT_RATIO = False
SHIFT = 0
@torch.no_grad()
def detect_image(model, image: np.ndarray) -> dict:
    """Run the detector on *image* and return its detections as a plain dict.

    Args:
        model: Detector called as ``model(image)``; the result must expose
            ``.pandas().xyxy`` (a sequence of per-image DataFrames).
        image: Input image, passed to the model unchanged.

    Returns:
        ``DataFrame.to_dict()`` of the first image's detections, i.e. a
        mapping ``{column_name: {row_index: value}}``. Returns an empty dict
        when the detector produced no rows.
    """
    detections = model(image).pandas().xyxy[0]
    if detections.empty:
        # Route through logging (as the rest of this module does) instead of
        # writing straight to stdout.
        logging.info("No results from yolov5 model!")
        return {}
    return detections.to_dict()
@torch.no_grad()
def classify_image(model: MosquitoClassifier, image: np.ndarray, bbox: list) -> tuple:
    """Classify the region of *image* inside *bbox* with the mosquito classifier.

    Args:
        model: Classifier called on a single-image batch; its output is
            reduced with argmax/max over all of its elements.
        image: Image array indexed as ``image[row, column, channel]``.
        bbox: ``[xmin, ymin, xmax, ymax]`` pixel box; x indexes columns and
            y indexes rows. The box is clamped to the image bounds. A box
            that is empty after clamping falls back to the whole image.

    Returns:
        ``(label, score)``. *label* is the entry of the fixed label list at
        the argmax of the model output, and *score* is the maximum output
        value. NOTE(review): whether *score* is a probability or a raw logit
        depends on ``MosquitoClassifier``'s head. Confirm before treating it
        as a calibrated confidence.
    """
    labels = (
        "albopictus",
        "culex",
        "japonicus-koreicus",
        "culiseta",
        "anopheles",
        "aegypti",
    )

    height, width = image.shape[:2]
    xmin, ymin, xmax, ymax = (int(v) for v in bbox[:4])
    # Clamp into the image: a negative start index would otherwise wrap
    # around (Python slicing semantics) and select the wrong region.
    xmin, ymin = max(0, xmin), max(0, ymin)
    xmax, ymax = min(width, xmax), min(height, ymax)

    if xmin < xmax and ymin < ymax:
        image_cropped = image[ymin:ymax, xmin:xmax, :]
    else:
        # A zero-area box (e.g. a sub-pixel-wide detection truncated by int())
        # would hand an empty array to pre_process_foo; classify the whole
        # frame instead.
        logging.warning(
            "Degenerate bbox %s for image of shape %s; classifying full image",
            bbox,
            image.shape,
        )
        image_cropped = image

    transform = pre_process_foo(IMG_SIZE, DATASET)
    # Add a leading batch dimension of size 1 before moving to DEVICE.
    x = torch.unsqueeze(transform(image_cropped), 0).to(device=DEVICE)
    p: torch.Tensor = model(x)
    ind = torch.argmax(p).item()
    return labels[ind], p.max().item()
def extract_predicted_mosquito_bbox(extractedInformation):
    """Return the row-0 detection box as ``[xmin, ymin, xmax, ymax]`` ints.

    Args:
        extractedInformation: Detection mapping in the shape produced by
            ``detect_image`` (``{column_name: {row_index: value}}``), or
            ``None``/empty when nothing was detected.

    Returns:
        The row-0 box with each coordinate truncated by ``int()``, or ``[]``
        when *extractedInformation* is ``None`` or empty.

    Raises:
        KeyError: if a coordinate column, or row 0 within it, is missing.
    """
    # An empty dict is what detect_image returns for "no detections"; treat
    # it like None instead of failing on a missing column.
    if not extractedInformation:
        return []
    # Row 0 is taken as "the" detection. NOTE(review): this relies on the
    # detector listing its most confident box first; confirm for the YOLOv5
    # version in use.
    return [
        int(extractedInformation[key][0])
        for key in ("xmin", "ymin", "xmax", "ymax")
    ]
class YOLOV5CLIPModel:
    """Two-stage mosquito predictor: a YOLOv5 detector, then a classifier on the crop.

    Attributes:
        det: Detector loaded through ``torch.hub.load`` from a local repo.
        cls: ``MosquitoClassifier`` restored from a checkpoint and put in eval mode.
    """

    def __init__(
        self,
        trained_model_path: str = "my_models/yolo_weights/mosquitoalert-yolov5-baseline.pt",
        repo_path: str = "my_models/torch_hub_cache/yolov5",
        clip_model_path: str = "my_models/clip_weights/best_clf.ckpt",
    ):
        """Load both models from disk.

        Args:
            trained_model_path: YOLOv5 weights file passed to the hub ``custom`` entry point.
            repo_path: Local checkout of the YOLOv5 hub repo (``source="local"``).
            clip_model_path: Classifier checkpoint to restore.
        """
        self.det = torch.hub.load(
            repo_path,
            "custom",
            path=trained_model_path,
            force_reload=True,
            source="local",
        )
        self.cls = MosquitoClassifier.load_from_checkpoint(
            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
        ).eval()

    def predict(self, image: np.ndarray):
        """Detect and classify the mosquito in *image*.

        Args:
            image: Image array indexed as ``image[row, column, ...]``.

        Returns:
            ``(class_name, confidence, bbox)`` with bbox ``[xmin, ymin, xmax, ymax]``.
            When the detector finds nothing, no classification is run and the
            result is ``("albopictus", 0.0, <box covering the whole image>)``.
        """
        start = time.perf_counter()
        predictedInformation = detect_image(self.det, image)

        mosquito_class_name_predicted = "albopictus"
        mosquito_class_confidence = 0.0
        # bbox convention is [xmin, ymin, xmax, ymax]: x spans columns
        # (shape[1], the width) and y spans rows (shape[0], the height).
        height, width = image.shape[:2]
        mosquito_class_bbox = [0, 0, width, height]

        if predictedInformation:
            mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)
            mosquito_class_name_predicted, mosquito_class_confidence = classify_image(
                self.cls, image, mosquito_class_bbox
            )

        # perf_counter is monotonic; convert seconds to the advertised milliseconds.
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        logging.info("[PREDICTION] Total time passed %.1fms", elapsed_ms)
        return (
            mosquito_class_name_predicted,
            mosquito_class_confidence,
            mosquito_class_bbox,
        )
|