|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from functools import partial |
|
|
|
|
|
class Registry(object):
    """A named mapping from string keys to module build functions.

    Functions are added through :meth:`register` (directly or as a
    decorator) or through the decorator factory :meth:`register_with_name`,
    and looked up again with :meth:`get`.
    """

    def __init__(self, name):
        """Create an empty registry.

        Args:
            name: Label for this registry; it is used in ``repr`` output
                and in the duplicate-registration error message.
        """
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        return "{}(name={}, items={})".format(
            self.__class__.__name__, self._name, list(self._module_dict.keys())
        )

    def __len__(self):
        """Return the number of registered build functions."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if ``key`` is a registered name."""
        return key in self._module_dict

    @property
    def name(self):
        """The label given to this registry at construction time."""
        return self._name

    @property
    def module_dict(self):
        """The underlying name -> build function dict (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Return the build function registered under ``key``, or None."""
        return self._module_dict.get(key, None)

    def registe_with_name(self, module_name=None, force=False):
        """Return a decorator that registers a function under ``module_name``.

        The method name is misspelled but kept for backward compatibility;
        :meth:`register_with_name` is the correctly spelled equivalent.

        Args:
            module_name: Key to register under. When None, the decorated
                function's ``__name__`` is used.
            force: When True, silently replace an existing entry that has
                the same key.

        Returns:
            A ``functools.partial`` of :meth:`register` that expects the
            build function as its single positional argument.
        """
        return partial(self.register, module_name=module_name, force=force)

    def register_with_name(self, module_name=None, force=False):
        """Correctly spelled alias of :meth:`registe_with_name`."""
        return self.registe_with_name(module_name=module_name, force=force)

    def register(self, module_build_function, module_name=None, force=False):
        """Register a module build function.

        Can be called directly or used as a bare decorator, because the
        function that was passed in is returned unchanged.

        Args:
            module_build_function: The function to register. It must
                satisfy ``inspect.isfunction``, so classes, built-ins and
                ``functools.partial`` objects are rejected.
            module_name: Key to register under. When None, the function's
                ``__name__`` is used.
            force: When True, overwrite an existing entry with the same
                key instead of raising.

        Returns:
            ``module_build_function`` itself.

        Raises:
            TypeError: If ``module_build_function`` is not a function.
            KeyError: If the key is already registered and ``force`` is
                False.
        """
        if not inspect.isfunction(module_build_function):
            raise TypeError(
                "module_build_function must be a function, but got {}".format(
                    type(module_build_function)
                )
            )
        if module_name is None:
            module_name = module_build_function.__name__
        if not force and module_name in self._module_dict:
            raise KeyError("{} is already registered in {}".format(module_name, self.name))
        self._module_dict[module_name] = module_build_function

        return module_build_function
|
|
|
|
|
# Module-level Registry instance, labelled "model build functions".
MODULE_BUILD_FUNCS = Registry("model build functions")
|
|