RSPrompter / mmyolo /deploy /object_detection.py
KyanChen's picture
Upload 89 files
3094730
raw
history blame
4.52 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, Optional
import torch
from mmdeploy.codebase.base import CODEBASE, MMCodebase
from mmdeploy.codebase.mmdet.deploy import ObjectDetection
from mmdeploy.utils import Codebase, Task
from mmengine import Config
from mmengine.registry import Registry
MMYOLO_TASK = Registry('mmyolo_tasks')
@CODEBASE.register_module(Codebase.MMYOLO.value)
class MMYOLO(MMCodebase):
"""MMYOLO codebase class."""
task_registry = MMYOLO_TASK
@classmethod
def register_deploy_modules(cls):
"""register all rewriters for mmdet."""
import mmdeploy.codebase.mmdet.models # noqa: F401
import mmdeploy.codebase.mmdet.ops # noqa: F401
import mmdeploy.codebase.mmdet.structures # noqa: F401
@classmethod
def register_all_modules(cls):
"""register all modules."""
from mmdet.utils.setup_env import \
register_all_modules as register_all_modules_mmdet
from mmyolo.utils.setup_env import \
register_all_modules as register_all_modules_mmyolo
cls.register_deploy_modules()
register_all_modules_mmyolo(True)
register_all_modules_mmdet(False)
def _get_dataset_metainfo(model_cfg: Config):
"""Get metainfo of dataset.
Args:
model_cfg Config: Input model Config object.
Returns:
list[str]: A list of string specifying names of different class.
"""
from mmyolo import datasets # noqa
from mmyolo.registry import DATASETS
module_dict = DATASETS.module_dict
for dataloader_name in [
'test_dataloader', 'val_dataloader', 'train_dataloader'
]:
if dataloader_name not in model_cfg:
continue
dataloader_cfg = model_cfg[dataloader_name]
dataset_cfg = dataloader_cfg.dataset
dataset_cls = module_dict.get(dataset_cfg.type, None)
if dataset_cls is None:
continue
if hasattr(dataset_cls, '_load_metainfo') and isinstance(
dataset_cls._load_metainfo, Callable):
meta = dataset_cls._load_metainfo(
dataset_cfg.get('metainfo', None))
if meta is not None:
return meta
if hasattr(dataset_cls, 'METAINFO'):
return dataset_cls.METAINFO
return None
@MMYOLO_TASK.register_module(Task.OBJECT_DETECTION.value)
class YOLOObjectDetection(ObjectDetection):
"""YOLO Object Detection task."""
def get_visualizer(self, name: str, save_dir: str):
"""Get visualizer.
Args:
name (str): Name of visualizer.
save_dir (str): Directory to save visualization results.
Returns:
Visualizer: A visualizer instance.
"""
from mmdet.visualization import DetLocalVisualizer # noqa: F401,F403
metainfo = _get_dataset_metainfo(self.model_cfg)
visualizer = super().get_visualizer(name, save_dir)
if metainfo is not None:
visualizer.dataset_meta = metainfo
return visualizer
def build_pytorch_model(self,
model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module:
"""Initialize torch model.
Args:
model_checkpoint (str): The checkpoint file of torch model,
defaults to `None`.
cfg_options (dict): Optional config key-pair parameters.
Returns:
nn.Module: An initialized torch model generated by other OpenMMLab
codebases.
"""
from copy import deepcopy
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import MODELS
from mmyolo.utils import switch_to_deploy
model = deepcopy(self.model_cfg.model)
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
preprocess_cfg.update(
deepcopy(self.model_cfg.get('data_preprocessor', {})))
model.setdefault('data_preprocessor', preprocess_cfg)
model = MODELS.build(model)
if model_checkpoint is not None:
from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(model, model_checkpoint, map_location=self.device)
model = revert_sync_batchnorm(model)
switch_to_deploy(model)
model = model.to(self.device)
model.eval()
return model