oguzakif's picture
init repo
d4b77ac
raw
history blame
No virus
2.96 kB
from utils.dist import *
from parse import *
from utils.util import find_free_port
import torch.multiprocessing as mp
import torch.distributed
from importlib import import_module
import os
import glob
from inputs import args_parser
def main_worker(rank, opt):
    """Prepare a single training process and run the selected network's training.

    Args:
        rank: Process index handed in by the launcher. Used as both the local
            and the global rank when ``opt`` has no ``'local_rank'`` entry, and
            always forwarded to the network constructor.
        opt: Mutable options mapping. Reads ``'distributed'``,
            ``'init_method'``, ``'world_size'`` and ``'network'``; writes
            ``'device'`` and, when absent, ``'local_rank'`` / ``'global_rank'``.
    """
    # The caller may already have filled in the ranks; otherwise derive both
    # of them from the process index.
    if 'local_rank' not in opt:
        opt['local_rank'] = rank
        opt['global_rank'] = rank

    if opt['distributed']:
        # Bind this process to its GPU before creating the process group.
        torch.cuda.set_device(int(opt['local_rank']))
        group_settings = {
            'backend': 'nccl',
            'init_method': opt['init_method'],
            'world_size': opt['world_size'],
            'rank': opt['global_rank'],
            'group_name': 'mtorch',
        }
        torch.distributed.init_process_group(**group_settings)
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), int(opt['local_rank'])))

    # CUDA device object when a GPU is available, the plain string 'cpu' otherwise.
    opt['device'] = (
        torch.device("cuda:{}".format(opt['local_rank']))
        if torch.cuda.is_available()
        else 'cpu'
    )

    # The network implementation is chosen by name: networks.<opt['network']>
    # must expose a ``Network`` class accepting (opt, rank) with a ``train()`` method.
    network_module = import_module('networks.{}'.format(opt['network']))
    network_module.Network(opt, rank).train()
def _first_match(directory, pattern):
    """Return the first path inside ``directory`` that matches ``pattern``.

    Raises:
        FileNotFoundError: If no file in ``directory`` matches ``pattern``.
    """
    matches = glob.glob(os.path.join(directory, pattern))
    if not matches:
        raise FileNotFoundError(
            "No '{}' file found in flow checkpoint directory: {}".format(
                pattern, directory))
    return matches[0]


def main(args_obj):
    """Build the options mapping and launch the training worker(s).

    Args:
        args_obj: Mapping of command-line arguments. It is handed to
            ``parse`` to build the options mapping, and its ``'finetune'``
            entry (``1`` enables fine-tuning) is read directly.

    Raises:
        FileNotFoundError: If the directory named by ``opt['flow_checkPoint']``
            contains no ``*.tar`` checkpoint or no ``*.yaml`` config.
    """
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # opt['flow_checkPoint'] initially names a directory; after this section it
    # holds the checkpoint file path and the parsed config sits in
    # opt['flow_config'].
    flow_dir = opt['flow_checkPoint']
    checkpoint = _first_match(flow_dir, '*.tar')
    config = _first_match(flow_dir, '*.yaml')
    # NOTE(review): `yaml` is not imported explicitly among the visible
    # imports; presumably it arrives through one of the star imports - confirm.
    # NOTE(review): consider yaml.safe_load if the config holds only plain data.
    with open(config, 'r') as f:
        configs = yaml.full_load(f)
    opt['flow_config'] = configs
    opt['flow_checkPoint'] = checkpoint

    # Read the flag from the argument mapping instead of the module-global
    # `args`, which only exists when this file is executed as a script.
    opt['finetune'] = args_obj['finetune'] == 1

    # A non-empty state option overrides the corresponding entry in opt['path'].
    for state_key in ('gen_state', 'dis_state', 'opt_state'):
        if opt[state_key] != '':
            opt['path'][state_key] = opt[state_key]

    # Collapse the per-axis scalar options into (height, width) pairs.
    opt['input_resolution'] = (opt['res_h'], opt['res_w'])
    opt['kernel_size'] = (opt['kernel_size_h'], opt['kernel_size_w'])
    opt['stride'] = (opt['stride_h'], opt['stride_w'])
    opt['padding'] = (opt['pad_h'], opt['pad_w'])

    print('model is: {}'.format(opt['model']))

    # Reuse the address that went into init_method so both decisions agree.
    if master_ip == "127.0.0.1":
        # localhost
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # multiple processes should be launched by openmpi
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)
if __name__ == '__main__':
    # Script entry point. The parsed namespace stays bound to the module-level
    # name `args`; `main` receives its attributes as a plain dict.
    args = args_parser()
    main(vars(args))