# Spaces:
# Runtime error
# Runtime error
from __future__ import absolute_import | |
import torch | |
from .pcb import * | |
from .mlfn import * | |
from .hacnn import * | |
from .osnet import * | |
from .senet import * | |
from .mudeep import * | |
from .nasnet import * | |
from .resnet import * | |
from .densenet import * | |
from .xception import * | |
from .osnet_ain import * | |
from .resnetmid import * | |
from .shufflenet import * | |
from .squeezenet import * | |
from .inceptionv4 import * | |
from .mobilenetv2 import * | |
from .resnet_ibn_a import * | |
from .resnet_ibn_b import * | |
from .shufflenetv2 import * | |
from .inceptionresnetv2 import * | |
# Registry mapping a model name (str) to the callable that builds it.
# ``build_model`` looks names up here and ``show_avai_models`` prints the keys.
__model_factory = {
    # image classification models
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    'resnet50_fc512': resnet50_fc512,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': inceptionresnetv2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    'resnet50_ibn_a': resnet50_ibn_a,
    'resnet50_ibn_b': resnet50_ibn_b,
    # lightweight models
    # 'nasnsetmobile' is a misspelling; it is kept so existing callers and
    # configs that use it keep working. 'nasnetamobile' is the correctly
    # spelled alias, matching the builder's own name like every other key.
    'nasnsetmobile': nasnetamobile,
    'nasnetamobile': nasnetamobile,
    'mobilenetv2_x1_0': mobilenetv2_x1_0,
    'mobilenetv2_x1_4': mobilenetv2_x1_4,
    'shufflenet': shufflenet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
    # reid-specific models
    'mudeep': MuDeep,
    'resnet50mid': resnet50mid,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
    'mlfn': mlfn,
    'osnet_x1_0': osnet_x1_0,
    'osnet_x0_75': osnet_x0_75,
    'osnet_x0_5': osnet_x0_5,
    'osnet_x0_25': osnet_x0_25,
    'osnet_ibn_x1_0': osnet_ibn_x1_0,
    'osnet_ain_x1_0': osnet_ain_x1_0,
    'osnet_ain_x0_75': osnet_ain_x0_75,
    'osnet_ain_x0_5': osnet_ain_x0_5,
    'osnet_ain_x0_25': osnet_ain_x0_25,
}
def show_avai_models():
    """Print the names of all registered models.

    The names are the keys of the module-level model registry, printed as a
    single Python list.

    Examples::
        >>> from torchreid import models
        >>> models.show_avai_models()
    """
    # Iterating a dict yields its keys, so this is the list of model names.
    registered_names = list(__model_factory)
    print(registered_names)
def build_model(
    name, num_classes, loss='softmax', pretrained=True, use_gpu=True
):
    """Build a model by looking up its name in the model registry.

    The matching builder is called with ``num_classes``, ``loss``,
    ``pretrained`` and ``use_gpu`` passed as keyword arguments, and its
    result is returned unchanged.

    Args:
        name (str): model name; must be a key of the model registry.
        num_classes (int): number of training identities.
        loss (str, optional): loss function to optimize the model. Currently
            supports "softmax" and "triplet". Default is "softmax".
        pretrained (bool, optional): whether to load ImageNet-pretrained
            weights. Default is True.
        use_gpu (bool, optional): whether to use gpu. Default is True.

    Returns:
        nn.Module

    Raises:
        KeyError: if ``name`` is not a registered model name. The message
            lists every available name.

    Examples::
        >>> from torchreid import models
        >>> model = models.build_model('resnet50', 751, loss='softmax')
    """
    known_names = list(__model_factory.keys())

    if name in known_names:
        builder = __model_factory[name]
        return builder(
            num_classes=num_classes,
            loss=loss,
            pretrained=pretrained,
            use_gpu=use_gpu
        )

    # Unknown name: report it together with the full list of valid choices.
    message = 'Unknown model: {}. Must be one of {}'.format(name, known_names)
    raise KeyError(message)