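"""Training launch script: parse options, configure (optional) distributed
training, and start one trainer process per rank."""
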
import torch
import torch.distributed
import torch.multiprocessing as mp
from importlib import import_module

from parse import parse
from utils.dist import (get_world_size, get_local_rank, get_global_rank,
                        get_master_ip)
from utils.util import find_free_port
from flow_inputs import args_parser


def main_worker(rank, opt):
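    """Per-process entry point: set up the process group (if distributed)
    and run training.

    `rank` is the worker index assigned by mp.spawn on a single node, or -1
    when ranks come from an external launcher such as OpenMPI.
    """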
    # mp.spawn passes the worker index as `rank`; reuse it as both the
    # node-local and the global rank unless the launcher already set them.
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = rank
    if opt['distributed']:
        # Pin this process to its GPU before creating the NCCL process group.
        torch.cuda.set_device(int(opt['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=opt['init_method'],
                                             world_size=opt['world_size'],
                                             rank=opt['global_rank'],
                                             group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), int(opt['local_rank'])))

    # Select the compute device for this process, falling back to CPU when
    # no GPU is available.
    if torch.cuda.is_available():
        opt['device'] = torch.device("cuda:{}".format(opt['local_rank']))
    else:
        opt['device'] = 'cpu'

    # Dynamically import the requested network package (networks/<name>)
    # and hand control to its trainer.
    pkg = import_module('networks.{}'.format(opt['network']))
    trainer = pkg.Network(opt, rank)
    trainer.train()


def main(args_obj):
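    """Build the option dict from the parsed arguments and launch workers."""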
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    # Build a TCP rendezvous address on a free port for process-group init.
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print(f'Import network module: {opt["network"]}')

    # Optional checkpoint overrides: generator and optimizer state paths.
    if opt['gen_state'] != '':
        opt['path']['gen_state'] = opt['gen_state']
    if opt['opt_state'] != '':
        opt['path']['opt_state'] = opt['opt_state']

    # Read the flag from args_obj rather than the module-level `args`, so
    # main() also works when imported rather than run as a script.
    opt['finetune'] = args_obj['finetune'] == 1

    print(f'model is: {opt["model"]}')

    if master_ip == "127.0.0.1":
        # Single node: spawn one worker process per rank on localhost.
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # Multi-node: processes are expected to be launched by OpenMPI,
        # which supplies the per-process ranks via the environment.
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)


if __name__ == '__main__':
    args = args_parser()
    args_obj = vars(args)
    main(args_obj)