Vincentqyw
fix: roma
8b973ee
raw
history blame
2.55 kB
import torch.utils.data
from dataset import Offline_Dataset
import yaml
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
import torch.distributed as dist
import torch
import os
from collections import namedtuple
from train import train
from config import get_config, print_usage
def _make_loader(config, split, shuffle, batch_size, num_workers):
    """Build a DistributedSampler-backed DataLoader for one dataset split.

    Args:
        config: run configuration, forwarded to ``Offline_Dataset``.
        split: split name forwarded to ``Offline_Dataset`` (e.g. "train", "valid").
        shuffle: whether the ``DistributedSampler`` shuffles indices.
        batch_size: batch size used by this process's DataLoader.
        num_workers: number of DataLoader worker processes for this process.

    Returns:
        A ``torch.utils.data.DataLoader`` over the split, using the dataset's
        own ``collate_fn``.
    """
    dataset = Offline_Dataset(config, split)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )


def main(config, model_config):
    """Build the matcher, wrap it in DDP, create the data loaders and train.

    Args:
        config: parsed run configuration (from ``get_config``). This function
            reads ``model_name``, ``local_rank`` and ``train_batch_size``; the
            object is also forwarded to ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters, passed to the model constructor
            and to ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither "SGM" nor "SG".
        ValueError: if ``config.train_batch_size`` is smaller than the
            distributed world size (the per-process batch would be 0).
    """
    # Initialize network
    if config.model_name == "SGM":
        model = SGM_Model(model_config)
    elif config.model_name == "SG":
        model = SG_Model(model_config)
    else:
        raise NotImplementedError(f"unknown model_name: {config.model_name!r}")

    # Initialize DDP: this process drives the GPU selected by config.local_rank;
    # the process group is set up from environment variables ("env://").
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f"cuda:{config.local_rank}")
    model.to(device)
    dist.init_process_group(backend="nccl", init_method="env://")
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank]
    )
    if config.local_rank == 0:
        os.system("nvidia-smi")

    # The global train batch and the 8 loader workers are split evenly across
    # processes. Fail early with a clear message rather than letting the
    # DataLoader reject a zero batch size after the datasets are loaded.
    world_size = dist.get_world_size()
    train_batch_per_rank = config.train_batch_size // world_size
    if train_batch_per_rank < 1:
        raise ValueError(
            f"train_batch_size ({config.train_batch_size}) must be at least the "
            f"world size ({world_size})"
        )
    num_workers = 8 // world_size

    # Initialize datasets and loaders (train first, then valid).
    train_loader = _make_loader(
        config, "train", shuffle=True,
        batch_size=train_batch_per_rank, num_workers=num_workers,
    )
    # NOTE(review): the valid loader uses the full train_batch_size per process
    # (not divided by world size) — kept as-is; confirm this is intentional.
    valid_loader = _make_loader(
        config, "valid", shuffle=False,
        batch_size=config.train_batch_size, num_workers=num_workers,
    )

    if config.local_rank == 0:
        print("start training .....")
    train(model, train_loader, valid_loader, config, model_config)
if __name__ == "__main__":
    import sys

    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()
    # Reject unknown command-line arguments before touching the filesystem.
    if len(unparsed) > 0:
        print_usage()
        sys.exit(1)

    # Load model hyper-parameters from YAML. yaml.load() without an explicit
    # Loader raises TypeError on PyYAML >= 6 (and is unsafe on older versions),
    # so use safe_load for this plain-data config file.
    with open(config.config_path, "r", encoding="utf-8") as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError(
            f"model config {config.config_path!r} must contain a YAML mapping, "
            f"got {type(model_config).__name__}"
        )
    # Expose each top-level YAML key as an attribute (model_config.<key>).
    model_config = namedtuple("model_config", model_config.keys())(
        *model_config.values()
    )
    main(config, model_config)