"""Training entry point for the IEBins depth-estimation model.

The script parses its configuration at import time (either from the command
line or from a single ``@``-style argument file), trains ``NewCRFDepth`` with a
scale-invariant log loss summed over all refinement stages, writes
TensorBoard summaries and optionally runs an online evaluation that keeps the
best checkpoint for each evaluation metric.
"""
import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp

import os
import sys
import time
try:
    # Unused in this file; kept so that anything relying on the name keeps
    # working. ``telnetlib`` was removed from the standard library in
    # Python 3.13, so a hard import would abort the whole script there.
    from telnetlib import IP
except ImportError:
    IP = None
import argparse
import numpy as np
from tqdm import tqdm

from tensorboardX import SummaryWriter

from utils import (post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics,
                   entropy_loss, colormap, block_print, enable_print, normalize_result,
                   inv_normalize, convert_arg_line_to_args, colormap_magma)
from networks.NewCRFDepth import NewCRFDepth
from networks.depth_update import *
from datetime import datetime
from sum_depth import Sum_depth


parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--mode', type=str, help='train or test', default='train')
parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--pretrain', type=str, help='path of pretrained encoder', default=None)

# Dataset
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
parser.add_argument('--data_path', type=str, help='path to the data', required=True)
parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=True)
parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
parser.add_argument('--min_depth', type=float, help='minimum depth in estimation', default=0.1)

# Log and save
parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
parser.add_argument('--log_freq', type=int, help='Logging frequency in global steps', default=100)
parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in global steps', default=5000)

# Training
# NOTE(review): --weight_decay and --adam_eps are parsed but the optimizer
# built in main_worker only receives the learning rate.
parser.add_argument('--weight_decay', type=float, help='weight decay factor for optimization', default=1e-2)
parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
parser.add_argument('--adam_eps', type=float, help='epsilon in Adam optimizer', default=1e-6)
parser.add_argument('--batch_size', type=int, help='batch size', default=4)
parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1)
parser.add_argument('--variance_focus', type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85)

# Preprocessing
parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')

# Multi-gpu training
parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1)
parser.add_argument('--world_size', type=int, help='number of nodes for distributed training', default=1)
parser.add_argument('--rank', type=int, help='node rank for distributed training', default=0)
parser.add_argument('--dist_url', type=str, help='url used to set up distributed training', default='tcp://127.0.0.1:1234')
parser.add_argument('--dist_backend', type=str, help='distributed backend', default='nccl')
parser.add_argument('--gpu', type=int, help='GPU id to use.', default=None)
parser.add_argument('--multiprocessing_distributed', help='Use multi-processing distributed training to launch '
                                                          'N processes per node, which has N GPUs. This is the '
                                                          'fastest way to use PyTorch for either single node or '
                                                          'multi node data parallel training', action='store_true',)

# Online eval
parser.add_argument('--do_online_eval', help='if set, perform online eval in every eval_freq steps', action='store_true')
parser.add_argument('--data_path_eval', type=str, help='path to the data for online evaluation', required=False)
parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for online evaluation', required=False)
parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for online evaluation', required=False)
parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
parser.add_argument('--eval_freq', type=int, help='Online evaluation frequency in global steps', default=500)
parser.add_argument('--eval_summary_directory', type=str, help='output directory for eval summary,'
                                                               'if empty outputs to checkpoint folder', default='')

# A single positional argument is treated as the path of an argument file.
if len(sys.argv) == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

if args.dataset == 'kitti' or args.dataset == 'nyu':
    from dataloaders.dataloader import NewDataLoader


def online_eval(model, dataloader_eval, gpu, epoch, ngpus, group, post_process=False):
    """Evaluate ``model`` on ``dataloader_eval`` and return averaged metrics.

    Args:
        model: network called as ``model(image)``; the last entry of the first
            returned list is used as the depth prediction.
        dataloader_eval: object whose ``.data`` yields dicts with the keys
            ``'image'``, ``'depth'`` and ``'has_valid_depth'``.
        gpu: CUDA device the accumulators and images are moved to.
        epoch: unused; kept for interface compatibility.
        ngpus: unused; kept for interface compatibility.
        group: process group used for the all-reduce when
            ``args.multiprocessing_distributed`` is set.
        post_process: if True, combine the prediction with the prediction of
            the horizontally flipped image via ``post_process_depth``.

    Returns:
        A CPU tensor with 10 entries (9 metrics in the order silog, abs_rel,
        log10, rms, sq_rel, log_rms, d1, d2, d3, followed by the sample
        count, every entry divided by that count) on the reporting process,
        ``None`` on all other processes.
    """
    # Slots 0-8 accumulate the metrics, slot 9 counts evaluated samples.
    eval_measures = torch.zeros(10).cuda(device=gpu)

    for eval_sample_batched in tqdm(dataloader_eval.data):
        with torch.no_grad():
            image = eval_sample_batched['image'].cuda(gpu, non_blocking=True)
            gt_depth = eval_sample_batched['depth']
            has_valid_depth = eval_sample_batched['has_valid_depth']
            if not has_valid_depth:
                # print('Invalid depth. continue.')
                continue

            pred_depths_r_list, _, _ = model(image)
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
            else:
                # Previously ``pred_depth`` was left undefined on this path.
                pred_depth = pred_depths_r_list[-1]

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

        if args.do_kb_crop:
            # Paste the 352x1216 prediction back into a ground-truth sized canvas.
            height, width = gt_depth.shape
            top_margin = int(height - 352)
            left_margin = int((width - 1216) / 2)
            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
            pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
            pred_depth = pred_depth_uncropped

        # Clamp predictions into the evaluation range and neutralise inf/nan.
        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)

        if args.garg_crop or args.eigen_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = np.zeros(valid_mask.shape)

            if args.garg_crop:
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
                          int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
            elif args.eigen_crop:
                if args.dataset == 'kitti':
                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
                              int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
                elif args.dataset == 'nyu':
                    eval_mask[45:471, 41:601] = 1

            valid_mask = np.logical_and(valid_mask, eval_mask)

        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])

        eval_measures[:9] += torch.tensor(measures).cuda(device=gpu)
        eval_measures[9] += 1

    if args.multiprocessing_distributed:
        # group = dist.new_group([i for i in range(ngpus)])
        dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)

    if not args.multiprocessing_distributed or gpu == 0:
        eval_measures_cpu = eval_measures.cpu()
        cnt = eval_measures_cpu[9].item()
        eval_measures_cpu /= cnt
        print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
        print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                     'sq_rel', 'log_rms', 'd1', 'd2',
                                                                                     'd3'))
        for i in range(8):
            print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
        print('{:7.4f}'.format(eval_measures_cpu[8]))
        return eval_measures_cpu

    return None


def _is_main_process(args, ngpus_per_node):
    """Return True on the process that owns logging and summary writing."""
    return not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0


def _log_images(writer, dataset, image, depth_gt, pred_depths_r_list, pred_depths_c_list,
                uncertainty_maps_list, num_images, global_step):
    """Write input, ground truth and every per-stage prediction as images.

    Depth maps use ``colormap`` for NYU and ``colormap_magma(log10(.))``
    otherwise. Every entry of the three lists is logged, so the tag set adapts
    to however many refinement stages the model returned.
    """
    if dataset == 'nyu':
        render_depth = colormap
    else:
        def render_depth(depth):
            return colormap_magma(torch.log10(depth))

    for i in range(num_images):
        writer.add_image('depth_gt/image/{}'.format(i), render_depth(depth_gt[i, :, :, :].data), global_step)
        writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
        for stage, pred in enumerate(pred_depths_r_list):
            writer.add_image('depth_r_est{}/image/{}'.format(stage, i),
                             render_depth(pred[i, :, :, :].data), global_step)
        for stage, pred in enumerate(pred_depths_c_list):
            writer.add_image('depth_c_est{}/image/{}'.format(stage, i),
                             render_depth(pred[i, :, :, :].data), global_step)
        # NOTE(review): the uncertainty maps are logged for every dataset; the
        # original nesting of these calls could not be recovered - confirm.
        for stage, uncertainty in enumerate(uncertainty_maps_list):
            writer.add_image('uncer_est{}/image/{}'.format(stage, i),
                             colormap(uncertainty[i, :, :, :].data), global_step)


def main_worker(gpu, ngpus_per_node, args):
    """Run the full training loop on one process.

    Args:
        gpu: CUDA device index for this process, or ``None``.
        ngpus_per_node: number of GPUs on this node; also used to derive the
            global rank and the per-process batch size.
        args: parsed command-line namespace. ``args.gpu``, ``args.rank`` and
            ``args.batch_size`` are updated in place.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("== Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # model
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
    model.train()

    num_params = sum(p.numel() for p in model.parameters())
    print("== Total number of parameters: {}".format(num_params))

    num_params_update = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("== Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # The requested batch size is split across the GPUs of the node.
            args.batch_size = args.batch_size // ngpus_per_node
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                              find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("== Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("== Model Initialized")

    global_step = 0
    # Indices 0-5 of the metrics are "lower is better", 6-8 "higher is better".
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.Adam([{'params': model.module.parameters()}], lr=args.learning_rate)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
            # NOTE(review): checkpoints written below contain a numpy array
            # ('best_eval_steps'); recent PyTorch releases default to
            # weights_only=True in torch.load - verify loading still works.
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not args.retrain:
                try:
                    global_step = checkpoint['global_step']
                    best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                    best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                    best_eval_steps = checkpoint['best_eval_steps']
                except KeyError:
                    print("Could not load values for online evaluation")

            print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path,
                                                                      checkpoint.get('global_step', global_step)))
            model_just_loaded = True
            # Only delete what was actually loaded; deleting unconditionally
            # raised NameError when the checkpoint file was missing.
            del checkpoint
        else:
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))

    cudnn.benchmark = True

    dataloader = NewDataLoader(args, 'train')
    # The eval loader is only consumed by the online evaluation below, so it
    # is not built when the eval paths may be unset.
    dataloader_eval = NewDataLoader(args, 'online_eval') if args.do_online_eval else None

    # ===== Evaluation before training ======
    # model.eval()
    # with torch.no_grad():
    #     eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)

    is_main = _is_main_process(args, ngpus_per_node)
    model_dir = args.log_directory + '/' + args.model_name

    # Logging
    if is_main:
        writer = SummaryWriter(model_dir + '/summaries', flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)
    # NOTE(review): this module is constructed but never used below.
    sum_localdepth = Sum_depth().cuda(args.gpu)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    # -1 means "not specified": decay to a tenth of the initial rate.
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    # A process group can only be created after init_process_group; in a
    # plain single-process run there is none and online_eval does not need it.
    group = dist.new_group(list(range(ngpus_per_node))) if args.distributed else None

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()
            si_loss = 0

            image = sample_batched['image'].cuda(args.gpu, non_blocking=True)
            depth_gt = sample_batched['depth'].cuda(args.gpu, non_blocking=True)

            pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = model(image, epoch, step)

            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            # The loss is the sum of the silog loss over every refinement stage.
            valid = mask.to(torch.bool)
            for pred_depth_r in pred_depths_r_list:
                si_loss += silog_criterion.forward(pred_depth_r, depth_gt, valid)

            loss = si_loss
            loss.backward()

            # Polynomial decay (power 0.9) from learning_rate to end_learning_rate.
            current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_main:
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
                # if np.isnan(loss.cpu().item()):
                #     print('NaN in loss occurred. Aborting training.')
                #     return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
                if is_main:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))

                if is_main:
                    writer.add_scalar('silog_loss', si_loss, global_step)
                    # writer.add_scalar('var_loss', var_loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
                    # Avoid log10(0) when rendering the ground truth.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e-3, depth_gt)
                    # The last batch of an epoch may hold fewer samples.
                    _log_images(writer, args.dataset, image, depth_gt, pred_depths_r_list,
                                pred_depths_c_list, uncertainty_maps_list,
                                min(num_log_images, image.shape[0]), global_step)

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                with torch.no_grad():
                    eval_measures = online_eval(model, dataloader_eval, gpu, epoch, ngpus_per_node, group, post_process=True)
                if eval_measures is not None:
                    exp_name = '%s'%(datetime.now().strftime('%m%d'))
                    log_txt = os.path.join(model_dir, exp_name+'_logs.txt')
                    with open(log_txt, 'a') as txtfile:
                        txtfile.write(">>>>>>>>>>>>>>>>>>>>>>>>>Step:%d>>>>>>>>>>>>>>>>>>>>>>>>>\n"%(int(global_step)))
                        txtfile.write("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}\n".format('silog',
                                      'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2','d3'))
                        txtfile.write("depth estimation\n")
                        line = ''.join('{:7.4f}, '.format(eval_measures[i]) for i in range(9))
                        txtfile.write(line+'\n')

                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
                            old_best = best_eval_measures_higher_better[i-6].item()
                            best_eval_measures_higher_better[i-6] = measure.item()
                            is_best = True
                        if is_best:
                            # Replace the previous best checkpoint of this metric.
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
                            model_path = model_dir + old_best_name
                            if os.path.exists(model_path):
                                os.remove(model_path)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
                            checkpoint = {'global_step': global_step,
                                          'model': model.state_dict(),
                                          'optimizer': optimizer.state_dict(),
                                          'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                          'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                          'best_eval_steps': best_eval_steps
                                          }
                            torch.save(checkpoint, model_dir + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                enable_print()

            model_just_loaded = False
            global_step += 1

        epoch += 1

    if is_main:
        writer.close()
        if args.do_online_eval:
            eval_summary_writer.close()


def main():
    """Prepare the run directory, snapshot the sources and launch training.

    Returns:
        ``-1`` when the mode is not ``'train'`` or when several GPUs are
        visible without ``--multiprocessing_distributed``; ``None`` otherwise.
    """
    # Imported here so the module-level import order (which includes a star
    # import) is left untouched.
    import glob
    import shutil

    if args.mode != 'train':
        print('train.py is only for training.')
        return -1

    exp_name = '%s'%(datetime.now().strftime('%m%d'))
    args.log_directory = os.path.join(args.log_directory, exp_name)
    run_dir = os.path.join(args.log_directory, args.model_name)
    # Creates missing parents and tolerates an already existing directory.
    os.makedirs(run_dir, exist_ok=True)

    # Keep a copy of the argument file next to the results when one was given.
    if len(sys.argv) > 1 and os.path.isfile(sys.argv[1]):
        shutil.copy(sys.argv[1], run_dir)

    save_files = True
    if save_files:
        networks_savepath = os.path.join(run_dir, 'networks')
        dataloaders_savepath = os.path.join(run_dir, 'dataloaders')
        # Best-effort snapshot of the sources, resolved relative to the
        # current working directory.
        snapshot = [
            (['iebins/train.py'], run_dir),
            (sorted(glob.glob('iebins/networks/*.py')), networks_savepath),
            (sorted(glob.glob('iebins/dataloaders/*.py')), dataloaders_savepath),
        ]
        for sources, destination in snapshot:
            os.makedirs(destination, exist_ok=True)
            for source in sources:
                if os.path.isfile(source):
                    shutil.copy(source, destination)
                else:
                    print("== Could not snapshot '{}': file not found".format(source))

    torch.cuda.empty_cache()
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1 and not args.multiprocessing_distributed:
        print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
        return -1

    if args.do_online_eval:
        print("You have specified --do_online_eval.")
        print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
              .format(args.eval_freq))

    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)


if __name__ == '__main__':
    main()