LoCoNet_ASD / builder.py
xiziwang
push files
2e36228
# -*- coding: utf-8 -*-
#================================================================
# Don't go gently into that good night.
#
# author: klaus
# description:
#
#================================================================
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
from mmaction.utils import import_module_error_func
# One registry (child of mmcv's MODELS) holds every model component. The
# per-category names are plain aliases of that same object, kept so that
# registration and build call sites can state which kind of component they mean.
MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = NECKS = HEADS = MODELS
RECOGNIZERS = LOSSES = LOCALIZERS = MODELS
try:
    from mmdet.models.builder import DETECTORS, build_detector
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    # mmdet is optional. Without it, alias DETECTORS to the shared MODELS
    # registry and define a stub build_detector so this module still imports.
    DETECTORS = MODELS

    # The decorator (from mmaction.utils) presumably replaces the stub with a
    # function that raises an ImportError mentioning 'mmdet' -- confirm there.
    @import_module_error_func('mmdet')
    def build_detector(cfg, train_cfg=None, test_cfg=None):
        """Placeholder used when mmdet is not installed."""
        pass
def build_backbone(cfg):
    """Construct a backbone from its config.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``BACKBONES``.

    Returns:
        Whatever ``BACKBONES.build`` produces for ``cfg``.
    """
    backbone = BACKBONES.build(cfg)
    return backbone
def build_head(cfg):
    """Construct a head from its config.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``HEADS``.

    Returns:
        Whatever ``HEADS.build`` produces for ``cfg``.
    """
    head = HEADS.build(cfg)
    return head
def build_recognizer(cfg, train_cfg=None, test_cfg=None):
    """Build a recognizer from its config.

    Args:
        cfg (dict): Model config; its ``type`` key names a class registered in
            ``RECOGNIZERS``. It may carry its own ``train_cfg``/``test_cfg``.
        train_cfg (dict | None): Deprecated outer training config, forwarded
            to the build via ``default_args``.
        test_cfg (dict | None): Deprecated outer testing config, forwarded
            to the build via ``default_args``.

    Returns:
        The object produced by ``RECOGNIZERS.build``.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is given both as an
            argument and inside ``cfg``.
    """
    if train_cfg is not None or test_cfg is not None:
        # stacklevel=2 attributes the warning to the caller, not this module.
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model. Details see this '
            'PR: https://github.com/open-mmlab/mmaction2/pull/629',
            UserWarning,
            stacklevel=2)
    # Explicit raises instead of ``assert`` so the checks survive
    # ``python -O``. AssertionError is kept so callers see the same type.
    if cfg.get('train_cfg') is not None and train_cfg is not None:
        raise AssertionError(
            'train_cfg specified in both outer field and model field')
    if cfg.get('test_cfg') is not None and test_cfg is not None:
        raise AssertionError(
            'test_cfg specified in both outer field and model field ')
    return RECOGNIZERS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_loss(cfg):
    """Construct a loss from its config.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``LOSSES``.

    Returns:
        Whatever ``LOSSES.build`` produces for ``cfg``.
    """
    loss = LOSSES.build(cfg)
    return loss
def build_localizer(cfg):
    """Construct a localizer from its config.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``LOCALIZERS``.

    Returns:
        Whatever ``LOCALIZERS.build`` produces for ``cfg``.
    """
    localizer = LOCALIZERS.build(cfg)
    return localizer
def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build a model, dispatching on the registry that knows ``cfg['type']``.

    Args:
        cfg (dict): Model config with a ``type`` key.
        train_cfg (dict | None): Deprecated outer training config.
        test_cfg (dict | None): Deprecated outer testing config.

    Returns:
        The built model.

    Raises:
        KeyError: If ``cfg`` has no ``type`` key.
        ImportError: If ``type`` is an mmdet-only model (``FastRCNN``) that is
            not registered anywhere, i.e. mmdet is not installed.
        ValueError: If ``type`` is not registered in any known registry.
    """
    # Only the type name is needed, so read it directly. There is no need to
    # copy the whole config and pop from the copy.
    obj_type = cfg['type']
    if obj_type in LOCALIZERS:
        # NOTE: in this module LOCALIZERS and RECOGNIZERS (and DETECTORS when
        # mmdet is missing) are aliases of MODELS. Every registered type
        # therefore matches here, and the RECOGNIZERS branch below is reached
        # only if the registries are ever split. build_localizer takes no
        # train_cfg/test_cfg, so warn instead of dropping them silently.
        if train_cfg is not None or test_cfg is not None:
            warnings.warn(
                f'train_cfg and test_cfg passed to build_model are ignored '
                f'for {obj_type}; please specify them in model instead.',
                UserWarning,
                stacklevel=2)
        return build_localizer(cfg)
    if obj_type in RECOGNIZERS:
        return build_recognizer(cfg, train_cfg, test_cfg)
    if obj_type in DETECTORS:
        if train_cfg is not None or test_cfg is not None:
            warnings.warn(
                'train_cfg and test_cfg is deprecated, '
                'please specify them in model. Details see this '
                'PR: https://github.com/open-mmlab/mmaction2/pull/629',
                UserWarning,
                stacklevel=2)
        return build_detector(cfg, train_cfg, test_cfg)
    # Types known to live in mmdet: give an actionable install hint instead of
    # the generic "not registered" error.
    model_in_mmdet = ['FastRCNN']
    if obj_type in model_in_mmdet:
        raise ImportError(
            'Please install mmdet for spatial temporal detection tasks.')
    raise ValueError(
        f'{obj_type} is not registered in '
        'LOCALIZERS, RECOGNIZERS or DETECTORS')
def build_neck(cfg):
    """Construct a neck from its config.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``NECKS``.

    Returns:
        Whatever ``NECKS.build`` produces for ``cfg``.
    """
    neck = NECKS.build(cfg)
    return neck