import torch
import torch.optim as optim
import numpy as np
import itertools


def singleton(class_):
    """Decorator that caches and reuses a single instance of the decorated class."""
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance


class get_optimizer(object):
    """Registry mapping config names ('sgd', 'adam', 'adamw') to torch optimizer classes."""
    def __init__(self):
        self.optimizer = {}
        self.register(optim.SGD, 'sgd')
        self.register(optim.Adam, 'adam')
        self.register(optim.AdamW, 'adamw')

    def register(self, optim_class, name):
        self.optimizer[name] = optim_class

    def __call__(self, net, cfg):
        if cfg is None:
            return None
        t = cfg.type
        # Unwrap (Distributed)DataParallel so attributes of the underlying
        # module (e.g. parameter_group) are visible.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        pg = getattr(netm, 'parameter_group', None)
        if pg is not None:
            # Build one optimizer param group per named entry; each entry may
            # be a module, a single parameter, or a list mixing the two.
            params = []
            for group_name, module_or_para in pg.items():
                if not isinstance(module_or_para, list):
                    module_or_para = [module_or_para]
                grouped_params = [
                    mi.parameters() if isinstance(mi, torch.nn.Module) else [mi]
                    for mi in module_or_para]
                grouped_params = itertools.chain(*grouped_params)
                params.append({'params': grouped_params, 'name': group_name})
        else:
            params = net.parameters()
        # lr=0 is a placeholder; the actual learning rate is expected to be set
        # later (e.g. per param group by the training loop or scheduler), so
        # cfg.args must not contain 'lr'.
        return self.optimizer[t](params, lr=0, **cfg.args)
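

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of building an optimizer through the registry above.
# The `cfg` object here is a hypothetical stand-in: the registry only assumes
# a config exposing `.type` (a registered name) and `.args` (extra kwargs
# forwarded to the torch optimizer, excluding 'lr', which is set separately).
if __name__ == '__main__':
    from types import SimpleNamespace

    net = torch.nn.Linear(4, 2)
    cfg = SimpleNamespace(type='adamw', args={'weight_decay': 0.01})

    optimizer = get_optimizer()(net, cfg)
    # lr starts at 0; assign the actual learning rate before training.
    for group in optimizer.param_groups:
        group['lr'] = 1e-4
    print(optimizer)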