import argparse
import os
import shutil
import warnings

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.dataset import MegaDepthDataset
from lib.exceptions import NoGradientError
from lib.loss import loss_function
from lib.model import D2Net
# CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)
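# Note: seeding alone does not make CUDA runs bit-exact; full determinism
# would also require torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False, at some cost in speed.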
# Argument parsing
parser = argparse.ArgumentParser(description="Training script")

parser.add_argument(
    "--dataset_path", type=str, required=True, help="path to the dataset"
)
parser.add_argument(
    "--scene_info_path", type=str, required=True, help="path to the processed scenes"
)

parser.add_argument(
    "--preprocessing",
    type=str,
    default="caffe",
    help="image preprocessing (caffe or torch)",
)
parser.add_argument(
    "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model"
)

parser.add_argument(
    "--num_epochs", type=int, default=10, help="number of training epochs"
)
parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument(
    "--num_workers", type=int, default=4, help="number of workers for data loading"
)

parser.add_argument(
    "--use_validation",
    dest="use_validation",
    action="store_true",
    help="use the validation split",
)
parser.set_defaults(use_validation=False)

parser.add_argument(
    "--log_interval", type=int, default=250, help="loss logging interval"
)
parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file")

parser.add_argument(
    "--plot", dest="plot", action="store_true", help="plot training pairs"
)
parser.set_defaults(plot=False)

parser.add_argument(
    "--checkpoint_directory",
    type=str,
    default="checkpoints",
    help="directory for training checkpoints",
)
parser.add_argument(
    "--checkpoint_prefix",
    type=str,
    default="d2",
    help="prefix for training checkpoints",
)

args = parser.parse_args()

print(args)
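# Example invocation (the paths are placeholders for a local MegaDepth setup):
#   python train.py --dataset_path /path/to/MegaDepth \
#       --scene_info_path /path/to/scene_info --use_validation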
# Create the folders for plotting if need be
if args.plot:
    plot_path = "train_vis"
    if os.path.isdir(plot_path):
        print("[Warning] Plotting directory already exists.")
    else:
        os.mkdir(plot_path)
# Create the CNN model
model = D2Net(model_file=args.model_file, use_cuda=use_cuda)

# Optimizer: only update the parameters that D2Net leaves trainable
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
)
# Dataset
if args.use_validation:
    validation_dataset = MegaDepthDataset(
        scene_list_path="megadepth_utils/valid_scenes.txt",
        scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25,
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers
    )

training_dataset = MegaDepthDataset(
    scene_list_path="megadepth_utils/train_scenes.txt",
    scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing,
)
training_dataloader = DataLoader(
    training_dataset, batch_size=args.batch_size, num_workers=args.num_workers
)
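# Note: neither loader sets shuffle=True; the image pairs are instead
# re-sampled via build_dataset(), which is called once per epoch below
# (assumed to be where the per-epoch randomization happens).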
# Define epoch function
def process_epoch(
    epoch_idx,
    model,
    loss_function,
    optimizer,
    dataloader,
    device,
    log_file,
    args,
    train=True,
):
    epoch_losses = []

    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, batch in progress_bar:
        if train:
            optimizer.zero_grad()

        batch["train"] = train
        batch["epoch_idx"] = epoch_idx
        batch["batch_idx"] = batch_idx
        batch["batch_size"] = args.batch_size
        batch["preprocessing"] = args.preprocessing
        batch["log_interval"] = args.log_interval

        try:
            loss = loss_function(model, batch, device, plot=args.plot)
        except NoGradientError:
            # Skip pairs without enough valid correspondences to define a loss.
            continue

        # .item() works for both 0-dim and 1-element loss tensors
        current_loss = loss.item()
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write(
                "[%s] epoch %d - batch %d / %d - avg_loss: %f\n"
                % (
                    "train" if train else "valid",
                    epoch_idx,
                    batch_idx,
                    len(dataloader),
                    np.mean(epoch_losses),
                )
            )

        if train:
            loss.backward()
            optimizer.step()

    log_file.write(
        "[%s] epoch %d - avg_loss: %f\n"
        % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses))
    )
    log_file.flush()

    return np.mean(epoch_losses)
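# Note: torch.set_grad_enabled(train) above flips the *global* autograd mode
# and is only reset by the next call to process_epoch. A more defensive
# variant would wrap the loop in "with torch.set_grad_enabled(train):" so the
# previous mode is restored automatically.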
# Create the checkpoint directory
if os.path.isdir(args.checkpoint_directory):
    print("[Warning] Checkpoint directory already exists.")
else:
    os.mkdir(args.checkpoint_directory)

# Open the log file for writing ("a+" appends, so earlier runs are kept)
if os.path.exists(args.log_file):
    print("[Warning] Log file already exists.")
log_file = open(args.log_file, "a+")
# Initialize the history
train_loss_history = []
validation_loss_history = []
if args.use_validation:
    # Validate once before training to get a baseline for the best checkpoint
    validation_dataset.build_dataset()
    min_validation_loss = process_epoch(
        0,
        model,
        loss_function,
        optimizer,
        validation_dataloader,
        device,
        log_file,
        args,
        train=False,
    )
# Start the training
for epoch_idx in range(1, args.num_epochs + 1):
    # Process epoch
    training_dataset.build_dataset()
    train_loss_history.append(
        process_epoch(
            epoch_idx,
            model,
            loss_function,
            optimizer,
            training_dataloader,
            device,
            log_file,
            args,
        )
    )

    if args.use_validation:
        validation_loss_history.append(
            process_epoch(
                epoch_idx,
                model,
                loss_function,
                optimizer,
                validation_dataloader,
                device,
                log_file,
                args,
                train=False,
            )
        )

    # Save the current checkpoint
    checkpoint_path = os.path.join(
        args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx)
    )
    checkpoint = {
        "args": args,
        "epoch_idx": epoch_idx,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_loss_history": train_loss_history,
        "validation_loss_history": validation_loss_history,
    }
    torch.save(checkpoint, checkpoint_path)
    if args.use_validation and validation_loss_history[-1] < min_validation_loss:
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)
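# Resuming from a checkpoint is not wired into this script; a minimal sketch,
# assuming the same model/optimizer construction as above:
#   checkpoint = torch.load(checkpoint_path, map_location=device)
#   model.load_state_dict(checkpoint["model"])
#   optimizer.load_state_dict(checkpoint["optimizer"])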
# Close the log file
log_file.close()