Spaces:
Running
on
Zero
Running
on
Zero
import torch


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintain an exponential moving average (EMA) of a model's parameters.

    The first call to ``update_parameters`` copies the model's values
    verbatim. Every later call applies::

        ema_avg = decay * avg_model_param + (1 - decay) * model_param

    `torch.optim.swa_utils.AveragedModel
    <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    does the bookkeeping: the deep copy of the model, device placement, the
    ``n_averaged`` counter and the update loop.

    Args:
        model: the module whose parameters are averaged. ``AveragedModel``
            keeps its own deep copy in ``self.module``.
        decay: weight given to the running average on each update. It must
            lie in ``[0, 1]``; values closer to 1 make the average move more
            slowly.
        device: device for the averaged copy (default ``"cpu"``). It is
            forwarded unchanged to ``AveragedModel``.
        use_buffers: keyword-only. When ``True``, module buffers (e.g.
            BatchNorm running statistics) are averaged with the same rule as
            parameters. The default ``False`` preserves this class's original
            behaviour, in which ``AveragedModel`` decides how buffers are
            handled.

    Raises:
        ValueError: if ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, model, decay, device="cpu", *, use_buffers=False):
        # A decay outside [0, 1] makes the "average" diverge or flip sign.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")

        def ema_avg(avg_model_param, model_param, num_averaged):
            # ``num_averaged`` is required by AveragedModel's avg_fn
            # signature but unused: an EMA weights by ``decay``, not by count.
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(
            model, device=device, avg_fn=ema_avg, use_buffers=use_buffers
        )