import sys
import time

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import wandb

from solver_ddp import Solver


def train(args):
    """Spawn one DDP worker process per local GPU."""
    solver = Solver()
    ngpus_per_node = torch.cuda.device_count() // args.sys_params.n_nodes
    print(f"Using {ngpus_per_node} GPU(s) per node")
    # Total number of processes across all nodes.
    args.sys_params.world_size = ngpus_per_node * args.sys_params.n_nodes
    mp.spawn(worker, nprocs=ngpus_per_node, args=(solver, ngpus_per_node, args))


def worker(gpu, solver, ngpus_per_node, args):
    """Per-process training loop; `gpu` is the local rank on this node."""
    # Global rank = node rank * GPUs per node + local GPU index.
    args.sys_params.rank = args.sys_params.rank * ngpus_per_node + gpu
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=args.sys_params.world_size,
        rank=args.sys_params.rank,
    )
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node
    solver.set_gpu(args)

    start_epoch = solver.start_epoch
    if args.dir_params.resume:
        # Resume from the epoch after the loaded checkpoint.
        start_epoch += 1

    for epoch in range(start_epoch, args.hyperparams.epochs + 1):
        # Reseed the DistributedSampler so shuffling differs across epochs.
        solver.train_sampler.set_epoch(epoch)
        solver.train(args, epoch)
        time.sleep(1)
        solver.multi_validate(args, epoch)

        if solver.stop:
            print("Applying early stopping")
            if args.wandb_params.use_wandb:
                wandb.finish()
            sys.exit()

    if args.wandb_params.use_wandb:
        wandb.finish()
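

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not the project's actual entry point):
# init_method="env://" requires MASTER_ADDR / MASTER_PORT to be set before the
# workers call dist.init_process_group(). The SimpleNamespace config below only
# mirrors the attributes accessed above and is purely illustrative; the real
# `args` object presumably comes from the project's own config system.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    from types import SimpleNamespace

    # Single-node defaults; multi-node runs would point these at the master node.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    args = SimpleNamespace(
        sys_params=SimpleNamespace(n_nodes=1, rank=0, world_size=1),
        dir_params=SimpleNamespace(resume=False),
        hyperparams=SimpleNamespace(epochs=100),
        wandb_params=SimpleNamespace(use_wandb=False),
    )
    train(args)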