Spaces:
Sleeping
Sleeping
File size: 4,496 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import cv2
import numpy as np
import torch
from torch import nn
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
class KeypointPostProcessor(nn.Module):
    """Attach keypoint predictions to detection boxes.

    An optional ``keypointer`` callable converts the raw network output into
    keypoint coordinates and per-keypoint scores. When no keypointer is
    configured, the input is passed through unchanged and no "logits" field
    is attached to the resulting keypoints.
    """

    def __init__(self, keypointer=None):
        super(KeypointPostProcessor, self).__init__()
        self.keypointer = keypointer

    def forward(self, x, boxes):
        """Build one result ``BoxList`` per input image.

        Arguments:
            x: keypoint head output. It is handed to ``self.keypointer`` when
                one is set; otherwise it is used as the keypoint data itself.
            boxes: list containing exactly one ``BoxList`` (only non-batched
                inference is supported).

        Returns:
            list of ``BoxList``: a copy of each input box list (same boxes,
            image size and extra fields) with an added "keypoints" field that
            holds a ``PersonKeypoints`` object. The keypoints carry a
            "logits" field only when the keypointer produced scores.
        """
        mask_prob = x
        scores = None
        if self.keypointer:
            mask_prob, scores = self.keypointer(x, boxes)

        assert len(boxes) == 1, "Only non-batched inference supported for now"
        boxes_per_image = [box.bbox.size(0) for box in boxes]
        mask_prob = mask_prob.split(boxes_per_image, dim=0)
        if scores is None:
            # No keypointer ran, so there is nothing to split. Previously this
            # path crashed with AttributeError on ``None.split``.
            scores = [None] * len(boxes)
        else:
            scores = scores.split(boxes_per_image, dim=0)

        results = []
        for prob, box, score in zip(mask_prob, boxes, scores):
            # Rebuild the box list so the input object is not mutated.
            bbox = BoxList(box.bbox, box.size, mode="xyxy")
            for field in box.fields():
                bbox.add_field(field, box.get_field(field))
            keypoints = PersonKeypoints(prob, box.size)
            if score is not None:
                keypoints.add_field("logits", score)
            bbox.add_field("keypoints", keypoints)
            results.append(bbox)
        return results
def heatmaps_to_keypoints(maps, rois, min_size=0):
    """Extract predicted keypoint locations from heatmaps.

    Each ROI's heatmaps are resized to the ROI's (ceiled) width and height,
    and the location of the maximum response of every keypoint channel is
    mapped back to image coordinates.

    Arguments:
        maps: array in NCHW layout, one heatmap per ROI and keypoint.
        rois: array of shape (#rois, 4) holding (x1, y1, x2, y2) per ROI.
        min_size: lower bound on the side length of the resized heatmap.
            The default of 0 disables the bound (the ROI size, which is
            clamped to at least 1, is used as is).

    Returns:
        tuple ``(keypoints, scores)`` where ``keypoints`` is a float32 array
        of shape (#rois, #keypoints, 3) holding (x, y, 1) for each keypoint
        and ``scores`` is a float32 array of shape (#rois, #keypoints) with
        the resized heatmap value at the selected location.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Degenerate ROIs are treated as at least one pixel wide/high.
    widths = np.maximum(rois[:, 2] - rois[:, 0], 1)
    heights = np.maximum(rois[:, 3] - rois[:, 1], 1)
    widths_ceil = np.ceil(widths)
    heights_ceil = np.ceil(heights)

    # NCHW to NHWC for use with OpenCV
    maps = np.transpose(maps, [0, 2, 3, 1])
    num_keypoints = maps.shape[3]
    keypoint_index = np.arange(num_keypoints)

    xy_preds = np.zeros((len(rois), 3, num_keypoints), dtype=np.float32)
    end_scores = np.zeros((len(rois), num_keypoints), dtype=np.float32)
    for i in range(len(rois)):
        # cv2.resize requires integer sizes; the ceiled values are numpy
        # floats, so cast explicitly. Since the ceiled sizes are >= 1, taking
        # the maximum with a non-positive min_size leaves them unchanged.
        roi_map_width = int(np.maximum(widths_ceil[i], min_size))
        roi_map_height = int(np.maximum(heights_ceil[i], min_size))
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height

        roi_map = cv2.resize(
            maps[i],
            (roi_map_width, roi_map_height),
            interpolation=cv2.INTER_CUBIC,
        )
        if roi_map.ndim == 2:
            # OpenCV drops the channel axis for single-channel input.
            roi_map = roi_map[:, :, np.newaxis]
        # Bring back to CHW
        roi_map = np.transpose(roi_map, [2, 0, 1])

        # Flat argmax per keypoint channel, then split into (row, column).
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(axis=1)
        y_int, x_int = np.divmod(pos, w)

        x = (x_int + 0.5) * width_correction
        y = (y_int + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[keypoint_index, y_int, x_int]

    return np.transpose(xy_preds, [0, 2, 1]), end_scores
class Keypointer(object):
    """
    Converts keypoint heatmaps into keypoint coordinates and scores
    for the boxes of a single image.

    ``padding`` is stored on the instance; nothing in this class reads it.
    """

    def __init__(self, padding=0):
        self.padding = padding

    def __call__(self, masks, boxes):
        """Run ``heatmaps_to_keypoints`` on CPU and move the results back.

        ``boxes`` may be a single ``BoxList`` or a list holding exactly one.
        Returns a ``(keypoints, scores)`` pair of tensors placed on the
        device of ``masks``.
        """
        # TODO do this properly
        box_lists = [boxes] if isinstance(boxes, BoxList) else boxes
        assert len(box_lists) == 1

        # The conversion is implemented with numpy/OpenCV, so go through CPU.
        heatmaps = masks.detach().cpu().numpy()
        rois = box_lists[0].bbox.cpu().numpy()
        keypoints, keypoint_scores = heatmaps_to_keypoints(heatmaps, rois)

        target_device = masks.device
        keypoints_tensor = torch.from_numpy(keypoints).to(target_device)
        scores_tensor = torch.as_tensor(keypoint_scores, device=target_device)
        return keypoints_tensor, scores_tensor
def make_roi_keypoint_post_processor(cfg):
    """Build a ``KeypointPostProcessor`` backed by a default ``Keypointer``.

    ``cfg`` is accepted for interface compatibility; it is not read here.
    """
    return KeypointPostProcessor(Keypointer())
|