File size: 2,385 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch.utils.data
from dataset import Offline_Dataset
import yaml
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
import torch.distributed as dist
import torch
import os
from collections import namedtuple
from train import train
from config import get_config, print_usage


def main(config,model_config):
    """Set up distributed training for a matcher model and run it.

    Builds the matcher selected by ``config.model_name``, moves it to the GPU
    given by ``config.local_rank``, initialises the process group (NCCL
    backend, ``env://`` init method), wraps the model in
    DistributedDataParallel, builds distributed train/valid data loaders and
    hands everything to ``train``.

    Args:
        config: run configuration from ``get_config``. This function reads
            ``model_name``, ``local_rank`` and ``train_batch_size``; the whole
            object is also passed to ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters (a namedtuple built from the YAML
            config in ``__main__``), passed to the model constructor and to
            ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither 'SGM' nor 'SG'.
    """
    # Reject unknown model names before touching the GPU or the process group.
    chosen = config.model_name
    if chosen not in ('SGM', 'SG'):
        raise NotImplementedError
    matcher_cls = SGM_Model if chosen == 'SGM' else SG_Model
    model = matcher_cls(model_config)

    # Pin this process to its GPU, join the process group, then wrap for DDP.
    rank = config.local_rank
    torch.cuda.set_device(rank)
    model.to(torch.device(f'cuda:{rank}'))
    dist.init_process_group(backend='nccl', init_method='env://')
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    if rank == 0:
        # Print GPU status once, from local rank 0 only.
        os.system('nvidia-smi')

    def make_loader(split, shuffle, batch_size):
        # Build the dataset for one split, a DistributedSampler over it and
        # the DataLoader that iterates it (worker count shared across ranks).
        dataset = Offline_Dataset(config, split)
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=8 // dist.get_world_size(),
            pin_memory=False,
            sampler=sampler,
            collate_fn=dataset.collate_fn,
        )

    # The training batch size is divided among processes; the validation
    # loader uses the full train_batch_size in every process.
    # NOTE(review): confirm the undivided validation batch size is intended.
    train_loader = make_loader('train', True, config.train_batch_size // dist.get_world_size())
    valid_loader = make_loader('valid', False, config.train_batch_size)

    if rank == 0:
        print('start training .....')
    train(model, train_loader, valid_loader, config, model_config)

if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()
    # If we have unparsed arguments, print usage and exit before touching
    # any files, so a bad command line always yields the usage message.
    if len(unparsed) > 0:
        print_usage()
        # SystemExit instead of the site-provided exit() builtin, which is
        # absent when Python runs with -S or in frozen builds.
        raise SystemExit(1)

    # safe_load: yaml.load() without an explicit Loader raises TypeError on
    # PyYAML >= 6 and can construct arbitrary objects on older versions.
    with open(config.config_path, 'r', encoding='utf-8') as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError(
            f'{config.config_path} must contain a YAML mapping of model settings, '
            f'got {type(model_config).__name__}'
        )
    # Expose the YAML keys as read-only attributes (model_config.<key>).
    model_config=namedtuple('model_config',model_config.keys())(*model_config.values())

    main(config,model_config)