Spaces:
Sleeping
Sleeping
import torch | |
from ultralytics import YOLO | |
class PearDetectionModel:
    """Wrap an Ultralytics YOLO detector and flag pears that contain a burn box.

    ``config`` must provide:
        model_path: path/identifier passed to ``YOLO(...)``.
        classes: label lookup indexed by integer class id (used as
            ``names[int(cls)]``), e.g. a list or an ``{int: str}`` dict.
    """

    # Label whose presence marks the pear as defective.
    BURN_LABEL = "burn_bbox"
    # Confidence cut-off applied when no burn box is among the raw detections.
    DEFAULT_CONF_THRESHOLD = 0.8
    # Looser cut-off applied once at least one burn box has been detected.
    BURN_CONF_THRESHOLD = 0.5

    def __init__(self, config) -> None:
        """Load the YOLO weights and class names described by ``config``."""
        # NOTE(review): self.device is computed but never passed to predict();
        # Ultralytics selects its own device. Kept for callers that may read it.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = YOLO(config["model_path"], task="detect")
        self.names = config["classes"]

    def detect(self, img):
        """Run the detector on ``img`` and return the first result's boxes.

        The boxes are moved to the CPU and converted to NumPy-backed arrays.
        Only ``results[0]`` is used, so presumably ``img`` is a single image.
        """
        results = self.model.predict(img)
        return results[0].boxes.cpu().numpy()

    def inference(self, img):
        """Detect pears in ``img`` and report whether a burn box survives filtering.

        The confidence threshold depends on the raw detections. If any box is
        labelled ``BURN_LABEL``, the looser ``BURN_CONF_THRESHOLD`` is used;
        otherwise ``DEFAULT_CONF_THRESHOLD`` applies. Boxes whose confidence
        is at or below the chosen threshold are dropped.

        Returns:
            ``(flag, xyxy, conf)``:
                flag: 1 if any kept box is labelled ``BURN_LABEL``, else 0.
                xyxy: coordinates of the kept boxes.
                conf: confidences of the kept boxes.
        """
        pred = self.detect(img)
        # Choose the threshold from what was actually detected, not from the
        # static list of configured class names.
        raw_labels = [self.names[int(cat)] for cat in pred.cls]
        threshold = (
            self.BURN_CONF_THRESHOLD
            if self.BURN_LABEL in raw_labels
            else self.DEFAULT_CONF_THRESHOLD
        )
        pred = pred[pred.conf > threshold]
        labels = [self.names[int(cat)] for cat in pred.cls]
        flag = 1 if self.BURN_LABEL in labels else 0
        return flag, pred.xyxy, pred.conf

    def _preporcess(self, img):
        """Placeholder for image preprocessing; currently a no-op returning ``None``.

        The (misspelled) name is preserved for compatibility with any callers.
        """