Spaces:
Runtime error
Runtime error
| import inspect | |
| import platform | |
| from .registry import PLUGIN_LAYERS | |
| if platform.system() == 'Windows': | |
| import regex as re | |
| else: | |
| import re | |
def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the attribute ``_abbr_``, return it.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert camel case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Split a run of capitals from a following capitalized word,
        # e.g. "HTTPServer" -> "HTTP_Server".
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Split a lowercase letter or digit from a following capital,
        # e.g. "FancyBlock" -> "Fancy_Block".
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    # An explicit ``_abbr_`` class attribute overrides the derived name.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    return camel2snake(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (dict): cfg should contain:

            - type (str): identify plugin layer type.
            - layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into plugin abbreviation to
            create named layer. Default: ''.
        **kwargs: Extra keyword arguments passed to the plugin layer's
            constructor together with the remaining ``cfg`` entries.

    Returns:
        tuple[str, nn.Module]: ``(name, layer)`` where ``name`` is the
        abbreviation + postfix and ``layer`` is the created plugin layer.

    Raises:
        TypeError: If ``cfg`` is not a dict, or ``postfix`` is neither an
            int nor a str.
        KeyError: If ``cfg`` has no ``"type"`` key, or the type is not
            registered in ``PLUGIN_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Copy so that popping "type" does not mutate the caller's config.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    plugin_layer = PLUGIN_LAYERS.get(layer_type)

    abbr = infer_abbr(plugin_layer)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not isinstance(postfix, (int, str)):
        raise TypeError(
            f'postfix must be an int or str, but got {type(postfix)}')
    name = abbr + str(postfix)

    # NOTE: a key present in both ``kwargs`` and ``cfg`` makes this call
    # raise TypeError ("got multiple values for keyword argument").
    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer