Spaces:
Runtime error
Runtime error
import sys
import time
import warnings
from collections import OrderedDict

import torch
# Base ReID architecture names. `is_model_in_model_types` tests exact
# membership in this list; `get_model_name` returns the first entry that
# occurs as a substring of a weight-file name, so entry order matters there.
__model_types = [
    'resnet50',
    'mlfn',
    'hacnn',
    'mobilenetv2_x1_0',
    'mobilenetv2_x1_4',
    'osnet_x1_0',
    'osnet_x0_75',
    'osnet_x0_5',
    'osnet_x0_25',
    'osnet_ibn_x1_0',
    'osnet_ain_x1_0',
]
# Download URLs for pretrained ReID weights, keyed by weight-file name of the
# form "<architecture>_<training dataset>.pt".
#
# NOTE: every *_msmt17.pt key for resnet50 and osnet_x1_0/x0_75/x0_5/x0_25
# used to be listed twice. A dict literal keeps only the LAST value for a
# repeated key, so the earlier URLs were never used. Only the URLs that were
# actually in effect are kept below, in the same key order the old literal
# produced.
__trained_urls = {
    # Weights trained on market1501, dukemtmcreid and msmt17 ##################
    'resnet50_market1501.pt':
        'https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV',
    'resnet50_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg',
    'resnet50_msmt17.pt':
        'https://drive.google.com/uc?id=1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf',
    'resnet50_fc512_market1501.pt':
        'https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt',
    'resnet50_fc512_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx',
    'resnet50_fc512_msmt17.pt':
        'https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud',
    'mlfn_market1501.pt':
        'https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS',
    'mlfn_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum',
    'mlfn_msmt17.pt':
        'https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-',
    'hacnn_market1501.pt':
        'https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF',
    'hacnn_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH',
    'hacnn_msmt17.pt':
        'https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ',
    'mobilenetv2_x1_0_market1501.pt':
        'https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp',
    'mobilenetv2_x1_0_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds',
    'mobilenetv2_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ',
    'mobilenetv2_x1_4_market1501.pt':
        'https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5',
    'mobilenetv2_x1_4_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN',
    'mobilenetv2_x1_4_msmt17.pt':
        'https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz',
    'osnet_x1_0_market1501.pt':
        'https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA',
    'osnet_x1_0_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq',
    'osnet_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x',
    'osnet_x0_75_market1501.pt':
        'https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer',
    'osnet_x0_75_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or',
    'osnet_x0_75_msmt17.pt':
        'https://drive.google.com/uc?id=1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw',
    'osnet_x0_5_market1501.pt':
        'https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT',
    'osnet_x0_5_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu',
    'osnet_x0_5_msmt17.pt':
        'https://drive.google.com/uc?id=1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF',
    'osnet_x0_25_market1501.pt':
        'https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj',
    'osnet_x0_25_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l',
    'osnet_x0_25_msmt17.pt':
        'https://drive.google.com/uc?id=1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e',
    # Weights available only for msmt17 #######################################
    'osnet_ibn_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ',
    'osnet_ain_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal',
}
def show_downloadeable_models():
    """Print a header followed by the list of weight-file names in
    ``__trained_urls`` (i.e. the weights that have a known download URL)."""
    header = '\nAvailable .pt ReID models for automatic download'
    print(header)
    downloadable = list(__trained_urls)
    print(downloadable)
def get_model_url(model):
    """Return the download URL registered for ``model.name``.

    Args:
        model: object whose ``name`` attribute is the weight-file name, e.g.
            ``'osnet_x0_25_msmt17.pt'`` (presumably a ``pathlib.Path``;
            TODO confirm at call sites).

    Returns:
        str | None: the URL from ``__trained_urls``, or ``None`` when the name
        has no registered URL.
    """
    return __trained_urls.get(model.name)
def is_model_in_model_types(model):
    """Return True if ``model.name`` exactly equals one of ``__model_types``.

    This is an exact match; ``get_model_name`` instead does substring matching
    against the same list.

    Args:
        model: object with a ``name`` attribute (presumably a
            ``pathlib.Path``; TODO confirm at call sites).

    Returns:
        bool
    """
    return model.name in __model_types
def get_model_name(model):
    """Return the architecture name contained in ``model.name``.

    Entries of ``__model_types`` are tried in list order and the first one
    that appears as a substring of ``model.name`` is returned; ``None`` is
    returned when no entry matches.

    NOTE(review): because this is substring matching, a name such as
    ``'resnet50_fc512_market1501.pt'`` resolves to ``'resnet50'``.
    """
    candidates = (arch for arch in __model_types if arch in model.name)
    return next(candidates, None)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    Progress is written to stdout on a single, continuously rewritten line.
    It shows the percentage (when the size is known), MB received, KB/s and
    elapsed seconds.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    # Standard-library urllib; no dependency on the six compatibility shim.
    import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    start_time = time.monotonic()

    def _reporthook(count, block_size, total_size):
        # urlretrieve calls the hook once with count == 0 when the connection
        # is opened, then once after each block it reads.
        nonlocal start_time
        if count == 0:
            start_time = time.monotonic()
            return
        # Floor the elapsed time so very fast transfers cannot divide by zero.
        duration = max(time.monotonic() - start_time, 1e-6)
        progress_size = int(count * block_size)
        # total_size is -1 when the server sends no Content-Length.
        size_known = total_size > 0
        if size_known:
            # The final block is usually partial; don't report past the end.
            progress_size = min(progress_size, total_size)
        speed = int(progress_size / (1024 * duration))
        if size_known:
            percent = int(progress_size * 100 / total_size)
            msg = '\r...%d%%, %d MB, %d KB/s, %d seconds passed' % (
                percent, progress_size / (1024 * 1024), speed, duration)
        else:
            msg = '\r...%d MB, %d KB/s, %d seconds passed' % (
                progress_size / (1024 * 1024), speed, duration)
        sys.stdout.write(msg)
        sys.stdout.flush()

    try:
        urllib.request.urlretrieve(url, dst, _reporthook)
    finally:
        # Terminate the progress line even if the download fails.
        sys.stdout.write('\n')
def load_pretrained_weights(model, weight_path):
    r"""Loads pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".
        - Checkpoints saved from GPU tensors can be loaded on CPU-only hosts.

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights. The file may hold either
            a bare state dict or a dict with a ``'state_dict'`` entry.

    Examples::
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    # Without map_location, CUDA tensors in the checkpoint cannot be
    # deserialized on a machine without a GPU. load_state_dict copies the
    # values into the model's own parameters, so CPU loading is safe.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(weight_path, map_location=map_location)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []
    # Prefix typically left on keys by (Distributed)DataParallel wrappers.
    prefix = 'module.'

    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]
        # Keep only entries whose name exists in the model with the same size.
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Start from the model's current weights so unmatched layers keep them.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path)
        )
    else:
        print(
            'Successfully loaded pretrained weights from "{}"'.
            format(weight_path)
        )
        if discarded_layers:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )