# Spaces: Runtime error
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
from torchvision import models | |
import torch | |
from torch.nn import init | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import functools | |
from torch.autograd import grad | |
def gradient(inputs, outputs):
    """Differentiate ``outputs`` with respect to ``inputs`` through autograd.

    A tensor of ones shaped like ``outputs`` (on the same device) is used as
    the upstream gradient. The graph is retained and a graph of the
    derivative itself is built (``create_graph=True``), so the returned
    tensor can be differentiated again.

    Parameters:
        inputs (tensor)  -- tensor to differentiate with respect to
        outputs (tensor) -- tensor computed from ``inputs``

    Returns the first gradient produced by ``torch.autograd.grad``. Because
    ``allow_unused=True`` is passed, the result is ``None`` when ``outputs``
    does not depend on ``inputs``.
    """
    upstream = torch.ones_like(outputs,
                               requires_grad=False,
                               device=outputs.device)
    all_grads = grad(outputs=outputs,
                     inputs=inputs,
                     grad_outputs=upstream,
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True,
                     allow_unused=True)
    return all_grads[0]
# def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
#     "3x3 convolution with padding"
#     return nn.Conv2d(in_planes, out_planes, kernel_size=3,
#                      stride=strd, padding=padding, bias=bias)
def conv3x3(in_planes,
            out_planes,
            kernel=3,
            strd=1,
            dilation=1,
            padding=1,
            bias=False):
    """Build a 2D convolution; the defaults give a 3x3 kernel with padding 1.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        kernel (int)     -- kernel size (default 3)
        strd (int)       -- stride (default 1)
        dilation (int)   -- dilation (default 1)
        padding (int)    -- zero padding (default 1)
        bias (bool)      -- whether the layer has a bias term (default False)

    Returns the configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=kernel,
                        stride=strd,
                        dilation=dilation,
                        padding=padding,
                        bias=bias)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 2D convolution.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        stride (int)     -- stride of the convolution (default 1)

    Returns the configured ``nn.Conv2d`` layer.
    """
    pointwise = nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False)
    return pointwise
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains 'Conv' or 'Linear' get their weight
    initialized with ``init_type`` and their bias (if any) set to zero.
    Modules whose class name contains 'BatchNorm2d' get weight ~ N(1,
    init_gain) and bias zero. Every other module is left untouched.

    Raises NotImplementedError when ``init_type`` is unknown; the check only
    fires once a Conv/Linear module is actually visited.
    """

    def init_func(m):  # applied to every submodule by net.apply
        classname = m.__class__.__name__
        # Parameters may be absent or explicitly None (e.g. BatchNorm2d with
        # affine=False, or a layer built with bias=False); skip those instead
        # of crashing on `None.data`.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname

        if weight is not None and is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' %
                    init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix; only a normal distribution
            # applies.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. optionally wrap it for multi-GPU use; 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2.
                              ``None`` (the default) or an empty list means
                              no DataParallel wrapping.

    Return an initialized network.

    Raises AssertionError when ``gpu_ids`` is non-empty but CUDA is not
    available.
    """
    # Avoid a shared mutable default argument.
    if gpu_ids is None:
        gpu_ids = []

    if len(gpu_ids) > 0:
        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`.
        if not torch.cuda.is_available():
            raise AssertionError(
                'gpu_ids were given but CUDA is not available')
        # NOTE: only the emptiness of gpu_ids is consulted; DataParallel is
        # built without a device list and therefore uses its own default.
        net = torch.nn.DataParallel(net)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
def imageSpaceRotation(xy, rot):
    '''
    Weight each coordinate row of ``xy`` by the sine of its rotation angle
    and sum the rows.

    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles
    returns:
        tensor with dim 1 summed out, i.e. sum_c sin(rot[:, c]) * xy[:, c, :]

    rotation center will be always image center (other rotation center can be represented by additional z translation)
    '''
    # (B, 2) -> (B, 2, 1), then broadcast the per-axis sine over all N points.
    sin_weights = torch.sin(rot.unsqueeze(2)).expand_as(xy)
    weighted = sin_weights * xy
    return weighted.sum(dim=1)
def cal_gradient_penalty(netD,
                         real_data,
                         fake_data,
                         device,
                         type='mixed',
                         constant=1.0,
                         lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        device (str)                -- device on which the random mixing weights and the
                                       upstream gradient are placed
        type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)            -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss

    Returns a pair ``(gradient_penalty, gradients)`` where ``gradients`` is
    flattened to (batch, -1). When ``lambda_gp`` is not positive the pair is
    ``(0.0, None)``.

    Raises NotImplementedError for an unknown ``type``.

    Note: for ``type`` 'real' / 'fake' the passed tensor itself is switched to
    ``requires_grad=True``.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    # Pick the points at which the discriminator gradient is evaluated:
    # the real samples, the fake samples, or a random blend of both.
    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        per_sample = real_data.nelement() // batch
        # One random weight per sample, repeated over all of its elements.
        alpha = torch.rand(batch, 1)
        alpha = alpha.expand(batch, per_sample).contiguous()
        alpha = alpha.view(*real_data.shape).to(device)
        samples = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    scores = netD(samples)
    upstream = torch.ones(scores.size()).to(device)
    grads = torch.autograd.grad(outputs=scores,
                                inputs=samples,
                                grad_outputs=upstream,
                                create_graph=True,
                                retain_graph=True,
                                only_inputs=True)
    gradients = grads[0].view(real_data.size(0), -1)  # flat the data
    # The small epsilon is added to the gradients before taking the norm.
    grad_norm = (gradients + 1e-16).norm(2, dim=1)
    gradient_penalty = ((grad_norm - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer factory

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    For GroupNorm, the number of groups is fixed to 32.
    For 'none', ``None`` is returned.

    Raises NotImplementedError for any other name.
    """
    if norm_type == 'none':
        return None
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d,
                                 affine=True,
                                 track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d,
                                 affine=False,
                                 track_running_stats=False)
    if norm_type == 'group':
        return functools.partial(nn.GroupNorm, 32)
    raise NotImplementedError('normalization layer [%s] is not found' %
                              norm_type)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # Keep dim 0 and let view() infer the merged size of the rest.
        leading = input.size(0)
        return input.view(leading, -1)
class ConvBlock(nn.Module):
    """Residual block of three norm -> ReLU -> conv stages.

    The three convolutions produce out_planes/2, out_planes/4 and
    out_planes/4 channels; their outputs are concatenated along dim 1 and
    added to a shortcut of the input.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        opt              -- options object; reads ``opt.conv3x3`` (a
                            ``[kernel, stride, dilation, padding]`` list
                            shared by all three convolutions) and
                            ``opt.norm`` ('batch' | 'group')

    Raises NotImplementedError when ``opt.norm`` is neither 'batch' nor
    'group'.
    """

    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()
        [k, s, d, p] = opt.conv3x3
        half_planes = int(out_planes / 2)
        quarter_planes = int(out_planes / 4)
        self.conv1 = conv3x3(in_planes, half_planes, k, s, d, p)
        self.conv2 = conv3x3(half_planes, quarter_planes, k, s, d, p)
        self.conv3 = conv3x3(quarter_planes, quarter_planes, k, s, d, p)

        if opt.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.bn2 = nn.BatchNorm2d(half_planes)
            self.bn3 = nn.BatchNorm2d(quarter_planes)
            self.bn4 = nn.BatchNorm2d(in_planes)
        elif opt.norm == 'group':
            # 32 groups are hard-coded for every GroupNorm layer.
            self.bn1 = nn.GroupNorm(32, in_planes)
            self.bn2 = nn.GroupNorm(32, half_planes)
            self.bn3 = nn.GroupNorm(32, quarter_planes)
            self.bn4 = nn.GroupNorm(32, in_planes)
        else:
            # Fail early and clearly; without this the norm layers would be
            # missing and surface later as an AttributeError.
            raise NotImplementedError(
                'normalization layer [%s] is not found' % opt.norm)

        if in_planes != out_planes:
            # Project the shortcut to out_planes channels. bn4 is the same
            # module instance that is registered as self.bn4 above.
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=1,
                          bias=False),
            )
        else:
            # Identity shortcut; bn4 is created but not used in forward.
            self.downsample = None

    def forward(self, x):
        """Run the three stages and add the (optionally projected) input."""
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        # Concatenate the three stage outputs along dim 1.
        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out3 += residual
        return out3
class Vgg19(torch.nn.Module):
    """Pretrained torchvision VGG-19 feature stack split into five stages.

    The stages cover feature indices [0, 2), [2, 7), [7, 12), [12, 21) and
    [21, 30). Each layer keeps its original index as its submodule name.

    Parameters:
        requires_grad (bool) -- when False (default) every parameter is
                                frozen (``requires_grad = False``)
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features

        def take(start, stop):
            # Collect features[start:stop] into a Sequential, preserving the
            # original layer indices as module names.
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            return stage

        self.slice1 = take(0, 2)
        self.slice2 = take(2, 7)
        self.slice3 = take(7, 12)
        self.slice4 = take(12, 21)
        self.slice5 = take(21, 30)

        if not requires_grad:
            for frozen in self.parameters():
                frozen.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage outputs, in stage order."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4,
                  self.slice5)
        activations = []
        hidden = X
        for stage in stages:
            # Each stage consumes the previous stage's output.
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between Vgg19 features of two inputs.

    The five feature levels returned by ``Vgg19`` are weighted by
    1/32, 1/16, 1/8, 1/4 and 1 respectively.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the loss between ``x`` and ``y``.

        The features of ``y`` are detached, so no gradient flows through
        them.
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, feats_x, feats_y):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss