# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import logging
import typing as tp

import flashy
import torch

from ..optim import ModuleDictEMA
from .utils import copy_state


logger = logging.getLogger(__name__)


class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of best state_dict() for registered sources.

    BestStateDictManager has two main attributes:
        states (dict): State dict of the registered StateDictSource.
        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.

    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
    what to consider for best state.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype): Data type for the state parameters.
    """
    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
                 dtype: tp.Optional[torch.dtype] = None):
        self.device = device
        # Maps a registered source name to the copy of its state made by `copy_state`.
        self.states: dict = {}
        # Maps a group name to {id(tensor): entry name}. ModuleDictEMA sources get a group
        # under their own name; every other source is pooled under the 'base' group.
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return a dict mapping `id(tensor)` to entry name for the tensor values of `state_dict`.

        Only top-level values that are `torch.Tensor` instances are considered;
        any other value (nested dicts, scalars, ...) is ignored.
        """
        # NOTE(review): identity-based detection presumably relies on `state_dict()` handing
        # back the same tensor objects across calls -- confirm for the registered sources.
        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Check that `param_ids` shares no tensor id with any group other than `name`.

        Args:
            name (str): Group the ids are about to be stored under; that group is skipped.
            param_ids (dict): Mapping of `id(tensor)` to entry name, as built by
                `_get_parameter_ids`.
        Raises:
            AssertionError: If at least one id is already present in another group.
        """
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # Keys views support set operators directly. `set.intersection(...)` cannot be
            # called unbound on `dict_keys` objects: it raises TypeError.
            overlap = registered_param_ids.keys() & param_ids.keys()
            if overlap:
                # Report entry names rather than raw integer ids: they are what a reader
                # can act on, and `str.join` only accepts strings.
                overlapping_names = ' '.join(sorted(str(param_ids[pid]) for pid in overlap))
                # Raised explicitly (instead of `assert`) so the check survives `python -O`.
                raise AssertionError(
                    f"Found {len(overlap)} / {len(param_ids)} overlapping parameters"
                    f" in {name} and already registered {registered_name}: {overlapping_names}"
                )

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for `name` with a fresh copy of `source.state_dict()`.

        Args:
            name (str): Name of an already registered source.
            source (flashy.state.StateDictSource): Source providing the new state.
        Raises:
            ValueError: If `name` has not been registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register a new source under `name` and store a copy of its current state.

        Args:
            name (str): Unique name for the source.
            source (flashy.state.StateDictSource): Source whose state is tracked.
        Raises:
            ValueError: If `name` is already registered.
            AssertionError: If tensors of `source` are already tracked under another group.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Take a single snapshot so that the ids checked and the state copied below
        # come from the same `state_dict()` call.
        state = source.state_dict()
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(state)
        if isinstance(source, ModuleDictEMA):
            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap('base', param_ids)
            self.param_ids['base'].update(param_ids)
        # Register state
        self.states[name] = copy_state(state, device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the internal `states` mapping itself (not a copy)."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Copy the values of `state` in place into the already stored states.

        Every `name` in `state` and every key of its sub-state must already exist in
        `self.states`; the stored entries are updated through their `copy_` method.

        Args:
            state (flashy.state.StateDict): Mapping of source name to sub-state.
        """
        for name, sub_state in state.items():
            for k, v in sub_state.items():
                # In-place copy keeps the stored object, so its device and dtype are preserved
                # (assumes stored entries are tensors -- TODO confirm against `copy_state`).
                self.states[name][k].copy_(v)