# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import os
import torch
import torch.optim as optim

from tools import common, trainer
from tools.dataloader import *
from nets.patchnet import *
from nets.losses import *

default_net = "Quad_L2Net_ConfCFS()"

toy_db_debug = """SyntheticPairDataset(
    ImgFolder('imgs'),
    'RandomScale(256,1024,can_upscale=True)',
    'RandomTilting(0.5), PixelNoise(25)')"""

db_web_images = """SyntheticPairDataset(
    web_images,
    'RandomScale(256,1024,can_upscale=True)',
    'RandomTilting(0.5), PixelNoise(25)')"""

db_aachen_images = """SyntheticPairDataset(
    aachen_db_images,
    'RandomScale(256,1024,can_upscale=True)',
    'RandomTilting(0.5), PixelNoise(25)')"""

db_aachen_style_transfer = """TransformedPairs(
    aachen_style_transfer_pairs,
    'RandomScale(256,1024,can_upscale=True), RandomTilting(0.5), PixelNoise(25)')"""

db_aachen_flow = "aachen_flow_pairs"

data_sources = dict(
    D=toy_db_debug,
    W=db_web_images,
    A=db_aachen_images,
    F=db_aachen_flow,
    S=db_aachen_style_transfer,
)

default_dataloader = """PairLoader(CatPairDataset(`data`),
    scale   = 'RandomScale(256,1024,can_upscale=True)',
    distort = 'ColorJitter(0.2,0.2,0.2,0.1)',
    crop    = 'RandomCrop(192)')"""

default_sampler = """NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
                    subd_neg=-8, maxpool_pos=True)"""

default_loss = """MultiLoss(
    1, ReliabilityLoss(`sampler`, base=0.5, nq=20),
    1, CosimLoss(N=`N`),
    1, PeakyLoss(N=`N`))"""


class MyTrainer(trainer.Trainer):
    """This class implements the network training.
    forward_backward() is overridden to define how the loss is computed
    and back-propagated for one batch of image pairs.
    """

    def forward_backward(self, inputs):
        # the network consumes the image pair; all remaining entries of
        # `inputs` are forwarded to the loss together with the network output
        output = self.net(imgs=[inputs.pop("img1"), inputs.pop("img2")])
        allvars = dict(inputs, **output)
        loss, details = self.loss_func(**allvars)
        if torch.is_grad_enabled():
            loss.backward()
        return loss, details


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train R2D2")

    parser.add_argument("--data-loader", type=str, default=default_dataloader)
    parser.add_argument(
        "--train-data",
        type=str,
        default=list("WASF"),
        nargs="+",
        choices=set(data_sources.keys()),
    )
    parser.add_argument(
        "--net", type=str, default=default_net, help="network architecture"
    )
    parser.add_argument(
        "--pretrained", type=str, default="", help="pretrained model path"
    )
    parser.add_argument(
        "--save-path", type=str, required=True, help="model save path"
    )

    parser.add_argument("--loss", type=str, default=default_loss, help="loss function")
    parser.add_argument(
        "--sampler", type=str, default=default_sampler, help="AP sampler"
    )
    parser.add_argument(
        "--N", type=int, default=16, help="patch size for repeatability"
    )

    parser.add_argument(
        "--epochs", type=int, default=25, help="number of training epochs"
    )
    parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size")
    parser.add_argument("--learning-rate", "--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)

    parser.add_argument(
        "--threads", type=int, default=8, help="number of worker threads"
    )
    parser.add_argument("--gpu", type=int, nargs="+", default=[0], help="-1 for CPU")

    args = parser.parse_args()

    iscuda = common.torch_set_gpu(args.gpu)
    common.mkdir_for(args.save_path)

    # Create data loader
    from datasets import *

    db = [data_sources[key] for key in args.train_data]
    db = eval(args.data_loader.replace("`data`", ",".join(db)).replace("\n", ""))
    print("Training image database =", db)
    loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True)
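
    # For illustration (not executed): after the back-quote substitution above,
    # running with `--train-data W` alone makes eval() see, whitespace aside:
    #
    #   PairLoader(CatPairDataset(SyntheticPairDataset(web_images,
    #                   'RandomScale(256,1024,can_upscale=True)',
    #                   'RandomTilting(0.5), PixelNoise(25)')),
    #       scale   = 'RandomScale(256,1024,can_upscale=True)',
    #       distort = 'ColorJitter(0.2,0.2,0.2,0.1)',
    #       crop    = 'RandomCrop(192)')
    #
    # The same mechanism substitutes `sampler` and `N` in the loss template below.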

    # create network
    print("\n>> Creating net = " + args.net)
    net = eval(args.net)
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # initialization
    if args.pretrained:
        checkpoint = torch.load(
            args.pretrained, map_location=lambda storage, loc: storage
        )
        net.load_pretrained(checkpoint["state_dict"])

    # create losses
    loss = args.loss.replace("`sampler`", args.sampler).replace("`N`", str(args.N))
    print("\n>> Creating loss = " + loss)
    loss = eval(loss.replace("\n", ""))

    # create optimizer
    optimizer = optim.Adam(
        [p for p in net.parameters() if p.requires_grad],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    train = MyTrainer(net, loader, loss, optimizer)
    if iscuda:
        train = train.cuda()

    # Training loop #
    for epoch in range(args.epochs):
        print(f"\n>> Starting epoch {epoch}...")
        train()

    print(f"\n>> Saving model to {args.save_path}")
    torch.save({"net": args.net, "state_dict": net.state_dict()}, args.save_path)
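
# Example invocation (a sketch: the script name `train.py` and the output path are
# assumptions; --save-path is the only required argument, every other flag keeps
# the defaults defined above):
#
#   python train.py --save-path models/r2d2_demo.pt --train-data W A S F --gpu 0
#
# Use `--gpu -1` to run on CPU (see the --gpu help string).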