from copy import deepcopy
from collections import OrderedDict
import torch


class ModelEma:
    """Maintains an exponential moving average (EMA) copy of a model's weights."""

    def __init__(self, model, decay=0.9999, device=""):
        # Frozen copy of the model that accumulates the moving average.
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # optionally keep the EMA copy on a different device
        if device:
            self.ema.to(device=device)
        # True when the EMA copy is wrapped in DataParallel/DDP (keys carry a "module." prefix).
        self.ema_is_dp = hasattr(self.ema, "module")
        # The EMA copy is never trained directly, so freeze its parameters.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def load_checkpoint(self, checkpoint):
        # Accept either a path to a checkpoint file or an already-loaded dict.
        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint, map_location="cpu")

        assert isinstance(checkpoint, dict)
        if "model_ema" in checkpoint:
            # Re-key the saved EMA weights so the "module." prefix matches self.ema.
            new_state_dict = OrderedDict()
            for k, v in checkpoint["model_ema"].items():
                if self.ema_is_dp:
                    name = k if k.startswith("module") else "module." + k
                else:
                    name = k.replace("module.", "") if k.startswith("module") else k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        return self.ema.state_dict()

    def update(self, model):
        # If the live model is DP/DDP-wrapped but the EMA copy is not, the live
        # state dict keys carry an extra "module." prefix that must be added.
        needs_module = hasattr(model, "module") and not self.ema_is_dp
        with torch.no_grad():
            curr_msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = "module." + k
                model_v = curr_msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                # EMA update: ema = decay * ema + (1 - decay) * current weights.
                ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
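

# A minimal usage sketch (not part of the original file): `create_model`,
# `train_loader`, and the optimizer/criterion below are hypothetical stand-ins
# for whatever the surrounding training script provides. After each optimizer
# step, ModelEma.update() folds the latest weights into the EMA copy; the EMA
# state dict is what gets saved under the "model_ema" key and later restored
# via load_checkpoint().
#
#     model = create_model().cuda()
#     model_ema = ModelEma(model, decay=0.9999, device="cuda")
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     criterion = torch.nn.CrossEntropyLoss()
#
#     for images, targets in train_loader:
#         loss = criterion(model(images.cuda()), targets.cuda())
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         model_ema.update(model)
#
#     torch.save({"model_ema": model_ema.state_dict()}, "checkpoint.pth")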