|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from annotator.uniformer.mmcv.cnn import ConvModule, normal_init |
|
from annotator.uniformer.mmcv.ops import point_sample |
|
|
|
from annotator.uniformer.mmseg.models.builder import HEADS |
|
from annotator.uniformer.mmseg.ops import resize |
|
from ..losses import accuracy |
|
from .cascade_decode_head import BaseCascadeDecodeHead |
|
|
|
|
|
def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between the second highest and the
    highest predicted logit (a value <= 0; the closer to zero, the less
    the two best classes are separated and the more uncertain the
    location is).

    When ``seg_logits`` has a single channel (binary / class-agnostic
    prediction) there is no second logit to compare against, so the
    uncertainty is the negative absolute logit instead: a logit close to
    zero is the most uncertain.

    Args:
        seg_logits (Tensor): Semantic segmentation logits,
            shape (batch_size, num_classes, height, width). Only dim 1
            is interpreted as the class dimension, so any trailing
            layout works, e.g. (batch_size, num_classes, num_points).

    Returns:
        scores (Tensor): The uncertainty scores with the most uncertain
            locations having the highest uncertainty score, shape (
            batch_size, 1, height, width)
    """
    if seg_logits.shape[1] == 1:
        # ``topk(k=2)`` would raise on a single channel; fall back to the
        # distance of the logit from the decision boundary at zero.
        return -seg_logits.abs()
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    # topk is sorted descending: index 0 is the best, index 1 the runner-up.
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
|
|
|
|
|
@HEADS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """Point-wise refinement head used in PointRend.

    The head runs a shared multi-layer perceptron (built from 1x1
    ``Conv1d`` layers) over a set of sampled points. For every point the
    fine-grained features and the coarse prediction sampled at that point
    are concatenated and mapped to per-class logits.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            prediction to the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv
            layer. Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm
            layer. Default: None.
        act_cfg (dict|None): Dictionary to construct and config the
            activation layer. Default: dict(type='ReLU', inplace=False).
        **kwargs: Forwarded unchanged to ``BaseCascadeDecodeHead``.
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super().__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        # Channels appended after every fc layer when the coarse
        # prediction is re-injected.
        extra_channels = (
            self.num_classes if self.coarse_pred_each_layer else 0)
        hidden_channels = self.channels
        # The first layer sees all selected feature levels plus the
        # coarse prediction.
        next_in_channels = sum(self.in_channels) + self.num_classes

        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            self.fcs.append(
                ConvModule(
                    next_in_channels,
                    hidden_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            next_in_channels = hidden_channels + extra_channels

        self.fc_seg = nn.Conv1d(
            next_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            # Replace the dropout created by the parent class with plain
            # ``nn.Dropout`` for the point-wise features.
            self.dropout = nn.Dropout(self.dropout_ratio)
        # The parent's ``conv_seg`` classifier is superseded by ``fc_seg``.
        del self.conv_seg

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.fc_seg, std=0.001)

    def cls_seg(self, feat):
        """Classify each point with ``fc_seg`` (after optional dropout)."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.fc_seg(feat)

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        """Predict point logits from fine-grained and coarse features.

        Args:
            fine_grained_point_feats (Tensor): Point features sampled
                from the feature maps.
            coarse_point_feats (Tensor): Coarse prediction sampled at the
                same points.

        Returns:
            Tensor: Output of ``cls_seg`` for the points.
        """
        feats = torch.cat(
            [fine_grained_point_feats, coarse_point_feats], dim=1)
        for layer in self.fcs:
            feats = layer(feats)
            if self.coarse_pred_each_layer:
                feats = torch.cat((feats, coarse_point_feats), dim=1)
        return self.cls_seg(feats)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from the neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """
        per_level = [
            point_sample(level, points, align_corners=self.align_corners)
            for level in x
        ]
        if len(per_level) > 1:
            return torch.cat(per_level, dim=1)
        # Single level: hand back the sampled tensor itself, no concat.
        return per_level[0]

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous head.

        Args:
            prev_output (Tensor): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """
        return point_sample(
            prev_output, points, align_corners=self.align_corners)

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dicts. Not used by
                this method.
            gt_semantic_seg (Tensor): Semantic segmentation masks; the
                point labels are sampled from it.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self._transform_inputs(inputs)
        # Choosing where to sample must not contribute gradients.
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)

        fine_feats = self._get_fine_grained_point_feats(feats, points)
        coarse_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_feats, coarse_feats)

        # Nearest-neighbour sampling keeps the labels as valid class ids.
        sampled_gt = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = sampled_gt.squeeze(1).long()

        return self.losses(point_logits, point_label)

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Starting from a copy of ``prev_output``, the prediction is
        repeatedly upsampled and the logits at the most uncertain grid
        locations are overwritten with the point head's prediction.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dicts. Not used by
                this method.
            test_cfg (dict): The testing config. ``subdivision_steps``
                and ``scale_factor`` are read here.

        Returns:
            Tensor: Output segmentation map.
        """
        feats = self._transform_inputs(inputs)
        refined = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined = resize(
                refined,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined.shape

            point_indices, points = self.get_points_test(
                refined, calculate_uncertainty, cfg=test_cfg)
            fine_feats = self._get_fine_grained_point_feats(feats, points)
            # Coarse features always come from the original coarse
            # prediction, not from the partially refined map.
            coarse_feats = self._get_coarse_point_feats(prev_output, points)
            point_logits = self.forward(fine_feats, coarse_feats)

            # Write the new logits at the selected flat positions of
            # every channel.
            scatter_index = point_indices.unsqueeze(1).expand(
                -1, channels, -1)
            flat_logits = refined.reshape(
                batch_size, channels, height * width)
            flat_logits = flat_logits.scatter_(
                2, scatter_index, point_logits)
            refined = flat_logits.view(batch_size, channels, height, width)

        return refined

    def losses(self, point_logits, point_label):
        """Compute segmentation loss."""
        return {
            'loss_point': self.loss_decode(
                point_logits, point_label, ignore_index=self.ignore_index),
            'acc_point': accuracy(point_logits, point_label),
        }

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head. ``num_points``,
                ``oversample_ratio`` and ``importance_sample_ratio`` are
                read here.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1

        device = seg_logits.device
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)

        # Oversample random candidates and score them on the logits
        # sampled at those candidates.
        candidates = torch.rand(batch_size, num_sampled, 2, device=device)
        candidate_logits = point_sample(seg_logits, candidates)
        candidate_uncertainties = uncertainty_func(candidate_logits)

        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points

        _, top_idx = torch.topk(
            candidate_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)
        # Offset per-sample indices so they address the flattened
        # (batch_size * num_sampled) candidate list.
        batch_offset = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=device)
        top_idx += batch_offset[:, None]
        flat_candidates = candidates.view(-1, 2)
        point_coords = flat_candidates[top_idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)

        if num_random_points > 0:
            # Fill the remaining budget with fresh uniformly random points.
            random_coords = torch.rand(
                batch_size, num_random_points, 2, device=device)
            point_coords = torch.cat((point_coords, random_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find the ``num_points`` most uncertain locations of the
        uncertainty map computed from ``seg_logits``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.
                ``subdivision_num_points`` is read here.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid.
        """
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        num_cells = height * width
        h_step = 1.0 / height
        w_step = 1.0 / width

        flat_uncertainty = uncertainty_map.view(batch_size, num_cells)
        # Never ask for more points than there are grid cells.
        num_points = min(num_cells, cfg.subdivision_num_points)
        point_indices = flat_uncertainty.topk(num_points, dim=1)[1]

        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        # Flat index -> (column, row) -> normalized cell-centre coordinates.
        columns = (point_indices % width).float()
        rows = (point_indices // width).float()
        point_coords[:, :, 0] = w_step / 2.0 + columns * w_step
        point_coords[:, :, 1] = h_step / 2.0 + rows * h_step
        return point_indices, point_coords
|
|