Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.

    Each averaged parameter is updated as
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``.

    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA; this class only supplies the averaging function.

    Args:
        model: The model whose parameters are tracked; forwarded to
            ``AveragedModel``.
        decay: Weight given to the running average at every update. The
            incoming parameter value receives ``1 - decay``.
        device: Forwarded to ``AveragedModel`` as its ``device`` argument.
            Defaults to ``"cpu"``.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            # ``num_averaged`` is part of the ``avg_fn`` call signature used by
            # ``AveragedModel`` but a fixed-decay EMA does not depend on it.
            return decay * avg_model_param + (1 - decay) * model_param

        # Keyword arguments keep the call independent of the positional order
        # of ``AveragedModel.__init__``'s optional parameters.
        super().__init__(model, device=device, avg_fn=ema_avg)