# Spaces: Runtime error / Runtime error
# (page-status text captured together with this listing; not part of the module)
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch.nn.functional as F
from mmcv.runner import BaseModule, force_fp32

from ..builder import build_loss
from ..utils import interpolate_as
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
    """Base module of Semantic Head.

    Concrete heads implement :meth:`forward`; this base class provides the
    loss computation and the train / test entry points built on top of it.

    Args:
        num_classes (int): the number of classes.
        init_cfg (dict): the initialization config.
        loss_seg (dict): the loss of the semantic head.
    """

    # NOTE(review): the ``loss_seg`` default is a dict shared across calls.
    # It is only handed to ``build_loss`` here; this presumes ``build_loss``
    # does not mutate its argument -- confirm before relying on it.
    def __init__(self,
                 num_classes,
                 init_cfg=None,
                 loss_seg=dict(
                     type='CrossEntropyLoss',
                     ignore_index=255,
                     loss_weight=1.0)):
        super().__init__(init_cfg)
        self.loss_seg = build_loss(loss_seg)
        self.num_classes = num_classes

    # Per mmcv's documentation, ``force_fp32`` casts the listed arguments to
    # fp32 when fp16 is enabled on the module and is a no-op otherwise, so the
    # loss is not evaluated on half-precision logits.
    @force_fp32(apply_to=('seg_preds', ))
    def loss(self, seg_preds, gt_semantic_seg):
        """Get the loss of semantic head.

        Args:
            seg_preds (Tensor): The input logits with the shape (N, C, H, W).
            gt_semantic_seg: The ground truth of semantic segmentation with
                the shape (N, H, W).

        Returns:
            dict: the loss of semantic head, stored under key ``loss_seg``.
        """
        # Bring the logits to the ground truth's spatial size when the last
        # two dimensions differ (``interpolate_as`` is expected to resize its
        # first argument to match the second).
        if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_seg)

        # Move the class axis last so each row of the flattened tensor holds
        # the logits of one pixel.
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_seg.reshape(-1).long())
        return dict(loss_seg=loss_seg)

    @abstractmethod
    def forward(self, x):
        """Placeholder of forward function.

        Must be overridden by subclasses.

        Returns:
            dict[str, Tensor]: A dictionary, including features
                and predicted scores. Required keys: 'seg_preds'
                and 'feats'.
        """
        pass

    def forward_train(self, x, gt_semantic_seg):
        """Run :meth:`forward` and compute the loss of its ``seg_preds``.

        Args:
            x: Input passed unchanged to :meth:`forward`.
            gt_semantic_seg: Ground truth passed to :meth:`loss`.

        Returns:
            dict: the result of :meth:`loss`.
        """
        output = self.forward(x)
        seg_preds = output['seg_preds']
        return self.loss(seg_preds, gt_semantic_seg)

    def simple_test(self, x, img_metas, rescale=False):
        """Predict segmentation scores at image resolution.

        Only the first entry of ``img_metas`` is consulted for sizes.

        Args:
            x: Input passed unchanged to :meth:`forward`.
            img_metas (list[dict]): Image meta info; ``img_metas[0]`` must
                provide ``pad_shape`` and, when ``rescale`` is True, also
                ``img_shape`` and ``ori_shape``.
            rescale (bool): If True, crop the prediction to ``img_shape``
                and then resize it to ``ori_shape``. Default: False.

        Returns:
            Tensor: the resized ``seg_preds``.
        """
        output = self.forward(x)
        seg_preds = output['seg_preds']

        # Upsample the prediction to the padded input size.
        seg_preds = F.interpolate(
            seg_preds,
            size=img_metas[0]['pad_shape'][:2],
            mode='bilinear',
            align_corners=False)

        if rescale:
            # Keep only the top-left ``img_shape`` region of the prediction.
            # Slicing with [:2] (as for ``pad_shape`` above) accepts both
            # (H, W) and (H, W, C) shapes.
            h, w = img_metas[0]['img_shape'][:2]
            seg_preds = seg_preds[:, :, :h, :w]

            # Resize the cropped prediction to the original image size.
            h, w = img_metas[0]['ori_shape'][:2]
            seg_preds = F.interpolate(
                seg_preds, size=(h, w), mode='bilinear', align_corners=False)
        return seg_preds