from utils.dist import *
from parse import *
from utils.util import find_free_port
import torch.multiprocessing as mp
import torch.distributed
from importlib import import_module
import os
import glob
import yaml
from inputs import args_parser


def main_worker(rank, opt):
    # when spawned on localhost, `rank` (from mp.spawn) serves as both local and
    # global rank; when launched by OpenMPI, main() has already filled them in
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = rank
    if opt['distributed']:
        torch.cuda.set_device(int(opt['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=opt['init_method'],
                                             world_size=opt['world_size'],
                                             rank=opt['global_rank'],
                                             group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), int(opt['local_rank'])))

    if torch.cuda.is_available():
        opt['device'] = torch.device("cuda:{}".format(opt['local_rank']))
    else:
        opt['device'] = 'cpu'

    # dynamically load networks/<network>.py and run its Network trainer
    pkg = import_module('networks.{}'.format(opt['network']))
    trainer = pkg.Network(opt, rank)
    trainer.train()


def main(args_obj):
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    # TCP rendezvous address used by torch.distributed.init_process_group
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # resolve the flow checkpoint (*.tar) and its YAML config inside opt['flow_checkPoint']
    checkpoint = glob.glob(os.path.join(opt['flow_checkPoint'], '*.tar'))[0]
    config = glob.glob(os.path.join(opt['flow_checkPoint'], '*.yaml'))[0]
    with open(config, 'r') as f:
        opt['flow_config'] = yaml.full_load(f)
    opt['flow_checkPoint'] = checkpoint

    # use the dict passed into main() rather than the module-level `args`,
    # which only exists when this file is run as a script
    opt['finetune'] = args_obj['finetune'] == 1
    if opt['gen_state'] != '':
        opt['path']['gen_state'] = opt['gen_state']
    if opt['dis_state'] != '':
        opt['path']['dis_state'] = opt['dis_state']
    if opt['opt_state'] != '':
        opt['path']['opt_state'] = opt['opt_state']

    opt['input_resolution'] = (opt['res_h'], opt['res_w'])
    opt['kernel_size'] = (opt['kernel_size_h'], opt['kernel_size_w'])
    opt['stride'] = (opt['stride_h'], opt['stride_w'])
    opt['padding'] = (opt['pad_h'], opt['pad_w'])

    print('model is: {}'.format(opt['model']))

    if get_master_ip() == "127.0.0.1":
        # localhost
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # non-local master: worker processes are expected to have been launched
        # externally (e.g. by OpenMPI)
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)


if __name__ == '__main__':
    args = args_parser()
    args_obj = vars(args)
    main(args_obj)