import sys
import time
import urllib.request
import warnings
from collections import OrderedDict

import torch

# Architecture names recognised by this module. ``get_model_name`` returns the
# first entry (in this order) that occurs as a substring of the weight file
# name, so the order of this list is significant.
__model_types = [
    'resnet50', 'mlfn', 'hacnn', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4',
    'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25',
    'osnet_ibn_x1_0', 'osnet_ain_x1_0']

# Maps a weight file name to the URL it can be downloaded from.
__trained_urls = {

    # models keyed as <arch>_<dataset>.pt
    # (datasets: market1501, dukemtmcreid, msmt17) ############################
    'resnet50_market1501.pt':
    'https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV',
    'resnet50_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg',
    'resnet50_msmt17.pt':
    'https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj',

    'resnet50_fc512_market1501.pt':
    'https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt',
    'resnet50_fc512_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx',
    'resnet50_fc512_msmt17.pt':
    'https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud',

    'mlfn_market1501.pt':
    'https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS',
    'mlfn_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum',
    'mlfn_msmt17.pt':
    'https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-',

    'hacnn_market1501.pt':
    'https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF',
    'hacnn_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH',
    'hacnn_msmt17.pt':
    'https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ',

    'mobilenetv2_x1_0_market1501.pt':
    'https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp',
    'mobilenetv2_x1_0_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds',
    'mobilenetv2_x1_0_msmt17.pt':
    'https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ',

    'mobilenetv2_x1_4_market1501.pt':
    'https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5',
    'mobilenetv2_x1_4_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN',
    'mobilenetv2_x1_4_msmt17.pt':
    'https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz',

    'osnet_x1_0_market1501.pt':
    'https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA',
    'osnet_x1_0_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq',
    'osnet_x1_0_msmt17.pt':
    'https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M',

    'osnet_x0_75_market1501.pt':
    'https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer',
    'osnet_x0_75_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or',
    'osnet_x0_75_msmt17.pt':
    'https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc',

    'osnet_x0_5_market1501.pt':
    'https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT',
    'osnet_x0_5_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu',
    'osnet_x0_5_msmt17.pt':
    'https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv',

    'osnet_x0_25_market1501.pt':
    'https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj',
    'osnet_x0_25_dukemtmcreid.pt':
    'https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l',
    'osnet_x0_25_msmt17.pt':
    'https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF',

    # second group of *_msmt17.pt entries #####################################
    # NOTE(review): the first five keys below repeat keys already defined
    # above. In a dict literal the LAST value for a key wins, so these URLs
    # are the ones actually served and the earlier *_msmt17 URLs for
    # resnet50 / osnet_x1_0 / osnet_x0_75 / osnet_x0_5 / osnet_x0_25 are dead.
    # Kept verbatim to preserve runtime behavior; confirm which set is
    # intended and give the other set distinct file names.
    'resnet50_msmt17.pt':
    'https://drive.google.com/uc?id=1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf',
    'osnet_x1_0_msmt17.pt':
    'https://drive.google.com/uc?id=1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x',
    'osnet_x0_75_msmt17.pt':
    'https://drive.google.com/uc?id=1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw',
    'osnet_x0_5_msmt17.pt':
    'https://drive.google.com/uc?id=1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF',
    'osnet_x0_25_msmt17.pt':
    'https://drive.google.com/uc?id=1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e',
    'osnet_ibn_x1_0_msmt17.pt':
    'https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ',
    'osnet_ain_x1_0_msmt17.pt':
    'https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal',
}


def show_downloadeable_models():
    """Print the weight file names that have a registered download URL."""
    print('\nAvailable .pt ReID models for automatic download')
    print(list(__trained_urls.keys()))


def get_model_url(model):
    """Return the download URL registered for ``model.name``.

    Args:
        model: object exposing a ``name`` attribute holding the weight file
            name (presumably a ``pathlib.Path`` -- confirm against callers).

    Returns:
        str or None: the URL, or ``None`` when the name is not registered.
    """
    return __trained_urls.get(model.name)


def is_model_in_model_types(model):
    """Return True when ``model.name`` is exactly one of the known model types.

    Args:
        model: object exposing a ``name`` attribute.

    Returns:
        bool: result of an exact (not substring) match against the type list.
    """
    return model.name in __model_types


def get_model_name(model):
    """Return the first known model type contained in ``model.name``.

    Args:
        model: object exposing a ``name`` attribute.

    Returns:
        str or None: the first entry of the model type list that is a
        substring of ``model.name``, or ``None`` when nothing matches.
    """
    for x in __model_types:
        if x in model.name:
            return x
    return None


def download_url(url, dst):
    """Downloads file from a url to a destination.

    Progress is written to stdout while the transfer runs.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Set on the first hook call; kept local so that concurrent or repeated
    # downloads do not share (or leak) timing state through a module global.
    start_time = None

    def _reporthook(count, block_size, total_size):
        nonlocal start_time
        if count == 0 or start_time is None:
            start_time = time.monotonic()
            return
        duration = time.monotonic() - start_time
        progress_size = int(count * block_size)
        if total_size > 0:
            # The last block is usually partial; do not report more than
            # the file actually contains.
            progress_size = min(progress_size, total_size)
        # The clock may not have advanced yet on very fast transfers.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        if total_size > 0:
            percent = int(progress_size * 100 / total_size)
            sys.stdout.write(
                '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
                (percent, progress_size / (1024 * 1024), speed, duration)
            )
        else:
            # Total size is unknown (no Content-Length header), so a
            # percentage cannot be computed.
            sys.stdout.write(
                '\r...%d MB, %d KB/s, %d seconds passed' %
                (progress_size / (1024 * 1024), speed, duration)
            )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')


def load_pretrained_weights(model, weight_path):
    r"""Loads pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> from torchreid.utils import load_pretrained_weights
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    # Without a map_location, a checkpoint saved from a CUDA device cannot be
    # deserialized on a CPU-only host. Behavior on CUDA hosts is unchanged.
    map_location = None if torch.cuda.is_available() else 'cpu'
    # NOTE: torch.load unpickles the file; only load checkpoints from sources
    # you trust.
    checkpoint = torch.load(weight_path, map_location=map_location)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Start from the model's own weights so that every key is present, then
    # overwrite the ones the checkpoint could supply.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path)
        )
    else:
        print(
            'Successfully loaded pretrained weights from "{}"'.
            format(weight_path)
        )
        if discarded_layers:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )