# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from scepter.modules.utils.config import Config
from scepter.modules.utils.registry import Registry, build_from_config
def build_annotator(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load pretrained weights if the key `pretrain` exists.
        pretrain (str, dict): Describes how to load the pretrained model.
            str: treated as the model path;
            dict: should contain the key `path`, plus other parameters taken by load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be of type Config, got {type(cfg)}')
    if cfg.have('PRETRAINED_MODEL'):
        pretrain_cfg = cfg.PRETRAINED_MODEL
        if pretrain_cfg is not None and not isinstance(pretrain_cfg, str):
            raise TypeError('Pretrain parameter must be a string')
    else:
        pretrain_cfg = None
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    if pretrain_cfg is not None:
        # Load pretrained weights only if the built model supports it.
        if hasattr(model, 'load_pretrained_model'):
            model.load_pretrained_model(pretrain_cfg)
    return model
ANNOTATORS = Registry('ANNOTATORS', build_func=build_annotator)
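

if __name__ == '__main__':
    # Illustrative usage sketch only (not part of the original module): register a
    # toy annotator and build it through the ANNOTATORS registry, exercising the
    # PRETRAINED_MODEL path handled by build_annotator above. It assumes that
    # Registry exposes a register_class() decorator and a build() method, and that
    # Config accepts a cfg_dict argument, as in comparable scepter registries.
    # The class name ExampleAnnotator and the checkpoint path are hypothetical
    # placeholders.
    @ANNOTATORS.register_class()
    class ExampleAnnotator(object):
        def __init__(self, cfg, logger=None):
            self.cfg = cfg

        def load_pretrained_model(self, path):
            # A real annotator would load checkpoint weights from `path` here.
            print(f'Would load pretrained weights from {path}')

    example_cfg = Config(cfg_dict={
        'NAME': 'ExampleAnnotator',
        'PRETRAINED_MODEL': '/path/to/weights.pth'
    }, load=False)
    annotator = ANNOTATORS.build(example_cfg)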