from copy import deepcopy
import torch
import json
from collections import OrderedDict
import math


class ModelEma(torch.nn.Module):
    """EMA Model.

    Holds a deep copy of ``model`` (put in eval mode) whose state_dict entries
    are updated as an exponential moving average of another model's entries.

    Args:
        model: Module to track. It is deep-copied; the original is not modified.
        decay: Base EMA decay factor.
        tau: If non-zero, the effective decay warms up as
            ``decay * (1 - exp(-updates / tau))``; if 0, ``decay`` is used as is.
        device: If given, the EMA copy is moved to (and updated on) this device
            instead of staying on the source model's device.
    """

    def __init__(self, model, decay=0.9997, tau=0, device=None):
        super(ModelEma, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.tau = tau
        self.updates = 1
        # perform ema on different device from model if set
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _get_decay(self):
        """Return the decay for the current step (ramped up by ``tau`` if set)."""
        if self.tau == 0:
            return self.decay
        return self.decay * (1 - math.exp(-self.updates / self.tau))

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_tensor, model_tensor)`` into each EMA entry in place.

        Entries of the two state_dicts are paired positionally via ``zip``, so
        ``model`` is assumed to share the architecture of the copied model.
        Only floating-point / complex entries go through ``update_fn``; other
        entries (e.g. integer counters such as BatchNorm's
        ``num_batches_tracked``) are copied verbatim, because blending them in
        float and writing the result back into an integer tensor truncates.
        """
        with torch.no_grad():
            for ema_v, model_v in zip(
                    self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend ``model``'s state into the EMA copy and advance the step count."""
        decay = self._get_decay()
        self._update(model, update_fn=lambda e, m: decay * e + (1. - decay) * m)
        self.updates += 1

    def set(self, model):
        """Overwrite the EMA copy with ``model``'s current state."""
        self._update(model, update_fn=lambda e, m: m)


class BestMetricSingle():
    """Track the best value seen for one metric and the epoch it was seen at.

    Args:
        init_res: Starting "best" value; a result must beat it to be recorded.
        better: ``'large'`` if higher is better, ``'small'`` if lower is better.
    """

    def __init__(self, init_res=0.0, better='large') -> None:
        self.init_res = init_res
        self.best_res = init_res
        self.best_ep = -1  # -1 means no result has beaten init_res yet
        self.better = better
        assert better in ['large', 'small']

    def isbetter(self, new_res, old_res):
        """Return True if ``new_res`` strictly improves on ``old_res``."""
        if self.better == 'large':
            return new_res > old_res
        if self.better == 'small':
            return new_res < old_res

    def update(self, new_res, ep):
        """Record ``new_res`` at epoch ``ep`` if it is a new best; return whether it was."""
        if self.isbetter(new_res, self.best_res):
            self.best_res = new_res
            self.best_ep = ep
            return True
        return False

    def __str__(self) -> str:
        return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)

    def __repr__(self) -> str:
        return self.__str__()

    def summary(self) -> dict:
        """Return the best result and its epoch as a dict."""
        return {
            'best_res': self.best_res,
            'best_ep': self.best_ep,
        }


class BestMetricHolder():
    """Track the best metric overall and, optionally, separately for EMA/regular runs.

    Args:
        init_res: Starting "best" value passed to every tracker.
        better: ``'large'`` or ``'small'``, passed to every tracker.
        use_ema: If True, also keep separate trackers for EMA and regular results.
    """

    def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
        self.best_all = BestMetricSingle(init_res, better)
        self.use_ema = use_ema
        if use_ema:
            self.best_ema = BestMetricSingle(init_res, better)
            self.best_regular = BestMetricSingle(init_res, better)

    def update(self, new_res, epoch, is_ema=False):
        """Record ``new_res`` for ``epoch``; return whether it is the best overall.

        With ``use_ema`` enabled, the result is also fed to the EMA or regular
        tracker (selected by ``is_ema``); the return value always reflects the
        combined ``best_all`` tracker.
        """
        if self.use_ema:
            tracker = self.best_ema if is_ema else self.best_regular
            tracker.update(new_res, epoch)
        return self.best_all.update(new_res, epoch)

    def summary(self):
        """Return a dict of best results; keys are prefixed per tracker when use_ema."""
        if not self.use_ema:
            return self.best_all.summary()

        res = {}
        res.update({f'all_{k}': v for k, v in self.best_all.summary().items()})
        res.update({f'regular_{k}': v for k, v in self.best_regular.summary().items()})
        res.update({f'ema_{k}': v for k, v in self.best_ema.summary().items()})
        return res

    def __repr__(self) -> str:
        return json.dumps(self.summary(), indent=2)

    def __str__(self) -> str:
        return self.__repr__()


def clean_state_dict(state_dict):
    """Return an OrderedDict copy of ``state_dict`` with one leading ``module.`` stripped.

    Key order is preserved and values are not copied. Only a single leading
    ``module.`` is removed per key (presumably the prefix added by a
    (Distributed)DataParallel wrapper — confirm against callers).
    """
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict