Spaces:
Runtime error
Runtime error
from abc import ABCMeta, abstractmethod | |
import torch.nn as nn | |
class BaseDenseHead(nn.Module, metaclass=ABCMeta):
    """Base class for DenseHeads.

    Subclasses are expected to provide ``forward`` (invoked through
    ``self(x)``), ``loss`` and ``get_bboxes``. ``forward_train`` chains the
    three together for a training step.
    """

    def __init__(self):
        super(BaseDenseHead, self).__init__()

    def loss(self, **kwargs):
        """Compute losses of the head.

        Raises:
            NotImplementedError: Always. The base class has no loss
                definition; subclasses must override this method.
        """
        # Fail loudly instead of silently handing ``None`` back to
        # ``forward_train`` when a subclass forgets to override this.
        raise NotImplementedError(
            '{} must implement loss()'.format(type(self).__name__))

    def get_bboxes(self, **kwargs):
        """Transform network output for a batch into bbox predictions.

        Raises:
            NotImplementedError: Always. The base class has no decoding
                logic; subclasses must override this method.
        """
        # Same reasoning as ``loss``: a silent ``None`` proposal list would
        # only surface as a confusing error further downstream.
        raise NotImplementedError(
            '{} must implement get_bboxes()'.format(type(self).__name__))

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Run the head on ``x`` and compute the training losses.

        Args:
            x (list[Tensor]): Features from FPN.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,). When None, labels are not passed to
                ``loss`` at all.
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration
                forwarded to ``get_bboxes`` as ``cfg``. If None, no
                proposals are generated and only the losses are returned.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict[str, Tensor] | tuple: When ``proposal_cfg`` is None, the
            value returned by ``loss`` (a dictionary of loss components).
            Otherwise a tuple ``(losses, proposal_list)`` where
            ``proposal_list`` is the value returned by ``get_bboxes``
            (proposals of each image).
        """
        # The forward output is concatenated with a tuple below, so
        # ``forward`` is expected to return a tuple.
        outs = self(x)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        # Proposals are decoded from the same forward outputs used for loss.
        proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
        return losses, proposal_list