Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta, abstractmethod | |
| from typing import Union | |
| import torch | |
| from mmengine.model.base_model import BaseModel | |
| from mmocr.utils import (OptConfigType, OptMultiConfig, OptRecSampleList, | |
| RecForwardResults, RecSampleList) | |
class BaseRecognizer(BaseModel, metaclass=ABCMeta):
    """Abstract base class for recognizers.

    Subclasses must implement the four hooks :meth:`extract_feat`,
    :meth:`loss`, :meth:`predict` and :meth:`_forward`. :meth:`forward`
    is the single entry point and dispatches to the last three of them
    according to ``mode``.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (dict or ConfigDict or List[dict], optional): The config
            to control the initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        # Both arguments are forwarded unchanged to the parent constructor.
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    # NOTE(review): the ``with_*`` predicates below are plain methods, so
    # they must be *called* (``self.with_backbone()``); a bare
    # ``self.with_backbone`` is a bound method and is therefore always
    # truthy. Their "bool: ..." summaries read like property docs - confirm
    # against callers whether ``@property`` was intended before changing
    # the call syntax.

    def with_backbone(self):
        """bool: whether the recognizer has a ``backbone`` attribute."""
        return hasattr(self, 'backbone')

    def with_encoder(self):
        """bool: whether the recognizer has an ``encoder`` attribute."""
        return hasattr(self, 'encoder')

    def with_preprocessor(self):
        """bool: whether the recognizer has a ``preprocessor`` attribute."""
        return hasattr(self, 'preprocessor')

    def with_decoder(self):
        """bool: whether the recognizer has a ``decoder`` attribute."""
        return hasattr(self, 'decoder')

    @abstractmethod
    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Extract features from images.

        Abstract hook: concrete recognizers must override it.

        Args:
            inputs (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The extracted features.
        """

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptRecSampleList = None,
                mode: str = 'tensor',
                **kwargs) -> RecForwardResults:
        """The unified entry for a forward process in both training and test.

        The method accepts three modes: "tensor", "predict" and "loss":

        - "tensor": Delegate to :meth:`_forward`, which is meant to run the
          network and return a tensor or tuple of tensors without any
          post-processing, same as a common nn.Module.
        - "predict": Delegate to :meth:`predict`, which is meant to return
          post-processed predictions as a list of data samples
          (``RecSampleList``).
        - "loss": Delegate to :meth:`loss`, which is meant to return the
          losses computed from the given inputs and data samples.

        This method performs neither back propagation nor optimizer
        updating; those are expected to happen in :meth:`train_step`
        (inherited, not defined in this class).

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (RecSampleList, optional): The annotation data of
                every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.
            **kwargs: Extra keyword arguments, passed through unchanged to
                the selected hook.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, the result of :meth:`_forward`.
            - If ``mode="predict"``, the result of :meth:`predict`.
            - If ``mode="loss"``, the result of :meth:`loss`.

        Raises:
            RuntimeError: If ``mode`` is not one of 'loss', 'predict' or
                'tensor'.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        if mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        if mode == 'tensor':
            return self._forward(inputs, data_samples, **kwargs)
        # Anything else is a caller error; fail loudly instead of guessing.
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
             **kwargs) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples.

        Abstract hook: concrete recognizers must override it. It is the
        target of ``forward(..., mode='loss')``.
        """

    @abstractmethod
    def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
                **kwargs) -> RecSampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Abstract hook: concrete recognizers must override it. It is the
        target of ``forward(..., mode='predict')``.
        """

    @abstractmethod
    def _forward(self,
                 inputs: torch.Tensor,
                 data_samples: OptRecSampleList = None,
                 **kwargs):
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.

        Abstract hook: concrete recognizers must override it. It is the
        target of ``forward(..., mode='tensor')``.
        """