Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Callable, Dict, Optional | |
import torch | |
from mmdeploy.codebase.base import CODEBASE, MMCodebase | |
from mmdeploy.codebase.mmdet.deploy import ObjectDetection | |
from mmdeploy.utils import Codebase, Task | |
from mmengine import Config | |
from mmengine.registry import Registry | |
MMYOLO_TASK = Registry('mmyolo_tasks') | |
class MMYOLO(MMCodebase): | |
"""MMYOLO codebase class.""" | |
task_registry = MMYOLO_TASK | |
def register_deploy_modules(cls): | |
"""register all rewriters for mmdet.""" | |
import mmdeploy.codebase.mmdet.models # noqa: F401 | |
import mmdeploy.codebase.mmdet.ops # noqa: F401 | |
import mmdeploy.codebase.mmdet.structures # noqa: F401 | |
def register_all_modules(cls): | |
"""register all modules.""" | |
from mmdet.utils.setup_env import \ | |
register_all_modules as register_all_modules_mmdet | |
from mmyolo.utils.setup_env import \ | |
register_all_modules as register_all_modules_mmyolo | |
cls.register_deploy_modules() | |
register_all_modules_mmyolo(True) | |
register_all_modules_mmdet(False) | |
def _get_dataset_metainfo(model_cfg: Config): | |
"""Get metainfo of dataset. | |
Args: | |
model_cfg Config: Input model Config object. | |
Returns: | |
list[str]: A list of string specifying names of different class. | |
""" | |
from mmyolo import datasets # noqa | |
from mmyolo.registry import DATASETS | |
module_dict = DATASETS.module_dict | |
for dataloader_name in [ | |
'test_dataloader', 'val_dataloader', 'train_dataloader' | |
]: | |
if dataloader_name not in model_cfg: | |
continue | |
dataloader_cfg = model_cfg[dataloader_name] | |
dataset_cfg = dataloader_cfg.dataset | |
dataset_cls = module_dict.get(dataset_cfg.type, None) | |
if dataset_cls is None: | |
continue | |
if hasattr(dataset_cls, '_load_metainfo') and isinstance( | |
dataset_cls._load_metainfo, Callable): | |
meta = dataset_cls._load_metainfo( | |
dataset_cfg.get('metainfo', None)) | |
if meta is not None: | |
return meta | |
if hasattr(dataset_cls, 'METAINFO'): | |
return dataset_cls.METAINFO | |
return None | |
class YOLOObjectDetection(ObjectDetection): | |
"""YOLO Object Detection task.""" | |
def get_visualizer(self, name: str, save_dir: str): | |
"""Get visualizer. | |
Args: | |
name (str): Name of visualizer. | |
save_dir (str): Directory to save visualization results. | |
Returns: | |
Visualizer: A visualizer instance. | |
""" | |
from mmdet.visualization import DetLocalVisualizer # noqa: F401,F403 | |
metainfo = _get_dataset_metainfo(self.model_cfg) | |
visualizer = super().get_visualizer(name, save_dir) | |
if metainfo is not None: | |
visualizer.dataset_meta = metainfo | |
return visualizer | |
def build_pytorch_model(self, | |
model_checkpoint: Optional[str] = None, | |
cfg_options: Optional[Dict] = None, | |
**kwargs) -> torch.nn.Module: | |
"""Initialize torch model. | |
Args: | |
model_checkpoint (str): The checkpoint file of torch model, | |
defaults to `None`. | |
cfg_options (dict): Optional config key-pair parameters. | |
Returns: | |
nn.Module: An initialized torch model generated by other OpenMMLab | |
codebases. | |
""" | |
from copy import deepcopy | |
from mmengine.model import revert_sync_batchnorm | |
from mmengine.registry import MODELS | |
from mmyolo.utils import switch_to_deploy | |
model = deepcopy(self.model_cfg.model) | |
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {})) | |
preprocess_cfg.update( | |
deepcopy(self.model_cfg.get('data_preprocessor', {}))) | |
model.setdefault('data_preprocessor', preprocess_cfg) | |
model = MODELS.build(model) | |
if model_checkpoint is not None: | |
from mmengine.runner.checkpoint import load_checkpoint | |
load_checkpoint(model, model_checkpoint, map_location=self.device) | |
model = revert_sync_batchnorm(model) | |
switch_to_deploy(model) | |
model = model.to(self.device) | |
model.eval() | |
return model | |