IDD-YOLO-Tracking / trackers /strongsort /deep /reid_model_factory.py
Bhaskar Saranga
Added tracker
e215925
raw
history blame contribute delete
No virus
7.53 kB
import torch
from collections import OrderedDict
# Architecture identifiers recognised by this factory. Order matters to
# callers that return the first identifier found inside a weight file name.
__model_types = [
    'resnet50',
    'mlfn',
    'hacnn',
    'mobilenetv2_x1_0',
    'mobilenetv2_x1_4',
    'osnet_x1_0',
    'osnet_x0_75',
    'osnet_x0_5',
    'osnet_x0_25',
    'osnet_ibn_x1_0',
    'osnet_ain_x1_0',
]
# Maps a weight file name ("<architecture>_<dataset>.pt") to its Google Drive
# download URL.
#
# NOTE: this literal used to list a second URL for several "*_msmt17.pt" keys
# (resnet50, osnet_x1_0, osnet_x0_75, osnet_x0_5, osnet_x0_25). Python keeps
# the LAST value given for a duplicated key in a dict display, so the earlier
# URLs were never reachable. Only the effective URL is kept below, at the
# position where the key first appeared, so both the lookups and the key
# order are unchanged.
__trained_urls = {
    # resnet50 ################################################################
    'resnet50_market1501.pt':
        'https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV',
    'resnet50_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg',
    'resnet50_msmt17.pt':
        'https://drive.google.com/uc?id=1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf',
    'resnet50_fc512_market1501.pt':
        'https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt',
    'resnet50_fc512_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx',
    'resnet50_fc512_msmt17.pt':
        'https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud',
    # mlfn ####################################################################
    'mlfn_market1501.pt':
        'https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS',
    'mlfn_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum',
    'mlfn_msmt17.pt':
        'https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-',
    # hacnn ###################################################################
    'hacnn_market1501.pt':
        'https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF',
    'hacnn_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH',
    'hacnn_msmt17.pt':
        'https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ',
    # mobilenetv2 #############################################################
    'mobilenetv2_x1_0_market1501.pt':
        'https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp',
    'mobilenetv2_x1_0_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds',
    'mobilenetv2_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ',
    'mobilenetv2_x1_4_market1501.pt':
        'https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5',
    'mobilenetv2_x1_4_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN',
    'mobilenetv2_x1_4_msmt17.pt':
        'https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz',
    # osnet ###################################################################
    'osnet_x1_0_market1501.pt':
        'https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA',
    'osnet_x1_0_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq',
    'osnet_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x',
    'osnet_x0_75_market1501.pt':
        'https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer',
    'osnet_x0_75_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or',
    'osnet_x0_75_msmt17.pt':
        'https://drive.google.com/uc?id=1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw',
    'osnet_x0_5_market1501.pt':
        'https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT',
    'osnet_x0_5_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu',
    'osnet_x0_5_msmt17.pt':
        'https://drive.google.com/uc?id=1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF',
    'osnet_x0_25_market1501.pt':
        'https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj',
    'osnet_x0_25_dukemtmcreid.pt':
        'https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l',
    'osnet_x0_25_msmt17.pt':
        'https://drive.google.com/uc?id=1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e',
    # osnet variants, msmt17 only #############################################
    'osnet_ibn_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ',
    'osnet_ain_x1_0_msmt17.pt':
        'https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal',
}
def show_downloadeable_models():
    """Print every weight file name that has a registered download URL."""
    available = [file_name for file_name in __trained_urls]
    print('\nAvailable .pt ReID models for automatic download')
    print(available)
def get_model_url(model):
    """Return the download URL registered for ``model``, or ``None``.

    Args:
        model: object exposing a ``name`` attribute holding the weight file
            name (presumably a ``pathlib.Path`` -- confirm against callers).

    Returns:
        The URL string stored in ``__trained_urls`` under ``model.name``, or
        ``None`` when that file name has no registered URL.
    """
    # The original fell through a bare `None` expression in its else-branch;
    # dict.get makes the "unknown model -> None" result explicit.
    return __trained_urls.get(model.name)
def is_model_in_model_types(model):
    """Tell whether ``model.name`` is exactly one of the known architectures.

    Returns:
        bool: ``True`` when ``model.name`` equals an entry of
        ``__model_types``, ``False`` otherwise.
    """
    return model.name in __model_types
def get_model_name(model):
    """Find the architecture identifier embedded in ``model.name``.

    The identifiers in ``__model_types`` are tried in list order and the
    first one that occurs inside ``model.name`` is returned.

    Returns:
        The matching identifier, or ``None`` when none of them occurs.
    """
    candidates = (arch for arch in __model_types if arch in model.name)
    return next(candidates, None)
def download_url(url, dst):
    """Download the file at ``url`` to ``dst``, reporting progress on stdout.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    # Imported at function scope: the module header does not import these
    # names, so the progress hook used to fail with NameError on first use.
    import sys
    import time
    import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Closure state instead of a module-level global, so repeated downloads
    # do not share (or leak) their start timestamp.
    start_time = None

    def _reporthook(count, block_size, total_size):
        """Progress callback: blocks transferred so far, block size, total size."""
        nonlocal start_time
        if count == 0 or start_time is None:
            # First callback happens before any data is read: start the clock.
            start_time = time.monotonic()
            return
        duration = time.monotonic() - start_time
        progress_size = int(count * block_size)
        if total_size > 0:
            # The last block is usually partial, so count * block_size can
            # overshoot the real size; clamp so the report never exceeds 100%.
            progress_size = min(progress_size, total_size)
            percent = int(progress_size * 100 / total_size)
        else:
            # Size unknown or zero: no meaningful percentage, avoid dividing.
            percent = 0
        # A very fast first block can yield a zero duration.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        sys.stdout.write(
            '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
            (percent, progress_size / (1024 * 1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
def load_pretrained_weights(model, weight_path):
    r"""Loads pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> from torchreid.utils import load_pretrained_weights
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    # Imported at function scope: the module header does not import
    # `warnings`, so the "nothing matched" path used to raise NameError
    # instead of emitting the intended warning.
    import warnings

    # Without a CUDA device, remap storages to the CPU so checkpoints saved
    # on a GPU machine can still be read.
    map_location = None if torch.cuda.is_available() else 'cpu'
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    checkpoint = torch.load(weight_path, map_location=map_location)

    # Checkpoints either wrap the weights under 'state_dict' or are the
    # state dict itself.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []
    prefix = 'module.'

    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]  # discard module.

        # Keep an entry only if the model has a parameter of that name with
        # the same size.
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Start from the model's own state and overwrite only the matched
    # entries, so load_state_dict always receives a complete state dict.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path)
        )
    else:
        print(
            'Successfully loaded pretrained weights from "{}"'.
            format(weight_path)
        )
        if discarded_layers:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )