import torch


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponential moving average (EMA) of a model's parameters.

    This is a thin wrapper around
    ``torch.optim.swa_utils.AveragedModel`` that installs an averaging
    function implementing exponential decay::

        ema_avg = decay * avg_model_param + (1 - decay) * model_param

    Args:
        model: Model whose parameters are tracked; forwarded unchanged to
            ``AveragedModel``.
        decay: Weight given to the running average on each update; the
            incoming parameter is weighted by ``1 - decay``.
        device: Forwarded to ``AveragedModel`` as its ``device`` argument.
            Defaults to ``"cpu"``.
    """

    def __init__(self, model, decay, device="cpu"):
        def _blend(running, incoming, num_averaged):
            # ``num_averaged`` is part of the avg_fn call signature but is
            # not used: the decay factor is constant across updates.
            kept = decay * running
            added = (1 - decay) * incoming
            return kept + added

        super().__init__(model, device=device, avg_fn=_blend)