File size: 1,245 Bytes
e34aada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch

def get_grad_norm(model, l=2):
    """Return the global L1 or L2 norm of all existing parameter gradients.

    Args:
        model: a ``torch.nn.Module`` (its ``parameters()`` are used) or any
            iterable of parameters/tensors exposing ``.grad``.
        l: which norm to compute; only ``1`` and ``2`` are supported.

    Returns:
        The norm as a Python float. Parameters whose ``.grad`` is ``None``
        are skipped; if none have a gradient the result is ``0.0``.

    Raises:
        ValueError: if ``l`` is not 1 or 2 (checked before iterating, so it
            is raised even when no parameter has a gradient).
    """
    if l not in (1, 2):
        raise ValueError("Now we only implement l1/l2 norm !")
    if isinstance(model, torch.nn.Module):
        params = model.parameters()
    else:
        params = model
    # Start from a float so the "no gradients" case yields 0.0 for both norms.
    accu_grad = 0.0
    for p in params:
        if p.grad is None:
            continue
        if l == 1:
            accu_grad += p.grad.abs().sum()
        else:
            accu_grad += p.grad.pow(2).sum()
    if l == 2:
        accu_grad = accu_grad ** 0.5
    # accu_grad is a tensor once any gradient was summed, otherwise a float.
    if isinstance(accu_grad, torch.Tensor):
        return accu_grad.item()
    return float(accu_grad)

class GradBuffer:
    """Accumulate parameter gradients across calls and add them back later.

    ``add`` sums the current ``.grad`` of each parameter (keyed by its
    ``named_parameters`` name) into an internal buffer; ``apply`` adds the
    buffered sums onto a model's gradients and empties the buffer.
    """

    def __init__(self):
        # Parameter name -> accumulated gradient tensor owned by this buffer.
        self.buffer = {}

    def add(self, model):
        """Add the current gradients of ``model`` into the buffer.

        Parameters whose ``.grad`` is ``None`` are skipped.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.data
            stored = self.buffer.get(name)
            if stored is None:
                # Copy so later in-place changes to param.grad (e.g. zero_grad)
                # do not alter what was buffered.
                self.buffer[name] = grad.clone()
            else:
                # The buffered tensor is our own copy, so accumulate in place.
                stored.add_(grad)

    def apply(self, model):
        """Add buffered gradients onto ``model``'s gradients, then clear the buffer.

        A parameter whose ``.grad`` is ``None`` (e.g. after
        ``optimizer.zero_grad(set_to_none=True)``) but that still requires
        grad receives the buffered tensor as its gradient instead of having
        it silently discarded. Buffered entries with no matching parameter
        name are dropped.
        """
        for name, param in model.named_parameters():
            stored = self.buffer.get(name)
            if stored is None:
                continue
            if param.grad is None:
                if param.requires_grad:
                    # Safe to hand over: the buffer is reset below, so this
                    # tensor is referenced only by param.grad afterwards.
                    param.grad = stored
            else:
                param.grad.data += stored
        self.buffer = {}