# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import logging

import torch
import torch.nn as nn

logger = logging.getLogger('global')


class Features(nn.Module):
    def __init__(self):
        super(Features, self).__init__()
        self.feature_size = -1

    def forward(self, x):
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1):
        # Expose only trainable parameters, with the base learning rate
        # scaled by feature_mult for the backbone.
        params = filter(lambda x: x.requires_grad, self.parameters())
        params = [{'params': params, 'lr': start_lr * feature_mult}]
        return params

    def load_model(self, f='pretrain.model'):
        # Open the checkpoint in binary mode and keep only the weights
        # whose names match the current model's state dict.
        with open(f, 'rb') as fp:
            pretrained_dict = torch.load(fp)
        model_dict = self.state_dict()
        print(pretrained_dict.keys())
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        print(pretrained_dict.keys())
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
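
# A minimal usage sketch, not part of the original file: a hypothetical
# Features subclass wired into torch.optim via param_groups. ToyFeatures
# and all hyperparameters below are illustrative assumptions.
def _demo_param_groups():
    import torch.optim as optim

    class ToyFeatures(Features):  # hypothetical backbone
        def __init__(self):
            super(ToyFeatures, self).__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.feature_size = 16

        def forward(self, x):
            return self.conv(x)

    net = ToyFeatures()
    # Each group already carries its own 'lr', so none is passed to SGD.
    groups = net.param_groups(start_lr=0.01, feature_mult=0.1)
    return optim.SGD(groups, momentum=0.9)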


class MultiStageFeature(Features):
    def __init__(self):
        super(MultiStageFeature, self).__init__()
        # Subclasses fill these in: `layers` lists backbone stages in the
        # order they are unlocked, `change_point` holds ratio thresholds
        # in [0, 1], and `train_nums` the matching number of trainable stages.
        self.layers = []
        self.train_num = -1
        self.change_point = []
        self.train_nums = []

    def unfix(self, ratio=0.0):
        # Progressively unfreeze stages: `ratio` is the fraction of training
        # completed, and the highest change point it has passed decides how
        # many stages become trainable.
        if self.train_num == -1:
            self.train_num = 0
            self.unlock()
            self.eval()

        for p, t in reversed(list(zip(self.change_point, self.train_nums))):
            if ratio >= p:
                if self.train_num != t:
                    self.train_num = t
                    self.unlock()
                    return True
                break
        return False

    def train_layers(self):
        return self.layers[:self.train_num]

    def unlock(self):
        # Freeze everything, then re-enable gradients only for the layers
        # selected by the current schedule.
        for p in self.parameters():
            p.requires_grad = False

        logger.info('Current training {} layers:\n\t{}'.format(self.train_num, self.train_layers()))
        for m in self.train_layers():
            for p in m.parameters():
                p.requires_grad = True

    def train(self, mode=True):
        # Frozen stages stay in eval mode; only the unlocked layers switch
        # to training mode.
        self.training = mode
        if not mode:
            super(MultiStageFeature, self).train(False)
        else:
            for m in self.train_layers():
                m.train(True)
        return self
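
# A minimal usage sketch, not part of the original file: a hypothetical
# MultiStageFeature subclass showing the staged-unfreezing schedule.
# ToyBackbone, its stages, and the schedule values are illustrative
# assumptions; real subclasses populate the same fields.
def _demo_unfix_schedule():
    class ToyBackbone(MultiStageFeature):
        def __init__(self):
            super(ToyBackbone, self).__init__()
            self.stage1 = nn.Conv2d(3, 8, 3, padding=1)
            self.stage2 = nn.Conv2d(8, 16, 3, padding=1)
            # stage1 trains from the start; both stages train after
            # half of the schedule has elapsed.
            self.layers = [self.stage1, self.stage2]
            self.change_point = [0.0, 0.5]
            self.train_nums = [1, 2]

        def forward(self, x):
            return self.stage2(self.stage1(x))

    net = ToyBackbone()
    for epoch, total in [(0, 20), (10, 20), (15, 20)]:
        changed = net.unfix(ratio=epoch / total)
        # Count how many parameter tensors are currently trainable.
        print(epoch, changed, sum(p.requires_grad for p in net.parameters()))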