# Pear-playground / detection.py
# Uploaded by Kaori1707 ("Upload 10 files", commit 69ef5c2, verified)
import torch
from ultralytics import YOLO
class PearDetectionModel:
    """Wrap an Ultralytics YOLO detector and flag images containing a burn defect.

    Args:
        config: Mapping with at least
            ``"model_path"`` -- weights path/name passed to ``YOLO(...)``;
            ``"classes"`` -- indexable by integer class id (e.g. a list of
            names or an ``{id: name}`` dict), yielding the class name.
    """

    def __init__(self, config) -> None:
        # NOTE(review): self.device is selected here but never used elsewhere in
        # this class (predict() is called without ``device=``) -- confirm intent.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = YOLO(config["model_path"], task="detect")
        self.names = config["classes"]

    def detect(self, img):
        """Run the detector on ``img`` and return the boxes of the first result.

        The boxes are moved to CPU and converted with ``.numpy()``; the returned
        object exposes ``.xyxy``, ``.conf`` and ``.cls`` and supports
        boolean-mask indexing. Only ``results[0]`` is used, so if a batch is
        passed, the remaining results are ignored.
        """
        results = self.model.predict(img)
        return results[0].boxes.cpu().numpy()

    def inference(
        self,
        img,
        *,
        conf_threshold=0.8,
        burn_conf_threshold=0.5,
        burn_label="burn_bbox",
    ):
        """Detect boxes in ``img`` and report whether a burn defect was found.

        The confidence cut-off depends on what the detector found in this
        image. If any raw detection is labelled ``burn_label``, a box is kept
        only when its confidence exceeds ``burn_conf_threshold``. Otherwise the
        cut-off is ``conf_threshold``. Comparisons are strict
        (``conf > threshold``).

        Args:
            img: Image input, forwarded unchanged to ``self.detect``.
            conf_threshold: Cut-off used when no ``burn_label`` box is detected.
            burn_conf_threshold: Cut-off used when a ``burn_label`` box is detected.
            burn_label: Class name treated as the defect of interest.

        Returns:
            ``(flag, xyxy, conf)``. ``flag`` is 1 if any box kept after
            filtering is labelled ``burn_label``, else 0. ``xyxy`` and ``conf``
            are the coordinates and confidences of the kept boxes.
        """
        pred = self.detect(img)
        # Choose the threshold from this image's detections, not from the
        # configured class list: ``self.names`` is identical for every image, so
        # testing it made the branch constant (and for an {id: name} dict it
        # iterated the ids, so the label was never found).
        detected = [self.names[int(cat)] for cat in pred.cls]
        threshold = burn_conf_threshold if burn_label in detected else conf_threshold
        pred = pred[pred.conf > threshold]

        labels = [self.names[int(cat)] for cat in pred.cls]
        flag = 1 if burn_label in labels else 0
        return flag, pred.xyxy, pred.conf

    def _preporcess(self, img):
        """Placeholder for image preprocessing; not implemented (returns None).

        Not called anywhere in this class. The misspelled name is kept so that
        any existing references still resolve.
        """
        pass