Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS | |
from mmcv.utils import Registry, build_from_cfg | |
# Registry holding the optimizer-constructor classes registered by this
# package. It is created as a child of mmcv's optimizer-builder registry
# (``MMCV_OPTIMIZER_BUILDERS``); ``build_optimizer_constructor`` consults this
# registry first and falls back to the mmcv one.
OPTIMIZER_BUILDERS = Registry(
    name='optimizer builder',
    parent=MMCV_OPTIMIZER_BUILDERS,
)
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor object from a config dict.

    The name in ``cfg['type']`` is looked up first in this module's
    ``OPTIMIZER_BUILDERS`` registry and then in mmcv's
    ``MMCV_OPTIMIZER_BUILDERS`` registry. The first registry that contains it
    is handed to ``build_from_cfg`` together with ``cfg``.

    Args:
        cfg (dict): Constructor config. Must contain a ``'type'`` key naming a
            registered optimizer constructor.

    Returns:
        The object returned by ``build_from_cfg`` for the matching registry.

    Raises:
        KeyError: If ``cfg`` has no (or a ``None``) ``'type'`` entry, or if
            the type is registered in neither registry.
    """
    constructor_type = cfg.get('type')
    if constructor_type is None:
        # Fail with a clear message here rather than passing ``None`` on to
        # the registry membership tests below.
        raise KeyError(
            f'`cfg` must contain the key "type" naming an optimizer '
            f'constructor, but got {cfg}')

    # Order matters: constructors registered in this package take precedence
    # over same-named ones in mmcv's registry.
    for registry in (OPTIMIZER_BUILDERS, MMCV_OPTIMIZER_BUILDERS):
        if constructor_type in registry:
            return build_from_cfg(cfg, registry)

    raise KeyError(f'{constructor_type} is not registered '
                   'in the optimizer builder registry.')
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    Two optional keys are stripped from a deep copy of ``cfg``:

    * ``'constructor'`` names the optimizer-constructor type and defaults to
      ``'DefaultOptimizerConstructor'``.
    * ``'paramwise_cfg'`` defaults to ``None``.

    Everything left in the copy is passed as ``optimizer_cfg`` to the
    constructor built by ``build_optimizer_constructor``. That constructor is
    then called with ``model``.

    Args:
        model: Object handed to the constructor call (presumably a
            ``torch.nn.Module`` — TODO confirm against callers).
        cfg (dict): Optimizer config. It is not modified.

    Returns:
        Whatever the optimizer constructor returns when called with ``model``.
    """
    # Deep-copy first so the pops below never touch the caller's config.
    remaining_cfg = copy.deepcopy(cfg)
    ctor_type = remaining_cfg.pop('constructor',
                                  'DefaultOptimizerConstructor')
    param_cfg = remaining_cfg.pop('paramwise_cfg', None)

    ctor_cfg = {
        'type': ctor_type,
        'optimizer_cfg': remaining_cfg,
        'paramwise_cfg': param_cfg,
    }
    return build_optimizer_constructor(ctor_cfg)(model)