import cv2
import numpy as np
import torch
from torch import nn

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints


class KeypointPostProcessor(nn.Module):
    def __init__(self, keypointer=None):
        super(KeypointPostProcessor, self).__init__()
        self.keypointer = keypointer

    def forward(self, x, boxes):
        mask_prob = x

        scores = None
        if self.keypointer:
            # Convert per-ROI heatmaps into image-space keypoint coordinates
            # and the corresponding heatmap scores.
            mask_prob, scores = self.keypointer(x, boxes)

        assert len(boxes) == 1, "Only non-batched inference supported for now"
        boxes_per_image = [box.bbox.size(0) for box in boxes]
        mask_prob = mask_prob.split(boxes_per_image, dim=0)
        scores = scores.split(boxes_per_image, dim=0)

        results = []
        for prob, box, score in zip(mask_prob, boxes, scores):
            bbox = BoxList(box.bbox, box.size, mode="xyxy")
            for field in box.fields():
                bbox.add_field(field, box.get_field(field))
            prob = PersonKeypoints(prob, box.size)
            prob.add_field("logits", score)
            bbox.add_field("keypoints", prob)
            results.append(bbox)

        return results


def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Returns a tuple of two arrays: keypoint predictions of shape
    (#rois, #keypoints, 3), where the last dimension holds (x, y, prob)
    with prob set to 1, and heatmap scores of shape (#rois, #keypoints)
    taken at each keypoint's argmax location.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = np.maximum(widths, 1)
    heights = np.maximum(heights, 1)
    widths_ceil = np.ceil(widths)
    heights_ceil = np.ceil(heights)

    # NCHW to NHWC for use with OpenCV
    maps = np.transpose(maps, [0, 2, 3, 1])
    min_size = 0  # cfg.KRCNN.INFERENCE_MIN_SIZE
    num_keypoints = maps.shape[3]
    xy_preds = np.zeros((len(rois), 3, num_keypoints), dtype=np.float32)
    end_scores = np.zeros((len(rois), num_keypoints), dtype=np.float32)
    for i in range(len(rois)):
        if min_size > 0:
            roi_map_width = int(np.maximum(widths_ceil[i], min_size))
            roi_map_height = int(np.maximum(heights_ceil[i], min_size))
        else:
            # cv2.resize expects integer output dimensions
            roi_map_width = int(widths_ceil[i])
            roi_map_height = int(heights_ceil[i])
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        # Upsample the heatmap to the (rounded-up) ROI resolution
        roi_map = cv2.resize(
            maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC
        )
        # Bring back to CHW
        roi_map = np.transpose(roi_map, [2, 0, 1])
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        # Argmax over each keypoint heatmap, then recover (x, y) indices
        pos = roi_map.reshape(num_keypoints, -1).argmax(axis=1)
        x_int = pos % w
        y_int = (pos - x_int) // w
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        # Map discrete heatmap indices to continuous image coordinates
        x = (x_int + 0.5) * width_correction
        y = (y_int + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[np.arange(num_keypoints), y_int, x_int]

    return np.transpose(xy_preds, [0, 2, 1]), end_scores


class Keypointer(object):
    """
    Converts the keypoint heatmaps predicted for each ROI into keypoint
    locations in the image, using the corresponding bounding boxes.
    """

    def __init__(self, padding=0):
        self.padding = padding  # currently unused

    def __call__(self, masks, boxes):
        # TODO do this properly
        if isinstance(boxes, BoxList):
            boxes = [boxes]
        assert len(boxes) == 1

        result, scores = heatmaps_to_keypoints(
            masks.detach().cpu().numpy(), boxes[0].bbox.cpu().numpy()
        )
        return torch.from_numpy(result).to(masks.device), torch.as_tensor(
            scores, device=masks.device
        )
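
# Illustrative example of the conversion in heatmaps_to_keypoints (comments
# only; the values are made up for this sketch). Assuming a single ROI
# [10, 20, 66, 76] (a 56x56 box) and 56x56 heatmaps whose peak for keypoint 0
# lies at the discrete cell (x=14, y=28):
#
#   rois = np.array([[10.0, 20.0, 66.0, 76.0]], dtype=np.float32)
#   maps = np.zeros((1, 17, 56, 56), dtype=np.float32)
#   maps[0, 0, 28, 14] = 5.0
#   xy, scores = heatmaps_to_keypoints(maps, rois)
#
# The width/height corrections are 1, so xy[0, 0] is (10 + 14.5, 20 + 28.5, 1)
# = (24.5, 48.5, 1) under the Heckbert c = d + 0.5 convention noted above, and
# scores[0, 0] is the heatmap value at the peak (5.0).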

def make_roi_keypoint_post_processor(cfg):
    keypointer = Keypointer()
    keypoint_post_processor = KeypointPostProcessor(keypointer)
    return keypoint_post_processor
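
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test of the
# post-processor on dummy data. The heatmap size (56x56), image size and box
# coordinates are arbitrary choices for illustration; cfg is ignored by the
# factory above, so None is passed here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    post_processor = make_roi_keypoint_post_processor(cfg=None)

    # Two ROIs with 17 keypoint heatmaps each (COCO convention), all zeros.
    heatmaps = torch.zeros(2, 17, 56, 56)
    proposals = [
        BoxList(
            torch.tensor([[10.0, 20.0, 66.0, 76.0], [0.0, 0.0, 56.0, 56.0]]),
            image_size=(100, 100),
            mode="xyxy",
        )
    ]

    (result,) = post_processor(heatmaps, proposals)
    keypoints = result.get_field("keypoints")
    print(keypoints.keypoints.shape)            # torch.Size([2, 17, 3])
    print(keypoints.get_field("logits").shape)  # torch.Size([2, 17])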