| |
| import inspect |
| import logging |
| import sys |
| from collections.abc import Callable |
| from contextlib import contextmanager |
| from importlib import import_module |
| from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union |
|
|
| from rich.console import Console |
| from rich.table import Table |
|
|
| from mmengine.config.utils import MODULE2PACKAGE |
| from mmengine.utils import get_object_from_string, is_seq_of |
| from .default_scope import DefaultScope |
|
|
|
|
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Args:
        name (str): Registry name.
        build_func (callable, optional): A function to construct instance
            from Registry. :func:`build_from_cfg` is used if neither ``parent``
            nor ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Defaults to None.
        parent (:obj:`Registry`, optional): Parent registry. The class
            registered in children registry could be built from parent.
            Defaults to None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Defaults to None.
        locations (list, optional): The locations to import the modules
            registered in this registry. Defaults to None, which is treated
            as an empty list.
            New in version 0.4.0.

    Examples:
        >>> # define a registry
        >>> MODELS = Registry('models')
        >>> # register the `ResNet` to `MODELS`
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> # build model from `MODELS`
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

        >>> # hierarchical registry
        >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
        >>> @DETECTORS.register_module()
        >>> class FasterRCNN:
        >>>     pass
        >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))

        >>> # add locations to enable auto import
        >>> DETECTORS = Registry('detectors', parent=MODELS,
        >>>     scope='det', locations=['det.models.detectors'])
        >>> # define this class in 'det.models.detectors'
        >>> @DETECTORS.register_module()
        >>> class MaskRCNN:
        >>>     pass
        >>> # The registry will auto import det.models.detectors.MaskRCNN
        >>> maskrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))

    More advanced usages can be found at
    https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None,
                 locations: Optional[List] = None):
        from .build_functions import build_from_cfg
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()
        # Use a fresh list per instance when nothing is passed, so that
        # registries never share one default list object.
        self._locations: List = [] if locations is None else locations
        self._imported = False

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            # ``infer_scope`` looks a fixed number of frames up the call
            # stack, so it has to be called directly from ``__init__``.
            self._scope = self.infer_scope()

        # Declare the attribute type once, then assign it in either branch.
        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        # ``self.build_func`` is chosen with the following priority:
        # 1. the ``build_func`` argument
        # 2. ``parent.build_func``
        # 3. ``build_from_cfg``
        self.build_func: Callable
        if build_func is None:
            if self.parent is not None:
                self.build_func = self.parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

    def __len__(self) -> int:
        """Return the number of names registered directly in this registry."""
        return len(self._module_dict)

    def __contains__(self, key) -> bool:
        """Return whether :meth:`get` finds something for ``key``."""
        return self.get(key) is not None

    def __repr__(self) -> str:
        """Render the registered names and objects as a ``rich`` table."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        # Capture the console output instead of printing it, so that the
        # rendered table can be returned as a string.
        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @staticmethod
    def infer_scope() -> str:
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Returns:
            str: The inferred scope name.

        Examples:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> # The scope of ``ResNet`` will be ``mmdet``.
        """
        from ..logging import print_log

        # The call stack seen from here is:
        # frame-0: ``infer_scope`` itself
        # frame-1: ``Registry.__init__``, which calls ``infer_scope``
        # frame-2: the place where ``Registry(...)`` is called
        module = inspect.getmodule(sys._getframe(2))
        if module is not None:
            # The top-level package name is used as the scope.
            filename = module.__name__
            split_filename = filename.split('.')
            scope = split_filename[0]
        else:
            # No module could be resolved for the calling frame; fall back
            # to "mmengine" and warn.
            scope = 'mmengine'
            print_log(
                'set scope as "mmengine" when scope can not be inferred. You '
                'can silence this warning by passing a "scope" argument to '
                'Registry like `Registry(name, scope="toy")`',
                logger='current',
                level=logging.WARNING)

        return scope

    @staticmethod
    def split_scope_key(key: str) -> Tuple[Optional[str], str]:
        """Split scope and key.

        The first scope will be split from key.

        Args:
            key (str): The key to split, e.g. ``'mmdet.ResNet'``.

        Returns:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        """
        split_index = key.find('.')
        if split_index != -1:
            # Only the first dot separates the scope; later dots stay in
            # the remaining key.
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self) -> str:
        return self._name

    @property
    def scope(self) -> str:
        return self._scope

    @property
    def module_dict(self) -> Dict[str, Type]:
        return self._module_dict

    @property
    def children(self) -> Dict[str, 'Registry']:
        return self._children

    @property
    def root(self) -> 'Registry':
        return self._get_root_registry()

    @contextmanager
    def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
        """Temporarily switch default scope to the target scope, and get the
        corresponding registry.

        If the registry of the corresponding scope exists, yield the
        registry, otherwise yield the current itself.

        Args:
            scope (str, optional): The target scope.

        Examples:
            >>> from mmengine.registry import Registry, DefaultScope, MODELS
            >>> import time
            >>> # External Registry
            >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',
            >>>     parent=MODELS)
            >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls',
            >>>     parent=MODELS)
            >>> # Local Registry
            >>> CUSTOM_MODELS = Registry('custom_model', scope='custom',
            >>>     parent=MODELS)
            >>>
            >>> # Initiate DefaultScope
            >>> DefaultScope.get_instance(f'scope_{time.time()}',
            >>>     scope_name='custom')
            >>> # Check default scope
            >>> DefaultScope.get_current_instance().scope_name
            custom
            >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmcls
            >>>     registry.scope
            mmcls
            >>> # Nested switch scope
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmdet
            >>>     mmdet_registry.scope
            mmdet
            >>>     with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry:
            >>>         DefaultScope.get_current_instance().scope_name
            mmcls
            >>>         mmcls_registry.scope
            mmcls
            >>>
            >>> # Check switch back to original scope.
            >>> DefaultScope.get_current_instance().scope_name
            custom
        """  # noqa: E501
        from ..logging import print_log

        # Switch to the given scope for the duration of the ``with`` block.
        # If a registry with that scope exists in the tree below the root,
        # yield it; otherwise yield this registry.
        with DefaultScope.overwrite_default_scope(scope):
            default_scope = DefaultScope.get_current_instance()
            if default_scope is not None:
                scope_name = default_scope.scope_name
                # Importing ``<scope>.registry`` gives that package the
                # chance to attach its registries to the tree.
                try:
                    import_module(f'{scope_name}.registry')
                except (ImportError, AttributeError, ModuleNotFoundError):
                    # Report the scope that was actually imported; the
                    # ``scope`` argument may be None.
                    if scope_name in MODULE2PACKAGE:
                        print_log(
                            f'{scope_name} is not installed and its '
                            'modules will not be registered. If you '
                            'want to use modules defined in '
                            f'{scope_name}, Please install {scope_name} by '
                            f'`pip install {MODULE2PACKAGE[scope_name]}.',
                            logger='current',
                            level=logging.WARNING)
                    else:
                        print_log(
                            f'Failed to import `{scope_name}.registry` '
                            'make sure the registry.py exists in '
                            f'`{scope_name}` package.',
                            logger='current',
                            level=logging.WARNING)
                root = self._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    # No registry with this scope in the tree: fall back to
                    # the current registry and warn.
                    print_log(
                        f'Failed to search registry with scope "{scope_name}" '
                        f'in the "{root.name}" registry tree. '
                        f'As a workaround, the current "{self.name}" registry '
                        f'in "{self.scope}" is used to build instance. This '
                        'may cause unexpected failure when running the built '
                        f'modules. Please check whether "{scope_name}" is a '
                        'correct scope, or whether the registry is '
                        'initialized.',
                        logger='current',
                        level=logging.WARNING)
                    registry = self
            else:
                # There is no default scope at all; use the current registry.
                registry = self
            yield registry

    def _get_root_registry(self) -> 'Registry':
        """Return the root registry."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def import_from_location(self) -> None:
        """Import modules from the pre-defined locations in
        ``self._locations``.

        The work is done at most once per registry: after a successful run
        ``self._imported`` is set and later calls return immediately. When no
        location is configured and the scope is a key of ``MODULE2PACKAGE``,
        ``{scope}.utils.register_all_modules(False)`` is called instead.
        """
        if self._imported:
            return

        from ..logging import print_log

        if not self._locations and self.scope in MODULE2PACKAGE:
            print_log(
                f'The "{self.name}" registry in {self.scope} did not '
                'set import location. Fallback to call '
                f'`{self.scope}.utils.register_all_modules` '
                'instead.',
                logger='current',
                level=logging.DEBUG)
            try:
                module = import_module(f'{self.scope}.utils')
            except (ImportError, AttributeError, ModuleNotFoundError):
                if self.scope in MODULE2PACKAGE:
                    print_log(
                        f'{self.scope} is not installed and its '
                        'modules will not be registered. If you '
                        'want to use modules defined in '
                        f'{self.scope}, Please install {self.scope} by '
                        f'`pip install {MODULE2PACKAGE[self.scope]}.',
                        logger='current',
                        level=logging.WARNING)
                else:
                    print_log(
                        f'Failed to import {self.scope} and register '
                        'its modules, please make sure you '
                        'have registered the module manually.',
                        logger='current',
                        level=logging.WARNING)
            else:
                # Errors raised while registering are deliberately not
                # caught here: only the import itself is guarded.
                module.register_all_modules(False)  # type: ignore

        for loc in self._locations:
            import_module(loc)
            print_log(
                f"Modules of {self.scope}'s {self.name} registry have "
                f'been automatically imported from {loc}',
                logger='current',
                level=logging.DEBUG)
        # Only reached when every import above succeeded.
        self._imported = True

    def get(self, key: str) -> Optional[Type]:
        """Get the registry record.

        If ``key`` represents the whole object name with its module
        information, for example, `mmengine.model.BaseModel`, ``get``
        will directly return the class object :class:`BaseModel`.

        Otherwise, it will first parse ``key`` and check whether it
        contains a scope name. The logic to search for ``key``:

        - ``key`` does not contain a scope name, i.e., it is purely a module
          name like "ResNet": :meth:`get` will search for ``ResNet`` from the
          current registry to its parent or ancestors until finding it.

        - ``key`` contains a scope name and it is equal to the scope of the
          current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
          will only search for ``ResNet`` in the current registry.

        - ``key`` contains a scope name and it is not equal to the scope of
          the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
          scope exists in its children, :meth:`get` will get "FCNet" from
          them. If not, :meth:`get` will first get the root registry and root
          registry call its own :meth:`get` method.

        Args:
            key (str): Name of the registered item, e.g., the class name in
                string format.

        Returns:
            Type or None: Return the corresponding class if ``key`` exists
            in the registry tree, otherwise the result of resolving ``key``
            with ``get_object_from_string``.

        Raises:
            TypeError: If ``key`` is not a str.
            RuntimeError: If ``get_object_from_string`` raises while
                resolving ``key``.

        Examples:
            >>> # define a registry
            >>> MODELS = Registry('models')
            >>> # register `ResNet` to `MODELS`
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet_cls = MODELS.get('ResNet')

            >>> # hierarchical registry
            >>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
            >>> # `ResNet` does not exist in `DETECTORS` but `get` method
            >>> # will try to search from its parents or ancestors
            >>> resnet_cls = DETECTORS.get('ResNet')
            >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
            >>> @CLASSIFIER.register_module()
            >>> class MobileNet:
            >>>     pass
            >>> # `get` from its sibling registries
            >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
        """
        from ..logging import print_log

        if not isinstance(key, str):
            raise TypeError(
                'The key argument of `Registry.get` must be a str, '
                f'got {type(key)}')

        scope, real_key = self.split_scope_key(key)
        obj_cls = None
        registry_name = self.name
        scope_name = self.scope

        # Lazily import the modules so that they get registered before the
        # lookup below.
        self.import_from_location()

        if scope is None or scope == self._scope:
            # Look in this registry first.
            if real_key in self._module_dict:
                obj_cls = self._module_dict[real_key]
            elif scope is None:
                # Unscoped key: walk up through the parent and ancestors.
                parent = self.parent
                while parent is not None:
                    if real_key in parent._module_dict:
                        obj_cls = parent._module_dict[real_key]
                        registry_name = parent.name
                        scope_name = parent.scope
                        break
                    parent = parent.parent
        else:
            # Importing ``<scope>.registry`` may add new nodes to the
            # registry tree; a failure is only logged.
            try:
                import_module(f'{scope}.registry')
                print_log(
                    f'Registry node of {scope} has been automatically '
                    'imported.',
                    logger='current',
                    level=logging.DEBUG)
            except (ImportError, AttributeError, ModuleNotFoundError):
                print_log(
                    f'Cannot auto import {scope}.registry, please check '
                    f'whether the package "{scope}" is installed correctly '
                    'or import the registry manually.',
                    logger='current',
                    level=logging.DEBUG)

            if scope in self._children:
                obj_cls = self._children[scope].get(real_key)
                registry_name = self._children[scope].name
                scope_name = scope
            else:
                root = self._get_root_registry()

                if scope != root._scope and scope not in root._children:
                    # Skip on purpose: ``root.get(key)`` would take this
                    # very branch again and recurse without end.
                    pass
                else:
                    obj_cls = root.get(key)

        if obj_cls is None:
            # Last resort: treat ``key`` as a full object path such as
            # 'mmengine.model.BaseModel'.
            try:
                obj_cls = get_object_from_string(key)
            except Exception as err:
                raise RuntimeError(f'Failed to get {key}') from err

        if obj_cls is not None:
            # Not every callable has ``__name__``; fall back to ``str()``.
            cls_name = getattr(obj_cls, '__name__', str(obj_cls))
            print_log(
                f'Get class `{cls_name}` from "{registry_name}"'
                f' registry in "{scope_name}"',
                logger='current',
                level=logging.DEBUG)

        return obj_cls

    def _search_child(self, scope: str) -> Optional['Registry']:
        """Depth-first search for the corresponding registry in its children.

        Note that the method only search for the corresponding registry from
        the current registry. Therefore, if we want to search from the root
        registry, :meth:`_get_root_registry` should be called to get the
        root registry first.

        Args:
            scope (str): The scope name used for searching for its
                corresponding registry.

        Returns:
            Registry or None: Return the corresponding registry if ``scope``
            exists, otherwise return None.
        """
        if self._scope == scope:
            return self

        for child in self._children.values():
            registry = child._search_child(scope)
            if registry is not None:
                return registry

        return None

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance.

        Build an instance by calling :attr:`build_func`.

        Args:
            cfg (dict): Config dict needs to be built.

        Returns:
            Any: The constructed object.

        Examples:
            >>> from mmengine import Registry
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     def __init__(self, depth, stages=4):
            >>>         self.depth = depth
            >>>         self.stages = stages
            >>> cfg = dict(type='ResNet', depth=50)
            >>> model = MODELS.build(cfg)
        """
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _add_child(self, registry: 'Registry') -> None:
        """Add a child for a registry.

        Args:
            registry (:obj:`Registry`): The ``registry`` will be added as a
                child of the ``self``.
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = False) -> None:
        """Register a module.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but generally all ``Callable`` are acceptable.
            module_name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
                Defaults to None.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.

        Raises:
            TypeError: If ``module`` is not callable.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        # Normalise a single name to a list so one loop handles both forms.
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                existed_module = self.module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = False,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.
            module (type, optional): Module class or function to be registered.
                Defaults to None.

        Returns:
            type or Callable: ``module`` itself when it is given, otherwise a
            decorator that registers the decorated object and returns it.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is not None,
                a str, or a sequence of str.

        Examples:
            >>> backbones = Registry('backbone')
            >>> # as a decorator
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> # as a normal function
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Validate ``name`` up front, so that misuse is reported here and
        # not when the returned decorator is applied.
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # Used as a normal function: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # Used as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register
|
|