#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
import math
from copy import deepcopy

import torch
import torch.nn as nn


class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Keep a frozen copy of the (unwrapped) model as the EMA model.
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates performed so far
        # The effective decay ramps up from ~0 toward `decay` over training, so
        # the EMA can follow the rapidly changing weights early on.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        # Blend the current model weights into the EMA: ema = d * ema + (1 - d) * model.
        with torch.no_grad():
            self.updates += 1
            decay = self.decay(self.updates)

            state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, item in self.ema.state_dict().items():
                if item.dtype.is_floating_point:
                    item *= decay
                    item += (1 - decay) * state_dict[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Copy non-tensor attributes (e.g. class names, hyperparameters) to the EMA model.
        copy_attr(self.ema, model, include, exclude)


def copy_attr(a, b, include=(), exclude=()):
    """Copy attributes from instance b to instance a, honoring the include/exclude filters."""
    for k, item in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, item)


def is_parallel(model):
    # Return True if model's type is DP or DDP, else False.
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
    return model.module if is_parallel(model) else model
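

# ---------------------------------------------------------------------------
# Worked numbers for the decay ramp above (default decay=0.9999, time constant
# of 2000 updates hard-coded in __init__): decay(x) = 0.9999 * (1 - exp(-x / 2000))
#   updates = 1      -> decay ~= 0.0005  (EMA copies the model almost exactly)
#   updates = 2000   -> decay ~= 0.632   (about 1 - 1/e of the way up)
#   updates = 20000  -> decay ~= 0.9999  (EMA is now a long-horizon average)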
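

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The tiny nn.Linear model, random
# data, and toy objective below are stand-ins, not part of this module; real
# code would create the EMA after device assignment and before DDP wrapping,
# as the class docstring warns.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = nn.Linear(4, 2)  # toy stand-in for a real model
    ema = ModelEMA(model, decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for step in range(10):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()  # toy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # EMA tracks the raw (unwrapped) weights

    # Evaluate or checkpoint the smoothed weights via `ema.ema`.
    print({k: tuple(v.shape) for k, v in ema.ema.state_dict().items()})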