import numpy as np |
|
from skimage.metrics import structural_similarity |
|
import torch |
|
|
|
from saicinpainting.utils import get_shape |
|
|
|
|
|
class PerceptualLoss(torch.nn.Module): |
|
def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True): |
|
|
|
|
|
super(PerceptualLoss, self).__init__() |
|
self.use_gpu = use_gpu |
|
self.spatial = spatial |
|
self.model = DistModel() |
|
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, |
|
model_path=model_path, spatial=self.spatial) |
|
|
|
def forward(self, pred, target, normalize=True): |
|
""" |
|
Pred and target are Variables. |
|
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] |
|
If normalize is False, assumes the images are already between [-1,+1] |
|
Inputs pred and target are Nx3xHxW |
|
Output pytorch Variable N long |
|
""" |
|
|
|
if normalize: |
|
target = 2 * target - 1 |
|
pred = 2 * pred - 1 |
|
|
|
return self.model(target, pred) |
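
# Hypothetical usage sketch (not part of the original file; assumes LPIPS weights
# for the chosen backbone are available at the default model_path):
#
#     loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=False)
#     pred = torch.rand(1, 3, 64, 64)    # values in [0, 1]
#     target = torch.rand(1, 3, 64, 64)
#     d = loss_fn(pred, target, normalize=True)  # tensor of per-image distances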
|
|
|
|
|
def normalize_tensor(in_feat, eps=1e-10): |
|
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) |
|
return in_feat / (norm_factor + eps) |
|
|
|
|
|
def l2(p0, p1, data_range=255.):
    return .5 * np.mean((p0 / data_range - p1 / data_range) ** 2)
|
|
|
|
|
def psnr(p0, p1, peak=255.): |
|
return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) |
|
|
|
|
|
def dssim(p0, p1, data_range=255.):
    # NOTE: for scikit-image >= 0.19, multichannel=True should become channel_axis=-1
    return (1 - structural_similarity(p0, p1, data_range=data_range, multichannel=True)) / 2.
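
# Worked example for the numpy metrics above (a sketch; values are illustrative):
#
#     a = np.zeros((8, 8, 3), dtype=np.float64)
#     b = np.full((8, 8, 3), 25.5)
#     l2(a, b, data_range=255.)     # 0.5 * mean((0.1)^2) = 0.005
#     psnr(a, b, peak=255.)         # 10 * log10(255^2 / 25.5^2) = 20.0 dB
#     dssim(a, b, data_range=255.)  # in [0, 1]; 0 means structurally identical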
|
|
|
|
|
def rgb2lab(in_img, mean_cent=False): |
|
from skimage import color |
|
img_lab = color.rgb2lab(in_img) |
|
if (mean_cent): |
|
img_lab[:, :, 0] = img_lab[:, :, 0] - 50 |
|
return img_lab |
|
|
|
|
|
def tensor2np(tensor_obj): |
|
|
|
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) |
|
|
|
|
|
def np2tensor(np_obj): |
|
|
|
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) |
|
|
|
|
|
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): |
|
|
|
from skimage import color |
|
|
|
img = tensor2im(image_tensor) |
|
img_lab = color.rgb2lab(img) |
|
if (mc_only): |
|
img_lab[:, :, 0] = img_lab[:, :, 0] - 50 |
|
if (to_norm and not mc_only): |
|
img_lab[:, :, 0] = img_lab[:, :, 0] - 50 |
|
img_lab = img_lab / 100. |
|
|
|
return np2tensor(img_lab) |
|
|
|
|
|
def tensorlab2tensor(lab_tensor, return_inbnd=False): |
|
from skimage import color |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
lab = tensor2np(lab_tensor) * 100. |
|
lab[:, :, 0] = lab[:, :, 0] + 50 |
|
|
|
rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) |
|
if (return_inbnd): |
|
|
|
lab_back = color.rgb2lab(rgb_back.astype('uint8')) |
|
mask = 1. * np.isclose(lab_back, lab, atol=2.) |
|
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) |
|
return (im2tensor(rgb_back), mask) |
|
else: |
|
return im2tensor(rgb_back) |
|
|
|
|
|
def rgb2lab(in_img):
    # NOTE: this re-definition shadows rgb2lab(in_img, mean_cent=False) above;
    # this version expects uint8-range input and rescales to [0, 1] first.
    from skimage import color
    return color.rgb2lab(in_img / 255.)
|
|
|
|
|
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): |
|
image_numpy = image_tensor[0].cpu().float().numpy() |
|
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor |
|
return image_numpy.astype(imtype) |
|
|
|
|
|
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): |
|
return torch.Tensor((image / factor - cent) |
|
[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) |
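
# Round-trip sketch (illustrative, not part of the original file): im2tensor maps
# an HxWx3 uint8 image to a 1x3xHxW float tensor in [-1, 1], and tensor2im
# inverts it up to float32 rounding:
#
#     img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
#     t = im2tensor(img)       # shape (1, 3, 32, 32), values in [-1, 1]
#     img_back = tensor2im(t)  # shape (32, 32, 3), dtype uint8
#     # img_back matches img up to at most off-by-one from float32 rounding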
|
|
|
|
|
def tensor2vec(vector_tensor): |
|
return vector_tensor.data.cpu().numpy()[:, :, 0, 0] |
|
|
|
|
|
def voc_ap(rec, prec, use_07_metric=False): |
|
""" ap = voc_ap(rec, prec, [use_07_metric]) |
|
Compute VOC AP given precision and recall. |
|
If use_07_metric is true, uses the |
|
VOC 07 11 point method (default:False). |
|
""" |
|
if use_07_metric: |
|
|
|
ap = 0. |
|
for t in np.arange(0., 1.1, 0.1): |
|
if np.sum(rec >= t) == 0: |
|
p = 0 |
|
else: |
|
p = np.max(prec[rec >= t]) |
|
ap = ap + p / 11. |
|
else: |
|
|
|
|
|
mrec = np.concatenate(([0.], rec, [1.])) |
|
mpre = np.concatenate(([0.], prec, [0.])) |
|
|
|
|
|
for i in range(mpre.size - 1, 0, -1): |
|
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) |
|
|
|
|
|
|
|
i = np.where(mrec[1:] != mrec[:-1])[0] |
|
|
|
|
|
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) |
|
return ap |
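
# Toy example for voc_ap (a sketch): after the backward max, AP sums
# (recall step) * (interpolated precision) over the points where recall changes.
#
#     rec = np.array([0.2, 0.4, 0.6, 0.8])
#     prec = np.array([1.0, 0.5, 0.67, 0.5])
#     voc_ap(rec, prec)  # = 0.2*1.0 + 0.2*0.67 + 0.2*0.67 + 0.2*0.5 + 0.2*0 ~ 0.57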
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseModel(torch.nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def name(self): |
|
return 'BaseModel' |
|
|
|
def initialize(self, use_gpu=True): |
|
self.use_gpu = use_gpu |
|
|
|
def forward(self): |
|
pass |
|
|
|
|
|
def optimize_parameters(self): |
|
pass |
|
|
|
def get_current_visuals(self): |
|
return self.input |
|
|
|
def get_current_errors(self): |
|
return {} |
|
|
|
def save(self, label): |
|
pass |
|
|
|
|
|
def save_network(self, network, path, network_label, epoch_label): |
|
save_filename = '%s_net_%s.pth' % (epoch_label, network_label) |
|
save_path = os.path.join(path, save_filename) |
|
torch.save(network.state_dict(), save_path) |
|
|
|
|
|
def load_network(self, network, network_label, epoch_label): |
|
save_filename = '%s_net_%s.pth' % (epoch_label, network_label) |
|
save_path = os.path.join(self.save_dir, save_filename) |
|
print('Loading network from %s' % save_path) |
|
network.load_state_dict(torch.load(save_path, map_location='cpu')) |
|
|
|
    def update_learning_rate(self):
|
pass |
|
|
|
def get_image_paths(self): |
|
return self.image_paths |
|
|
|
    def save_done(self, flag=False):
        # writes the flag both as a binary .npy file and as a plain-text file
        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from collections import OrderedDict |
|
from scipy.ndimage import zoom |
|
from tqdm import tqdm |
|
|
|
|
|
class DistModel(BaseModel): |
|
def name(self): |
|
return self.model_name |
|
|
|
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, |
|
model_path=None, |
|
use_gpu=True, printNet=False, spatial=False, |
|
is_train=False, lr=.0001, beta1=0.5, version='0.1'): |
|
''' |
|
INPUTS |
|
model - ['net-lin'] for linearly calibrated network |
|
['net'] for off-the-shelf network |
|
['L2'] for L2 distance in Lab colorspace |
|
['SSIM'] for ssim in RGB colorspace |
|
net - ['squeeze','alex','vgg'] |
|
model_path - if None, will look in weights/[NET_NAME].pth |
|
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM |
|
use_gpu - bool - whether or not to use a GPU |
|
printNet - bool - whether or not to print network architecture out |
|
spatial - bool - whether to output an array containing varying distances across spatial dimensions |
|
|
is_train - bool - [True] for training mode |
|
lr - float - initial learning rate |
|
beta1 - float - initial momentum term for adam |
|
version - 0.1 for latest, 0.0 was original (with a bug) |
|
''' |
|
BaseModel.initialize(self, use_gpu=use_gpu) |
|
|
|
self.model = model |
|
self.net = net |
|
self.is_train = is_train |
|
self.spatial = spatial |
|
self.model_name = '%s [%s]' % (model, net) |
|
|
|
if (self.model == 'net-lin'): |
|
self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, |
|
use_dropout=True, spatial=spatial, version=version, lpips=True) |
|
kw = dict(map_location='cpu') |
|
            if model_path is None:
                model_path = os.path.abspath(
                    os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth'))
|
|
|
if (not is_train): |
|
self.net.load_state_dict(torch.load(model_path, **kw), strict=False) |
|
|
|
elif (self.model == 'net'): |
|
self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) |
|
elif (self.model in ['L2', 'l2']): |
|
self.net = L2(use_gpu=use_gpu, colorspace=colorspace) |
|
self.model_name = 'L2' |
|
elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): |
|
self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace) |
|
self.model_name = 'SSIM' |
|
else: |
|
raise ValueError("Model [%s] not recognized." % self.model) |
|
|
|
self.trainable_parameters = list(self.net.parameters()) |
|
|
|
if self.is_train: |
|
|
|
self.rankLoss = BCERankingLoss() |
|
self.trainable_parameters += list(self.rankLoss.net.parameters()) |
|
self.lr = lr |
|
self.old_lr = lr |
|
self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999)) |
|
else: |
|
self.net.eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (printNet): |
|
print('---------- Networks initialized -------------') |
|
print_network(self.net) |
|
print('-----------------------------------------------') |
|
|
|
def forward(self, in0, in1, retPerLayer=False): |
|
''' Function computes the distance between image patches in0 and in1 |
|
INPUTS |
|
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] |
|
OUTPUT |
|
computed distances between in0 and in1 |
|
''' |
|
|
|
return self.net(in0, in1, retPerLayer=retPerLayer) |
|
|
|
|
|
def optimize_parameters(self): |
|
self.forward_train() |
|
self.optimizer_net.zero_grad() |
|
self.backward_train() |
|
self.optimizer_net.step() |
|
self.clamp_weights() |
|
|
|
def clamp_weights(self): |
|
for module in self.net.modules(): |
|
if (hasattr(module, 'weight') and module.kernel_size == (1, 1)): |
|
module.weight.data = torch.clamp(module.weight.data, min=0) |
|
|
|
def set_input(self, data): |
|
self.input_ref = data['ref'] |
|
self.input_p0 = data['p0'] |
|
self.input_p1 = data['p1'] |
|
self.input_judge = data['judge'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_train(self): |
|
|
|
|
|
|
|
assert False, "We shoud've not get here when using LPIPS as a metric" |
|
|
|
self.d0 = self(self.var_ref, self.var_p0) |
|
self.d1 = self(self.var_ref, self.var_p1) |
|
self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) |
|
|
|
self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) |
|
|
|
self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.) |
|
|
|
return self.loss_total |
|
|
|
def backward_train(self): |
|
torch.mean(self.loss_total).backward() |
|
|
|
def compute_accuracy(self, d0, d1, judge): |
|
''' d0, d1 are Variables, judge is a Tensor ''' |
|
d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() |
|
judge_per = judge.cpu().numpy().flatten() |
|
return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) |
|
|
|
def get_current_errors(self): |
|
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), |
|
('acc_r', self.acc_r)]) |
|
|
|
for key in retDict.keys(): |
|
retDict[key] = np.mean(retDict[key]) |
|
|
|
return retDict |
|
|
|
def get_current_visuals(self): |
|
zoom_factor = 256 / self.var_ref.data.size()[2] |
|
|
|
ref_img = tensor2im(self.var_ref.data) |
|
p0_img = tensor2im(self.var_p0.data) |
|
p1_img = tensor2im(self.var_p1.data) |
|
|
|
ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) |
|
p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) |
|
p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) |
|
|
|
return OrderedDict([('ref', ref_img_vis), |
|
('p0', p0_img_vis), |
|
('p1', p1_img_vis)]) |
|
|
|
def save(self, path, label): |
|
if (self.use_gpu): |
|
self.save_network(self.net.module, path, '', label) |
|
else: |
|
self.save_network(self.net, path, '', label) |
|
self.save_network(self.rankLoss.net, path, 'rank', label) |
|
|
|
def update_learning_rate(self, nepoch_decay): |
|
lrd = self.lr / nepoch_decay |
|
lr = self.old_lr - lrd |
|
|
|
for param_group in self.optimizer_net.param_groups: |
|
param_group['lr'] = lr |
|
|
|
        print('update lr decay: %f -> %f' % (self.old_lr, lr))
|
self.old_lr = lr |
|
|
|
|
|
def score_2afc_dataset(data_loader, func, name=''): |
|
''' Function computes Two Alternative Forced Choice (2AFC) score using |
|
distance function 'func' in dataset 'data_loader' |
|
INPUTS |
|
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside |
|
func - callable distance function - calling d=func(in0,in1) should take 2 |
|
pytorch tensors with shape Nx3xXxY, and return numpy array of length N |
|
OUTPUTS |
|
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators |
|
[1] - dictionary with following elements |
|
d0s,d1s - N arrays containing distances between reference patch to perturbed patches |
|
gts - N array in [0,1], preferred patch selected by human evaluators |
|
(closer to "0" for left patch p0, "1" for right patch p1, |
|
"0.6" means 60pct people preferred right patch, 40pct preferred left) |
|
            scores - N array in [0,1], fraction of the time the function agreed with humans
|
CONSTS |
|
N - number of test triplets in data_loader |
|
''' |
|
|
|
d0s = [] |
|
d1s = [] |
|
gts = [] |
|
|
|
for data in tqdm(data_loader.load_data(), desc=name): |
|
d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() |
|
d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() |
|
gts += data['judge'].cpu().numpy().flatten().tolist() |
|
|
|
d0s = np.array(d0s) |
|
d1s = np.array(d1s) |
|
gts = np.array(gts) |
|
scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 |
|
|
|
return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) |
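
# Sketch of the 2AFC agreement formula with toy values (illustrative only):
#
#     d0s = np.array([0.1, 0.5])   # distances ref <-> p0
#     d1s = np.array([0.3, 0.2])   # distances ref <-> p1
#     gts = np.array([0.0, 1.0])   # fraction of humans preferring p1
#     (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
#     # -> [1.0, 1.0]: the metric agrees with the humans on both triplets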
|
|
|
|
|
def score_jnd_dataset(data_loader, func, name=''): |
|
''' Function computes JND score using distance function 'func' in dataset 'data_loader' |
|
INPUTS |
|
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside |
|
func - callable distance function - calling d=func(in0,in1) should take 2 |
|
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N |
|
OUTPUTS |
|
[0] - JND score in [0,1], mAP score (area under precision-recall curve) |
|
[1] - dictionary with following elements |
|
ds - N array containing distances between two patches shown to human evaluator |
|
sames - N array containing fraction of people who thought the two patches were identical |
|
CONSTS |
|
N - number of test triplets in data_loader |
|
''' |
|
|
|
ds = [] |
|
gts = [] |
|
|
|
for data in tqdm(data_loader.load_data(), desc=name): |
|
ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() |
|
gts += data['same'].cpu().numpy().flatten().tolist() |
|
|
|
sames = np.array(gts) |
|
ds = np.array(ds) |
|
|
|
sorted_inds = np.argsort(ds) |
|
ds_sorted = ds[sorted_inds] |
|
sames_sorted = sames[sorted_inds] |
|
|
|
TPs = np.cumsum(sames_sorted) |
|
FPs = np.cumsum(1 - sames_sorted) |
|
FNs = np.sum(sames_sorted) - TPs |
|
|
|
precs = TPs / (TPs + FPs) |
|
recs = TPs / (TPs + FNs) |
|
score = voc_ap(recs, precs) |
|
|
|
return (score, dict(ds=ds, sames=sames)) |
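
# Toy walk-through (illustrative): sorting by distance treats small distances as
# "same" predictions, so cumulating the judgments traces a precision-recall curve
# whose area voc_ap integrates. With ds = [0.1, 0.9, 0.2, 0.8] and
# sames = [1, 0, 1, 0], sorting gives sames_sorted = [1, 1, 0, 0], hence
# precs = [1, 1, 2/3, 1/2], recs = [0.5, 1, 1, 1] and voc_ap(recs, precs) = 1.0:
# the distance perfectly separates "same" from "different" pairs.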
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from torch.autograd import Variable |
|
import numpy as np |
|
|
|
|
|
def spatial_average(in_tens, keepdim=True): |
|
return in_tens.mean([2, 3], keepdim=keepdim) |
|
|
|
|
|
def upsample(in_tens, out_H=64): |
|
in_H = in_tens.shape[2] |
|
scale_factor = 1. * out_H / in_H |
|
|
|
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) |
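
# Shape sketch (illustrative): upsample rescales both spatial dims by the single
# factor out_H / in_H, matching the square inputs used by LPIPS:
#
#     x = torch.rand(1, 1, 16, 16)
#     upsample(x, out_H=64).shape  # torch.Size([1, 1, 64, 64])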
|
|
|
|
|
|
|
class PNetLin(nn.Module): |
|
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, |
|
version='0.1', lpips=True): |
|
super(PNetLin, self).__init__() |
|
|
|
self.pnet_type = pnet_type |
|
self.pnet_tune = pnet_tune |
|
self.pnet_rand = pnet_rand |
|
self.spatial = spatial |
|
self.lpips = lpips |
|
self.version = version |
|
self.scaling_layer = ScalingLayer() |
|
|
|
if (self.pnet_type in ['vgg', 'vgg16']): |
|
net_type = vgg16 |
|
self.chns = [64, 128, 256, 512, 512] |
|
elif (self.pnet_type == 'alex'): |
|
net_type = alexnet |
|
self.chns = [64, 192, 384, 256, 256] |
|
elif (self.pnet_type == 'squeeze'): |
|
net_type = squeezenet |
|
self.chns = [64, 128, 256, 384, 384, 512, 512] |
|
self.L = len(self.chns) |
|
|
|
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) |
|
|
|
if (lpips): |
|
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) |
|
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) |
|
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) |
|
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) |
|
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) |
|
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] |
|
if (self.pnet_type == 'squeeze'): |
|
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) |
|
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) |
|
self.lins += [self.lin5, self.lin6] |
|
|
|
def forward(self, in0, in1, retPerLayer=False): |
|
|
|
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( |
|
in0, in1) |
|
outs0, outs1 = self.net(in0_input), self.net(in1_input) |
|
feats0, feats1, diffs = {}, {}, {} |
|
|
|
for kk in range(self.L): |
|
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) |
|
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 |
|
|
|
if (self.lpips): |
|
if (self.spatial): |
|
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] |
|
else: |
|
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] |
|
else: |
|
if (self.spatial): |
|
res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] |
|
else: |
|
res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] |
|
|
|
val = res[0] |
|
for l in range(1, self.L): |
|
val += res[l] |
|
|
|
if (retPerLayer): |
|
return (val, res) |
|
else: |
|
return val |
|
|
|
|
|
class ScalingLayer(nn.Module): |
|
def __init__(self): |
|
super(ScalingLayer, self).__init__() |
|
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) |
|
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) |
|
|
|
def forward(self, inp): |
|
return (inp - self.shift) / self.scale |
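
# Where these constants come from (a note, not in the original file): they are
# the ImageNet normalization statistics rewritten for [-1, 1] inputs, i.e.
# shift = 2 * mean - 1 and scale = 2 * std, with mean = (.485, .456, .406) and
# std = (.229, .224, .225). E.g. 2 * .485 - 1 = -.030 and 2 * .229 = .458.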
|
|
|
|
|
class NetLinLayer(nn.Module): |
|
''' A single linear layer which does a 1x1 conv ''' |
|
|
|
def __init__(self, chn_in, chn_out=1, use_dropout=False): |
|
super(NetLinLayer, self).__init__() |
|
|
|
layers = [nn.Dropout(), ] if (use_dropout) else [] |
|
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] |
|
self.model = nn.Sequential(*layers) |
|
|
|
|
|
class Dist2LogitLayer(nn.Module): |
|
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' |
|
|
|
def __init__(self, chn_mid=32, use_sigmoid=True): |
|
super(Dist2LogitLayer, self).__init__() |
|
|
|
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] |
|
layers += [nn.LeakyReLU(0.2, True), ] |
|
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] |
|
layers += [nn.LeakyReLU(0.2, True), ] |
|
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] |
|
if (use_sigmoid): |
|
layers += [nn.Sigmoid(), ] |
|
self.model = nn.Sequential(*layers) |
|
|
|
def forward(self, d0, d1, eps=0.1): |
|
return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) |
|
|
|
|
|
class BCERankingLoss(nn.Module): |
|
def __init__(self, chn_mid=32): |
|
super(BCERankingLoss, self).__init__() |
|
self.net = Dist2LogitLayer(chn_mid=chn_mid) |
|
|
|
self.loss = torch.nn.BCELoss() |
|
|
|
def forward(self, d0, d1, judge): |
|
per = (judge + 1.) / 2. |
|
self.logit = self.net(d0, d1) |
|
return self.loss(self.logit, per) |
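
# Sketch of the judge encoding (illustrative): DistModel passes judge in [-1, 1],
# which is mapped to a target probability per = (judge + 1) / 2 in [0, 1], so
# judge = -1 means "all humans preferred p0" and judge = +1 "all preferred p1".
#
#     d0 = torch.rand(4, 1, 1, 1)
#     d1 = torch.rand(4, 1, 1, 1)
#     judge = torch.ones(4, 1, 1, 1)   # everyone preferred p1
#     loss = BCERankingLoss()(d0, d1, judge)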
|
|
|
|
|
|
|
class FakeNet(nn.Module): |
|
def __init__(self, use_gpu=True, colorspace='Lab'): |
|
super(FakeNet, self).__init__() |
|
self.use_gpu = use_gpu |
|
self.colorspace = colorspace |
|
|
|
|
|
class L2(FakeNet): |
|
|
|
def forward(self, in0, in1, retPerLayer=None): |
|
assert (in0.size()[0] == 1) |
|
|
|
if (self.colorspace == 'RGB'): |
|
            # mean squared error over channel and spatial dims; one value per batch element
            value = torch.mean((in0 - in1) ** 2, dim=(1, 2, 3))
            return value
|
elif (self.colorspace == 'Lab'): |
|
            value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                       tensor2np(tensor2tensorlab(in1.data, to_norm=False)), data_range=100.).astype('float')
|
ret_var = Variable(torch.Tensor((value,))) |
|
|
|
|
|
return ret_var |
|
|
|
|
|
class DSSIM(FakeNet): |
|
|
|
def forward(self, in0, in1, retPerLayer=None): |
|
assert (in0.size()[0] == 1) |
|
|
|
        if self.colorspace == 'RGB':
            value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), data_range=255.).astype('float')
        elif self.colorspace == 'Lab':
            value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                          tensor2np(tensor2tensorlab(in1.data, to_norm=False)), data_range=100.).astype('float')
|
ret_var = Variable(torch.Tensor((value,))) |
|
|
|
|
|
return ret_var |
|
|
|
|
|
def print_network(net): |
|
num_params = 0 |
|
for param in net.parameters(): |
|
num_params += param.numel() |
|
print('Network', net) |
|
print('Total number of parameters: %d' % num_params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
import torch |
|
from torchvision import models as tv |
|
|
|
|
|
class squeezenet(torch.nn.Module): |
|
def __init__(self, requires_grad=False, pretrained=True): |
|
super(squeezenet, self).__init__() |
|
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features |
|
self.slice1 = torch.nn.Sequential() |
|
self.slice2 = torch.nn.Sequential() |
|
self.slice3 = torch.nn.Sequential() |
|
self.slice4 = torch.nn.Sequential() |
|
self.slice5 = torch.nn.Sequential() |
|
self.slice6 = torch.nn.Sequential() |
|
self.slice7 = torch.nn.Sequential() |
|
self.N_slices = 7 |
|
for x in range(2): |
|
self.slice1.add_module(str(x), pretrained_features[x]) |
|
for x in range(2, 5): |
|
self.slice2.add_module(str(x), pretrained_features[x]) |
|
for x in range(5, 8): |
|
self.slice3.add_module(str(x), pretrained_features[x]) |
|
for x in range(8, 10): |
|
self.slice4.add_module(str(x), pretrained_features[x]) |
|
for x in range(10, 11): |
|
self.slice5.add_module(str(x), pretrained_features[x]) |
|
for x in range(11, 12): |
|
self.slice6.add_module(str(x), pretrained_features[x]) |
|
for x in range(12, 13): |
|
self.slice7.add_module(str(x), pretrained_features[x]) |
|
if not requires_grad: |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, X): |
|
h = self.slice1(X) |
|
h_relu1 = h |
|
h = self.slice2(h) |
|
h_relu2 = h |
|
h = self.slice3(h) |
|
h_relu3 = h |
|
h = self.slice4(h) |
|
h_relu4 = h |
|
h = self.slice5(h) |
|
h_relu5 = h |
|
h = self.slice6(h) |
|
h_relu6 = h |
|
h = self.slice7(h) |
|
h_relu7 = h |
|
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) |
|
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) |
|
|
|
return out |
|
|
|
|
|
class alexnet(torch.nn.Module): |
|
def __init__(self, requires_grad=False, pretrained=True): |
|
super(alexnet, self).__init__() |
|
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features |
|
self.slice1 = torch.nn.Sequential() |
|
self.slice2 = torch.nn.Sequential() |
|
self.slice3 = torch.nn.Sequential() |
|
self.slice4 = torch.nn.Sequential() |
|
self.slice5 = torch.nn.Sequential() |
|
self.N_slices = 5 |
|
for x in range(2): |
|
self.slice1.add_module(str(x), alexnet_pretrained_features[x]) |
|
for x in range(2, 5): |
|
self.slice2.add_module(str(x), alexnet_pretrained_features[x]) |
|
for x in range(5, 8): |
|
self.slice3.add_module(str(x), alexnet_pretrained_features[x]) |
|
for x in range(8, 10): |
|
self.slice4.add_module(str(x), alexnet_pretrained_features[x]) |
|
for x in range(10, 12): |
|
self.slice5.add_module(str(x), alexnet_pretrained_features[x]) |
|
if not requires_grad: |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, X): |
|
h = self.slice1(X) |
|
h_relu1 = h |
|
h = self.slice2(h) |
|
h_relu2 = h |
|
h = self.slice3(h) |
|
h_relu3 = h |
|
h = self.slice4(h) |
|
h_relu4 = h |
|
h = self.slice5(h) |
|
h_relu5 = h |
|
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) |
|
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) |
|
|
|
return out |
|
|
|
|
|
class vgg16(torch.nn.Module): |
|
def __init__(self, requires_grad=False, pretrained=True): |
|
super(vgg16, self).__init__() |
|
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features |
|
self.slice1 = torch.nn.Sequential() |
|
self.slice2 = torch.nn.Sequential() |
|
self.slice3 = torch.nn.Sequential() |
|
self.slice4 = torch.nn.Sequential() |
|
self.slice5 = torch.nn.Sequential() |
|
self.N_slices = 5 |
|
for x in range(4): |
|
self.slice1.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(4, 9): |
|
self.slice2.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(9, 16): |
|
self.slice3.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(16, 23): |
|
self.slice4.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(23, 30): |
|
self.slice5.add_module(str(x), vgg_pretrained_features[x]) |
|
if not requires_grad: |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, X): |
|
h = self.slice1(X) |
|
h_relu1_2 = h |
|
h = self.slice2(h) |
|
h_relu2_2 = h |
|
h = self.slice3(h) |
|
h_relu3_3 = h |
|
h = self.slice4(h) |
|
h_relu4_3 = h |
|
h = self.slice5(h) |
|
h_relu5_3 = h |
|
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) |
|
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) |
|
|
|
return out |
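
# Feature-extraction sketch (illustrative; pretrained=True downloads the
# torchvision VGG16 weights on first use):
#
#     net = vgg16(pretrained=False)
#     feats = net(torch.rand(1, 3, 64, 64))
#     feats.relu1_2.shape   # torch.Size([1, 64, 64, 64])
#     feats.relu5_3.shape   # torch.Size([1, 512, 4, 4])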
|
|
|
|
|
class resnet(torch.nn.Module): |
|
def __init__(self, requires_grad=False, pretrained=True, num=18): |
|
super(resnet, self).__init__() |
|
if (num == 18): |
|
self.net = tv.resnet18(pretrained=pretrained) |
|
elif (num == 34): |
|
self.net = tv.resnet34(pretrained=pretrained) |
|
elif (num == 50): |
|
self.net = tv.resnet50(pretrained=pretrained) |
|
elif (num == 101): |
|
self.net = tv.resnet101(pretrained=pretrained) |
|
elif (num == 152): |
|
self.net = tv.resnet152(pretrained=pretrained) |
|
self.N_slices = 5 |
|
|
|
self.conv1 = self.net.conv1 |
|
self.bn1 = self.net.bn1 |
|
self.relu = self.net.relu |
|
self.maxpool = self.net.maxpool |
|
self.layer1 = self.net.layer1 |
|
self.layer2 = self.net.layer2 |
|
self.layer3 = self.net.layer3 |
|
        self.layer4 = self.net.layer4
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
|
|
|
def forward(self, X): |
|
h = self.conv1(X) |
|
h = self.bn1(h) |
|
h = self.relu(h) |
|
h_relu1 = h |
|
h = self.maxpool(h) |
|
h = self.layer1(h) |
|
h_conv2 = h |
|
h = self.layer2(h) |
|
h_conv3 = h |
|
h = self.layer3(h) |
|
h_conv4 = h |
|
h = self.layer4(h) |
|
h_conv5 = h |
|
|
|
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) |
|
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) |
|
|
|
return out |
|
|