# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa

from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

try:
    from mmcv.ops import point_sample
except ModuleNotFoundError:
    point_sample = None

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead


def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between top first and top second
    predicted logits.

    When ``seg_logits`` has a single channel there is no second logit to
    compare against, so the uncertainty is ``-abs(logit)``. This equals the
    top-2 margin of the equivalent two-class logits ``(logit, 0)``.

    Args:
        seg_logits (Tensor): Semantic segmentation logits with the class
            dimension at index 1, e.g. shape
            (batch_size, num_classes, height, width) or
            (batch_size, num_classes, num_points).

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score. Same shape as
            ``seg_logits`` except that the class dimension has size 1.
    """
    if seg_logits.shape[1] == 1:
        # ``topk(k=2)`` would raise on a single-channel input.
        return -torch.abs(seg_logits)
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    # second-best minus best is <= 0; values close to 0 are most uncertain.
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)


@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head use in PointRend.

    This head is implemented of `PointRend: Image Segmentation as
    Rendering <https://arxiv.org/abs/1912.08193>`_.
    ``PointHead`` use shared multi-layer perceptron (equivalent to
    nn.Conv1d) to predict the logit of input points. The fine-grained
    feature and coarse feature will be concatenate together for
    predication.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        coarse_pred_each_layer (bool): Whether concatenate coarse feature
            with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv
            layer. Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm
            layer. Default: None.
        act_cfg (dict|None): Dictionary to construct and config the
            activation layer. Default: dict(type='ReLU', inplace=False).
        **kwargs: Forwarded unchanged to ``BaseCascadeDecodeHead``. This
            head reads ``in_channels``, ``channels``, ``num_classes``,
            ``dropout_ratio``, ``align_corners``, ``ignore_index`` and
            ``loss_decode`` from ``self`` afterwards, so they are expected
            to be set by the base class.

    Raises:
        RuntimeError: If ``mmcv.ops.point_sample`` could not be imported.
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super().__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=dict(
                type='Normal', std=0.01, override=dict(name='fc_seg')),
            **kwargs)
        if point_sample is None:
            raise RuntimeError('Please install mmcv-full for '
                               'point_sample ops')

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        # The first fc consumes all selected fine-grained feature channels
        # concatenated with the coarse per-class prediction.
        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            # ``forward`` re-appends the coarse prediction after each fc
            # when ``coarse_pred_each_layer`` is set.
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            # Point features are (batch, channels, num_points), so use a
            # plain element-wise dropout.
            # NOTE(review): presumably this replaces a dropout layer that
            # the base class created -- confirm against the base class.
            self.dropout = nn.Dropout(self.dropout_ratio)
        # ``fc_seg`` is the classifier of this head.
        # NOTE(review): ``conv_seg`` is presumably created by the base
        # class and unused here -- confirm.
        delattr(self, 'conv_seg')

    def cls_seg(self, feat):
        """Classify each point with ``fc_seg``, after optional dropout.

        Args:
            feat (Tensor): Point features fed to ``fc_seg``.

        Returns:
            Tensor: Output of ``fc_seg``.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        """Predict point logits from fine-grained and coarse point features.

        Args:
            fine_grained_point_feats (Tensor): Sampled fine-grained
                features, concatenated with ``coarse_point_feats`` along
                dim 1.
            coarse_point_feats (Tensor): Sampled coarse predictions.

        Returns:
            Tensor: Point logits produced by :meth:`cls_seg`.
        """
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from by neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """
        fine_grained_feats_list = [
            point_sample(feat, points, align_corners=self.align_corners)
            for feat in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous decode head.

        Args:
            prev_output (Tensor): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """
        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def loss(self, inputs, prev_output, batch_data_samples: SampleList,
             train_cfg, **kwargs):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        # Choosing where to sample is not part of the differentiable graph.
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)

        losses = self.loss_by_feat(point_logits, points, batch_data_samples)

        return losses

    def predict(self, inputs, prev_output, batch_img_metas: List[dict],
                test_cfg, **kwargs):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (list[dict]): List of image info dict where
                each dict has: 'img_shape', 'scale_factor', 'flip', and may
                also contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """

        x = self._transform_inputs(inputs)
        # Work on a copy: the in-place scatter below must not touch
        # ``prev_output``, which is still sampled from in every step.
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            # Overwrite the logits of the selected points, addressed by
            # their flat index into the height * width grid.
            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return self.predict_by_feat(refined_seg_logits, batch_img_metas,
                                    **kwargs)

    def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
        """Compute segmentation loss.

        Args:
            point_logits (Tensor): Logits predicted for the sampled points.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).
            batch_data_samples (list[:obj:`SegDataSample`]): The seg data
                samples the ground truth is stacked from.

        Returns:
            dict[str, Tensor]: One entry ``'point' + loss_name`` per
                distinct loss name (losses sharing a name are summed) plus
                ``'acc_point'``.
        """
        gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
        # Labels are categorical, so sample them with nearest neighbour.
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        loss = dict()
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_module in losses_decode:
            loss_key = 'point' + loss_module.loss_name
            loss_value = loss_module(
                point_logits, point_label, ignore_index=self.ignore_index)
            # Accumulate instead of overwriting so that several loss
            # modules sharing one ``loss_name`` all contribute.
            if loss_key in loss:
                loss[loss_key] = loss[loss_key] + loss_value
            else:
                loss[loss_key] = loss_value

        loss['acc_point'] = accuracy(
            point_logits, point_label, ignore_index=self.ignore_index)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        # NOTE(review): unlike the other ``point_sample`` calls in this
        # class, ``align_corners`` is not passed here, so the op's default
        # applies -- confirm this is intended.
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results.  To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        # Turn per-sample indices into indices of the flattened
        # (batch_size * num_sampled) coordinate table.
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            # Fill the remaining budget with uniformly random points.
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid .
        """

        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        # Never request more points than the grid has cells.
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        # Flat index -> (x, y) of the cell centre in normalized coordinates.
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords