sunnychenxiwang's picture
Upload 1600 files
14c9181 verified
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from mmengine.structures import LabelData
from torch import Tensor
from mmocr.registry import MODELS, TASK_UTILS
from mmocr.structures import TextRecogDataSample # noqa F401
from mmocr.utils import DetSampleList, OptMultiConfig, RecSampleList
from .base_roi_head import BaseRoIHead
@MODELS.register_module()
class RecRoIHead(BaseRoIHead):
"""Simplest base roi head including one bbox head and one mask head."""
def __init__(self,
neck=None,
sampler: OptMultiConfig = None,
roi_extractor: OptMultiConfig = None,
rec_head: OptMultiConfig = None,
init_cfg=None):
super().__init__(init_cfg)
if sampler is not None:
self.sampler = TASK_UTILS.build(sampler)
if neck is not None:
self.neck = MODELS.build(neck)
self.roi_extractor = MODELS.build(roi_extractor)
self.rec_head = MODELS.build(rec_head)
def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
roi on the features of the upstream network.
Args:
x (tuple[Tensor]): List of multi-level img features.
rpn_results_list (list[:obj:`InstanceData`]): List of region
proposals.
DetSampleList (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Returns:
dict[str, Tensor]: A dictionary of loss components
"""
proposals = [
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
]
proposals = [p for p in proposals if len(p) > 0]
bbox_feats = self.roi_extractor(inputs, proposals)
rec_data_samples = [
TextRecogDataSample(gt_text=LabelData(item=text))
for proposal in proposals for text in proposal.texts
]
return self.rec_head.loss(bbox_feats, rec_data_samples)
def predict(self, inputs: Tuple[Tensor],
data_samples: DetSampleList) -> RecSampleList:
if hasattr(self, 'neck') and self.neck is not None:
inputs = self.neck(inputs)
pred_instances = [ds.pred_instances for ds in data_samples]
bbox_feats = self.roi_extractor(inputs, pred_instances)
if bbox_feats.size(0) == 0:
return []
len_instance = sum(
[len(instance_data) for instance_data in pred_instances])
rec_data_samples = [TextRecogDataSample() for _ in range(len_instance)]
rec_data_samples = self.rec_head.predict(bbox_feats, rec_data_samples)
return rec_data_samples