from collections import OrderedDict
from copy import deepcopy

import torch


class ModelEma:
    """Keep an exponential moving average (EMA) copy of a model's state.

    A deep copy of ``model`` is taken at construction time, switched to
    ``eval()`` mode and frozen (``requires_grad`` disabled on every
    parameter).  Each call to :meth:`update` blends the current state of
    the live model into the copy::

        ema = decay * ema + (1 - decay) * model

    Args:
        model: Module to shadow.  It is deep-copied, never modified.
        decay: Weight given to the previous EMA value on each update.
        device: Optional device for the EMA copy.  An empty string (the
            default) leaves the copy wherever ``deepcopy`` placed it.

    Attributes:
        ema: The averaged copy of the model.
        decay: See ``decay`` above.
        device: See ``device`` above.
        ema_is_dp: True when the EMA copy exposes a ``module`` attribute
            (as parallel wrappers do), in which case its state-dict keys
            carry a ``"module."`` prefix.
    """

    # Key prefix that appears in the state dict of a model exposing `.module`.
    _DP_PREFIX = "module."

    def __init__(self, model, decay=0.9999, device=""):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        self.ema_is_dp = hasattr(self.ema, "module")
        # The EMA weights are only ever written by `update`, never trained.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def load_checkpoint(self, checkpoint):
        """Restore the EMA weights from a checkpoint.

        Args:
            checkpoint: Either a path accepted by ``torch.load`` or an
                already-loaded dict.  The EMA state is read from its
                ``"model_ema"`` entry; when that entry is absent the call
                is a no-op.

        Keys are renamed so that the ``"module."`` prefix matches the EMA
        copy: added when the copy is wrapped and the key lacks it, removed
        when the copy is not wrapped and the key has it.
        """
        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.  Loading to CPU lets a checkpoint saved on an
            # accelerator be read on any host; load_state_dict below copies
            # the values onto whatever device the EMA parameters live on.
            checkpoint = torch.load(checkpoint, map_location="cpu")
        assert isinstance(checkpoint, dict), "checkpoint must be a dict or a path"

        if "model_ema" not in checkpoint:
            return

        prefix = self._DP_PREFIX
        new_state_dict = OrderedDict()
        for k, v in checkpoint["model_ema"].items():
            # Match on "module." (with the dot) so that names which merely
            # begin with "module", e.g. "module_head.weight", are not
            # mistaken for prefixed keys.
            has_prefix = k.startswith(prefix)
            if self.ema_is_dp:
                name = k if has_prefix else prefix + k
            else:
                # Strip only the leading prefix; later "module." segments
                # belong to genuinely nested submodules.
                name = k[len(prefix):] if has_prefix else k
            new_state_dict[name] = v
        self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the state dict of the EMA copy."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy.

        Floating-point (and complex) entries are updated in place with
        ``decay * ema + (1 - decay) * model``.  Other entries, such as
        integer counters, are copied verbatim, since averaging them would
        be truncated when written back into an integer tensor.

        Args:
            model: The live model.  It may or may not expose ``.module``;
                the ``"module."`` key prefix is added or removed as needed
                to line its state dict up with the EMA copy's.

        Raises:
            KeyError: If an EMA entry has no counterpart in ``model``.
        """
        prefix = self._DP_PREFIX
        model_is_dp = hasattr(model, "module")
        add_prefix = model_is_dp and not self.ema_is_dp
        strip_prefix = self.ema_is_dp and not model_is_dp

        with torch.no_grad():
            curr_msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if add_prefix:
                    src_key = prefix + k
                elif strip_prefix and k.startswith(prefix):
                    src_key = k[len(prefix):]
                else:
                    src_key = k

                model_v = curr_msd[src_key].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)

                if ema_v.dtype.is_floating_point or ema_v.dtype.is_complex:
                    ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
                else:
                    ema_v.copy_(model_v)