Spaces:
Build error
Build error
import copy | |
import torch.nn as nn | |
class EMAHelper(object): | |
def __init__(self, mu=0.999): | |
self.mu = mu | |
self.shadow = {} | |
def register(self, module): | |
if isinstance(module, nn.DataParallel): | |
module = module.module | |
for name, param in module.named_parameters(): | |
if param.requires_grad: | |
self.shadow[name] = param.data.clone() | |
def update(self, module): | |
if isinstance(module, nn.DataParallel): | |
module = module.module | |
for name, param in module.named_parameters(): | |
if param.requires_grad: | |
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data | |
def ema(self, module): | |
if isinstance(module, nn.DataParallel): | |
module = module.module | |
for name, param in module.named_parameters(): | |
if param.requires_grad: | |
param.data.copy_(self.shadow[name].data) | |
def ema_copy(self, module): | |
if isinstance(module, nn.DataParallel): | |
inner_module = module.module | |
module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) | |
module_copy.load_state_dict(inner_module.state_dict()) | |
module_copy = nn.DataParallel(module_copy) | |
else: | |
module_copy = type(module)(module.config).to(module.config.device) | |
module_copy.load_state_dict(module.state_dict()) | |
# module_copy = copy.deepcopy(module) | |
self.ema(module_copy) | |
return module_copy | |
def state_dict(self): | |
return self.shadow | |
def load_state_dict(self, state_dict): | |
self.shadow = state_dict | |