Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple | |
from mmengine.structures import LabelData | |
from torch import Tensor | |
from mmocr.registry import MODELS, TASK_UTILS | |
from mmocr.structures import TextRecogDataSample # noqa F401 | |
from mmocr.utils import DetSampleList, OptMultiConfig, RecSampleList | |
from .base_roi_head import BaseRoIHead | |
class RecRoIHead(BaseRoIHead):
    """RoI head that recognizes the text inside each region of an image.

    Region features are pooled from the upstream feature maps by
    ``roi_extractor`` and handed to ``rec_head``, which computes the
    recognition loss (training) or the recognition results (inference).
    When a ``neck`` is configured it is applied to the feature maps before
    the RoI extraction, in both training and inference.

    Args:
        neck (dict, optional): Config of the neck, built through ``MODELS``.
            No ``neck`` attribute is set when this is None.
            Defaults to None.
        sampler (dict or list[dict], optional): Config built through
            ``TASK_UTILS`` and stored as ``self.sampler``. It is not used by
            ``loss`` or ``predict`` of this class. No ``sampler`` attribute
            is set when this is None. Defaults to None.
        roi_extractor (dict or list[dict], optional): Config of the RoI
            extractor, built through ``MODELS``. Although it defaults to
            None, it is always passed to ``MODELS.build``.
        rec_head (dict or list[dict], optional): Config of the recognition
            head, built through ``MODELS``. Although it defaults to None,
            it is always passed to ``MODELS.build``.
        init_cfg (dict or list[dict], optional): Forwarded unchanged to the
            base class constructor. Defaults to None.
    """

    def __init__(self,
                 neck=None,
                 sampler: OptMultiConfig = None,
                 roi_extractor: OptMultiConfig = None,
                 rec_head: OptMultiConfig = None,
                 init_cfg=None):
        super().__init__(init_cfg)
        # ``sampler`` and ``neck`` are optional: the attribute is only
        # created when a config is given, so readers of these attributes
        # must tolerate their absence (see ``_apply_neck``).
        if sampler is not None:
            self.sampler = TASK_UTILS.build(sampler)
        if neck is not None:
            self.neck = MODELS.build(neck)
        self.roi_extractor = MODELS.build(roi_extractor)
        self.rec_head = MODELS.build(rec_head)

    def _apply_neck(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Run the optional neck over ``inputs``.

        Returns ``inputs`` unchanged when no neck was configured.
        """
        neck = getattr(self, 'neck', None)
        if neck is None:
            return inputs
        return neck(inputs)

    def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
        """Compute the recognition loss on the ground-truth text regions.

        Args:
            inputs (tuple[Tensor]): Image features from the upstream
                network.
            data_samples (list[:obj:`DetDataSample`]): The batch data
                samples. Each one must provide ``gt_instances`` carrying an
                ``ignored`` mask and a ``texts`` field.

        Returns:
            dict: The value returned by ``self.rec_head.loss``, i.e. the
            loss components of the recognition head.
        """
        # Apply the neck exactly as ``predict`` does, so that the RoI
        # extractor sees the same kind of features in training and
        # inference and a configured neck actually takes part in training.
        inputs = self._apply_neck(inputs)

        # Keep only the non-ignored ground-truth instances of every image
        # and drop images that have none left.
        # NOTE(review): dropping images changes the position of the
        # remaining entries relative to the batch dimension of ``inputs``.
        # If the RoI extractor derives the batch index from the list
        # position, features may be pooled from the wrong image - confirm
        # against the configured extractor.
        proposals = []
        for data_sample in data_samples:
            gt_instances = data_sample.gt_instances
            kept = gt_instances[~gt_instances.ignored]
            if len(kept) > 0:
                proposals.append(kept)

        bbox_feats = self.roi_extractor(inputs, proposals)

        # One recognition sample per kept instance, in the same order as
        # the proposals passed to the RoI extractor.
        rec_data_samples = [
            TextRecogDataSample(gt_text=LabelData(item=text))
            for proposal in proposals for text in proposal.texts
        ]
        return self.rec_head.loss(bbox_feats, rec_data_samples)

    def predict(self, inputs: Tuple[Tensor],
                data_samples: DetSampleList) -> RecSampleList:
        """Recognize the text of every predicted instance.

        Args:
            inputs (tuple[Tensor]): Image features from the upstream
                network.
            data_samples (list[:obj:`DetDataSample`]): The batch data
                samples. Each one must provide ``pred_instances``.

        Returns:
            list[:obj:`TextRecogDataSample`]: The value returned by
            ``self.rec_head.predict`` for one fresh sample per predicted
            instance, or an empty list when the RoI extractor yields no
            features.
        """
        inputs = self._apply_neck(inputs)
        pred_instances = [ds.pred_instances for ds in data_samples]
        bbox_feats = self.roi_extractor(inputs, pred_instances)
        # Nothing was detected in the whole batch: skip the recognizer.
        if bbox_feats.size(0) == 0:
            return []

        num_instances = sum(len(instances) for instances in pred_instances)
        rec_data_samples = [
            TextRecogDataSample() for _ in range(num_instances)
        ]
        return self.rec_head.predict(bbox_feats, rec_data_samples)