# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
'''
Code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py
'''
import warnings
import torch
from torch.optim import Optimizer
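

# EMA keeps, for each trainable parameter p, the exponential moving average
#   ema <- ema_decay * ema + (1 - ema_decay) * p
# which is refreshed after every call to step() (when ema_decay > 0). With
# memory_efficient=True the update is applied tensor-by-tensor in place; otherwise
# same-shape parameters are stacked and updated in one batched operation.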
class EMA(Optimizer):
    def __init__(self, opt, ema_decay, memory_efficient=False):
        # ema_decay <= 0 disables EMA entirely; step() then just forwards to the wrapped optimizer.
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        # Share state and param_groups with the wrapped optimizer so that checkpoints of
        # this wrapper also capture the underlying optimizer's state.
        self.state = opt.state
        self.param_groups = opt.param_groups
        self.defaults = {}
        self.memory_efficient = memory_efficient

    def step(self, *args, **kwargs):
        # for group in self.optimizer.param_groups:
        #     group.setdefault('amsgrad', False)
        #     group.setdefault('maximize', False)
        #     group.setdefault('foreach', None)
        #     group.setdefault('capturable', False)
        #     group.setdefault('differentiable', False)
        #     group.setdefault('fused', False)
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        for group in self.optimizer.param_groups:
            # Bucket parameters by shape (per parameter group) so same-shape tensors can be
            # updated together.
            ema, params = {}, {}
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]

                # State initialization: start the EMA at the current parameter value.
                if 'ema' not in state:
                    state['ema'] = p.data.clone()

                if p.shape not in params:
                    params[p.shape] = {'idx': 0, 'data': []}
                    ema[p.shape] = []

                params[p.shape]['data'].append(p.data)
                ema[p.shape].append(state['ema'])

            # Unused helper: stack via CPU to reduce peak GPU memory.
            # def stack(d, dim=0):
            #     return torch.stack([di.cpu() for di in d], dim=dim).cuda()

            for i in params:
                if self.memory_efficient:
                    # Update each EMA tensor in place, without stacking the parameters.
                    for j in range(len(params[i]['data'])):
                        ema[i][j].mul_(self.ema_decay).add_(params[i]['data'][j], alpha=1. - self.ema_decay)
                    ema[i] = torch.stack(ema[i], dim=0)
                else:
                    # Stack same-shape tensors and apply the EMA update in a single batched op.
                    params[i]['data'] = torch.stack(params[i]['data'], dim=0)
                    ema[i] = torch.stack(ema[i], dim=0)
                    ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)

            # Write the updated EMA slices back into the optimizer state.
            for p in group['params']:
                if p.grad is None:
                    continue
                idx = params[p.shape]['idx']
                self.optimizer.state[p]['ema'] = ema[p.shape][idx, :]
                params[p.shape]['idx'] += 1

        return retval

    def load_state_dict(self, state_dict):
        super(EMA, self).load_state_dict(state_dict)
        # load_state_dict loads the data to self.state and self.param_groups. We need to pass
        # this data to the underlying optimizer too.
        self.optimizer.state = self.state
        self.optimizer.param_groups = self.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """ This function swaps parameters with their ema values. It records the original
        parameters in the ema state, if store_params_in_ema is true."""
        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there are no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for i, p in enumerate(group['params']):
                if not p.requires_grad:
                    continue
                ema = self.optimizer.state[p]['ema']
                if store_params_in_ema:
                    # Keep the raw parameters in the ema slot so the swap can be undone later.
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    self.optimizer.state[p]['ema'] = tmp
                else:
                    p.data = ema.detach()
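

# A minimal usage sketch (illustrative, not part of the adapted file): the tiny linear
# model, Adam optimizer and training loop below are assumptions, showing how the wrapper
# is typically driven: wrap an existing optimizer, step it as usual, then swap in the
# EMA weights around evaluation with swap_parameters_with_ema.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)                                    # hypothetical toy model
    base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt = EMA(base_opt, ema_decay=0.999, memory_efficient=True)

    for _ in range(5):                                          # hypothetical training loop
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = ((model(x) - y) ** 2).mean()
        base_opt.zero_grad()
        loss.backward()
        opt.step()                                              # Adam update followed by the EMA update

    # Evaluate with the EMA weights, then swap again to restore the raw training weights.
    opt.swap_parameters_with_ema(store_params_in_ema=True)
    # ... run evaluation here ...
    opt.swap_parameters_with_ema(store_params_in_ema=True)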