# Hosting-page metadata captured with this file (not part of the program):
# uploaded by oguzakif, commit d4b77ac ("init repo"), file size 2.41 kB.
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import logging

import torch
import torch.nn as nn
# Module logger. logging.getLogger('global') returns the same logger object to
# every caller that asks for the name 'global' (presumably configured by the
# project's entry/training scripts -- TODO confirm).
logger = logging.getLogger('global')
class Features(nn.Module):
    """Base class for backbone feature extractors.

    Subclasses implement ``forward``; ``feature_size`` starts at the
    placeholder value -1 (presumably overwritten by subclasses -- TODO confirm).
    """

    def __init__(self):
        super(Features, self).__init__()
        # -1 marks "not set yet".
        self.feature_size = -1

    def forward(self, x):
        """Compute features for ``x``; must be overridden by subclasses."""
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1):
        """Build optimizer parameter groups for this module's trainable parameters.

        Args:
            start_lr: base learning rate.
            feature_mult: multiplier applied to ``start_lr`` for this module.

        Returns:
            A one-element list ``[{'params': [...], 'lr': start_lr * feature_mult}]``
            where ``'params'`` holds only parameters with ``requires_grad`` set.
        """
        # Materialize as a list: a bare ``filter`` object is a one-shot iterator
        # and the group would look empty if it were iterated a second time.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return [{'params': trainable, 'lr': start_lr * feature_mult}]

    def load_model(self, f='pretrain.model'):
        """Load weights from checkpoint ``f``, keeping only keys this module has.

        Checkpoint entries whose key is not in this module's ``state_dict`` are
        dropped; parameters absent from the checkpoint keep their current values.

        Args:
            f: path (or binary file object) accepted by ``torch.load``; expected
               to hold a state dict mapping names to tensors -- TODO confirm.
        """
        # Hand ``f`` directly to torch.load: it needs a binary stream or a path,
        # whereas the previous text-mode ``open(f)`` breaks deserialization.
        pretrained_dict = torch.load(f)
        model_dict = self.state_dict()
        logger.info('pretrained keys: %s', list(pretrained_dict.keys()))
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        logger.info('matched keys: %s', list(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
class MultiStageFeature(Features):
    """Feature extractor whose layers are unfrozen progressively during training.

    ``layers`` is an ordered list of sub-modules (presumably populated by
    subclasses -- TODO confirm); the first ``train_num`` of them are trainable.
    ``unfix`` chooses ``train_num`` from ``train_nums`` according to which
    threshold in ``change_point`` the progress value ``ratio`` has reached.
    """

    def __init__(self):
        super(MultiStageFeature, self).__init__()
        self.layers = []
        # -1 is a sentinel meaning "unfix() has not been called yet".
        self.train_num = -1
        # Parallel lists: once ratio >= change_point[i], train train_nums[i] layers.
        self.change_point = []
        self.train_nums = []

    def unfix(self, ratio=0.0):
        """Update which layers are trainable for progress value ``ratio``.

        On the first call every parameter is frozen and the module is switched
        to eval mode. The (change_point, train_nums) pairs are then scanned from
        last to first; the first pair with ``ratio >= change_point`` decides the
        new ``train_num`` and the scan stops there.

        Args:
            ratio: progress value compared against ``change_point`` entries.

        Returns:
            True if ``train_num`` changed (callers presumably rebuild their
            optimizer then -- TODO confirm), False otherwise.
        """
        if self.train_num == -1:
            self.train_num = 0
            self.unlock()
            self.eval()
        # NOTE(review): the reverse scan selects the largest reached threshold
        # only if change_point is sorted ascending -- confirm at call sites.
        for p, t in reversed(list(zip(self.change_point, self.train_nums))):
            if ratio >= p:
                if self.train_num != t:
                    self.train_num = t
                    self.unlock()
                    return True
                break
        return False

    def train_layers(self):
        """Return the first ``train_num`` entries of ``layers``.

        NOTE(review): while ``train_num`` is still the -1 sentinel, the slice
        ``layers[:-1]`` yields every layer except the last -- confirm intended.
        """
        return self.layers[:self.train_num]

    def unlock(self):
        """Freeze every parameter, then unfreeze those of ``train_layers()``."""
        for p in self.parameters():
            p.requires_grad = False
        # The layer list needs its own placeholder; the old format string had
        # only one ``{}`` and silently dropped it.
        logger.info('Current training %s layers:\n\t%s',
                    self.train_num, self.train_layers())
        for m in self.train_layers():
            for p in m.parameters():
                p.requires_grad = True

    def train(self, mode=True):
        """Set training mode, but only for the currently trainable layers.

        ``train(False)`` delegates to ``nn.Module.train(False)``, putting the
        whole module in eval mode. ``train(True)`` sets ``self.training`` and
        switches only ``train_layers()`` to train mode; other sub-modules keep
        whatever mode they already had (e.g. eval after ``unfix``).

        Args:
            mode: desired mode; defaults to True like ``nn.Module.train`` so
                that ``model.train()`` works.

        Returns:
            self, matching ``nn.Module.train``.
        """
        self.training = mode
        if not mode:
            super(MultiStageFeature, self).train(False)
        else:
            for m in self.train_layers():
                m.train(True)
        return self