from __future__ import absolute_import

import numpy as np
import torch
from torch import nn
from collections import OrderedDict
from torch.autograd import Variable
from scipy.ndimage import zoom
from tqdm import tqdm
import lpips
import os


class Trainer():
    ''' Wrapper that builds a perceptual distance model and drives its training / evaluation.

    initialize() constructs the distance network (plus, in training mode, the ranking
    loss and its Adam optimizer). set_input() followed by optimize_parameters() runs one
    training step on a batch of (ref, p0, p1, judge) data.
    '''

    def name(self):
        ''' Return the display name assigned in initialize(). '''
        return self.model_name

    def initialize(self, model='lpips', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False,
                   model_path=None, use_gpu=True, printNet=False, spatial=False,
                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=None):
        '''
        INPUTS
            model - ['lpips'] for linearly calibrated network
                    ['baseline'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - gpus to use; None (the default) means [0]
        RAISES
            ValueError - if model is not one of the names listed above
        '''
        # None instead of a list literal in the signature, so that instances never
        # share (and cannot accidentally mutate) one default list object
        if gpu_ids is None:
            gpu_ids = [0]

        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids
        self.model = model
        self.net = net  # network name for now; replaced by the constructed module below
        self.is_train = is_train
        self.spatial = spatial
        self.model_name = '%s [%s]'%(model,net)

        if(self.model == 'lpips'): # pretrained net + linear layer
            self.net = lpips.LPIPS(pretrained=not is_train, net=net, version=version, lpips=True, spatial=spatial,
                pnet_rand=pnet_rand, pnet_tune=pnet_tune,
                use_dropout=True, model_path=model_path, eval_mode=False)
        elif(self.model=='baseline'): # pretrained network
            self.net = lpips.LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
        elif(self.model in ['L2','l2']):
            self.net = lpips.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = lpips.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = lpips.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            # previously called networks.print_network, but no 'networks' module is
            # imported in this file, so the branch raised NameError
            self._print_network(self.net)
            print('-----------------------------------------------')

    @staticmethod
    def _print_network(net):
        ''' Print the module structure of net followed by its total parameter count. '''
        num_params = sum(param.numel() for param in net.parameters())
        print('Network', net)
        print('Total number of parameters: %d' % num_params)

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            retPerLayer - forwarded unchanged to the underlying network
        OUTPUT
            computed distances between in0 and in1
        '''
        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        ''' Run one optimization step on the batch stored by set_input(). '''
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        ''' Clamp the weights of every layer with a (1,1) kernel to be non-negative. '''
        for module in self.net.modules():
            # getattr with a default: a layer that owns a weight but has no kernel_size
            # attribute (e.g. nn.Linear, batch norm) is skipped instead of raising
            # AttributeError
            if(hasattr(module, 'weight') and getattr(module, 'kernel_size', None)==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        ''' Store one batch for training.
        INPUTS
            data - dict with keys 'ref', 'p0', 'p1', 'judge'; each value is moved to
                gpu_ids[0] when use_gpu is set
        '''
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self): # run forward pass
        ''' Compute both distances, the per-sample accuracy and the ranking loss.
        OUTPUT
            the ranking loss, also stored in self.loss_total
        '''
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        # judge*2-1 rescales the judgment linearly before it is passed to the ranking loss
        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)

        return self.loss_total

    def backward_train(self):
        ''' Backpropagate the mean of the loss computed by forward_train(). '''
        torch.mean(self.loss_total).backward()

    # NOTE(review): the source text between the start of compute_accuracy and the
    # final print of the learning-rate update below was lost (everything from a
    # '<' to the next '>' is missing). The bodies of compute_accuracy and
    # update_learning_rate are reconstructed from the surviving fragments and must
    # be confirmed against version control; any methods that originally sat between
    # them (the otherwise-unused OrderedDict / zoom imports suggest there were some)
    # need to be restored from there as well.

    def compute_accuracy(self,d0,d1,judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        # NOTE(review): only "d1_lt_d0 = (d1" survives in the source; the comparison
        # and the agreement formula below are reconstructed - confirm
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def update_learning_rate(self,nepoch_decay):
        ''' Lower the optimizer learning rate by one decay step and record it in old_lr. '''
        # NOTE(review): the method name, the nepoch_decay parameter and the decay rule
        # are reconstructed; only the tail of the print call and the old_lr update
        # survive in the source - confirm
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        # NOTE(review): 'type' here is the Python builtin, so the first field prints
        # "<class 'type'>" - kept as found; confirm which label was intended
        print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
        self.old_lr = lr

    def get_image_paths(self):
        ''' Return self.image_paths (not assigned anywhere in the visible part of this class). '''
        return self.image_paths

    def save_done(self, flag=False):
        ''' Write flag into self.save_dir twice: via np.save and, as text, via np.savetxt.

        np.save appends '.npy' to the file name, so the two writes go to different files.
        self.save_dir is not assigned anywhere in the visible part of this class.
        '''
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')


def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
gts = [] for data in tqdm(data_loader.load_data(), desc=name): d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() gts+=data['judge'].cpu().numpy().flatten().tolist() d0s = np.array(d0s) d1s = np.array(d1s) gts = np.array(gts) scores = (d0s