File size: 4,269 Bytes
a722365 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
class LossScaler:
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
class DynamicLossScaler:
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
# return False
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
cpu_sum = float(x.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
if overflow:
#self.cur_scale /= self.scale_factor
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
# self.cur_scale = 1
self.cur_iter += 1
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
|