# Provenance: initial commit based on a github repo (commit 3ef1661).
# (Stripped non-Python web-viewer residue — "raw / history blame / No virus / 14 kB" —
# that had been pasted above the module.)
import os
import torch
import torch.nn as nn
from mono.utils.comm import main_process
import copy
import inspect
import logging
import glob
class LrUpdater():
    """Learning-rate scheduler, modelled on the LR hooks in MMCV.

    Args:
        by_epoch (bool): if True the LR changes epoch by epoch, otherwise
            iteration by iteration. Must agree with ``runner.type``.
        warmup (str): type of warmup used. It can be None (use no warmup),
            'constant', 'linear' or 'exp'.
        warmup_iters (int): the number of iterations that warmup lasts.
            When ``by_epoch`` is True this counts epochs instead.
        warmup_ratio (float): LR used at the beginning of warmup equals
            ``warmup_ratio * initial_lr``; must lie in (0, 1].
        runner (dict): run config; ``runner.type`` must mention either
            'IterBasedRunner' (with ``runner.max_iters``) or
            'EpochBasedRunner' (with ``runner.max_epoches``).
    """
    def __init__(self,
                 by_epoch: bool = True,
                 warmup: str = None,
                 warmup_iters: int = 0,
                 warmup_ratio: float = 0.1,
                 runner: dict = None):
        # validate the "warmup" argument
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    f'"{warmup}" is not a supported type for warming up, valid'
                    ' types are "constant" and "linear"')
            assert warmup_iters > 0, \
                '"warmup_iters" must be a positive integer'
            assert 0 < warmup_ratio <= 1.0, \
                '"warmup_ratio" must be in range (0,1]'
        # BUG FIX: default used to be a shared mutable dict (`runner={}`),
        # which could never work anyway -- `.type` is accessed below, so an
        # empty plain dict raised AttributeError instead of this clear error.
        if runner is None:
            raise RuntimeError('runner should be set.')
        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        self.runner = runner

        self.max_iters = None
        self.max_epoches = None
        if 'IterBasedRunner' in self.runner.type:
            self.max_iters = self.runner.max_iters
            assert self.by_epoch == False
            self.warmup_by_epoch = False
        elif 'EpochBasedRunner' in self.runner.type:
            self.max_epoches = self.runner.max_epoches
            assert self.by_epoch == True
            self.warmup_by_epoch = True
        else:
            raise ValueError(f'{self.runner.type} is not a supported type for running.')

        if self.warmup_by_epoch:
            # In epoch mode `warmup_iters` actually counts epochs.
            self.warmup_epochs = self.warmup_iters
            self.warmup_iters = None
        else:
            self.warmup_epochs = None

        self.base_lr = []      # initial lr for all param groups
        self.regular_lr = []   # expected lr if no warming up is performed
        self._step_count = 0   # completed steps (iters or epochs)

    def _warmup_span(self):
        """Length of the warmup phase in the scheduler's native unit."""
        return self.warmup_epochs if self.warmup_by_epoch else self.warmup_iters

    def _set_lr(self, optimizer, lr_groups):
        """Write `lr_groups` into the param groups of one optimizer or a dict of them."""
        if isinstance(optimizer, dict):
            for k, optim in optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(optimizer.param_groups,
                                       lr_groups):
                param_group['lr'] = lr

    def get_lr(self, _iter, max_iter, base_lr):
        """Policy-specific LR at `_iter`; implemented by subclasses."""
        raise NotImplementedError

    def get_regular_lr(self, _iter, optimizer):
        """LR for every param group at `_iter`, ignoring warmup."""
        max_iters = self.max_iters if not self.by_epoch else self.max_epoches
        if isinstance(optimizer, dict):
            lr_groups = {}
            for k in optimizer.keys():
                _lr_group = [
                    self.get_lr(_iter, max_iters, _base_lr)
                    for _base_lr in self.base_lr[k]
                ]
                lr_groups.update({k: _lr_group})
            return lr_groups
        else:
            return [self.get_lr(_iter, max_iters, _base_lr) for _base_lr in self.base_lr]

    def get_warmup_lr(self, cur_iters):
        """Warmup-scaled LR at `cur_iters` (iterations, or epochs in epoch mode)."""
        # BUG FIX: the formulas previously divided by `self.warmup_iters`,
        # which is None in epoch mode (see __init__); use the
        # unit-appropriate warmup length instead.
        span = self._warmup_span()

        def _get_warmup_lr(cur_iters, regular_lr):
            if self.warmup == 'constant':
                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
            elif self.warmup == 'linear':
                k = (1 - cur_iters / span) * (1 - self.warmup_ratio)
                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
            elif self.warmup == 'exp':
                k = self.warmup_ratio**(1 - cur_iters / span)
                warmup_lr = [_lr * k for _lr in regular_lr]
            return warmup_lr

        if isinstance(self.regular_lr, dict):
            lr_groups = {}
            for key, regular_lr in self.regular_lr.items():
                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
            return lr_groups
        else:
            return _get_warmup_lr(cur_iters, self.regular_lr)

    def before_run(self, optimizer):
        """Record each param group's initial LR before training starts."""
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
        if isinstance(optimizer, dict):
            self.base_lr = {}
            for k, optim in optimizer.items():
                for group in optim.param_groups:
                    group.setdefault('initial_lr', group['lr'])
                _base_lr = [
                    group['initial_lr'] for group in optim.param_groups
                ]
                self.base_lr.update({k: _base_lr})
        else:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            self.base_lr = [
                group['initial_lr'] for group in optimizer.param_groups
            ]

    def after_train_epoch(self, optimizer):
        """Advance one epoch and write the appropriate LRs into `optimizer`."""
        self._step_count += 1
        curr_epoch = self._step_count
        self.regular_lr = self.get_regular_lr(curr_epoch, optimizer)
        # BUG FIX: attribute was misspelled `warmup_epoches`, raising
        # AttributeError whenever warmup was enabled in epoch mode.
        if self.warmup is None or curr_epoch > self.warmup_epochs:
            self._set_lr(optimizer, self.regular_lr)
        else:
            warmup_lr = self.get_warmup_lr(curr_epoch)
            self._set_lr(optimizer, warmup_lr)

    def after_train_iter(self, optimizer):
        """Advance one iteration and write the appropriate LRs into `optimizer`."""
        self._step_count += 1
        cur_iter = self._step_count
        self.regular_lr = self.get_regular_lr(cur_iter, optimizer)
        if self.warmup is None or cur_iter >= self.warmup_iters:
            self._set_lr(optimizer, self.regular_lr)
        else:
            warmup_lr = self.get_warmup_lr(cur_iter)
            self._set_lr(optimizer, warmup_lr)

    def get_curr_lr(self, cur_iter):
        """Return the LR that applies at `cur_iter`, honoring warmup if active."""
        # BUG FIX: compare against the unit-appropriate warmup length
        # (`warmup_iters` is None in epoch mode).
        if self.warmup is None or cur_iter >= self._warmup_span():
            return self.regular_lr
        else:
            return self.get_warmup_lr(cur_iter)

    def state_dict(self):
        """
        Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
        Args:
            @state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
class PolyLrUpdater(LrUpdater):
    """Polynomial decay: lr = (base_lr - min_lr) * (1 - t / T)**power + min_lr."""

    def __init__(self, power=1., min_lr=0., **kwargs):
        # Record the decay shape first, then delegate the generic
        # warmup/runner setup to the base class.
        self.power = power
        self.min_lr = min_lr
        super(PolyLrUpdater, self).__init__(**kwargs)

    def get_lr(self, _iter, max_iters, base_lr):
        # Fraction of the schedule remaining, raised to `power`,
        # then rescaled into the [min_lr, base_lr] range.
        remaining = 1 - _iter / max_iters
        return (base_lr - self.min_lr) * remaining ** self.power + self.min_lr
def build_lr_schedule_with_cfg(cfg):
    """Construct an LR schedule object from ``cfg.lr_config``.

    The ``policy`` entry selects the scheduler class; the remaining entries
    are forwarded as keyword arguments. Only 'poly' is currently supported.
    """
    # Work on a deep copy so popping keys never mutates the caller's config.
    lr_cfg = copy.deepcopy(cfg.lr_config)
    policy = lr_cfg.pop('policy')
    if policy != 'poly':
        raise RuntimeError(f'{cfg.lr_config.policy} \
is not supported in this framework.')
    return PolyLrUpdater(runner=cfg.runner, **lr_cfg)
#def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
# """Sets the learning rate to the base LR decayed by 10 every step epochs"""
# lr = base_lr * (multiplier ** (epoch // step_epoch))
# return lr
def register_torch_optimizers():
    """Discover every optimizer class exposed by ``torch.optim``.

    Returns:
        dict: mapping of class name -> optimizer class (subclass of
            ``torch.optim.Optimizer``).
    """
    registry = {}
    for attr_name in dir(torch.optim):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # Keep only genuine Optimizer subclasses; skip modules/functions.
        if inspect.isclass(candidate) and issubclass(candidate, torch.optim.Optimizer):
            registry[attr_name] = candidate
    return registry


# Name -> class table of all optimizers available in the installed torch.
TORCH_OPTIMIZER = register_torch_optimizers()
def build_optimizer_with_cfg(cfg, model):
    """Build a torch optimizer with per-module parameter groups.

    ``cfg.optimizer`` must contain a ``type`` key naming a class registered
    in TORCH_OPTIMIZER; every remaining key (e.g. 'encoder', 'decoder',
    'others') is treated as a module-name pattern whose value holds that
    group's optimizer kwargs (lr, weight_decay, ...). Trainable parameters
    whose names match no pattern fall into the 'others' group.

    NOTE(review): when 'others' is absent it is defaulted from the 'decoder'
    entry, so a 'decoder' key is effectively required -- confirm against
    caller configs. Also assumes cfg.optimizer supports attribute-style
    access with dict semantics (EasyDict/Config).
    """
    # encoder_parameters = []
    # decoder_parameters = []
    # nongrad_parameters = []
    # for key, value in dict(model.named_parameters()).items():
    #     if value.requires_grad:
    #         if 'encoder' in key:
    #             encoder_parameters.append(value)
    #         else:
    #             decoder_parameters.append(value)
    #     else:
    #         nongrad_parameters.append(value)
    #params = [{"params": filter(lambda p: p.requires_grad, model.parameters())}]
    # Deep copy so the pops below do not mutate the caller's config.
    optim_cfg = copy.deepcopy(cfg.optimizer)
    optim_type = optim_cfg.pop('type', None)
    if optim_type is None:
        raise RuntimeError(f'{optim_type} is not set')
    if optim_type not in TORCH_OPTIMIZER:
        raise RuntimeError(f'{optim_type} is not supported in torch {torch.__version__}')
    # Default the catch-all group to the decoder settings.
    if 'others' not in optim_cfg:
        optim_cfg['others'] = optim_cfg['decoder']
    def match(key1, key_list, strict_match=False):
        # Return the first pattern in `key_list` that matches parameter name
        # `key1`, or None. Loose mode: substring containment. Strict mode:
        # exact match against the second dotted component (assumes names like
        # 'module.encoder.xxx' -- TODO confirm the wrapper prefix).
        if not strict_match:
            for k in key_list:
                if k in key1:
                    return k
        else:
            for k in key_list:
                if k == key1.split('.')[1]:
                    return k
        return None
    optim_obj = TORCH_OPTIMIZER[optim_type]
    # 'strict_match' is a meta option, not a parameter group; pop it out.
    matching_type = optim_cfg.pop('strict_match', False)
    module_names = optim_cfg.keys()
    # One parameter bucket per configured group (plus the 'others' fallback).
    model_parameters = {i: [] for i in module_names}
    model_parameters['others'] = []
    nongrad_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            match_key = match(key, module_names, matching_type)
            # if optim_cfg[match_key]['lr'] == 0:
            #     value.requires_grad=False
            #     continue
            if match_key is None:
                model_parameters['others'].append(value)
            else:
                model_parameters[match_key].append(value)
        else:
            # Frozen parameters are excluded from every optimizer group.
            nongrad_parameters.append(value)
    # One param_group per configured module, carrying its own kwargs.
    optims = [{'params':model_parameters[k], **optim_cfg[k]} for k in optim_cfg.keys()]
    optimizer = optim_obj(optims)
    # optim_args_encoder = optim_cfg.optimizer.encoder
    # optim_args_decoder = optim_cfg.optimizer.decoder
    # optimizer = optim_obj(
    #     [{'params':encoder_parameters, **optim_args_encoder},
    #      {'params':decoder_parameters, **optim_args_decoder},
    #     ])
    return optimizer
def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load the check point for resuming training or finetuning.
    """
    logger = logging.getLogger()
    if not os.path.isfile(load_path):
        # Missing checkpoint: only the main process raises; other ranks
        # return their arguments untouched.
        if main_process():
            raise RuntimeError(f"No weight found at '{load_path}'")
        return model, optimizer, scheduler, loss_scaler

    if main_process():
        logger.info(f"Loading weight '{load_path}'")
    checkpoint = torch.load(load_path, map_location="cpu")
    ckpt_state_dict = checkpoint['model_state_dict']
    model.module.load_state_dict(ckpt_state_dict, strict=strict_match)

    # Restore training state only for the components the caller supplied.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    if loss_scaler is not None and 'scaler' in checkpoint:
        loss_scaler.load_state_dict(checkpoint['scaler'])
        print('Loss scaler loaded', loss_scaler)

    # Drop references to the (potentially large) checkpoint tensors.
    del ckpt_state_dict
    del checkpoint

    if main_process():
        logger.info(f"Successfully loaded weight: '{load_path}'")
        if scheduler is not None and optimizer is not None:
            logger.info(f"Resume training from: '{load_path}'")
    return model, optimizer, scheduler, loss_scaler
def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
    """
    Save the model, optimizer, lr scheduler.

    Args:
        cfg: experiment config; needs ``runner.type``, the matching
            ``runner.max_iters`` / ``runner.max_epoches``, and ``work_dir``.
        model: wrapped model; ``model.module.state_dict()`` is what gets saved.
        optimizer: its ``state_dict()`` is stored under 'optimizer'.
        scheduler: its ``state_dict()`` is stored under 'scheduler'.
        curr_iter (int): step number used to name the checkpoint file.
        curr_epoch: unused; kept for interface compatibility.
        loss_scaler: optional AMP scaler; stored under 'scaler' when given.
    """
    logger = logging.getLogger()
    if 'IterBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_iters
    elif 'EpochBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_epoches
    else:
        raise TypeError(f'{cfg.runner.type} is not supported')

    # BUG FIX: `max_iters` computed above was dead code -- the dict below
    # re-derived the value with an `in` test on cfg.runner. Reuse it instead.
    ckpt = dict(model_state_dict=model.module.state_dict(),
                optimizer=optimizer.state_dict(),
                max_iter=max_iters,
                scheduler=scheduler.state_dict(),
                )
    if loss_scaler is not None:
        # amp state_dict
        ckpt.update(dict(scaler=loss_scaler.state_dict()))

    ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)
    torch.save(ckpt, save_name)

    # Keep only the last 8 ckpts. BUG FIX: glob AFTER saving so the new file
    # is counted (previously 9 files survived). Zero-padded step numbers make
    # lexicographic order chronological.
    saved_ckpts = sorted(glob.glob(ckpt_dir + '/step*.pth'))
    while len(saved_ckpts) > 8:
        os.remove(saved_ckpts.pop(0))
    logger.info(f'Save model: {save_name}')
if __name__ == '__main__':
    # Smoke check: list the optimizer classes discovered in torch.optim.
    print(TORCH_OPTIMIZER)