# Copyright (c) OpenMMLab. All rights reserved. import copy import inspect import torch from ...utils import Registry, build_from_cfg OPTIMIZERS = Registry('optimizer') OPTIMIZER_BUILDERS = Registry('optimizer builder') def register_torch_optimizers(): torch_optimizers = [] for module_name in dir(torch.optim): if module_name.startswith('__'): continue _optim = getattr(torch.optim, module_name) if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): OPTIMIZERS.register_module()(_optim) torch_optimizers.append(module_name) return torch_optimizers TORCH_OPTIMIZERS = register_torch_optimizers() def build_optimizer_constructor(cfg): return build_from_cfg(cfg, OPTIMIZER_BUILDERS) def build_optimizer(model, cfg): optimizer_cfg = copy.deepcopy(cfg) constructor_type = optimizer_cfg.pop('constructor', 'DefaultOptimizerConstructor') paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) optim_constructor = build_optimizer_constructor( dict( type=constructor_type, optimizer_cfg=optimizer_cfg, paramwise_cfg=paramwise_cfg)) optimizer = optim_constructor(model) return optimizer