import copy
import warnings

from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead


def replace_ImageToTensor(pipelines):
    """Replace the ImageToTensor transform in a data pipeline with
    DefaultFormatBundle, which is normally useful in batch inference.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    pipelines = copy.deepcopy(pipelines)
    for i, pipeline in enumerate(pipelines):
        if pipeline['type'] == 'MultiScaleFlipAug':
            assert 'transforms' in pipeline
            pipeline['transforms'] = replace_ImageToTensor(
                pipeline['transforms'])
        elif pipeline['type'] == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            pipelines[i] = {'type': 'DefaultFormatBundle'}
    return pipelines
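

# Usage sketch (assuming a typical MMDetection test config `cfg`;
# `samples_per_gpu` is an illustrative variable, not defined in this module):
# when images are batched for inference, the test pipeline is usually
# rewritten in place so batches are collated correctly:
#
#     if samples_per_gpu > 1:
#         cfg.data.test.pipeline = replace_ImageToTensor(
#             cfg.data.test.pipeline)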


def get_loading_pipeline(pipeline):
    """Only keep the loading image and annotations related configuration.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The new pipeline list that only keeps the loading image
            and annotations related configuration.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    loading_pipeline_cfg = []
    for cfg in pipeline:
        obj_cls = PIPELINES.get(cfg['type'])
        # TODO: use a more elegant way to distinguish loading modules
        if obj_cls is not None and obj_cls in (LoadImageFromFile,
                                               LoadAnnotations):
            loading_pipeline_cfg.append(cfg)
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return loading_pipeline_cfg
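

# Usage sketch (assuming a standard MMDetection config `cfg` and the
# `build_dataset` helper; names are illustrative): analysis or visualization
# tools can rebuild a dataset that only loads images and annotations:
#
#     cfg.data.train.pipeline = get_loading_pipeline(cfg.data.train.pipeline)
#     dataset = build_dataset(cfg.data.train)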


@HOOKS.register_module()
class NumClassCheckHook(Hook):

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            for name, module in model.named_modules():
                if hasattr(module, 'num_classes') and not isinstance(
                        module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not match '
                         f'the length of `CLASSES` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}')
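
    # Example of what this check catches (hypothetical values): a config that
    # sets `model.roi_head.bbox_head.num_classes = 80` while the dataset's
    # `CLASSES` tuple has only 20 entries will fail the assertion above
    # before training starts.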

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
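

# Usage sketch (assuming the hook is registered via `@HOOKS.register_module()`
# as above): it can be enabled from a config file through `custom_hooks`, so
# the head/dataset class-count consistency is verified before each train and
# val epoch:
#
#     custom_hooks = [dict(type='NumClassCheckHook')]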