# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # ModelEMA implementation is taken from # https://github.com/facebookresearch/demucs from collections import defaultdict import typing as tp import torch import torch.nn as nn def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set: names: set = set() for (name, sub_module) in module.named_modules(): if name == '': buffer_names = module._non_persistent_buffers_set buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name for buff_name in buffer_names} names.update(buffer_names) else: sub_name = f"{root}.{name}" if len(root) > 0 else name sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name) names.update(sub_buffer_names) return names def _get_named_tensors(module: nn.Module): non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module) named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers() if name not in non_persistent_buffers_set] named_parameters = list(module.named_parameters()) return named_parameters + named_buffers class ModuleDictEMA: """Exponential Moving Average over a nn.ModuleDict. You can switch to the EMA weights temporarily. """ def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'): self.decay = decay self.module_dict = module_dict self.state: dict = defaultdict(dict) self.count = 0 self.device = device self.unbias = unbias self._init() def _init(self): for module_name, module in self.module_dict.items(): for key, val in _get_named_tensors(module): if not val.is_floating_point(): continue device = self.device or val.device if key not in self.state[module_name]: self.state[module_name][key] = val.detach().to(device, copy=True) def step(self): if self.unbias: self.count = self.count * self.decay + 1 w = 1 / self.count else: w = 1 - self.decay for module_name, module in self.module_dict.items(): for key, val in _get_named_tensors(module): if not val.is_floating_point(): continue device = self.device or val.device self.state[module_name][key].mul_(1 - w) self.state[module_name][key].add_(val.detach().to(device), alpha=w) def state_dict(self): return {'state': self.state, 'count': self.count} def load_state_dict(self, state): self.count = state['count'] for module_name, module in state['state'].items(): for key, val in module.items(): self.state[module_name][key].copy_(val)