# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Dict
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, cat
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from .point_features import (
    get_uncertain_point_coords_on_grid,
    get_uncertain_point_coords_with_randomness,
    point_sample,
)
from .point_head import build_point_head


def calculate_uncertainty(sem_seg_logits):
    """
    For each location of the prediction `sem_seg_logits` we estimate uncertainty as the
    difference between the top first and top second predicted logits.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
            C is the number of foreground classes. The values are logits.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
    # Gap between runner-up and best logit is <= 0; a small gap (value near 0) means the two
    # top classes are nearly tied, i.e. the prediction there is uncertain.
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)


@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
    """
    A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
    and a point head set in `MODEL.POINT_HEAD.NAME`.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

        # The coarse head produces an initial (low-resolution) per-pixel prediction that the
        # point head then refines at adaptively selected locations.
        self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
            cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
        )(cfg, input_shape)
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
        # Read point-head options from the config and build the point head itself.
        # fmt: off
        assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        feature_channels             = {k: v.channels for k, v in input_shape.items()}
        self.in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.oversample_ratio        = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        self.subdivision_steps       = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.subdivision_num_points  = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        # The point head consumes the concatenation of all requested feature maps (channel-wise).
        in_channels = int(np.sum([feature_channels[f] for f in self.in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

    def forward(self, features, targets=None):
        """
        Args:
            features (dict[str, Tensor]): feature maps keyed by name; must contain every name in
                `self.in_features` plus whatever the coarse head consumes.
            targets (Tensor or None): per-pixel ground-truth labels; required in training.

        Returns:
            In training: (None, dict of losses) — coarse-head losses plus "loss_sem_seg_point".
            In inference: (sem_seg_logits, {}) — logits progressively upsampled and refined.
        """
        coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)

        if self.training:
            losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)

            # Point selection is a sampling heuristic, not a learned op — keep it out of autograd.
            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    coarse_sem_seg_logits,
                    calculate_uncertainty,
                    self.train_num_points,
                    self.oversample_ratio,
                    self.importance_sample_ratio,
                )

            coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
            fine_grained_features = cat(
                [
                    point_sample(features[in_feature], point_coords, align_corners=False)
                    for in_feature in self.in_features
                ],
                dim=1,
            )
            point_logits = self.point_head(fine_grained_features, coarse_features)

            # Sample GT labels at the same points; "nearest" keeps labels discrete, and the
            # float round-trip exists only because grid sampling requires a float tensor.
            point_targets = (
                point_sample(
                    targets.unsqueeze(1).to(torch.float),
                    point_coords,
                    mode="nearest",
                    align_corners=False,
                )
                .squeeze(1)
                .to(torch.long)
            )
            losses["loss_sem_seg_point"] = F.cross_entropy(
                point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
            )
            return None, losses
        else:
            sem_seg_logits = coarse_sem_seg_logits.clone()
            # Iteratively 2x-upsample, then re-predict only the most uncertain points.
            for _ in range(self.subdivision_steps):
                sem_seg_logits = F.interpolate(
                    sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(sem_seg_logits)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.subdivision_num_points
                )
                fine_grained_features = cat(
                    [
                        point_sample(features[in_feature], point_coords, align_corners=False)
                        for in_feature in self.in_features
                    ],
                    # BUGFIX: dim=1 was missing here (the training branch has it). The default
                    # dim=0 would stack multiple feature maps along the batch dimension instead
                    # of concatenating channels whenever len(self.in_features) > 1.
                    dim=1,
                )
                coarse_features = point_sample(
                    coarse_sem_seg_logits, point_coords, align_corners=False
                )
                point_logits = self.point_head(fine_grained_features, coarse_features)

                # Scatter the refined point predictions back into the upsampled grid.
                N, C, H, W = sem_seg_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                sem_seg_logits = (
                    sem_seg_logits.reshape(N, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(N, C, H, W)
                )
            return sem_seg_logits, {}