Spaces:
Runtime error
Runtime error
File size: 4,487 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import numpy as np
import torch
from easydict import EasyDict
from .misc import BlackHole
def get_optimizer(cfg, model):
    """Build the optimizer described by ``cfg`` for ``model``'s parameters.

    Args:
        cfg: Config object with attributes ``type``, ``lr``, ``weight_decay``,
            ``beta1``, ``beta2``. Only ``type == 'adam'`` is recognized.
        model: Module whose parameters are to be optimized.

    Raises:
        NotImplementedError: for any optimizer type other than 'adam'.
    """
    if cfg.type != 'adam':
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
    return torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        betas=(cfg.beta1, cfg.beta2),
        weight_decay=cfg.weight_decay,
    )
def get_scheduler(cfg, optimizer):
    """Build the LR scheduler described by ``cfg`` for ``optimizer``.

    Supported ``cfg.type`` values:
        None        -> BlackHole() (no-op stand-in so callers can call .step()
                       unconditionally)
        'plateau'   -> ReduceLROnPlateau(factor, patience, min_lr)
        'multistep' -> MultiStepLR(milestones, gamma)
        'exp'       -> ExponentialLR(gamma)

    Raises:
        NotImplementedError: for any other scheduler type.
    """
    if cfg.type is None:
        return BlackHole()
    elif cfg.type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr,
        )
    elif cfg.type == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
        )
    elif cfg.type == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.gamma,
        )
    else:
        # A second, unreachable `cfg.type is None` branch used to sit here
        # (None is already handled by the first branch); it has been removed.
        raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
def get_warmup_sched(cfg, optimizer):
    """Linear LR warm-up: scale each group's LR by it/max_iters, capped at 1.

    Returns a no-op BlackHole when ``cfg`` is None, so callers can invoke
    ``.step()`` unconditionally.
    """
    if cfg is None:
        return BlackHole()

    def warmup_factor(it):
        # Ramps linearly from 0 to 1 over cfg.max_iters, then stays at 1.
        return min(it / cfg.max_iters, 1.0)

    multipliers = [warmup_factor for _ in optimizer.param_groups]
    return torch.optim.lr_scheduler.LambdaLR(optimizer, multipliers)
def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others=None):
    """Log a dict of losses to a text logger and a TensorBoard-style writer.

    Args:
        out: Dict of scalar loss tensors; must contain the key 'overall'.
        it: Iteration number (used in the log line and as the scalar step).
        tag: Prefix for the log line and writer scalar names (e.g. 'train').
        logger: Object exposing ``.info(str)``; defaults to a no-op BlackHole.
        writer: Object exposing ``.add_scalar()`` / ``.flush()``; defaults to
            a no-op BlackHole.
        others: Optional extra metrics (name -> number) appended to the log
            line and written as scalars. Was a mutable ``{}`` default; now a
            None sentinel (the dict is only read, but the idiom is safer).
    """
    if others is None:
        others = {}
    logstr = '[%s] Iter %05d' % (tag, it)
    logstr += ' | loss %.4f' % out['overall'].item()
    for k, v in out.items():
        if k == 'overall':
            continue
        logstr += ' | loss(%s) %.4f' % (k, v.item())
    for k, v in others.items():
        logstr += ' | %s %2.4f' % (k, v)
    logger.info(logstr)
    for k, v in out.items():
        if k == 'overall':
            writer.add_scalar('%s/loss' % tag, v, it)
        else:
            writer.add_scalar('%s/loss_%s' % (tag, k), v, it)
    for k, v in others.items():
        writer.add_scalar('%s/%s' % (tag, k), v, it)
    writer.flush()
class ValidationLossTape(object):
    """Accumulates per-batch validation losses and reports sample-weighted
    averages via ``log_losses``."""

    def __init__(self):
        super().__init__()
        self.accumulate = {}  # loss name -> running sum (detached tensor)
        self.others = {}      # extra metric name -> running sum
        self.total = 0        # total sample count accumulated so far

    def update(self, out, n, others=None):
        """Add one batch of losses.

        Args:
            out: Dict of loss tensors for this batch.
            n: Number of samples in the batch (added to the divisor used
               by ``log``).
            others: Optional extra metrics. Was a mutable ``{}`` default;
                now a None sentinel. NOTE(review): values here are assumed
                to be tensors (``.clone().detach()`` is called on them) —
                confirm against callers.
        """
        if others is None:
            others = {}
        self.total += n
        for k, v in out.items():
            if k not in self.accumulate:
                self.accumulate[k] = v.clone().detach()
            else:
                self.accumulate[k] += v.clone().detach()
        for k, v in others.items():
            if k not in self.others:
                self.others[k] = v.clone().detach()
            else:
                self.others[k] += v.clone().detach()

    def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'):
        """Log the running averages and return the average 'overall' loss.

        Divides each accumulated sum by ``self.total``; calling this before
        any ``update`` would divide by zero.
        """
        avg = EasyDict({k: v / self.total for k, v in self.accumulate.items()})
        avg_others = EasyDict({k: v / self.total for k, v in self.others.items()})
        log_losses(avg, it, tag, logger, writer, others=avg_others)
        return avg['overall']
def recursive_to(obj, device):
    """Recursively move all tensors inside nested lists/tuples/dicts to
    ``device``; non-tensor leaves are returned unchanged.

    For non-CPU targets a non-blocking CUDA transfer is attempted first,
    falling back to ``.to(device)`` if that raises a RuntimeError.
    """
    if isinstance(obj, torch.Tensor):
        if device == 'cpu':
            return obj.cpu()
        try:
            return obj.cuda(device=device, non_blocking=True)
        except RuntimeError:
            return obj.to(device)
    if isinstance(obj, dict):
        return {key: recursive_to(val, device=device) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [recursive_to(item, device=device) for item in obj]
        # Preserve the container type (tuple stays tuple, list stays list).
        return moved if isinstance(obj, list) else tuple(moved)
    return obj
def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'):
    """Return a loss weight derived from sequence length.

    mode='sqrt'   -> sqrt(length / max_length)
    mode='linear' -> length / max_length
    mode=None     -> constant 1.0

    Raises:
        ValueError: for any other mode.
    """
    ratio = length / max_length
    if mode == 'sqrt':
        return np.sqrt(ratio)
    if mode == 'linear':
        return ratio
    if mode is None:
        return 1.0
    raise ValueError('Unknown reweighting mode: %s' % mode)
def sum_weighted_losses(losses, weights):
    """Return the (optionally weighted) sum of a dict of scalar losses.

    Args:
        losses: Dict of scalar tensors (or numbers).
        weights: Dict mapping the same keys to weights, or None for a plain
            unweighted sum.
    """
    total = 0
    if weights is None:
        for value in losses.values():
            total = total + value
    else:
        for key, value in losses.items():
            total = total + weights[key] * value
    return total
def count_parameters(model):
    """Return the total number of elements across all of ``model``'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|