Robert001's picture
first commit
b334e29
raw
history blame
No virus
39.7 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.runner import force_fp32
from mmdet.core import build_sampler, fast_nms, images_to_levels, multi_apply
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead
@HEADS.register_module()
class YOLACTHead(AnchorHead):
    """YOLACT box head used in https://arxiv.org/abs/1904.02689.

    Note that YOLACT head is a light version of RetinaNet head.
    Four differences are described as follows:

    1. YOLACT box head has three-times fewer anchors.
    2. YOLACT box head shares the convs for box and cls branches.
    3. YOLACT box head uses OHEM instead of Focal loss.
    4. YOLACT box head predicts a set of mask coefficients for each box.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): Config dict for anchor generator
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        num_head_convs (int): Number of the conv layers shared by
            box and cls branches.
        num_protos (int): Number of the mask coefficients.
        use_ohem (bool): If true, ``loss_single_OHEM`` will be used for
            cls loss calculation. If false, ``loss_single`` will be used.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=3,
                     scales_per_octave=1,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     reduction='none',
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
                 num_head_convs=1,
                 num_protos=32,
                 use_ohem=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        # These are assigned before the parent constructor runs;
        # presumably ``AnchorHead.__init__`` triggers ``_init_layers``,
        # which reads them -- TODO confirm against the base class.
        self.num_head_convs = num_head_convs
        self.num_protos = num_protos
        self.use_ohem = use_ohem
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(YOLACTHead, self).__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            anchor_generator=anchor_generator,
            **kwargs)
        if self.use_ohem:
            # With OHEM the hard negatives are picked inside
            # ``loss_single_OHEM``, so a pseudo sampler is installed here.
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
            self.sampling = False

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.head_convs = nn.ModuleList()
        for i in range(self.num_head_convs):
            # Only the first shared conv consumes the raw input channels.
            chn = self.in_channels if i == 0 else self.feat_channels
            self.head_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        # Three sibling 3x3 prediction convs on top of the shared tower.
        self.conv_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.conv_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)
        self.conv_coeff = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.num_protos,
            3,
            padding=1)

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.head_convs:
            xavier_init(m.conv, distribution='uniform', bias=0)
        xavier_init(self.conv_cls, distribution='uniform', bias=0)
        xavier_init(self.conv_reg, distribution='uniform', bias=0)
        xavier_init(self.conv_coeff, distribution='uniform', bias=0)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level,
                    the channels number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single
                    scale level, the channels number is num_anchors * 4.
                coeff_pred (Tensor): Mask coefficients for a single scale
                    level, the channels number is
                    num_anchors * num_protos. ``tanh`` is applied, so
                    values lie in [-1, 1].
        """
        for head_conv in self.head_convs:
            x = head_conv(x)
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        coeff_pred = self.conv_coeff(x).tanh()
        return cls_score, bbox_pred, coeff_pred

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """A combination of the func:``AnchorHead.loss`` and
        func:``SSDHead.loss``.

        When ``self.use_ohem == True``, it functions like ``SSDHead.loss``,
        otherwise, it follows ``AnchorHead.loss``. Besides, it additionally
        returns ``sampling_results``.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            tuple:
                dict[str, Tensor]: A dictionary of loss components.
                List[:obj:``SamplingResult``]: Sampler results for each image.
            ``None`` is returned instead when ``get_targets`` yields
            ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            unmap_outputs=not self.use_ohem,
            return_sampling_results=True)
        if cls_reg_targets is None:
            # NOTE(review): the normal path returns a 2-tuple; a caller
            # that unpacks the result would fail here -- confirm intent.
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, sampling_results) = cls_reg_targets

        if self.use_ohem:
            num_images = len(img_metas)
            # Flatten every level to (num_images, num_priors, C) and join
            # the levels so that each image is handled as one sample.
            all_cls_scores = torch.cat([
                s.permute(0, 2, 3, 1).reshape(
                    num_images, -1, self.cls_out_channels) for s in cls_scores
            ], 1)
            all_labels = torch.cat(labels_list, -1).view(num_images, -1)
            all_label_weights = torch.cat(label_weights_list,
                                          -1).view(num_images, -1)
            all_bbox_preds = torch.cat([
                b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
                for b in bbox_preds
            ], -2)
            all_bbox_targets = torch.cat(bbox_targets_list,
                                         -2).view(num_images, -1, 4)
            all_bbox_weights = torch.cat(bbox_weights_list,
                                         -2).view(num_images, -1, 4)

            # concat all level anchors to a single tensor
            all_anchors = [
                torch.cat(anchor_list[i]) for i in range(num_images)
            ]

            # check NaN and Inf
            assert torch.isfinite(all_cls_scores).all().item(), \
                'classification scores become infinite or NaN!'
            assert torch.isfinite(all_bbox_preds).all().item(), \
                'bbox predications become infinite or NaN!'

            losses_cls, losses_bbox = multi_apply(
                self.loss_single_OHEM,
                all_cls_scores,
                all_bbox_preds,
                all_anchors,
                all_labels,
                all_label_weights,
                all_bbox_targets,
                all_bbox_weights,
                num_total_samples=num_total_pos)
        else:
            num_total_samples = ((num_total_pos + num_total_neg)
                                 if self.sampling else num_total_pos)

            # anchor number of multi levels
            num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
            # concat all level anchors and flags to a single tensor
            concat_anchor_list = [
                torch.cat(img_anchors) for img_anchors in anchor_list
            ]
            all_anchor_list = images_to_levels(concat_anchor_list,
                                               num_level_anchors)
            losses_cls, losses_bbox = multi_apply(
                self.loss_single,
                cls_scores,
                bbox_preds,
                all_anchor_list,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
                num_total_samples=num_total_samples)

        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results

    def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels,
                         label_weights, bbox_targets, bbox_weights,
                         num_total_samples):
        """Compute the OHEM cls / bbox loss of one image.

        See func:``SSDHead.loss``. All positives contribute to the cls
        loss; of the negatives only those with the largest loss are kept,
        at most ``train_cfg.neg_pos_ratio`` per positive (all negatives
        when there is no positive).

        Returns:
            tuple[Tensor, Tensor]: The cls loss with a leading dim of 1
            and the bbox loss.
        """
        loss_cls_all = self.loss_cls(cls_score, labels, label_weights)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
            as_tuple=False).reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero(
            as_tuple=False).view(-1)

        num_pos_samples = pos_inds.size(0)
        if num_pos_samples == 0:
            num_neg_samples = neg_inds.size(0)
        else:
            num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
            if num_neg_samples > neg_inds.size(0):
                num_neg_samples = neg_inds.size(0)
        # Hard negative mining: keep the highest-loss negatives only.
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'coeff_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   coeff_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        """Similar to func:``AnchorHead.get_bboxes``, but additionally
        processes coeff_preds.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            coeff_preds (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor], list[Tensor]]: Three lists
                with one entry per image. The first holds (n, 5) tensors,
                where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1. The second holds (n,) tensors with the
                predicted class label of each box. The third holds
                (n, num_protos) tensors with the predicted mask
                coefficients of the instance inside each box.
        """
        # All three prediction lists must describe the same pyramid levels;
        # otherwise the per-level ``zip`` downstream would drop levels
        # silently.
        assert len(cls_scores) == len(bbox_preds) == len(coeff_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device=device)

        det_bboxes = []
        det_labels = []
        det_coeffs = []
        for img_id, img_meta in enumerate(img_metas):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            coeff_pred_list = [
                coeff_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_meta['img_shape']
            scale_factor = img_meta['scale_factor']
            bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                               coeff_pred_list, mlvl_anchors,
                                               img_shape, scale_factor, cfg,
                                               rescale)
            det_bboxes.append(bbox_res[0])
            det_labels.append(bbox_res[1])
            det_coeffs.append(bbox_res[2])
        return det_bboxes, det_labels, det_coeffs

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           coeff_preds_list,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Similar to func:``AnchorHead._get_bboxes_single``, but
        additionally processes coeff_preds_list and uses fast NMS instead of
        traditional NMS.

        Args:
            cls_score_list (list[Tensor]): Box scores for a single scale level
                Has shape (num_anchors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas for a single
                scale level with shape (num_anchors * 4, H, W).
            coeff_preds_list (list[Tensor]): Mask coefficients for a single
                scale level with shape (num_anchors * num_protos, H, W).
            mlvl_anchors (list[Tensor]): Box reference for a single scale level
                with shape (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5) tensor,
                where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between
                0 and 1. The second item is an (n,) tensor where each item is
                the predicted class label of the corresponding box. The third
                item is an (n, num_protos) tensor where each item is the
                predicted mask coefficients of instance inside the
                corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        # Include the coefficient list so a level mismatch fails loudly
        # instead of being truncated by ``zip`` below.
        assert (len(cls_score_list) == len(bbox_pred_list) ==
                len(coeff_preds_list) == len(mlvl_anchors))
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_coeffs = []
        for cls_score, bbox_pred, coeff_pred, anchors in \
                zip(cls_score_list, bbox_pred_list,
                    coeff_preds_list, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # (C, H, W) -> (H * W * num_anchors, per-anchor channels)
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            coeff_pred = coeff_pred.permute(1, 2,
                                            0).reshape(-1, self.num_protos)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                # Get maximum scores for foreground classes.
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                coeff_pred = coeff_pred[topk_inds, :]
            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_coeffs.append(coeff_pred)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_coeffs = torch.cat(mlvl_coeffs)
        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores,
                                                      mlvl_coeffs,
                                                      cfg.score_thr,
                                                      cfg.iou_thr, cfg.top_k,
                                                      cfg.max_per_img)
        return det_bboxes, det_labels, det_coeffs
@HEADS.register_module()
class YOLACTSegmHead(nn.Module):
    """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689.

    Adds a semantic segmentation loss on the feature space. The layers are
    only evaluated during training, so they improve performance without
    any speed penalty.

    Args:
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of categories excluding the background
            category.
        loss_segm (dict): Config of semantic segmentation loss.
    """

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 loss_segm=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0)):
        super().__init__()
        self.fp16_enabled = False
        self.in_channels = in_channels
        self.num_classes = num_classes
        # The loss is built before the conv layer is created.
        self.loss_segm = build_loss(loss_segm)
        self._init_layers()

    def _init_layers(self):
        """Create the single 1x1 conv that predicts per-class logits."""
        self.segm_conv = nn.Conv2d(
            self.in_channels, self.num_classes, kernel_size=1)

    def init_weights(self):
        """Xavier-initialize (uniform) the segmentation conv."""
        xavier_init(self.segm_conv, distribution='uniform')

    def forward(self, x):
        """Run the segmentation conv on the upstream feature.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.

        Returns:
            Tensor: Predicted semantic segmentation map with shape
                (N, num_classes, H, W).
        """
        segm_pred = self.segm_conv(x)
        return segm_pred

    @force_fp32(apply_to=('segm_pred', ))
    def loss(self, segm_pred, gt_masks, gt_labels):
        """Compute the per-image semantic segmentation losses.

        Args:
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (N, num_classes, H, W).
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape of the input image.
            gt_labels (list[Tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, list[Tensor]]: ``loss_segm`` holds one loss per image.
        """
        num_imgs, _, mask_h, mask_w = segm_pred.size()
        # Every image is normalized by the size of the whole batch.
        avg_factor = num_imgs * mask_h * mask_w
        per_img_losses = []
        for img_idx, pred in enumerate(segm_pred):
            targets = self.get_targets(pred, gt_masks[img_idx].float(),
                                       gt_labels[img_idx])
            if targets is not None:
                img_loss = self.loss_segm(
                    pred, targets, avg_factor=avg_factor)
            else:
                # No ground truth in this image: all-zero targets together
                # with all-zero weights are passed to the loss.
                img_loss = self.loss_segm(pred, torch.zeros_like(pred),
                                          torch.zeros_like(pred))
            per_img_losses.append(img_loss)
        return dict(loss_segm=per_img_losses)

    def get_targets(self, segm_pred, gt_masks, gt_labels):
        """Build the semantic segmentation targets of one image.

        Args:
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (num_classes, H, W).
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape of the input image.
            gt_labels (Tensor): Class indices corresponding to each box.

        Returns:
            Tensor | None: Semantic segmentation targets with shape
                (num_classes, H, W), or ``None`` when the image has no
                ground truth mask.
        """
        if gt_masks.size(0) == 0:
            return None
        _, mask_h, mask_w = segm_pred.size()
        with torch.no_grad():
            # Resize the instance masks to the prediction resolution and
            # binarize them again.
            resized = F.interpolate(
                gt_masks.unsqueeze(0), (mask_h, mask_w),
                mode='bilinear',
                align_corners=False)
            binary_masks = resized.squeeze(0).gt(0.5).float()
            segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
            for obj_idx, obj_mask in enumerate(binary_masks):
                # NOTE(review): with 0-based labels this maps label 0 to
                # index -1, i.e. the last channel -- confirm the label
                # convention expected here.
                channel = gt_labels[obj_idx] - 1
                # Union of all instance masks that share this channel.
                segm_targets[channel] = torch.max(segm_targets[channel],
                                                  obj_mask)
            return segm_targets
@HEADS.register_module()
class YOLACTProtonet(nn.Module):
    """YOLACT mask head used in https://arxiv.org/abs/1904.02689.

    This head outputs the mask prototypes for YOLACT.

    Args:
        in_channels (int): Number of channels in the input feature map.
        proto_channels (tuple[int]): Output channels of protonet convs.
        proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
        include_last_relu (bool): Whether to keep the last relu of the
            protonet.
        num_protos (int): Number of prototypes.
        num_classes (int): Number of categories excluding the background
            category.
        loss_mask_weight (float): Reweight the mask loss by this factor.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
    """

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 proto_channels=(256, 256, 256, None, 256, 32),
                 proto_kernel_sizes=(3, 3, 3, -2, 3, 1),
                 include_last_relu=True,
                 num_protos=32,
                 loss_mask_weight=1.0,
                 max_masks_to_train=100):
        super(YOLACTProtonet, self).__init__()
        self.in_channels = in_channels
        self.proto_channels = proto_channels
        self.proto_kernel_sizes = proto_kernel_sizes
        self.include_last_relu = include_last_relu
        self.protonet = self._init_layers()

        self.loss_mask_weight = loss_mask_weight
        self.num_protos = num_protos
        self.num_classes = num_classes
        self.max_masks_to_train = max_masks_to_train
        self.fp16_enabled = False

    def _init_layers(self):
        """A helper function to take a config setting and turn it into a
        network.

        Returns:
            nn.Sequential: The layers described by ``proto_channels`` and
                ``proto_kernel_sizes``, each followed by a ReLU (the last
                ReLU is dropped when ``include_last_relu`` is False).
        """
        # Possible patterns:
        # ( 256, 3) -> conv
        # ( 256,-2) -> deconv
        # (None,-2) -> bilinear interpolate
        in_channels = self.in_channels
        protonets = nn.ModuleList()
        for num_channels, kernel_size in zip(self.proto_channels,
                                             self.proto_kernel_sizes):
            if kernel_size > 0:
                layer = nn.Conv2d(
                    in_channels,
                    num_channels,
                    kernel_size,
                    padding=kernel_size // 2)
            else:
                if num_channels is None:
                    layer = InterpolateModule(
                        scale_factor=-kernel_size,
                        mode='bilinear',
                        align_corners=False)
                else:
                    # NOTE(review): ``kernel_size`` is not positive in this
                    # branch, so ``kernel_size // 2`` is not positive
                    # either (-1 for a kernel of -2) -- confirm that this
                    # padding value is intended.
                    layer = nn.ConvTranspose2d(
                        in_channels,
                        num_channels,
                        -kernel_size,
                        padding=kernel_size // 2)
            protonets.append(layer)
            protonets.append(nn.ReLU(inplace=True))
            # Interpolation layers keep the channel count unchanged.
            in_channels = num_channels if num_channels is not None \
                else in_channels
        if not self.include_last_relu:
            protonets = protonets[:-1]
        return nn.Sequential(*protonets)

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.protonet:
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None):
        """Forward feature from the upstream network to get prototypes and
        linearly combine the prototypes, using masks coefficients, into
        instance masks. Finally, crop the instance masks with given bboxes.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.
            coeff_pred (list[Tensor]): In training mode, mask coefficients
                for each scale level with shape
                (N, num_anchors * num_protos, H, W). Otherwise it is
                indexed per image and each item is used as a
                (num_instances, num_protos) matrix.
            bboxes (list[Tensor]): Boxes used for cropping, one 2-D tensor
                per image whose first four columns are absolute
                (tl_x, tl_y, br_x, br_y). During training, they are
                ground truth boxes. During testing, they are predicted
                boxes. The tensors are not modified.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image. Only read in training mode.

        Returns:
            list[Tensor]: Predicted instance segmentation masks, one
                (num_instances, H, W) tensor per image.
        """
        prototypes = self.protonet(x)
        # (N, C, H, W) -> (N, H, W, C) so prototypes can be matrix-multiplied
        # with the coefficients.
        prototypes = prototypes.permute(0, 2, 3, 1).contiguous()

        num_imgs = x.size(0)
        # Training state
        if self.training:
            coeff_pred_list = []
            for coeff_pred_per_level in coeff_pred:
                coeff_pred_per_level = \
                    coeff_pred_per_level.permute(0, 2, 3, 1)\
                    .reshape(num_imgs, -1, self.num_protos)
                coeff_pred_list.append(coeff_pred_per_level)
            coeff_pred = torch.cat(coeff_pred_list, dim=1)

        mask_pred_list = []
        for idx in range(num_imgs):
            cur_prototypes = prototypes[idx]
            cur_coeff_pred = coeff_pred[idx]
            cur_bboxes = bboxes[idx]
            cur_img_meta = img_meta[idx]

            # Testing state
            if not self.training:
                # Clone so that the in-place normalization below does not
                # change the tensor owned by the caller.
                bboxes_for_cropping = cur_bboxes.clone()
            else:
                cur_sampling_results = sampling_results[idx]
                pos_assigned_gt_inds = \
                    cur_sampling_results.pos_assigned_gt_inds
                bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
                pos_inds = cur_sampling_results.pos_inds
                cur_coeff_pred = cur_coeff_pred[pos_inds]

            # Linearly combine the prototypes with the mask coefficients
            mask_pred = cur_prototypes @ cur_coeff_pred.t()
            mask_pred = torch.sigmoid(mask_pred)

            # ``crop`` expects relative coordinates.
            h, w = cur_img_meta['img_shape'][:2]
            bboxes_for_cropping[:, 0] /= w
            bboxes_for_cropping[:, 1] /= h
            bboxes_for_cropping[:, 2] /= w
            bboxes_for_cropping[:, 3] /= h

            mask_pred = self.crop(mask_pred, bboxes_for_cropping)
            mask_pred = mask_pred.permute(2, 0, 1).contiguous()
            mask_pred_list.append(mask_pred)
        return mask_pred_list

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results):
        """Compute loss of the head.

        Args:
            mask_pred (list[Tensor]): Predicted instance masks of each
                image; the first dim indexes the positive samples.
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape of the input image.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image.

        Returns:
            dict[str, list[Tensor]]: ``loss_mask`` holds one loss per
                image, each divided by the number of trained positives
                over the whole batch.
        """
        loss_mask = []
        num_imgs = len(mask_pred)
        total_pos = 0
        for idx in range(num_imgs):
            cur_mask_pred = mask_pred[idx]
            cur_gt_masks = gt_masks[idx].float()
            cur_gt_bboxes = gt_bboxes[idx]
            cur_img_meta = img_meta[idx]
            cur_sampling_results = sampling_results[idx]

            pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds
            num_pos = pos_assigned_gt_inds.size(0)
            # Since we're producing (near) full image masks,
            # it'd take too much vram to backprop on every single mask.
            # Thus we select only a subset.
            if num_pos > self.max_masks_to_train:
                perm = torch.randperm(num_pos)
                select = perm[:self.max_masks_to_train]
                cur_mask_pred = cur_mask_pred[select]
                pos_assigned_gt_inds = pos_assigned_gt_inds[select]
                num_pos = self.max_masks_to_train
            total_pos += num_pos

            gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds]

            mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks,
                                            pos_assigned_gt_inds)
            if num_pos == 0:
                # Zero loss that still depends on the prediction tensor.
                loss = cur_mask_pred.sum() * 0.
            elif mask_targets is None:
                loss = F.binary_cross_entropy(cur_mask_pred,
                                              torch.zeros_like(cur_mask_pred),
                                              torch.zeros_like(cur_mask_pred))
            else:
                cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1)
                loss = F.binary_cross_entropy(
                    cur_mask_pred, mask_targets,
                    reduction='none') * self.loss_mask_weight

                # Reweight each mask by the inverse of the relative area
                # of its ground truth box.
                h, w = cur_img_meta['img_shape'][:2]
                gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
                                   gt_bboxes_for_reweight[:, 0]) / w
                gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
                                    gt_bboxes_for_reweight[:, 1]) / h
                loss = loss.mean(dim=(1,
                                      2)) / gt_bboxes_width / gt_bboxes_height
                loss = torch.sum(loss)
            loss_mask.append(loss)

        if total_pos == 0:
            total_pos += 1  # avoid nan
        loss_mask = [img_loss / total_pos for img_loss in loss_mask]

        return dict(loss_mask=loss_mask)

    def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds):
        """Compute instance segmentation targets for each image.

        Args:
            mask_pred (Tensor): Predicted instance masks; only the last
                two dims (H, W) are read.
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape of the input image.
            pos_assigned_gt_inds (Tensor): GT indices of the corresponding
                positive samples.

        Returns:
            Tensor | None: Instance segmentation targets with shape
                (num_instances, H, W), or ``None`` when there is no
                ground truth mask.
        """
        if gt_masks.size(0) == 0:
            return None
        mask_h, mask_w = mask_pred.shape[-2:]
        # Resize to the prediction resolution, then binarize again.
        gt_masks = F.interpolate(
            gt_masks.unsqueeze(0), (mask_h, mask_w),
            mode='bilinear',
            align_corners=False).squeeze(0)
        gt_masks = gt_masks.gt(0.5).float()
        mask_targets = gt_masks[pos_assigned_gt_inds]
        return mask_targets

    def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
        """Resize, binarize, and format the instance mask predictions.

        Args:
            mask_pred (Tensor): shape (N, H, W).
            label_pred (Tensor): shape (N, ).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, masks are resized to ``ori_shape``;
                otherwise to ``ori_shape`` multiplied by ``scale_factor``.

        Returns:
            list[list[ndarray]]: Mask predictions grouped by their
                predicted classes.
        """
        ori_shape = img_meta['ori_shape']
        scale_factor = img_meta['scale_factor']
        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            # scale_factor[0] scales the width, scale_factor[1] the height.
            img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)

        cls_segms = [[] for _ in range(self.num_classes)]
        if mask_pred.size(0) == 0:
            return cls_segms

        mask_pred = F.interpolate(
            mask_pred.unsqueeze(0), (img_h, img_w),
            mode='bilinear',
            align_corners=False).squeeze(0) > 0.5
        mask_pred = mask_pred.cpu().numpy().astype(np.uint8)

        for mask, label in zip(mask_pred, label_pred):
            cls_segms[label].append(mask)
        return cls_segms

    def crop(self, masks, boxes, padding=1):
        """Crop predicted masks by zeroing out everything not in the predicted
        bbox.

        Args:
            masks (Tensor): shape [H, W, N].
            boxes (Tensor): bbox coords in relative point form with
                shape [N, 4].
            padding (int): Number of pixels by which every box is enlarged
                on each side before cropping.

        Return:
            Tensor: The cropped masks.
        """
        h, w, n = masks.size()
        x1, x2 = self.sanitize_coordinates(
            boxes[:, 0], boxes[:, 2], w, padding, cast=False)
        y1, y2 = self.sanitize_coordinates(
            boxes[:, 1], boxes[:, 3], h, padding, cast=False)

        # Pixel x / y index of every mask element, broadcast to [H, W, N].
        x_index = torch.arange(
            w, device=masks.device, dtype=x1.dtype).view(1, -1,
                                                         1).expand(h, w, n)
        y_index = torch.arange(
            h, device=masks.device, dtype=x1.dtype).view(-1, 1,
                                                         1).expand(h, w, n)

        masks_left = x_index >= x1.view(1, 1, -1)
        masks_right = x_index < x2.view(1, 1, -1)
        masks_up = y_index >= y1.view(1, 1, -1)
        masks_down = y_index < y2.view(1, 1, -1)

        crop_mask = masks_left * masks_right * masks_up * masks_down

        return masks * crop_mask.float()

    def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
        """Convert a pair of relative coordinates to ordered absolute ones.

        Both inputs are scaled by ``img_size`` and reordered element-wise so
        that the first output is never larger than the second. The range is
        then widened by ``padding`` on both sides and clipped to
        ``[0, img_size]``. The input tensors are not modified.

        Args:
            x1 (Tensor): shape (N, ), relative coordinates.
            x2 (Tensor): shape (N, ), relative coordinates.
            img_size (int): Size of the input image along this axis.
            padding (int): The lower bound is moved down and the upper
                bound is moved up by this amount before clipping.
            cast (bool): If cast is false, the result won't be cast to longs.

        Returns:
            tuple:
                x1 (Tensor): Sanitized lower coordinates.
                x2 (Tensor): Sanitized upper coordinates.
        """
        x1 = x1 * img_size
        x2 = x2 * img_size
        if cast:
            x1 = x1.long()
            x2 = x2.long()
        # Both extremes must be computed from the same, not yet reordered
        # operands; overwriting ``x1`` first would lose the larger value
        # whenever the inputs arrive in reversed order.
        lower = torch.min(x1, x2)
        upper = torch.max(x1, x2)
        lower = torch.clamp(lower - padding, min=0)
        upper = torch.clamp(upper + padding, max=img_size)
        return lower, upper
class InterpolateModule(nn.Module):
    """Module wrapper around ``F.interpolate``.

    The positional and keyword arguments given to the constructor are
    stored and forwarded unchanged on every call.
    """

    def __init__(self, *args, **kwargs):
        super(InterpolateModule, self).__init__()
        self.kwargs = kwargs
        self.args = args

    def forward(self, x):
        """Interpolate ``x`` with the stored arguments."""
        resized = F.interpolate(x, *self.args, **self.kwargs)
        return resized