from .config import set_layer_config from .helpers import load_checkpoint from .gen_efficientnet import * from .mobilenetv3 import * def create_model( model_name='mnasnet_100', pretrained=None, num_classes=1000, in_chans=3, checkpoint_path='', **kwargs): model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) if model_name in globals(): create_fn = globals()[model_name] model = create_fn(**model_kwargs) else: raise RuntimeError('Unknown model (%s)' % model_name) if checkpoint_path and not pretrained: load_checkpoint(model, checkpoint_path) return model