| | |
| | import inspect |
| | import logging |
| | import sys |
| | from collections.abc import Callable |
| | from contextlib import contextmanager |
| | from importlib import import_module |
| | from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union |
| |
|
| | from rich.console import Console |
| | from rich.table import Table |
| |
|
| | from mmengine.config.utils import MODULE2PACKAGE |
| | from mmengine.utils import get_object_from_string, is_seq_of |
| | from .default_scope import DefaultScope |
| |
|
| |
|
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Args:
        name (str): Registry name.
        build_func (callable, optional): A function to construct instance
            from Registry. :func:`build_from_cfg` is used if neither ``parent``
            nor ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Defaults to None.
        parent (:obj:`Registry`, optional): Parent registry. The class
            registered in children registry could be built from parent.
            Defaults to None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the top-level package of the module that instantiates the
            registry, e.g. mmdet, mmcls, mmseg. Defaults to None.
        locations (list, optional): The locations to import the modules
            registered in this registry. ``None`` means no locations; every
            registry then gets its own empty list. Defaults to None.
            New in version 0.4.0.

    Examples:
        >>> # define a registry
        >>> MODELS = Registry('models')
        >>> # register `ResNet` to `MODELS`
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> # build model from `MODELS`
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

        >>> # hierarchical registry
        >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
        >>> @DETECTORS.register_module()
        >>> class FasterRCNN:
        >>>     pass
        >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))

        >>> # add locations to enable auto import
        >>> DETECTORS = Registry('detectors', parent=MODELS,
        >>>     scope='det', locations=['det.models.detectors'])
        >>> # define this class in 'det.models.detectors'
        >>> @DETECTORS.register_module()
        >>> class MaskRCNN:
        >>>     pass
        >>> # The registry will auto import det.models.detectors.MaskRCNN
        >>> maskrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))

    More advanced usages can be found at
    https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None,
                 locations: Optional[List[str]] = None):
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()
        # A ``None`` default (instead of a ``[]`` literal evaluated once at
        # definition time) guarantees that registries created without
        # ``locations`` never share the same list object.
        self._locations: List[str] = [] if locations is None else locations
        # Flipped to True by ``import_from_location`` once it has completed.
        self._imported = False

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            self._scope = self.infer_scope()

        # Attach this registry to ``parent`` (keyed by its scope) so it can
        # be found when the registry tree is searched.
        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        # Build function resolution order: explicit argument, then the
        # parent's build function, then the default ``build_from_cfg``.
        self.build_func: Callable
        if build_func is not None:
            self.build_func = build_func
        elif self.parent is not None:
            self.build_func = self.parent.build_func
        else:
            # Function-scope import, only performed when the default is
            # actually needed (presumably kept local to avoid an import
            # cycle with ``build_functions`` -- TODO confirm).
            from .build_functions import build_from_cfg
            self.build_func = build_from_cfg

    def __len__(self):
        """Return the number of modules registered in this registry only."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return whether :meth:`get` can resolve ``key``.

        Note that :meth:`get` raises ``TypeError`` for non-str keys.
        """
        return self.get(key) is not None

    def __repr__(self):
        """Render the registered names/objects as a rich table string."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        # Capture the rendered table instead of printing it to stdout.
        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @staticmethod
    def infer_scope() -> str:
        """Infer the scope of registry.

        The name of the top-level package of the module that instantiates
        the registry will be returned.

        Returns:
            str: The inferred scope name.

        Examples:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> # The scope of ``ResNet`` will be ``mmdet``.
        """
        from ..logging import print_log

        # This is called from ``Registry.__init__``: frame 0 is this
        # function, frame 1 is ``__init__`` and frame 2 is the code that
        # instantiated the registry, whose module name yields the scope.
        module = inspect.getmodule(sys._getframe(2))
        if module is not None:
            filename = module.__name__
            split_filename = filename.split('.')
            scope = split_filename[0]
        else:
            # The caller's module cannot be determined: fall back to
            # "mmengine" and tell the user how to silence the warning.
            scope = 'mmengine'
            print_log(
                'set scope as "mmengine" when scope can not be inferred. You '
                'can silence this warning by passing a "scope" argument to '
                'Registry like `Registry(name, scope="toy")`',
                logger='current',
                level=logging.WARNING)

        return scope

    @staticmethod
    def split_scope_key(key: str) -> Tuple[Optional[str], str]:
        """Split scope and key.

        The first scope will be split from key.

        Returns:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            ('mmdet', 'ResNet')
            >>> Registry.split_scope_key('ResNet')
            (None, 'ResNet')
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    @property
    def root(self):
        return self._get_root_registry()

    @contextmanager
    def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
        """Temporarily switch default scope to the target scope, and get the
        corresponding registry.

        If the registry of the corresponding scope exists, yield the
        registry, otherwise yield the current itself.

        Args:
            scope (str, optional): The target scope. If None, the current
                default scope is kept.

        Examples:
            >>> from mmengine.registry import Registry, DefaultScope, MODELS
            >>> import time
            >>> # External Registry
            >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',
            >>>     parent=MODELS)
            >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls',
            >>>     parent=MODELS)
            >>> # Local Registry
            >>> CUSTOM_MODELS = Registry('custom_model', scope='custom',
            >>>     parent=MODELS)
            >>>
            >>> # Initiate DefaultScope
            >>> DefaultScope.get_instance(f'scope_{time.time()}',
            >>>     scope_name='custom')
            >>> # Check default scope
            >>> DefaultScope.get_current_instance().scope_name
            custom
            >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmcls
            >>>     registry.scope
            mmcls
            >>> # Nested switch scope
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmdet
            >>>     mmdet_registry.scope
            mmdet
            >>>     with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry:
            >>>         DefaultScope.get_current_instance().scope_name
            mmcls
            >>>         mmcls_registry.scope
            mmcls
            >>>
            >>> # Check switch back to original scope.
            >>> DefaultScope.get_current_instance().scope_name
            custom
        """
        from ..logging import print_log

        # Override the global default scope for the duration of the block.
        with DefaultScope.overwrite_default_scope(scope):
            default_scope = DefaultScope.get_current_instance()
            if default_scope is not None:
                # ``scope_name`` is the scope actually in effect; it differs
                # from the ``scope`` argument when ``scope`` is None, so all
                # messages and lookups below must use ``scope_name``.
                scope_name = default_scope.scope_name
                # Import ``<scope_name>.registry`` first, presumably so the
                # package's registries exist in the tree before the search
                # below.
                try:
                    import_module(f'{scope_name}.registry')
                except (ImportError, AttributeError, ModuleNotFoundError):
                    if scope_name in MODULE2PACKAGE:
                        print_log(
                            f'{scope_name} is not installed and its '
                            'modules will not be registered. If you '
                            'want to use modules defined in '
                            f'{scope_name}, Please install {scope_name} by '
                            f'`pip install {MODULE2PACKAGE[scope_name]}`.',
                            logger='current',
                            level=logging.WARNING)
                    else:
                        print_log(
                            f'Failed to import `{scope_name}.registry`, '
                            'make sure the registry.py exists in '
                            f'`{scope_name}` package.',
                            logger='current',
                            level=logging.WARNING)
                root = self._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    # No registry with that scope in the tree: warn and fall
                    # back to the current registry.
                    print_log(
                        f'Failed to search registry with scope "{scope_name}" '
                        f'in the "{root.name}" registry tree. '
                        f'As a workaround, the current "{self.name}" registry '
                        f'in "{self.scope}" is used to build instance. This '
                        'may cause unexpected failure when running the built '
                        f'modules. Please check whether "{scope_name}" is a '
                        'correct scope, or whether the registry is '
                        'initialized.',
                        logger='current',
                        level=logging.WARNING)
                    registry = self
            else:
                # No default scope at all: use the current registry.
                registry = self
            yield registry

    def _get_root_registry(self) -> 'Registry':
        """Return the root registry by following ``parent`` links."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def import_from_location(self) -> None:
        """Import modules from the pre-defined locations in
        ``self._locations``.

        If no locations were given and the scope is a key of
        ``MODULE2PACKAGE``, ``<scope>.utils.register_all_modules(False)`` is
        called instead as a fallback. ``self._imported`` is set once this
        method completes without raising; later calls are then no-ops.
        """
        if self._imported:
            return

        from ..logging import print_log

        if len(self._locations) == 0 and self.scope in MODULE2PACKAGE:
            print_log(
                f'The "{self.name}" registry in {self.scope} did not '
                'set import location. Fallback to call '
                f'`{self.scope}.utils.register_all_modules` '
                'instead.',
                logger='current',
                level=logging.DEBUG)
            try:
                module = import_module(f'{self.scope}.utils')
            except (ImportError, AttributeError, ModuleNotFoundError):
                # ``self.scope`` is guaranteed to be in ``MODULE2PACKAGE``
                # by the enclosing condition, so suggest installing it.
                print_log(
                    f'{self.scope} is not installed and its '
                    'modules will not be registered. If you '
                    'want to use modules defined in '
                    f'{self.scope}, Please install {self.scope} by '
                    f'`pip install {MODULE2PACKAGE[self.scope]}`.',
                    logger='current',
                    level=logging.WARNING)
            else:
                # The positional ``False`` is forwarded unchanged; presumably
                # it stops the package from re-initialising the default
                # scope -- TODO confirm against the package's signature.
                module.register_all_modules(False)

        for loc in self._locations:
            import_module(loc)
            print_log(
                f"Modules of {self.scope}'s {self.name} registry have "
                f'been automatically imported from {loc}',
                logger='current',
                level=logging.DEBUG)
        self._imported = True

    def get(self, key: str) -> Optional[Type]:
        """Get the registry record.

        If ``key`` represents the whole object name with its module
        information, for example, ``mmengine.model.BaseModel``, ``get``
        will directly return the class object :class:`BaseModel`.

        Otherwise, it will first parse ``key`` and check whether it
        contains a scope name. The logic to search for ``key``:

        - ``key`` does not contain a scope name, i.e., it is purely a module
          name like "ResNet": :meth:`get` will search for ``ResNet`` from the
          current registry to its parent or ancestors until finding it.

        - ``key`` contains a scope name and it is equal to the scope of the
          current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
          will only search for ``ResNet`` in the current registry.

        - ``key`` contains a scope name and it is not equal to the scope of
          the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
          scope exists in its children, :meth:`get` will get "FCNet" from
          them. If not, :meth:`get` will first get the root registry and root
          registry call its own :meth:`get` method.

        If none of the above finds ``key``, it is finally resolved as a full
        import path via ``get_object_from_string``.

        Args:
            key (str): Name of the registered item, e.g., the class name in
                string format.

        Returns:
            Type or None: Return the corresponding class if ``key`` exists,
            otherwise return None.

        Raises:
            TypeError: If ``key`` is not a str.
            RuntimeError: If resolving ``key`` as an import path raises; the
                original exception is chained as the cause.

        Examples:
            >>> # define a registry
            >>> MODELS = Registry('models')
            >>> # register `ResNet` to `MODELS`
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet_cls = MODELS.get('ResNet')

            >>> # hierarchical registry
            >>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
            >>> # `ResNet` does not exist in `DETECTORS` but `get` method
            >>> # will try to search from its parents or ancestors
            >>> resnet_cls = DETECTORS.get('ResNet')
            >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
            >>> @CLASSIFIER.register_module()
            >>> class MobileNet:
            >>>     pass
            >>> # `get` from its sibling registries
            >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
        """
        from ..logging import print_log

        if not isinstance(key, str):
            raise TypeError(
                'The key argument of `Registry.get` must be a str, '
                f'got {type(key)}')

        scope, real_key = self.split_scope_key(key)
        obj_cls = None
        # Where the object was found; only used for the debug log below.
        registry_name = self.name
        scope_name = self.scope

        # Make sure the modules of the configured locations are registered.
        self.import_from_location()

        if scope is None or scope == self._scope:
            if real_key in self._module_dict:
                obj_cls = self._module_dict[real_key]
            elif scope is None:
                # Unscoped key: walk up the ancestors until it is found.
                parent = self.parent
                while parent is not None:
                    if real_key in parent._module_dict:
                        obj_cls = parent._module_dict[real_key]
                        registry_name = parent.name
                        scope_name = parent.scope
                        break
                    parent = parent.parent
        else:
            # Try to import ``<scope>.registry`` so that a registry for the
            # requested scope is available in the tree.
            try:
                import_module(f'{scope}.registry')
                print_log(
                    f'Registry node of {scope} has been automatically '
                    'imported.',
                    logger='current',
                    level=logging.DEBUG)
            except (ImportError, AttributeError, ModuleNotFoundError):
                print_log(
                    f'Cannot auto import {scope}.registry, please check '
                    f'whether the package "{scope}" is installed correctly '
                    'or import the registry manually.',
                    logger='current',
                    level=logging.DEBUG)

            if scope in self._children:
                obj_cls = self._children[scope].get(real_key)
                registry_name = self._children[scope].name
                scope_name = scope
            else:
                root = self._get_root_registry()
                # Only delegate to the root when it can actually serve the
                # scope; otherwise leave ``obj_cls`` as None for the
                # import-path fallback below.
                if scope == root._scope or scope in root._children:
                    obj_cls = root.get(key)

        if obj_cls is None:
            # Fall back to treating ``key`` as a full import path, e.g.
            # ``mmengine.model.BaseModel``.
            try:
                obj_cls = get_object_from_string(key)
            except Exception as err:
                raise RuntimeError(f'Failed to get {key}') from err

        if obj_cls is not None:
            # Not every registered object has a ``__name__``.
            cls_name = getattr(obj_cls, '__name__', str(obj_cls))
            print_log(
                f'Get class `{cls_name}` from "{registry_name}"'
                f' registry in "{scope_name}"',
                logger='current',
                level=logging.DEBUG)

        return obj_cls

    def _search_child(self, scope: str) -> Optional['Registry']:
        """Depth-first search for the corresponding registry in its children.

        Note that the method only search for the corresponding registry from
        the current registry. Therefore, if we want to search from the root
        registry, :meth:`_get_root_registry` should be called to get the
        root registry first.

        Args:
            scope (str): The scope name used for searching for its
                corresponding registry.

        Returns:
            Registry or None: Return the corresponding registry if ``scope``
            exists, otherwise return None.
        """
        if self._scope == scope:
            return self

        for child in self._children.values():
            registry = child._search_child(scope)
            if registry is not None:
                return registry

        return None

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance.

        Build an instance by calling :attr:`build_func`, which receives
        ``cfg``, the extra positional/keyword arguments and
        ``registry=self``.

        Args:
            cfg (dict): Config dict needs to be built.

        Returns:
            Any: The constructed object.

        Examples:
            >>> from mmengine import Registry
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     def __init__(self, depth, stages=4):
            >>>         self.depth = depth
            >>>         self.stages = stages
            >>> cfg = dict(type='ResNet', depth=50)
            >>> model = MODELS.build(cfg)
        """
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _add_child(self, registry: 'Registry') -> None:
        """Add a child for a registry.

        Args:
            registry (:obj:`Registry`): The ``registry`` will be added as a
                child of the ``self``. Its scope must be set and must not
                already be used by another child.
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = False) -> None:
        """Register a module.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but generally all ``Callable`` are acceptable.
            module_name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
                Defaults to None.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.

        Raises:
            TypeError: If ``module`` is not callable.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                existed_module = self._module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = False,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.
            module (type, optional): Module class or function to be registered.
                Defaults to None.

        Returns:
            ``module`` itself when it is given; otherwise a decorator that
            registers and returns the decorated object.

        Examples:
            >>> backbones = Registry('backbone')
            >>> # as a decorator
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> # as a normal function
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # Used as a normal function: register immediately.
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # Used as a decorator: register the decorated object and return it.
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register
| |
|