import argparse import numpy as np import os import sys import shutil import torch import torch.optim as optim from torch.utils.data import DataLoader from tqdm import tqdm import warnings from lib.exceptions import NoGradientError from lib.losses.lossPhotoTourism import loss_function from lib.model import D2Net from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR # CUDA use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") # Seed torch.manual_seed(1) if use_cuda: torch.cuda.manual_seed(1) np.random.seed(1) # Argument parsing parser = argparse.ArgumentParser(description='Training script') parser.add_argument( '--dataset_path', type=str, default="/scratch/udit/phototourism/", help='path to the dataset' ) parser.add_argument( '--preprocessing', type=str, default='caffe', help='image preprocessing (caffe or torch)' ) parser.add_argument( '--init_model', type=str, default='models/d2net.pth', help='path to the initial model' ) parser.add_argument( '--num_epochs', type=int, default=10, help='number of training epochs' ) parser.add_argument( '--lr', type=float, default=1e-3, help='initial learning rate' ) parser.add_argument( '--batch_size', type=int, default=1, help='batch size' ) parser.add_argument( '--num_workers', type=int, default=16, help='number of workers for data loading' ) parser.add_argument( '--log_interval', type=int, default=250, help='loss logging interval' ) parser.add_argument( '--log_file', type=str, default='log.txt', help='loss logging file' ) parser.add_argument( '--plot', dest='plot', action='store_true', help='plot training pairs' ) parser.set_defaults(plot=False) parser.add_argument( '--checkpoint_directory', type=str, default='checkpoints', help='directory for training checkpoints' ) parser.add_argument( '--checkpoint_prefix', type=str, default='rord', help='prefix for training checkpoints' ) args = parser.parse_args() print(args) # Creating CNN model model = D2Net( model_file=args.init_model, use_cuda=False ) model = model.to(device) # Optimizer optimizer = optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr ) training_dataset = PhotoTourismIPR( base_path=args.dataset_path, preprocessing=args.preprocessing ) training_dataset.build_dataset() training_dataloader = DataLoader( training_dataset, batch_size=args.batch_size, num_workers=args.num_workers ) # Define epoch function def process_epoch( epoch_idx, model, loss_function, optimizer, dataloader, device, log_file, args, train=True, plot_path=None ): epoch_losses = [] torch.set_grad_enabled(train) progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) for batch_idx, batch in progress_bar: if train: optimizer.zero_grad() batch['train'] = train batch['epoch_idx'] = epoch_idx batch['batch_idx'] = batch_idx batch['batch_size'] = args.batch_size batch['preprocessing'] = args.preprocessing batch['log_interval'] = args.log_interval try: loss = loss_function(model, batch, device, plot=args.plot, plot_path=plot_path) except NoGradientError: # print("failed") continue current_loss = loss.data.cpu().numpy()[0] epoch_losses.append(current_loss) progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) if batch_idx % args.log_interval == 0: log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( 'train' if train else 'valid', epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) )) if train: loss.backward() optimizer.step() log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( 'train' if train else 'valid', epoch_idx, np.mean(epoch_losses) )) log_file.flush() return np.mean(epoch_losses) # Create the checkpoint directory checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix) if os.path.isdir(checkpoint_directory): print('[Warning] Checkpoint directory already exists.') else: os.makedirs(checkpoint_directory, exist_ok=True) # Open the log file for writing log_file = os.path.join(checkpoint_directory,args.log_file) if os.path.exists(log_file): print('[Warning] Log file already exists.') log_file = open(log_file, 'a+') # Create the folders for plotting if need be plot_path=None if args.plot: plot_path = os.path.join(checkpoint_directory,'train_vis') if os.path.isdir(plot_path): print('[Warning] Plotting directory already exists.') else: os.makedirs(plot_path, exist_ok=True) # Initialize the history train_loss_history = [] # Start the training for epoch_idx in range(1, args.num_epochs + 1): # Process epoch train_loss_history.append( process_epoch( epoch_idx, model, loss_function, optimizer, training_dataloader, device, log_file, args, train=True, plot_path=plot_path ) ) # Save the current checkpoint checkpoint_path = os.path.join( checkpoint_directory, '%02d.pth' % (epoch_idx) ) checkpoint = { 'args': args, 'epoch_idx': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'train_loss_history': train_loss_history, } torch.save(checkpoint, checkpoint_path) # Close the log file log_file.close()