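# Scraped from a Hugging Face file page (page chrome, not code):
# author: AlexZou · commit 7eb6194 "Upload 17 files" · size 2.75 kB
# (page links: "raw", "history", "blame"; scan status: "No virus")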
import math
import torch
import torch.nn as nn
import numpy as np
from skimage.measure.simple_metrics import compare_psnr
from torchvision import models
def weights_init_kaiming(m):
    """Initialise the parameters of a single module in place.

    Intended for ``model.apply(weights_init_kaiming)``, which calls it on every
    submodule. Dispatch is on the module's class name (checked in this order):

    * name contains ``'Conv'`` or ``'Linear'``: ``weight`` is drawn with
      Kaiming-normal init (``a=0``, ``mode='fan_in'``); ``bias`` is untouched.
    * name contains ``'BatchNorm'``: ``weight`` ~ N(0, sqrt(2/9/64)) clamped to
      [-0.025, 0.025], and ``bias`` is set to 0.
    * anything else is left unchanged.

    Modules whose name matches but that own no ``weight`` tensor (e.g. a custom
    ``ConvBlock`` container, or ``BatchNorm2d(affine=False)``) are skipped
    instead of raising AttributeError.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        # Nothing to initialise: containers or non-affine norm layers.
        return
    if 'Conv' in classname or 'Linear' in classname:
        nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
    elif 'BatchNorm' in classname:
        with torch.no_grad():
            weight.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)).clamp_(-0.025, 0.025)
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0.0)
class VGG19_PercepLoss(nn.Module):
    """Perceptual loss: mean squared error between VGG19 feature maps.

    Both inputs are passed through ``self.vgg`` (the ``features`` part of
    torchvision's VGG19) and compared at one named conv layer. The VGG
    parameters have ``requires_grad`` disabled, so the network itself is never
    trained through this loss.
    """

    # Position of each conv layer inside torchvision's ``vgg19().features``
    # (every conv is followed by a ReLU; a max-pool closes each stage).
    _CONV_LAYER_INDEX = {
        'conv1_1': 0, 'conv1_2': 2,
        'conv2_1': 5, 'conv2_2': 7,
        'conv3_1': 10, 'conv3_2': 12, 'conv3_3': 14, 'conv3_4': 16,
        'conv4_1': 19, 'conv4_2': 21, 'conv4_3': 23, 'conv4_4': 25,
        'conv5_1': 28, 'conv5_2': 30, 'conv5_3': 32, 'conv5_4': 34,
    }

    def __init__(self, _pretrained_=True):
        super(VGG19_PercepLoss, self).__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour
        # of ``weights=``; kept for compatibility with older releases.
        self.vgg = models.vgg19(pretrained=_pretrained_).features
        # Freeze the feature extractor.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def get_features(self, image, layers=None):
        """Run ``image`` through ``self.vgg`` and collect selected outputs.

        Args:
            image: input batch fed to the first VGG layer.
            layers: maps child-module names of ``self.vgg`` (string indices
                such as ``'30'``) to the key under which that module's output
                is stored. Defaults to ``{'30': 'conv5_2'}``.

        Returns:
            dict from each requested key to the captured tensor.
        """
        if layers is None:
            layers = {'30': 'conv5_2'}  # may add other layers
        features = {}
        x = image
        # NOTE(review): torchvision's VGG uses ReLU(inplace=True); when the
        # module after a captured conv is such a ReLU, the stored tensor is
        # overwritten in place, i.e. the returned feature is post-activation.
        # The full traversal is kept so existing loss values are unchanged.
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def forward(self, pred, true, layer='conv5_2'):
        """Return the MSE between ``pred`` and ``true`` features at ``layer``.

        Args:
            pred: predicted image batch.
            true: reference image batch.
            layer: VGG19 conv-layer name, ``'conv1_1'`` ... ``'conv5_4'``.

        Raises:
            KeyError: if ``layer`` is not a VGG19 conv-layer name.
        """
        if layer not in self._CONV_LAYER_INDEX:
            raise KeyError(
                f'unknown VGG19 layer {layer!r}; expected one of '
                f'{sorted(self._CONV_LAYER_INDEX)}'
            )
        # Request exactly the layer being compared (previously only conv5_2
        # was ever extracted, so any other name raised KeyError).
        layers = {str(self._CONV_LAYER_INDEX[layer]): layer}
        true_f = self.get_features(true, layers)
        pred_f = self.get_features(pred, layers)
        return torch.mean((true_f[layer] - pred_f[layer]) ** 2)
def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR in dB over a batch of images.

    Each image's PSNR is ``10 * log10(data_range**2 / MSE)``, the same formula
    as the former ``skimage.measure.compare_psnr`` (removed in scikit-image
    0.18). The per-image values are then averaged. An image identical to its
    reference contributes ``inf``.

    Args:
        img: tensor batch of reconstructed images, indexed along dim 0.
        imclean: reference tensor batch with the same shape as ``img``.
        data_range: span of valid pixel values (e.g. 1.0 or 255).

    Returns:
        float: batch-mean PSNR.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    Img = img.detach().cpu().numpy().astype(np.float32)
    Iclean = imclean.detach().cpu().numpy().astype(np.float32)
    if Img.shape != Iclean.shape:
        raise ValueError(f'shape mismatch: {Img.shape} vs {Iclean.shape}')
    n = Img.shape[0]
    if n == 0:
        raise ValueError('batch_PSNR needs at least one image')
    # Difference taken in float32, mean accumulated in float64 (as compare_psnr did).
    sq_err = np.square(Iclean - Img).reshape(n, -1)
    mse = np.mean(sq_err, axis=1, dtype=np.float64)
    with np.errstate(divide='ignore'):  # perfect match -> inf, without a warning
        psnr = 10.0 * np.log10((data_range ** 2) / mse)
    return float(np.mean(psnr))
def data_augmentation(image, mode):
    """Apply one of eight rotate/flip transforms to a 3-D array.

    The transform acts on axes 1 and 2 of ``image`` (presumably H and W of a
    channel-first C x H x W image). Axis order is preserved in the result.

    ``mode`` selects the transform; rotations are counter-clockwise, as in
    ``np.rot90``:
        0 identity               1 flip up/down
        2 rotate 90              3 rotate 90, then flip up/down
        4 rotate 180             5 rotate 180, then flip up/down
        6 rotate 270             7 rotate 270, then flip up/down
    Any other ``mode`` returns the image unchanged.

    The result may be a view of ``image``; NumPy's rot90/flipud do not copy.
    """
    hwc = np.transpose(image, (1, 2, 0))
    if mode in range(8):
        # Even modes are pure rotations; odd modes add a flip afterwards.
        quarter_turns, flip = divmod(mode, 2)
        hwc = np.rot90(hwc, k=quarter_turns)
        if flip:
            hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))