import torch
import torch.optim as optim
import numpy as np
import itertools

# Decorator that caches a single instance of the wrapped class: the first
# call constructs the object, every later call returns that same object.
def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
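
# Illustrative only (not part of the original module): applying the decorator
# makes repeated "constructions" return the same cached object, e.g.
#
#   @singleton
#   class Registry(object):
#       pass
#
#   assert Registry() is Registry()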

class get_optimizer(object):
    def __init__(self):
        # Registry mapping a short name to an optimizer class.
        self.optimizer = {}
        self.register(optim.SGD, 'sgd')
        self.register(optim.Adam, 'adam')
        self.register(optim.AdamW, 'adamw')

    def register(self, optim, name):
        self.optimizer[name] = optim

    def __call__(self, net, cfg):
        if cfg is None:
            return None
        t = cfg.type
        # Unwrap (Distributed)DataParallel so attributes are read from the
        # underlying module.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        pg = getattr(netm, 'parameter_group', None)
        if pg is not None:
            # Build one named param group per entry; each entry may be a
            # sub-module, a single parameter, or a list mixing the two.
            params = []
            for group_name, module_or_para in pg.items():
                if not isinstance(module_or_para, list):
                    module_or_para = [module_or_para]
                grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi]
                                  for mi in module_or_para]
                grouped_params = itertools.chain(*grouped_params)
                pg_dict = {'params': grouped_params, 'name': group_name}
                params.append(pg_dict)
        else:
            # No explicit grouping: optimize all parameters as one group.
            params = net.parameters()
        # lr is initialised to 0 and presumably set later (e.g. by a scheduler
        # writing to param_groups), so cfg.args is not expected to contain 'lr'.
        return self.optimizer[t](params, lr=0, **cfg.args)
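

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The config object is assumed to expose `.type` and `.args` exactly as read
# by __call__ above; SimpleNamespace stands in for whatever config class the
# project actually uses. The `parameter_group` attribute set on net2 below
# follows the convention assumed by __call__.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(type='adamw', args={'weight_decay': 0.01})

    # A plain module without `parameter_group` falls back to net.parameters().
    net = torch.nn.Linear(4, 2)
    opt = get_optimizer()(net, cfg)
    print(opt)

    # A module exposing `parameter_group` gets one named param group per
    # entry, mixing sub-modules and parameter lists.
    net2 = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    net2.parameter_group = {
        'encoder': net2[0],                  # a sub-module
        'head': list(net2[1].parameters()),  # a list of parameters
    }
    opt2 = get_optimizer()(net2, cfg)
    print([g['name'] for g in opt2.param_groups])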