import numpy as np import torch from easydict import EasyDict from .misc import BlackHole def get_optimizer(cfg, model): if cfg.type == 'adam': return torch.optim.Adam( model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(cfg.beta1, cfg.beta2, ) ) else: raise NotImplementedError('Optimizer not supported: %s' % cfg.type) def get_scheduler(cfg, optimizer): if cfg.type is None: return BlackHole() elif cfg.type == 'plateau': return torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, factor=cfg.factor, patience=cfg.patience, min_lr=cfg.min_lr, ) elif cfg.type == 'multistep': return torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=cfg.milestones, gamma=cfg.gamma, ) elif cfg.type == 'exp': return torch.optim.lr_scheduler.ExponentialLR( optimizer, gamma=cfg.gamma, ) elif cfg.type is None: return BlackHole() else: raise NotImplementedError('Scheduler not supported: %s' % cfg.type) def get_warmup_sched(cfg, optimizer): if cfg is None: return BlackHole() lambdas = [lambda it : (it / cfg.max_iters) if it <= cfg.max_iters else 1 for _ in optimizer.param_groups] warmup_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas) return warmup_sched def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others={}): logstr = '[%s] Iter %05d' % (tag, it) logstr += ' | loss %.4f' % out['overall'].item() for k, v in out.items(): if k == 'overall': continue logstr += ' | loss(%s) %.4f' % (k, v.item()) for k, v in others.items(): logstr += ' | %s %2.4f' % (k, v) logger.info(logstr) for k, v in out.items(): if k == 'overall': writer.add_scalar('%s/loss' % tag, v, it) else: writer.add_scalar('%s/loss_%s' % (tag, k), v, it) for k, v in others.items(): writer.add_scalar('%s/%s' % (tag, k), v, it) writer.flush() class ValidationLossTape(object): def __init__(self): super().__init__() self.accumulate = {} self.others = {} self.total = 0 def update(self, out, n, others={}): self.total += n for k, v in out.items(): if k not in self.accumulate: self.accumulate[k] = v.clone().detach() else: self.accumulate[k] += v.clone().detach() for k, v in others.items(): if k not in self.others: self.others[k] = v.clone().detach() else: self.others[k] += v.clone().detach() def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'): avg = EasyDict({k:v / self.total for k, v in self.accumulate.items()}) avg_others = EasyDict({k:v / self.total for k, v in self.others.items()}) log_losses(avg, it, tag, logger, writer, others=avg_others) return avg['overall'] def recursive_to(obj, device): if isinstance(obj, torch.Tensor): if device == 'cpu': return obj.cpu() try: return obj.cuda(device=device, non_blocking=True) except RuntimeError: return obj.to(device) elif isinstance(obj, list): return [recursive_to(o, device=device) for o in obj] elif isinstance(obj, tuple): return tuple(recursive_to(o, device=device) for o in obj) elif isinstance(obj, dict): return {k: recursive_to(v, device=device) for k, v in obj.items()} else: return obj def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'): if mode == 'sqrt': w = np.sqrt(length / max_length) elif mode == 'linear': w = length / max_length elif mode is None: w = 1.0 else: raise ValueError('Unknown reweighting mode: %s' % mode) return w def sum_weighted_losses(losses, weights): """ Args: losses: Dict of scalar tensors. weights: Dict of weights. """ loss = 0 for k in losses.keys(): if weights is None: loss = loss + losses[k] else: loss = loss + weights[k] * losses[k] return loss def count_parameters(model): return sum(p.numel() for p in model.parameters())