#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import logging

logger = logging.getLogger()


class Optimizer(object):
    '''SGD wrapper implementing exponential warmup followed by polynomial
    ("poly") learning-rate decay, with a 10x lr multiplier for parameter
    groups flagged with lr_mul (typically the task-specific head).'''

    def __init__(self,
            model,
            lr0,
            momentum,
            wd,
            warmup_steps,
            warmup_start_lr,
            max_iter,
            power,
            *args, **kwargs):
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.lr0 = lr0
        self.lr = self.lr0
        self.max_iter = float(max_iter)
        self.power = power
        self.it = 0
        # the model is expected to sort its parameters into four groups:
        # with/without weight decay, and with/without the 10x lr multiplier
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
        param_list = [
                {'params': wd_params},
                {'params': nowd_params, 'weight_decay': 0},
                {'params': lr_mul_wd_params, 'lr_mul': True},
                {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True},
        ]
        self.optim = torch.optim.SGD(
                param_list,
                lr=lr0,
                momentum=momentum,
                weight_decay=wd)
        # per-step multiplicative factor that grows the lr from
        # warmup_start_lr up to lr0 over warmup_steps iterations
        self.warmup_factor = (self.lr0 / self.warmup_start_lr) ** (1. / self.warmup_steps)

    def get_lr(self):
        if self.it <= self.warmup_steps:
            # exponential warmup: warmup_start_lr -> lr0
            lr = self.warmup_start_lr * (self.warmup_factor ** self.it)
        else:
            # poly decay: lr0 * (1 - t/T)^power over the remaining iterations
            factor = (1 - (self.it - self.warmup_steps) / (self.max_iter - self.warmup_steps)) ** self.power
            lr = self.lr0 * factor
        return lr

    def step(self):
        self.lr = self.get_lr()
        # groups flagged with lr_mul train at 10x the base lr
        for pg in self.optim.param_groups:
            if pg.get('lr_mul', False):
                pg['lr'] = self.lr * 10
            else:
                pg['lr'] = self.lr
        # keep the optimizer defaults in sync with the current base lr
        if self.optim.defaults.get('lr_mul', False):
            self.optim.defaults['lr'] = self.lr * 10
        else:
            self.optim.defaults['lr'] = self.lr
        self.it += 1
        self.optim.step()
        if self.it == self.warmup_steps + 2:
            logger.info('==> warmup done, switching to the poly lr strategy')

    def zero_grad(self):
        self.optim.zero_grad()
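

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: DummyModel, its
    # layer names, and the hyperparameter values below are illustrative
    # assumptions. The only real requirement is a model exposing get_params()
    # that returns the four parameter groups the constructor unpacks above.
    import torch.nn as nn

    class DummyModel(nn.Module):
        def __init__(self):
            super(DummyModel, self).__init__()
            self.backbone = nn.Conv2d(3, 8, 3)
            self.head = nn.Conv2d(8, 2, 1)

        def forward(self, x):
            return self.head(self.backbone(x))

        def get_params(self):
            # weights receive weight decay, biases do not; the head is
            # flagged for the 10x lr multiplier, as is common for
            # segmentation heads trained on top of a pretrained backbone
            wd_params = [self.backbone.weight]
            nowd_params = [self.backbone.bias]
            lr_mul_wd_params = [self.head.weight]
            lr_mul_nowd_params = [self.head.bias]
            return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params

    net = DummyModel()
    optim = Optimizer(
            model=net,
            lr0=1e-2,
            momentum=0.9,
            wd=5e-4,
            warmup_steps=5,
            warmup_start_lr=1e-5,
            max_iter=20,
            power=0.9)
    # run a few steps and print the schedule: the lr climbs from
    # warmup_start_lr to lr0 over the first 5 iterations, then decays
    # polynomially toward zero at max_iter
    for it in range(20):
        optim.zero_grad()
        loss = net(torch.randn(1, 3, 8, 8)).sum()
        loss.backward()
        optim.step()
        print('iter {:2d}: lr = {:.6f}'.format(it, optim.lr))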