# Hosting-page residue, not part of this module: "Spaces: Runtime error".
import torch | |
from mmdet.core import bbox2roi | |
from ..builder import HEADS, build_head | |
from .standard_roi_head import StandardRoIHead | |
class MaskScoringRoIHead(StandardRoIHead):
    """Mask Scoring RoIHead for Mask Scoring RCNN.

    Extends ``StandardRoIHead`` with an extra ``mask_iou_head`` that is fed
    the mask features together with the predicted mask of the relevant class,
    both during training (extra loss terms) and testing (per-mask scores).

    https://arxiv.org/abs/1903.00241
    """

    # NOTE(review): ``HEADS`` is imported by this file, but no
    # ``@HEADS.register_module()`` decorator is visible on this class.
    # Confirm whether the class is registered elsewhere before relying on
    # config-based lookup by name.

    def __init__(self, mask_iou_head, **kwargs):
        """Build the parent RoI head and the additional mask IoU head.

        Args:
            mask_iou_head: Config handed to ``build_head`` to construct the
                mask IoU head. Must not be ``None``.
            **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
        """
        assert mask_iou_head is not None
        super().__init__(**kwargs)
        self.mask_iou_head = build_head(mask_iou_head)

    def init_weights(self, pretrained):
        """Initialize the weights in head.

        Runs the parent initialization first, then initializes the mask IoU
        head.

        Args:
            pretrained: Forwarded unchanged to the parent ``init_weights``
                (presumably a path to pre-trained weights or ``None`` --
                verify against ``StandardRoIHead``). The argument is
                required; this method declares no default for it.
        """
        super().init_weights(pretrained)
        self.mask_iou_head.init_weights()

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        """Run forward function and calculate loss for Mask head in
        training.

        The parent implementation produces the regular mask results; this
        override additionally runs the mask IoU head on the positive samples
        and merges its losses into ``mask_results['loss_mask']``.

        Returns:
            The ``mask_results`` dict from the parent call. When its
            ``'loss_mask'`` entry is ``None`` it is returned untouched;
            otherwise ``'loss_mask'`` is updated in place with the mask IoU
            losses.
        """
        # Ground-truth class label of every positive sample, concatenated
        # over all sampling results.
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        mask_results = super()._mask_forward_train(x, sampling_results,
                                                   bbox_feats, gt_masks,
                                                   img_metas)
        if mask_results['loss_mask'] is None:
            return mask_results

        # mask iou head forward and loss
        mask_pred = mask_results['mask_pred']
        # For each sample keep only the mask predicted for its own label.
        pos_mask_pred = mask_pred[range(mask_pred.size(0)), pos_labels]
        mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'],
                                           pos_mask_pred)
        # Likewise keep only the IoU prediction of each sample's own label.
        pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
                                          pos_labels]

        mask_iou_targets = self.mask_iou_head.get_targets(
            sampling_results, gt_masks, pos_mask_pred,
            mask_results['mask_targets'], self.train_cfg)
        loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
                                                mask_iou_targets)
        mask_results['loss_mask'].update(loss_mask_iou)
        return mask_results

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Obtain mask prediction without augmentation.

        Args:
            x: Features passed through to ``self._mask_forward``.
            img_metas: Per-image meta dicts; ``'ori_shape'`` and
                ``'scale_factor'`` are read from each.
            det_bboxes: Per-image detected boxes; the first four columns
                are used as box coordinates when ``rescale`` is set.
            det_labels: Per-image detected labels.
            rescale (bool): Whether ``det_bboxes`` are in original-image
                scale and must be mapped back to the testing scale.

        Returns:
            list[tuple]: One ``(segm_result, mask_score)`` pair per image.
            For an image without detections both entries are lists holding
            one empty list per class.
        """

        def empty_per_class():
            # One empty list per class; ``num_classes`` is read lazily so it
            # is only touched when an image really has no detections.
            return [[] for _ in range(self.mask_head.num_classes)]

        # image shapes of images in the batch
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        num_imgs = len(det_bboxes)
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            # Nothing detected in any image: skip the mask heads entirely.
            segm_results = [empty_per_class() for _ in range(num_imgs)]
            mask_scores = [empty_per_class() for _ in range(num_imgs)]
            return list(zip(segm_results, mask_scores))

        # if det_bboxes is rescaled to the original image size, we need to
        # rescale it back to the testing scale to obtain RoIs.
        if rescale and not isinstance(scale_factors[0], float):
            scale_factors = [
                torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                for scale_factor in scale_factors
            ]
        if rescale:
            _bboxes = [
                det_bboxes[i][:, :4] * scale_factors[i]
                for i in range(num_imgs)
            ]
        else:
            _bboxes = [det_bboxes[i] for i in range(num_imgs)]

        mask_rois = bbox2roi(_bboxes)
        mask_results = self._mask_forward(x, mask_rois)
        concat_det_labels = torch.cat(det_labels)

        # get mask scores with mask iou head
        mask_feats = mask_results['mask_feats']
        mask_pred = mask_results['mask_pred']
        # Select, per detection, the mask predicted for its detected label.
        label_mask_pred = mask_pred[range(concat_det_labels.size(0)),
                                    concat_det_labels]
        mask_iou_pred = self.mask_iou_head(mask_feats, label_mask_pred)

        # split batch mask prediction back to each image
        num_bboxes_per_img = tuple(len(_bbox) for _bbox in _bboxes)
        mask_preds = mask_pred.split(num_bboxes_per_img, 0)
        mask_iou_preds = mask_iou_pred.split(num_bboxes_per_img, 0)

        # apply mask post-processing to each image individually
        segm_results = []
        mask_scores = []
        for i in range(num_imgs):
            if det_bboxes[i].shape[0] == 0:
                segm_results.append(empty_per_class())
                mask_scores.append(empty_per_class())
                continue

            segm_result = self.mask_head.get_seg_masks(
                mask_preds[i], _bboxes[i], det_labels[i], self.test_cfg,
                ori_shapes[i], scale_factors[i], rescale)
            # get mask scores with mask iou head
            mask_score = self.mask_iou_head.get_mask_scores(
                mask_iou_preds[i], det_bboxes[i], det_labels[i])
            segm_results.append(segm_result)
            mask_scores.append(mask_score)
        return list(zip(segm_results, mask_scores))