|  |  | 
					
						
						|  | from typing import List | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.nn import functional as F | 
					
						
						|  |  | 
					
						
						|  | from detectron2.config import configurable | 
					
						
						|  | from detectron2.layers import Conv2d, ConvTranspose2d, cat, interpolate | 
					
						
						|  | from detectron2.structures import Instances, heatmaps_to_keypoints | 
					
						
						|  | from detectron2.utils.events import get_event_storage | 
					
						
						|  | from detectron2.utils.registry import Registry | 
					
						
						|  |  | 
					
						
# Running count of minibatches skipped by `keypoint_rcnn_loss` because they
# contained no visible ground-truth keypoints; reported to the event storage
# under "kpts_num_skipped_batches".
_TOTAL_SKIPPED = 0


# Public API of this module.
__all__ = [
    "ROI_KEYPOINT_HEAD_REGISTRY",
    "build_keypoint_head",
    "BaseKeypointRCNNHead",
    "KRCNNConvDeconvUpsampleHead",
]
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Registry of keypoint-head classes. Concrete heads register themselves with
# `@ROI_KEYPOINT_HEAD_REGISTRY.register()` and are instantiated by name in
# `build_keypoint_head`.
ROI_KEYPOINT_HEAD_REGISTRY = Registry("ROI_KEYPOINT_HEAD")
ROI_KEYPOINT_HEAD_REGISTRY.__doc__ = """
Registry for keypoint heads, which make keypoint predictions from per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def build_keypoint_head(cfg, input_shape): | 
					
						
						|  | """ | 
					
						
						|  | Build a keypoint head from `cfg.MODEL.ROI_KEYPOINT_HEAD.NAME`. | 
					
						
						|  | """ | 
					
						
						|  | name = cfg.MODEL.ROI_KEYPOINT_HEAD.NAME | 
					
						
						|  | return ROI_KEYPOINT_HEAD_REGISTRY.get(name)(cfg, input_shape) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer): | 
					
						
						|  | """ | 
					
						
						|  | Arguments: | 
					
						
						|  | pred_keypoint_logits (Tensor): A tensor of shape (N, K, S, S) where N is the total number | 
					
						
						|  | of instances in the batch, K is the number of keypoints, and S is the side length | 
					
						
						|  | of the keypoint heatmap. The values are spatial logits. | 
					
						
						|  | instances (list[Instances]): A list of M Instances, where M is the batch size. | 
					
						
						|  | These instances are predictions from the model | 
					
						
						|  | that are in 1:1 correspondence with pred_keypoint_logits. | 
					
						
						|  | Each Instances should contain a `gt_keypoints` field containing a `structures.Keypoint` | 
					
						
						|  | instance. | 
					
						
						|  | normalizer (float): Normalize the loss by this amount. | 
					
						
						|  | If not specified, we normalize by the number of visible keypoints in the minibatch. | 
					
						
						|  |  | 
					
						
						|  | Returns a scalar tensor containing the loss. | 
					
						
						|  | """ | 
					
						
						|  | heatmaps = [] | 
					
						
						|  | valid = [] | 
					
						
						|  |  | 
					
						
						|  | keypoint_side_len = pred_keypoint_logits.shape[2] | 
					
						
						|  | for instances_per_image in instances: | 
					
						
						|  | if len(instances_per_image) == 0: | 
					
						
						|  | continue | 
					
						
						|  | keypoints = instances_per_image.gt_keypoints | 
					
						
						|  | heatmaps_per_image, valid_per_image = keypoints.to_heatmap( | 
					
						
						|  | instances_per_image.proposal_boxes.tensor, keypoint_side_len | 
					
						
						|  | ) | 
					
						
						|  | heatmaps.append(heatmaps_per_image.view(-1)) | 
					
						
						|  | valid.append(valid_per_image.view(-1)) | 
					
						
						|  |  | 
					
						
						|  | if len(heatmaps): | 
					
						
						|  | keypoint_targets = cat(heatmaps, dim=0) | 
					
						
						|  | valid = cat(valid, dim=0).to(dtype=torch.uint8) | 
					
						
						|  | valid = torch.nonzero(valid).squeeze(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if len(heatmaps) == 0 or valid.numel() == 0: | 
					
						
						|  | global _TOTAL_SKIPPED | 
					
						
						|  | _TOTAL_SKIPPED += 1 | 
					
						
						|  | storage = get_event_storage() | 
					
						
						|  | storage.put_scalar("kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False) | 
					
						
						|  | return pred_keypoint_logits.sum() * 0 | 
					
						
						|  |  | 
					
						
						|  | N, K, H, W = pred_keypoint_logits.shape | 
					
						
						|  | pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W) | 
					
						
						|  |  | 
					
						
						|  | keypoint_loss = F.cross_entropy( | 
					
						
						|  | pred_keypoint_logits[valid], keypoint_targets[valid], reduction="sum" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if normalizer is None: | 
					
						
						|  | normalizer = valid.numel() | 
					
						
						|  | keypoint_loss /= normalizer | 
					
						
						|  |  | 
					
						
						|  | return keypoint_loss | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def keypoint_rcnn_inference(pred_keypoint_logits: torch.Tensor, pred_instances: List[Instances]): | 
					
						
						|  | """ | 
					
						
						|  | Post process each predicted keypoint heatmap in `pred_keypoint_logits` into (x, y, score) | 
					
						
						|  | and add it to the `pred_instances` as a `pred_keypoints` field. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | pred_keypoint_logits (Tensor): A tensor of shape (R, K, S, S) where R is the total number | 
					
						
						|  | of instances in the batch, K is the number of keypoints, and S is the side length of | 
					
						
						|  | the keypoint heatmap. The values are spatial logits. | 
					
						
						|  | pred_instances (list[Instances]): A list of N Instances, where N is the number of images. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | None. Each element in pred_instances will contain extra "pred_keypoints" and | 
					
						
						|  | "pred_keypoint_heatmaps" fields. "pred_keypoints" is a tensor of shape | 
					
						
						|  | (#instance, K, 3) where the last dimension corresponds to (x, y, score). | 
					
						
						|  | The scores are larger than 0. "pred_keypoint_heatmaps" contains the raw | 
					
						
						|  | keypoint logits as passed to this function. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0) | 
					
						
						|  |  | 
					
						
						|  | pred_keypoint_logits = pred_keypoint_logits.detach() | 
					
						
						|  | keypoint_results = heatmaps_to_keypoints(pred_keypoint_logits, bboxes_flat.detach()) | 
					
						
						|  | num_instances_per_image = [len(i) for i in pred_instances] | 
					
						
						|  | keypoint_results = keypoint_results[:, :, [0, 1, 3]].split(num_instances_per_image, dim=0) | 
					
						
						|  | heatmap_results = pred_keypoint_logits.split(num_instances_per_image, dim=0) | 
					
						
						|  |  | 
					
						
						|  | for keypoint_results_per_image, heatmap_results_per_image, instances_per_image in zip( | 
					
						
						|  | keypoint_results, heatmap_results, pred_instances | 
					
						
						|  | ): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | instances_per_image.pred_keypoints = keypoint_results_per_image | 
					
						
						|  | instances_per_image.pred_keypoint_heatmaps = heatmap_results_per_image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class BaseKeypointRCNNHead(nn.Module): | 
					
						
						|  | """ | 
					
						
						|  | Implement the basic Keypoint R-CNN losses and inference logic described in | 
					
						
						|  | Sec. 5 of :paper:`Mask R-CNN`. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | @configurable | 
					
						
						|  | def __init__(self, *, num_keypoints, loss_weight=1.0, loss_normalizer=1.0): | 
					
						
						|  | """ | 
					
						
						|  | NOTE: this interface is experimental. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | num_keypoints (int): number of keypoints to predict | 
					
						
						|  | loss_weight (float): weight to multiple on the keypoint loss | 
					
						
						|  | loss_normalizer (float or str): | 
					
						
						|  | If float, divide the loss by `loss_normalizer * #images`. | 
					
						
						|  | If 'visible', the loss is normalized by the total number of | 
					
						
						|  | visible keypoints across images. | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.num_keypoints = num_keypoints | 
					
						
						|  | self.loss_weight = loss_weight | 
					
						
						|  | assert loss_normalizer == "visible" or isinstance(loss_normalizer, float), loss_normalizer | 
					
						
						|  | self.loss_normalizer = loss_normalizer | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def from_config(cls, cfg, input_shape): | 
					
						
						|  | ret = { | 
					
						
						|  | "loss_weight": cfg.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT, | 
					
						
						|  | "num_keypoints": cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS, | 
					
						
						|  | } | 
					
						
						|  | normalize_by_visible = ( | 
					
						
						|  | cfg.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS | 
					
						
						|  | ) | 
					
						
						|  | if not normalize_by_visible: | 
					
						
						|  | batch_size_per_image = cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE | 
					
						
						|  | positive_sample_fraction = cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION | 
					
						
						|  | ret["loss_normalizer"] = ( | 
					
						
						|  | ret["num_keypoints"] * batch_size_per_image * positive_sample_fraction | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | ret["loss_normalizer"] = "visible" | 
					
						
						|  | return ret | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, instances: List[Instances]): | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | x: input 4D region feature(s) provided by :class:`ROIHeads`. | 
					
						
						|  | instances (list[Instances]): contains the boxes & labels corresponding | 
					
						
						|  | to the input features. | 
					
						
						|  | Exact format is up to its caller to decide. | 
					
						
						|  | Typically, this is the foreground instances in training, with | 
					
						
						|  | "proposal_boxes" field and other gt annotations. | 
					
						
						|  | In inference, it contains boxes that are already predicted. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | A dict of losses if in training. The predicted "instances" if in inference. | 
					
						
						|  | """ | 
					
						
						|  | x = self.layers(x) | 
					
						
						|  | if self.training: | 
					
						
						|  | num_images = len(instances) | 
					
						
						|  | normalizer = ( | 
					
						
						|  | None if self.loss_normalizer == "visible" else num_images * self.loss_normalizer | 
					
						
						|  | ) | 
					
						
						|  | return { | 
					
						
						|  | "loss_keypoint": keypoint_rcnn_loss(x, instances, normalizer=normalizer) | 
					
						
						|  | * self.loss_weight | 
					
						
						|  | } | 
					
						
						|  | else: | 
					
						
						|  | keypoint_rcnn_inference(x, instances) | 
					
						
						|  | return instances | 
					
						
						|  |  | 
					
						
						|  | def layers(self, x): | 
					
						
						|  | """ | 
					
						
						|  | Neural network layers that makes predictions from regional input features. | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
@ROI_KEYPOINT_HEAD_REGISTRY.register()
class KRCNNConvDeconvUpsampleHead(BaseKeypointRCNNHead, nn.Sequential):
    """
    A standard keypoint head containing a series of 3x3 convs, followed by
    a transpose convolution and bilinear interpolation for upsampling.
    It is described in Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            num_keypoints (int): number of keypoint heatmaps to produce
            conv_dims: an iterable of output channel counts for each conv in the head
                e.g. (512, 512, 512) for three convs outputting 512 channels.
        """
        super().__init__(num_keypoints=num_keypoints, **kwargs)

        # Fixed 2x bilinear upsampling after the deconv (see `layers`).
        up_scale = 2.0
        in_channels = input_shape.channels

        # Register conv+ReLU pairs via add_module so that nn.Sequential
        # iteration in `layers` visits them in exactly this order.
        for idx, layer_channels in enumerate(conv_dims, 1):
            module = Conv2d(in_channels, layer_channels, 3, stride=1, padding=1)
            self.add_module("conv_fcn{}".format(idx), module)
            self.add_module("conv_fcn_relu{}".format(idx), nn.ReLU())
            in_channels = layer_channels

        # 4x4 stride-2 deconv doubles the spatial resolution and maps to one
        # channel per keypoint; padding = kernel//2 - 1 keeps output aligned.
        deconv_kernel = 4
        self.score_lowres = ConvTranspose2d(
            in_channels, num_keypoints, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1
        )
        self.up_scale = up_scale

        # Zero biases; Kaiming (fan_out, relu) for all weights, matching the
        # conv stack's ReLU nonlinearity.
        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    @classmethod
    def from_config(cls, cfg, input_shape):
        # Extend the base-class kwargs with this head's own config fields.
        ret = super().from_config(cfg, input_shape)
        ret["input_shape"] = input_shape
        ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS
        return ret

    def layers(self, x):
        """
        Neural network layers that makes predictions from regional input features.
        """
        # Iterating `self` relies on nn.Sequential ordering: the conv/ReLU
        # stack registered in __init__, then the deconv `score_lowres`.
        for layer in self:
            x = layer(x)
        # Final bilinear 2x upsample to reach the output heatmap resolution.
        x = interpolate(x, scale_factor=self.up_scale, mode="bilinear", align_corners=False)
        return x
					
						
						|  |  |