# Hugging Face Spaces status: Running
import os
import sys
from collections import namedtuple

import torch
import torch.distributed as dist
import torch.utils.data
import yaml

# NOTE(review): project-local imports are kept in their original relative
# order in case an earlier module (e.g. `dataset`) adjusts sys.path for the
# later ones — confirm before re-sorting.
from dataset import Offline_Dataset
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
from train import train
from config import get_config, print_usage
def _make_distributed_loader(dataset, batch_size, num_workers, shuffle):
    """Return a DataLoader over *dataset* whose indices come from a DistributedSampler.

    ``dataset`` must provide a ``collate_fn`` attribute, which is used to
    assemble batches. Call this only after the process group is initialized:
    the sampler asks the default process group for the replica count and rank.
    """
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, shuffle=shuffle
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )


def main(config, model_config):
    """Build the matcher, wrap it in DDP, build the data loaders, and train.

    Meant to run once per process under a distributed launcher. The process
    group is initialized from environment variables (``init_method="env://"``,
    NCCL backend), and each process binds to the GPU ``config.local_rank``.

    Args:
        config: parsed command-line configuration (from ``get_config``).
            This function reads ``model_name``, ``local_rank`` and
            ``train_batch_size``. The whole object is also passed to
            ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters loaded from the YAML file,
            forwarded to the model constructor and to ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither "SGM" nor "SG".
        ValueError: if ``config.train_batch_size`` is smaller than the world
            size, which would leave each rank with a zero-sized training batch.
    """
    # Initialize network
    if config.model_name == "SGM":
        model = SGM_Model(model_config)
    elif config.model_name == "SG":
        model = SG_Model(model_config)
    else:
        raise NotImplementedError(
            f"unsupported model_name {config.model_name!r}; expected 'SGM' or 'SG'"
        )

    # Initialize DDP: pin this process to its local GPU before creating the
    # NCCL process group and wrapping the model.
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f"cuda:{config.local_rank}")
    model.to(device)
    dist.init_process_group(backend="nccl", init_method="env://")
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank]
    )
    if config.local_rank == 0:
        os.system("nvidia-smi")

    world_size = dist.get_world_size()
    # The training batch size is divided across ranks (floor division; any
    # remainder is dropped). A result of 0 would make the DataLoader fail.
    train_batch_size = config.train_batch_size // world_size
    if train_batch_size < 1:
        raise ValueError(
            f"train_batch_size ({config.train_batch_size}) must be at least the "
            f"world size ({world_size}) so every rank gets a non-empty batch"
        )
    # A fixed budget of 8 loader workers is shared between ranks. With more
    # than 8 ranks this is 0, i.e. data is loaded in the main process.
    num_workers = 8 // world_size

    # Initialize datasets and loaders (training is shuffled, validation is not).
    train_dataset = Offline_Dataset(config, "train")
    train_loader = _make_distributed_loader(
        train_dataset, train_batch_size, num_workers, shuffle=True
    )
    valid_dataset = Offline_Dataset(config, "valid")
    # NOTE(review): validation uses the undivided train_batch_size per rank,
    # unlike training — confirm this is intended.
    valid_loader = _make_distributed_loader(
        valid_dataset, config.train_batch_size, num_workers, shuffle=False
    )

    if config.local_rank == 0:
        print("start training .....")
    train(model, train_loader, valid_loader, config, model_config)
if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()
    # If we have unparsed arguments, print usage and exit before doing any
    # other work (such as reading the model config file).
    if unparsed:
        print_usage()
        sys.exit(1)

    # yaml.safe_load: since PyYAML 6, yaml.load() without an explicit Loader
    # raises TypeError, and the full loaders can construct arbitrary objects.
    # NOTE(review): assumes the model config uses only plain YAML (mappings,
    # lists, scalars) — confirm no custom tags are needed.
    with open(config.config_path, "r", encoding="utf-8") as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError(
            f"model config {config.config_path!r} must contain a YAML mapping, "
            f"got {type(model_config).__name__}"
        )
    # Expose the top-level YAML keys as attributes of an immutable record.
    model_config = namedtuple("model_config", model_config.keys())(
        *model_config.values()
    )

    main(config, model_config)