import torch from ..utils import box_utils from .data_preprocessing import PredictionTransform from ..utils.misc import Timer class Predictor: def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None, iou_threshold=0.45, filter_threshold=0.01, candidate_size=200, sigma=0.5, device=None): self.net = net self.transform = PredictionTransform(size, mean, std) self.iou_threshold = iou_threshold self.filter_threshold = filter_threshold self.candidate_size = candidate_size self.nms_method = nms_method self.sigma = sigma if device: self.device = device else: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.net.to(self.device) self.net.eval() self.timer = Timer() def predict(self, image, top_k=-1, prob_threshold=None): cpu_device = torch.device("cpu") height, width, _ = image.shape image = self.transform(image) #print(image) images = image.unsqueeze(0) images = images.to(self.device) with torch.no_grad(): self.timer.start() scores, boxes = self.net.forward(images) print("Inference time: ", self.timer.end()) boxes = boxes[0] scores = scores[0] if not prob_threshold: prob_threshold = self.filter_threshold # this version of nms is slower on GPU, so we move data to CPU. boxes = boxes.to(cpu_device) scores = scores.to(cpu_device) picked_box_probs = [] picked_labels = [] for class_index in range(1, scores.size(1)): probs = scores[:, class_index] mask = probs > prob_threshold probs = probs[mask] if probs.size(0) == 0: continue subset_boxes = boxes[mask, :] box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1) box_probs = box_utils.nms(box_probs, self.nms_method, score_threshold=prob_threshold, iou_threshold=self.iou_threshold, sigma=self.sigma, top_k=top_k, candidate_size=self.candidate_size) picked_box_probs.append(box_probs) picked_labels.extend([class_index] * box_probs.size(0)) if not picked_box_probs: return torch.tensor([]), torch.tensor([]), torch.tensor([]) picked_box_probs = torch.cat(picked_box_probs) picked_box_probs[:, 0] *= width picked_box_probs[:, 1] *= height picked_box_probs[:, 2] *= width picked_box_probs[:, 3] *= height return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4]