# Atualli — duplicated from Atualli/yolox2 (commit 1cdd82e, 3.31 kB).
from yoloxdetect.utils.downloads import attempt_download_from_hub, attempt_download
from yolox.data.datasets import COCO_CLASSES
from yolox.data.data_augment import preproc
from yolox.utils import postprocess, vis
import importlib
import torch
import cv2
import os
class YoloxDetector2:
    """Single-image object detector built from a YOLOX experiment and checkpoint.

    After construction, behaviour is controlled by plain attributes that
    callers may change:

    * ``conf`` / ``iou`` -- confidence and NMS thresholds passed to
      ``postprocess`` (defaults 0.3 / 0.45).
    * ``classes`` -- class-name sequence; its length is passed to
      ``postprocess`` as the number of classes and the names are used when
      drawing (default ``COCO_CLASSES``).
    * ``torchyolo`` -- when truthy, ``predict`` returns raw tensors instead
      of drawing an annotated image.
    * ``show`` / ``save`` / ``save_path`` -- what ``predict`` does with the
      annotated image.
    """

    # Square input side used by ``predict`` when ``image_size`` is None.
    DEFAULT_IMAGE_SIZE = 640

    def __init__(
        self,
        model_path: str,
        config_path: str,
        device: str = "cpu",
        hf_model: bool = False,
    ):
        """Resolve the checkpoint location and load the model.

        Args:
            model_path: Checkpoint location; resolved with
                ``attempt_download_from_hub`` when ``hf_model`` is true,
                otherwise with ``attempt_download``.
            config_path: Importable module path of a YOLOX experiment file
                that exposes an ``Exp`` class.
            device: Torch device string the model and inputs are moved to.
            hf_model: Whether ``model_path`` refers to a hub-hosted model.
        """
        self.device = device
        self.config_path = config_path
        self.classes = COCO_CLASSES
        self.conf = 0.3
        self.iou = 0.45
        self.show = False
        self.save = True
        self.torchyolo = False
        # Always defined (``save`` starts out True anyway) so that toggling
        # ``save`` later can never hit a missing attribute.
        self.save_path = 'output/result.jpg'
        if hf_model:
            self.model_path = attempt_download_from_hub(model_path)
        else:
            self.model_path = attempt_download(model_path)
        self.load_model()

    def load_model(self):
        """Build the experiment's model and load the checkpoint weights.

        The checkpoint must hold the state dict under the ``"model"`` key.
        The model is moved to ``self.device``, switched to eval mode and
        stored on ``self.model``.
        """
        exp = importlib.import_module(self.config_path).Exp()
        model = exp.get_model()
        model.to(self.device)
        model.eval()
        # NOTE(security): torch.load unpickles the file, and the checkpoint
        # may have been downloaded; only load checkpoints from trusted sources.
        ckpt = torch.load(self.model_path, map_location=self.device)
        model.load_state_dict(ckpt["model"])
        self.model = model

    @staticmethod
    def _save_dir(save_path):
        """Return the directory component of ``save_path`` ('' if there is none)."""
        return os.path.dirname(save_path)

    def predict(self, image_path, image_size):
        """Run detection on a single image file.

        Args:
            image_path: Path of an image readable by ``cv2.imread``.
            image_size: Side length of the square network input; ``None``
                falls back to ``DEFAULT_IMAGE_SIZE``.

        Returns:
            ``None`` if no detection passes the thresholds. Otherwise,
            depending on the instance flags:

            * ``torchyolo`` truthy: ``[bboxes, scores, cls, classes]``, with
              boxes rescaled to the original image's coordinates;
            * ``show``: displays the annotated image, then returns ``None``;
            * ``save``: writes the annotated image to ``save_path`` and
              returns that path;
            * otherwise: the annotated image produced by ``vis``.

        Raises:
            OSError: if the image cannot be read or the result cannot be
                written.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError(f"could not read image: {image_path!r}")

        size = self.DEFAULT_IMAGE_SIZE if image_size is None else image_size
        # preproc resizes/pads into a size x size input and also returns the
        # scale factor it applied, needed to map boxes back to the original.
        img, ratio = preproc(image, input_size=(size, size))
        tensor = torch.from_numpy(img).to(self.device).unsqueeze(0).float()

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            prediction_result = self.model(tensor)
            detections = postprocess(
                prediction=prediction_result,
                num_classes=len(self.classes),
                conf_thre=self.conf,
                nms_thre=self.iou,
            )[0]
        if detections is None:
            return None

        output = detections.cpu()
        # Column layout per YOLOX's postprocess (confirm against the installed
        # version): 0-3 box, 4 objectness, 5 class confidence, 6 class index.
        bboxes = output[:, 0:4] / ratio
        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        if self.torchyolo:
            return [bboxes, scores, cls, self.classes]

        vis_res = vis(image, bboxes, scores, cls, self.conf, self.classes)
        if self.show:
            cv2.imshow("result", vis_res)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            return None
        if self.save:
            save_dir = self._save_dir(self.save_path)
            # A bare filename has no directory part; nothing to create then.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(self.save_path, vis_res):
                raise OSError(f"could not write result image: {self.save_path!r}")
            return self.save_path
        return vis_res