from itertools import count

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from detectron2.modeling import GeneralizedRCNNWithTTA, DatasetMapperTTA
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
from detectron2.structures import Instances, Boxes

class RegionSpotWithTTA(GeneralizedRCNNWithTTA):
    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode):
            model (RegionSpot): a RegionSpot model to apply TTA on.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        # Call nn.Module.__init__ directly: modules cannot be assigned
        # before Module.__init__() has run.
        nn.Module.__init__(self)
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size
        # cvpods-style TTA options.
        self.enable_cvpods_tta = cfg.TEST.AUG.CVPODS_TTA
        self.enable_scale_filter = cfg.TEST.AUG.SCALE_FILTER
        self.scale_ranges = cfg.TEST.AUG.SCALE_RANGES
        self.max_detection = cfg.MODEL.RegionSpot.NUM_PROPOSALS
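
    # A minimal usage sketch (illustrative only; `build_model` and the input
    # dict follow the usual detectron2 conventions, and the cfg is assumed to
    # define the TEST.AUG.* and MODEL.RegionSpot.* keys read above):
    #
    #   model = build_model(cfg)
    #   tta_model = RegionSpotWithTTA(cfg, model)
    #   outputs = tta_model([{"image": image_chw, "height": h, "width": w}])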

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        factors = 2 if self.tta_mapper.flip else 1
        if self.enable_scale_filter:
            assert len(batched_inputs) == len(self.scale_ranges) * factors

        outputs = []
        inputs, instances = [], []
        for idx, input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(input)
            instances.append(instance)
            if self.enable_cvpods_tta:
                # cvpods-style TTA runs one augmented input at a time.
                output = self.model.forward(inputs, do_postprocess=False)[0]
                inputs, instances = [], []
                if self.enable_scale_filter:
                    # Keep only boxes whose area falls inside the scale range
                    # associated with this augmentation.
                    pred_boxes = output.get("pred_boxes")
                    keep = self.filter_boxes(pred_boxes.tensor, *self.scale_ranges[idx // factors])
                    output = Instances(
                        image_size=output.image_size,
                        pred_boxes=Boxes(pred_boxes.tensor[keep]),
                        pred_classes=output.pred_classes[keep],
                        scores=output.scores[keep],
                    )
                outputs.append(output)
            else:
                if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
                    outputs.extend(
                        self.model.forward(
                            inputs,
                            do_postprocess=False,
                        )
                    )
                    inputs, instances = [], []
        return outputs
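
    # Scale filtering: with SCALE_RANGES such as [[0, 128], [96, 10000]]
    # (an assumed example configuration), each TTA scale only contributes
    # boxes whose area lies strictly between min_scale**2 and max_scale**2.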

    @staticmethod
    def filter_boxes(boxes, min_scale, max_scale):
        """
        Args:
            boxes (Tensor): (N, 4) boxes in "xyxy" format.
        Returns:
            Tensor: boolean mask of boxes whose area is within the scale range.
        """
        w = boxes[:, 2] - boxes[:, 0]
        h = boxes[:, 3] - boxes[:, 1]
        keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
        return keep
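
    # Per-image TTA flow: augment the input, collect boxes/scores/classes from
    # every augmented view (mapped back to the original image size), then merge
    # them into a single set of predictions.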

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions.
        all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        # Merge all detected boxes to obtain final predictions for boxes.
        if self.enable_cvpods_tta:
            merged_instances = self._merge_detections_cvpods_tta(all_boxes, all_scores, all_classes, orig_shape)
        else:
            merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)
        return {"instances": merged_instances}

    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        # Select from the union of all results.
        num_boxes = len(all_boxes)
        num_classes = self.cfg.MODEL.RegionSpot.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well.
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        for idx, cls, score in zip(count(), all_classes, all_scores):
            all_scores_2d[idx, cls] = score

        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            shape_hw,
            1e-8,
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )
        return merged_instances
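
    # Merging via box voting: instead of plain NMS, overlapping detections are
    # fused by (soft) voting across scales, then clipped to the image and
    # packed into an Instances object.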

    def _merge_detections_cvpods_tta(self, all_boxes, all_scores, all_classes, shape_hw):
        all_scores = torch.tensor(all_scores).to(all_boxes.device)
        all_classes = torch.tensor(all_classes).to(all_boxes.device)
        all_boxes, all_scores, all_classes = self.merge_result_from_multi_scales(
            all_boxes, all_scores, all_classes,
            nms_type="soft_vote", vote_thresh=0.65,
            max_detection=self.max_detection,
        )
        all_boxes = Boxes(all_boxes)
        all_boxes.clip(shape_hw)

        result = Instances(shape_hw)
        result.pred_boxes = all_boxes
        result.scores = all_scores
        result.pred_classes = all_classes.long()
        return result

    def merge_result_from_multi_scales(
        self, boxes, scores, labels, nms_type="soft_vote", vote_thresh=0.65, max_detection=100
    ):
        boxes, scores, labels = self.batched_vote_nms(
            boxes, scores, labels, nms_type, vote_thresh
        )
        number_of_detections = boxes.shape[0]
        # Limit to max_detection detections **over all classes**.
        if number_of_detections > max_detection > 0:
            boxes = boxes[:max_detection]
            scores = scores[:max_detection]
            labels = labels[:max_detection]
        return boxes, scores, labels
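
    # Class-aware voting uses the standard batched-NMS trick: shift each box by
    # label * (max_coordinate + 1) so boxes of different classes can never
    # overlap, run class-agnostic voting once, then undo the shift. E.g. with
    # max coordinate 639, class-1 boxes are moved by 640 before voting.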

    def batched_vote_nms(self, boxes, scores, labels, vote_type, vote_thresh=0.65):
        # Apply per-class NMS: add a per-class offset to the boxes first,
        # then remove it after voting.
        labels = labels.float()
        max_coordinates = boxes.max() + 1
        offsets = labels.reshape(-1, 1) * max_coordinates
        boxes = boxes + offsets

        boxes, scores, labels = self.bbox_vote(boxes, scores, labels, vote_thresh, vote_type)
        boxes -= labels.reshape(-1, 1) * max_coordinates
        return boxes, scores, labels
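
    # Greedy IoU clustering: repeatedly take the highest-scoring remaining box,
    # gather every box with IoU >= vote_thresh as its cluster, and replace the
    # cluster with a single voted box (score-weighted average of coordinates).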

    def bbox_vote(self, boxes, scores, labels, vote_thresh, vote_type="soft_vote"):
        assert boxes.shape[0] == scores.shape[0] == labels.shape[0]
        # Each row of `det` is (x1, y1, x2, y2, score, label).
        det = torch.cat((boxes, scores.reshape(-1, 1), labels.reshape(-1, 1)), dim=1)
        vote_results = torch.zeros(0, 6, device=det.device)
        if det.numel() == 0:
            return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]

        order = scores.argsort(descending=True)
        det = det[order]

        while det.shape[0] > 0:
            # IoU of the current highest-scoring box against all remaining boxes.
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            xx1 = torch.max(det[0, 0], det[:, 0])
            yy1 = torch.max(det[0, 1], det[:, 1])
            xx2 = torch.min(det[0, 2], det[:, 2])
            yy2 = torch.min(det[0, 3], det[:, 3])
            w = torch.clamp(xx2 - xx1, min=0.)
            h = torch.clamp(yy2 - yy1, min=0.)
            inter = w * h
            iou = inter / (area[0] + area - inter)

            # Collect the boxes to merge and drop them from the working set.
            merge_index = torch.where(iou >= vote_thresh)[0]
            vote_det = det[merge_index, :]
            det = det[iou < vote_thresh]

            if merge_index.shape[0] <= 1:
                vote_results = torch.cat((vote_results, vote_det), dim=0)
            else:
                if vote_type == "soft_vote":
                    vote_det_iou = iou[merge_index]
                    det_accu_sum = self.get_soft_dets_sum(vote_det, vote_det_iou)
                elif vote_type == "vote":
                    det_accu_sum = self.get_dets_sum(vote_det)
                else:
                    raise ValueError(f"Unknown vote_type: {vote_type}")
                vote_results = torch.cat((vote_results, det_accu_sum), dim=0)

        order = vote_results[:, 4].argsort(descending=True)
        vote_results = vote_results[order, :]
        return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]
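
    # Voted box = score-weighted mean of the cluster's coordinates, keeping the
    # cluster's best score and the label of its top-scoring member. E.g. two
    # boxes with x1 = 0 and 1 and scores 0.9 and 0.8 vote for
    # x1 = (0 * 0.9 + 1 * 0.8) / 1.7 ≈ 0.47.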

    @staticmethod
    def get_dets_sum(vote_det):
        vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4)
        max_score = vote_det[:, 4].max()
        det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
        det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4])
        det_accu_sum[:, 4] = max_score
        det_accu_sum[:, 5] = vote_det[0, 5]
        return det_accu_sum
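
    # Soft voting additionally keeps cluster members alive with scores decayed
    # by (1 - IoU), as in Soft-NMS; members whose decayed score drops below the
    # 0.05 threshold are discarded, the rest are appended after the voted box.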

    @staticmethod
    def get_soft_dets_sum(vote_det, vote_det_iou):
        # Keep a decayed copy of the cluster members (Soft-NMS style).
        soft_vote_det = vote_det.detach().clone()
        soft_vote_det[:, 4] *= (1 - vote_det_iou)

        INFERENCE_TH = 0.05
        soft_index = torch.where(soft_vote_det[:, 4] >= INFERENCE_TH)[0]
        soft_vote_det = soft_vote_det[soft_index, :]

        vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4)
        max_score = vote_det[:, 4].max()
        det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
        det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4])
        det_accu_sum[:, 4] = max_score
        det_accu_sum[:, 5] = vote_det[0, 5]

        if soft_vote_det.shape[0] > 0:
            det_accu_sum = torch.cat((det_accu_sum, soft_vote_det), dim=0)
        return det_accu_sum
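

# A minimal, illustrative smoke test for the voting path (assumes only torch;
# `__new__` skips cfg/model setup, which batched_vote_nms does not need):
if __name__ == "__main__":
    tta = RegionSpotWithTTA.__new__(RegionSpotWithTTA)
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [1.0, 1.0, 11.0, 11.0],
                          [50.0, 50.0, 60.0, 60.0]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    labels = torch.tensor([0.0, 0.0, 1.0])
    # The two overlapping class-0 boxes (IoU ≈ 0.68) merge into one voted box;
    # the class-1 box survives untouched.
    out_boxes, out_scores, out_labels = tta.batched_vote_nms(
        boxes, scores, labels, vote_type="vote", vote_thresh=0.65
    )
    print(out_boxes, out_scores, out_labels)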