"""Launch distributed (DDP) training of the SGM or SG matcher model."""

import os
from collections import namedtuple

import torch
import torch.distributed as dist
import torch.utils.data
import yaml

from config import get_config, print_usage
from dataset import Offline_Dataset
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
from train import train


def _load_model_config(config_path):
    """Read the YAML file at *config_path* and return it as a namedtuple.

    The top-level YAML keys become the field names of a ``model_config``
    namedtuple, so they must be valid Python identifiers.

    Raises:
        ValueError: if the file does not contain a top-level mapping
            (for example, if it is empty).
    """
    with open(config_path, 'r') as f:
        # ``yaml.load`` without an explicit Loader raises TypeError on
        # PyYAML >= 6 and is unsafe on older versions; ``safe_load`` handles
        # plain scalar/list/dict configuration files on every version.
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(
            f"model config {config_path!r} must contain a top-level mapping, "
            f"got {type(raw).__name__}"
        )
    return namedtuple('model_config', raw.keys())(*raw.values())


def main(config, model_config):
    """Build the model and data loaders, then hand over to ``train``.

    Args:
        config: run configuration; this function reads ``model_name``
            (``'SGM'`` or ``'SG'``), ``local_rank`` and ``train_batch_size``
            from it, and also passes it to ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters, forwarded unchanged to the
            model constructor and to ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither ``'SGM'``
            nor ``'SG'``.
    """
    # Initialize network
    if config.model_name == 'SGM':
        model = SGM_Model(model_config)
    elif config.model_name == 'SG':
        model = SG_Model(model_config)
    else:
        raise NotImplementedError(
            f"unsupported model_name: {config.model_name!r}"
        )

    # Initialize DDP: bind this process to its GPU, move the model there,
    # then join the process group (rendezvous details come from env vars).
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f'cuda:{config.local_rank}')
    model.to(device)
    dist.init_process_group(backend='nccl', init_method='env://')
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank]
    )

    if config.local_rank == 0:
        os.system('nvidia-smi')

    world_size = dist.get_world_size()
    # A budget of 8 loader workers is split evenly between the processes.
    workers_per_process = 8 // world_size

    # Initialize datasets
    train_dataset = Offline_Dataset(config, 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, shuffle=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # train_batch_size is divided across the participating processes.
        batch_size=config.train_batch_size // world_size,
        num_workers=workers_per_process,
        pin_memory=False,
        sampler=train_sampler,
        collate_fn=train_dataset.collate_fn,
    )

    valid_dataset = Offline_Dataset(config, 'valid')
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset, shuffle=False
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        # NOTE(review): unlike the training loader, this batch size is not
        # divided by world_size -- confirm that this is intentional.
        batch_size=config.train_batch_size,
        num_workers=workers_per_process,
        pin_memory=False,
        collate_fn=valid_dataset.collate_fn,
        sampler=valid_sampler,
    )

    if config.local_rank == 0:
        print('start training .....')
    train(model, train_loader, valid_loader, config, model_config)


if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()

    # If we have unparsed arguments, print usage and exit. This is checked
    # before the model config file is opened so that a malformed command line
    # is reported as such rather than as a file or YAML error.
    if unparsed:
        print_usage()
        raise SystemExit(1)

    model_config = _load_model_config(config.config_path)
    main(config, model_config)