import importlib
import os

import cv2
import torch
from yolox.data.data_augment import preproc
from yolox.data.datasets import COCO_CLASSES
from yolox.utils import postprocess, vis

from yoloxdetect2.utils.downloads import attempt_download, attempt_download_from_hub


class YoloxDetector2:
    """Run a YOLOX model on single images and visualise or save the detections.

    Attributes set in ``__init__`` (they may be changed after construction):
        classes: class-name sequence; its length is passed to NMS as
            ``num_classes`` and its entries label the drawn boxes.
        conf: confidence threshold for post-processing and for drawing.
        iou: IoU threshold used by NMS.
        show: display the annotated image in an OpenCV window.
        save: write the annotated image to ``save_path``.
        save_path: output file used when ``save`` is true.
        torchyolo: when true, ``predict`` returns raw detection arrays
            instead of an annotated image.
    """

    # Square network input side used by predict() when image_size is None.
    DEFAULT_IMAGE_SIZE = 640

    def __init__(
        self,
        model_path: str,
        config_path: str,
        device: str = "cpu",
        hf_model: bool = False,
    ):
        """Resolve/download the checkpoint and build the model.

        Args:
            model_path: checkpoint location, passed to
                ``attempt_download_from_hub`` when ``hf_model`` is true and to
                ``attempt_download`` otherwise.
            config_path: importable module path of a YOLOX experiment file
                that defines an ``Exp`` class.
            device: torch device string the model and inputs are moved to.
            hf_model: whether ``model_path`` refers to a Hugging Face Hub model.
        """
        self.device = device
        self.config_path = config_path
        self.classes = COCO_CLASSES
        self.conf = 0.3
        self.iou = 0.45
        self.show = False
        self.save = True
        self.torchyolo = False
        # Always defined so toggling ``save`` on later cannot hit AttributeError.
        self.save_path = "output/result.jpg"

        if hf_model:
            self.model_path = attempt_download_from_hub(model_path)
        else:
            self.model_path = attempt_download(model_path)

        self.load_model()

    def load_model(self):
        """Build the model from the experiment config and load the checkpoint weights."""
        current_exp = importlib.import_module(self.config_path)
        exp = current_exp.Exp()
        model = exp.get_model()
        model.to(self.device)
        model.eval()

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources.
        ckpt = torch.load(self.model_path, map_location=self.device)
        model.load_state_dict(ckpt["model"])
        self.model = model

    def predict(self, image_path, image_size=None):
        """Detect objects in one image file.

        Args:
            image_path: path to an image readable by ``cv2.imread``.
            image_size: side of the square network input; ``None`` uses
                ``DEFAULT_IMAGE_SIZE`` (640).

        Returns:
            ``None`` if no detection survives post-processing. Otherwise:
            - ``torchyolo`` true: ``[bboxes, scores, cls, class_names]`` with
              boxes in original-image pixel coordinates;
            - ``show`` true: the annotated image, after displaying it;
            - ``save`` true: the path the annotated image was written to;
            - otherwise: the annotated image.

        Raises:
            FileNotFoundError: if the image cannot be read.
            OSError: if saving the annotated image fails.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        size = self.DEFAULT_IMAGE_SIZE if image_size is None else image_size
        # preproc letterboxes into a (size, size) canvas and returns the
        # resize ratio min(size / h, size / w) that it applied.
        img, ratio = preproc(image, input_size=(size, size))
        img = torch.from_numpy(img).to(self.device).unsqueeze(0).float()

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            prediction_result = self.model(img)
        detections = postprocess(
            prediction=prediction_result,
            num_classes=len(self.classes),
            conf_thre=self.conf,
            nms_thre=self.iou,
        )[0]
        if detections is None:
            return None

        output = detections.cpu()
        # Undo the letterbox resize so boxes are in original-image pixels.
        bboxes = output[:, 0:4] / ratio
        # Final score = column 4 * column 5; column 6 holds the class id
        # (layout produced by YOLOX postprocess).
        scores = output[:, 4] * output[:, 5]
        cls = output[:, 6]

        if self.torchyolo:
            return [bboxes, scores, cls, self.classes]

        vis_res = vis(image, bboxes, scores, cls, self.conf, self.classes)
        if self.show:
            cv2.imshow("result", vis_res)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            return vis_res
        if self.save:
            return self._save_image(vis_res)
        return vis_res

    def _save_image(self, image):
        """Write *image* to ``self.save_path`` (creating its directory) and return the path."""
        save_dir = os.path.dirname(self.save_path)
        # dirname is "" for a bare filename; there is no directory to create then.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(self.save_path, image):
            raise OSError(f"Could not write image: {self.save_path}")
        return self.save_path