File size: 1,245 Bytes
e34aada |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
def get_grad_norm(model, l=2):
    """Compute the global l1 or l2 norm over all gradients of a model.

    Args:
        model: a ``torch.nn.Module`` or any iterable of parameters/tensors.
        l: norm order, 1 or 2 (default: 2).

    Returns:
        float: the requested norm over every parameter whose ``.grad`` is
        set; ``0.0`` when no parameter has a gradient.

    Raises:
        ValueError: if ``l`` is not 1 or 2.
    """
    # Validate up front — the original only raised when a gradient existed,
    # and silently fell through (or crashed on int.item()) otherwise.
    if l not in (1, 2):
        raise ValueError("Now we only implement l1/l2 norm !")
    if isinstance(model, torch.nn.Module):
        params = model.parameters()
    else:
        params = model
    accu_grad = 0
    for p in params:
        if p.grad is None:
            continue
        if l == 1:
            # BUG FIX: Tensor.abs() takes no argument; p.grad.abs(1) raised TypeError.
            accu_grad += p.grad.abs().sum()
        else:
            accu_grad += p.grad.pow(2).sum()
    if l == 2:
        accu_grad = accu_grad ** 0.5
    # If no gradient was accumulated, accu_grad is still a plain number
    # (int 0 for l=1, float 0.0 for l=2) — the original crashed on 0.item()
    # in the l=1 case. Normalize to float either way.
    if isinstance(accu_grad, torch.Tensor):
        return accu_grad.item()
    return float(accu_grad)
class GradBuffer:
    """Accumulates model gradients across calls and re-applies them later.

    Typical use: call ``add(model)`` after one or more backward passes to
    sum gradients keyed by parameter name, then ``apply(model)`` once to
    add the stored sums back onto the model's current gradients.
    """

    def __init__(self):
        # Maps parameter name -> accumulated gradient tensor.
        self.buffer = {}

    def add(self, model):
        """Add the current gradients of ``model`` into the buffer by name.

        Parameters without a gradient are skipped.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # `0 + grad` yields a fresh tensor on first insertion, so the
            # buffer never aliases param.grad.
            self.buffer[name] = self.buffer.get(name, 0) + param.grad.data

    def apply(self, model):
        """Add buffered gradients onto ``model``'s gradients, then clear.

        Parameters without a gradient, or with no buffered entry, are
        left untouched. The buffer is emptied afterwards.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if name in self.buffer:
                param.grad.data += self.buffer[name]
        self.buffer = {}