# third_party/RoRD/trainers/trainPT_combined.py
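# Fine-tunes a D2-Net model on PhotoTourism using two losses: the original
# D2-Net loss (lib.loss) and the RoRD PhotoTourism loss
# (lib.losses.lossPhotoTourism). PhotoTourismCombined yields (batch, method)
# pairs, and `method` selects which of the two losses is applied to each batch
# (see process_epoch below).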
import argparse
import numpy as np
import os
import sys
sys.path.append("../")
import shutil
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import warnings
# from lib.dataset import MegaDepthDataset
from lib.exceptions import NoGradientError
from lib.loss import loss_function as orig_loss
from lib.losses.lossPhotoTourism import loss_function as ipr_loss
from lib.model import D2Net
from lib.dataloaders.datasetPhotoTourism_combined import PhotoTourismCombined
# CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:1" if use_cuda else "cpu")
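# NOTE: the device index is hard-coded to the second GPU ("cuda:1"); adjust it
# on single-GPU machines.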
# Seed
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)
# Argument parsing
parser = argparse.ArgumentParser(description='Training script')
parser.add_argument(
    '--dataset_path', type=str, default="/scratch/udit/phototourism/",
    help='path to the dataset'
)
# parser.add_argument(
#     '--scene_info_path', type=str, required=True,
#     help='path to the processed scenes'
# )
parser.add_argument(
    '--preprocessing', type=str, default='caffe',
    help='image preprocessing (caffe or torch)'
)
parser.add_argument(
    '--model_file', type=str, default='models/d2_ots.pth',
    help='path to the full model'
)
parser.add_argument(
    '--num_epochs', type=int, default=10,
    help='number of training epochs'
)
parser.add_argument(
    '--lr', type=float, default=1e-3,
    help='initial learning rate'
)
parser.add_argument(
    '--batch_size', type=int, default=1,
    help='batch size'
)
parser.add_argument(
    '--num_workers', type=int, default=16,
    help='number of workers for data loading'
)
parser.add_argument(
    '--use_validation', dest='use_validation', action='store_true',
    help='use the validation split'
)
parser.set_defaults(use_validation=False)
parser.add_argument(
    '--log_interval', type=int, default=250,
    help='loss logging interval'
)
parser.add_argument(
    '--log_file', type=str, default='log.txt',
    help='loss logging file'
)
parser.add_argument(
    '--plot', dest='plot', action='store_true',
    help='plot training pairs'
)
parser.set_defaults(plot=False)
parser.add_argument(
    '--checkpoint_directory', type=str, default='checkpoints',
    help='directory for training checkpoints'
)
parser.add_argument(
    '--checkpoint_prefix', type=str, default='d2',
    help='prefix for training checkpoints'
)
args = parser.parse_args()
print(args)
# Creating CNN model
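# Weights are loaded with use_cuda=False and the model is moved to `device` afterwards.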
model = D2Net(
    model_file=args.model_file,
    use_cuda=False
)
model = model.to(device)
# Optimizer
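# Only parameters left trainable by D2Net (requires_grad=True) are updated.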
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
)
# Dataset
if args.use_validation:
    validation_dataset = PhotoTourismCombined(
        # scene_list_path='megadepth_utils/valid_scenes.txt',
        # scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25
    )
    # validation_dataset.build_dataset()

    validation_dataloader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

training_dataset = PhotoTourismCombined(
    # scene_list_path='megadepth_utils/train_scenes.txt',
    # scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing
)
# training_dataset.build_dataset()

training_dataloader = DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers
)
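# No shuffle flag is passed to the DataLoaders; PhotoTourismCombined presumably
# handles pair sampling itself (note pairs_per_scene for validation above).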
# Define epoch function
def process_epoch(
    epoch_idx,
    model, loss_function, optimizer, dataloader, device,
    log_file, args, train=True, plot_path=None
):
    epoch_losses = []

    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, (batch, method) in progress_bar:
        if train:
            optimizer.zero_grad()

        batch['train'] = train
        batch['epoch_idx'] = epoch_idx
        batch['batch_idx'] = batch_idx
        batch['batch_size'] = args.batch_size
        batch['preprocessing'] = args.preprocessing
        batch['log_interval'] = args.log_interval

        try:
            # `method` indexes into the loss_function list supplied by the caller:
            # index 0 -> orig_loss, index 1 -> ipr_loss (see the call sites below).
            loss = loss_function[method](
                model, batch, device, plot=args.plot, plot_path=plot_path
            )
        except NoGradientError:
            # print("failed")
            continue

        current_loss = loss.data.cpu().numpy()[0]
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
                'train' if train else 'valid',
                epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
            ))

        if train:
            loss.backward()
            optimizer.step()

    log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
        'train' if train else 'valid',
        epoch_idx,
        np.mean(epoch_losses)
    ))
    log_file.flush()

    return np.mean(epoch_losses)
# Create the checkpoint directory
checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix)
if os.path.isdir(checkpoint_directory):
    print('[Warning] Checkpoint directory already exists.')
else:
    os.makedirs(checkpoint_directory, exist_ok=True)
# Open the log file for writing
log_file = os.path.join(checkpoint_directory, args.log_file)
if os.path.exists(log_file):
    print('[Warning] Log file already exists.')
log_file = open(log_file, 'a+')
# Create the folders for plotting if need be
plot_path = None
if args.plot:
    plot_path = os.path.join(checkpoint_directory, 'train_vis')
    if os.path.isdir(plot_path):
        print('[Warning] Plotting directory already exists.')
    else:
        os.makedirs(plot_path, exist_ok=True)
# Initialize the history
train_loss_history = []
validation_loss_history = []
if args.use_validation:
    min_validation_loss = process_epoch(
        0,
        model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device,
        log_file, args,
        train=False
    )
# Start the training
for epoch_idx in range(1, args.num_epochs + 1):
    # Process epoch
    train_loss_history.append(
        process_epoch(
            epoch_idx,
            model, [orig_loss, ipr_loss], optimizer, training_dataloader, device,
            log_file, args, train=True, plot_path=plot_path
        )
    )

    if args.use_validation:
        validation_loss_history.append(
            process_epoch(
                epoch_idx,
                model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device,
                log_file, args,
                train=False
            )
        )

    # Save the current checkpoint
    checkpoint_path = os.path.join(
        checkpoint_directory,
        '%02d.pth' % (epoch_idx)
    )
    checkpoint = {
        'args': args,
        'epoch_idx': epoch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss_history': train_loss_history,
        'validation_loss_history': validation_loss_history
    }
    torch.save(checkpoint, checkpoint_path)

    if (
        args.use_validation and
        validation_loss_history[-1] < min_validation_loss
    ):
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            checkpoint_directory,
            '%s.best.pth' % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)
# Close the log file
log_file.close()
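
# Illustrative invocation (not from the original repo; the dataset path and the
# checkpoint prefix are placeholders, the flags are the ones defined by the
# parser above):
#
#   python trainPT_combined.py \
#       --dataset_path /path/to/phototourism \
#       --model_file models/d2_ots.pth \
#       --checkpoint_directory checkpoints \
#       --checkpoint_prefix d2_combined \
#       --num_epochs 10 --lr 1e-3 --batch_size 1 --use_validation --plot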