import copy
from functools import partial
from typing import Any, Dict, Optional, Union

import hydra
import torch
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf
from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
from torch.optim.optimizer import Optimizer

from nemo.core.config import OptimizerParams, get_optimizer_config, register_optimizer_params
from nemo.core.optim.adafactor import Adafactor
from nemo.core.optim.adan import Adan
from nemo.core.optim.novograd import Novograd
from nemo.utils import logging
from nemo.utils.model_utils import maybe_update_config_version

AVAILABLE_OPTIMIZERS = {
    'sgd': optim.SGD,
    'adam': optim.Adam,
    'adamw': optim.AdamW,
    'adadelta': adadelta.Adadelta,
    'adamax': adamax.Adamax,
    'adagrad': adagrad.Adagrad,
    'rmsprop': rmsprop.RMSprop,
    'rprop': rprop.Rprop,
    'novograd': Novograd,
    'adafactor': Adafactor,
    'adan': Adan,
}

try:
    from apex.optimizers import FusedAdam, FusedLAMB

    HAVE_APEX = True

    AVAILABLE_OPTIMIZERS['lamb'] = FusedLAMB
    AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
except ModuleNotFoundError:
    HAVE_APEX = False

HAVE_APEX_DISTRIBUTED_ADAM = False
if HAVE_APEX:
    try:
        # Distributed fused Adam relies on additional Apex components that may not be
        # present in every Apex build; register it only if the import succeeds.
        from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

        HAVE_APEX_DISTRIBUTED_ADAM = True

        AVAILABLE_OPTIMIZERS['distributed_fused_adam'] = MegatronDistributedFusedAdam
    except (ImportError, ModuleNotFoundError):
        HAVE_APEX_DISTRIBUTED_ADAM = False

try:
    # MegatronFusedAdam is optional; skip registration if its dependencies are missing.
    from nemo.core.optim.megatron_fused_adam import MegatronFusedAdam

    AVAILABLE_OPTIMIZERS['megatron_fused_adam'] = MegatronFusedAdam
except (ImportError, ModuleNotFoundError):
    pass

__all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args']


def parse_optimizer_args(
    optimizer_name: str, optimizer_kwargs: Union[DictConfig, Dict[str, Any]]
) -> Union[Dict[str, Any], DictConfig]:
| """ |
| Parses a list of strings, of the format "key=value" or "key2=val1,val2,..." |
| into a dictionary of type {key=value, key2=[val1, val2], ...} |
| |
| This dictionary is then used to instantiate the chosen Optimizer. |
| |
| Args: |
| optimizer_name: string name of the optimizer, used for auto resolution of params |
| optimizer_kwargs: Either a list of strings in a specified format, |
| or a dictionary. If a dictionary is provided, it is assumed the dictionary |
| is the final parsed value, and simply returned. |
| If a list of strings is provided, each item in the list is parsed into a |
| new dictionary. |
| |
| Returns: |
| A dictionary |
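
    Example:
        An illustrative sketch; the optimizer name and values below are arbitrary, and the
        second call assumes the standard `novograd_params` config is registered::

            # A dict without 'name' or '_target_' is treated as final and returned as-is
            kwargs = parse_optimizer_args('novograd', {'lr': 0.01, 'betas': [0.8, 0.25]})

            # With name='auto', defaults come from the registered `novograd_params`
            # config and the remaining keys act as overrides
            kwargs = parse_optimizer_args('novograd', {'name': 'auto', 'lr': 0.01})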
| """ |
    kwargs = {}

    if optimizer_kwargs is None:
        return kwargs

    optimizer_kwargs = copy.deepcopy(optimizer_kwargs)
    optimizer_kwargs = maybe_update_config_version(optimizer_kwargs)

    if isinstance(optimizer_kwargs, DictConfig):
        optimizer_kwargs = OmegaConf.to_container(optimizer_kwargs, resolve=True)

    # If a dict-like object is provided, perform stepwise resolution
    if hasattr(optimizer_kwargs, 'keys'):
        # Attempt class path resolution via Hydra
        if '_target_' in optimizer_kwargs:
            optimizer_kwargs_config = OmegaConf.create(optimizer_kwargs)
            optimizer_instance = hydra.utils.instantiate(optimizer_kwargs_config)
            optimizer_instance = vars(optimizer_instance)
            return optimizer_instance

        # If no class path is provided, resolve via the registered params config name
        if 'name' in optimizer_kwargs:
            # 'auto' resolves to the params config of the chosen optimizer
            if optimizer_kwargs['name'] == 'auto':
                optimizer_params_name = "{}_params".format(optimizer_name)
                optimizer_kwargs.pop('name')
            else:
                optimizer_params_name = optimizer_kwargs.pop('name')

            # Collect overrides, either nested under a `params` key or given directly
            if 'params' in optimizer_kwargs:
                optimizer_params_override = optimizer_kwargs.get('params')
            else:
                optimizer_params_override = optimizer_kwargs

            if isinstance(optimizer_params_override, DictConfig):
                optimizer_params_override = OmegaConf.to_container(optimizer_params_override, resolve=True)

            optimizer_params_cls = get_optimizer_config(optimizer_params_name, **optimizer_params_override)

            # If we are provided just a Config object, simply return its fields as a dictionary
            if optimizer_params_name is None:
                optimizer_params = vars(optimizer_params_cls)
                return optimizer_params
            else:
                # If we are provided a partial class instantiation of a Config,
                # instantiate it and retrieve its vars as a dictionary
                optimizer_params = optimizer_params_cls()
                optimizer_params = vars(optimizer_params)
                return optimizer_params

        # No resolution hints were provided; return the dictionary as-is
        return optimizer_kwargs

    return kwargs


def register_optimizer(name: str, optimizer: Optimizer, optimizer_params: OptimizerParams):
    """
    Checks if the optimizer name exists in the registry, and if it doesn't, adds it.

    This allows custom optimizers to be added and called by name during instantiation.

    Args:
        name: Name of the optimizer. Will be used as key to retrieve the optimizer.
        optimizer: Optimizer class.
        optimizer_params: The parameters as a dataclass of the optimizer.
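
    Example:
        A minimal sketch; ``MyOptimizer`` and ``MyOptimizerParams`` are hypothetical
        user-defined classes (an ``Optimizer`` subclass and its ``OptimizerParams``
        dataclass), not part of NeMo::

            register_optimizer('my_optimizer', MyOptimizer, MyOptimizerParams)
            opt_cls = get_optimizer('my_optimizer', lr=1e-3)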
| """ |
    if name in AVAILABLE_OPTIMIZERS:
        raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}")

    AVAILABLE_OPTIMIZERS[name] = optimizer

    optim_name = "{}_params".format(optimizer.__name__)
    register_optimizer_params(name=optim_name, optimizer_params=optimizer_params)


def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
    """
    Convenience method to obtain an Optimizer class and partially instantiate it with optimizer kwargs.

    Args:
        name: Name of the Optimizer in the registry.
        kwargs: Optional kwargs of the optimizer used during instantiation.

    Returns:
        A partially instantiated Optimizer.
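
    Example:
        Illustrative usage (assumes an existing ``torch.nn.Module`` instance named ``model``)::

            opt_cls = get_optimizer('adamw', lr=1e-3, weight_decay=0.01)
            optimizer = opt_cls(model.parameters())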
| """ |
    if name not in AVAILABLE_OPTIMIZERS:
        raise ValueError(
            f"Cannot resolve optimizer '{name}'. Available optimizers are: {AVAILABLE_OPTIMIZERS.keys()}"
        )
    if name == 'fused_adam':
        if not torch.cuda.is_available():
            raise ValueError('CUDA must be available to use fused_adam.')

    optimizer = AVAILABLE_OPTIMIZERS[name]
    optimizer = partial(optimizer, **kwargs)
    return optimizer


def init_optimizer_states(optimizer: Optimizer):
    """
    Eagerly allocates the Adam-style state buffers (`exp_avg`, `exp_avg_sq`, and
    `max_exp_avg_sq` when `amsgrad` is enabled) for every parameter of an Adam, AdamW,
    or Apex FusedAdam optimizer, so that the state exists before the first call to
    `step()`. Other optimizer types are left untouched.
    """
    adam_nondist_optims = (optim.Adam, optim.AdamW)
    if HAVE_APEX:
        adam_nondist_optims += (FusedAdam,)
    if isinstance(optimizer, adam_nondist_optims):
        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]
                # Only populate state that has not already been created by a previous step
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    if group.get('amsgrad'):
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)