character-360 / annotator /registry.py
aki-0421
F: add
a3a3ae4 unverified
raw
history blame
1.22 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from scepter.modules.utils.config import Config
from scepter.modules.utils.registry import Registry, build_from_config
def build_annotator(cfg, registry, logger=None, *args, **kwargs):
""" After build model, load pretrained model if exists key `pretrain`.
pretrain (str, dict): Describes how to load pretrained model.
str, treat pretrain as model path;
dict: should contains key `path`, and other parameters token by function load_pretrained();
"""
if not isinstance(cfg, Config):
raise TypeError(f'Config must be type dict, got {type(cfg)}')
if cfg.have('PRETRAINED_MODEL'):
pretrain_cfg = cfg.PRETRAINED_MODEL
if pretrain_cfg is not None and not isinstance(pretrain_cfg, (str)):
raise TypeError('Pretrain parameter must be a string')
else:
pretrain_cfg = None
model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
if pretrain_cfg is not None:
if hasattr(model, 'load_pretrained_model'):
model.load_pretrained_model(pretrain_cfg)
return model
ANNOTATORS = Registry('ANNOTATORS', build_func=build_annotator)