"""Network Modules | |
- encoder3: vgg encoder up to relu31 | |
- decoder3: mirror decoder to encoder3 | |
- encoder4: vgg encoder up to relu41 | |
- decoder4: mirror decoder to encoder4 | |
- encoder5: vgg encoder up to relu51 | |
- styleLoss: gram matrix loss for all style layers | |
- styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask | |
- GramMatrix: compute gram matrix for one layer | |
- LossCriterion: style transfer loss that include both content & style losses | |
- LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask | |
- VQEmbedding: codebook class for VQVAE | |
""" | |
import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vq_functions import vq, vq_st

class MetaModule(nn.Module):
    """
    Base class for PyTorch meta-learning modules. These modules accept an
    additional argument `params` in their `forward` method.

    Notes
    -----
    Objects inherited from `MetaModule` are fully compatible with PyTorch
    modules from `torch.nn.Module`. The argument `params` is a dictionary of
    tensors, with full support of the computation graph (for differentiation).
    """
    def meta_named_parameters(self, prefix='', recurse=True):
        gen = self._named_members(
            lambda module: module._parameters.items()
            if isinstance(module, MetaModule) else [],
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def meta_parameters(self, recurse=True):
        for name, param in self.meta_named_parameters(recurse=recurse):
            yield param

class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and
    biases, as for instance output by a hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        weight = params['weight']
        # Swap the last two dimensions of the (possibly batched) weight matrix.
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        if bias is not None:
            output = output + bias.unsqueeze(-2)
        return output
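
# Illustrative sketch (not part of the original code): how `BatchLinear` can be
# driven by externally supplied, batched parameters, e.g. produced by a
# hypernetwork. The tensor shapes below are assumptions chosen for the example.
def _example_batch_linear():
    layer = BatchLinear(in_features=8, out_features=4)
    x = torch.randn(2, 5, 8)              # batch of 2, 5 tokens, 8 features each
    params = OrderedDict(
        weight=torch.randn(2, 4, 8),      # one weight matrix per batch element
        bias=torch.randn(2, 4),           # one bias vector per batch element
    )
    out = layer(x, params=params)         # -> shape (2, 5, 4)
    return out.shape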

class decoder1(nn.Module):
    def __init__(self):
        super(decoder1, self).__init__()
        self.reflecPad2 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv3 = nn.Conv2d(64, 3, 3, 1, 0)
        # 224 x 224

    def forward(self, x):
        out = self.reflecPad2(x)
        out = self.conv3(out)
        return out

class decoder2(nn.Module):
    def __init__(self):
        super(decoder2, self).__init__()
        # decoder
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        out = self.reflecPad5(x)
        out = self.conv5(out)
        out = self.relu5(out)
        out = self.unpool(out)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        out = self.reflecPad7(out)
        out = self.conv7(out)
        return out

class encoder3(nn.Module):
    def __init__(self):
        super(encoder3, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

    def forward(self, x):
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.reflecPad3(out)
        out = self.conv3(out)
        pool1 = self.relu3(out)
        out, pool_idx = self.maxPool(pool1)
        out = self.reflecPad4(out)
        out = self.conv4(out)
        out = self.relu4(out)
        out = self.reflecPad5(out)
        out = self.conv5(out)
        pool2 = self.relu5(out)
        out, pool_idx2 = self.maxPool2(pool2)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        return out

class decoder3(nn.Module):
    def __init__(self):
        super(decoder3, self).__init__()
        # decoder
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        out = self.reflecPad7(x)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.unpool(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out_relu9 = self.relu9(out)
        out = self.unpool2(out_relu9)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        return out
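
# Illustrative sketch (assumption, not from the original code): round-tripping a
# 224x224 image through encoder3/decoder3 to check the expected tensor shapes.
def _example_encoder3_decoder3():
    enc, dec = encoder3(), decoder3()
    x = torch.randn(1, 3, 224, 224)   # dummy RGB image
    feat = enc(x)                     # relu3_1 features: 1 x 256 x 56 x 56
    recon = dec(feat)                 # reconstruction: 1 x 3 x 224 x 224
    return feat.shape, recon.shape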

class encoder4(nn.Module):
    def __init__(self):
        super(encoder4, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self, x, sF=None, matrix11=None, matrix21=None, matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if matrix31 is not None:
            feature3, transmatrix3 = matrix31(output['r31'], sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        return output
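
# Illustrative sketch (assumption): the main keys and feature sizes encoder4
# returns for a 224x224 input when no transform matrix is supplied.
def _example_encoder4_features():
    enc = encoder4()
    out = enc(torch.randn(1, 3, 224, 224))
    # out['r11'], out['r12']: 1 x 64 x 224 x 224
    # out['r21'], out['r22']: 1 x 128 x 112 x 112
    # out['r31'] ... out['r34']: 1 x 256 x 56 x 56
    # out['r41']: 1 x 512 x 28 x 28
    return {k: tuple(v.shape) for k, v in out.items()}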

class decoder4(nn.Module):
    def __init__(self):
        super(decoder4, self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 256, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112
        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        # decoder
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out
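
# Illustrative sketch (assumption): decoder4 reconstructs an image from the
# relu4_1 ('r41') features produced by encoder4.
def _example_encoder4_decoder4():
    enc, dec = encoder4(), decoder4()
    feats = enc(torch.randn(1, 3, 224, 224))
    recon = dec(feats['r41'])         # 1 x 512 x 28 x 28 -> 1 x 3 x 224 x 224
    return recon.shape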

class encoder5(nn.Module):
    def __init__(self):
        super(encoder5, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self, x, sF=None, contentV256=None, styleV256=None,
                matrix11=None, matrix21=None, matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])
        # out = self.reflecPad3(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])
        # out = self.reflecPad5(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if styleV256 is not None:
            feature = matrix31(output['r31'], sF['r31'], contentV256, styleV256)
            out = self.reflecPad7(feature)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        output['r51'] = out
        return output

class styleLoss(nn.Module):
    def forward(self, input, target):
        ib, ic, ih, iw = input.size()
        iF = input.view(ib, ic, -1)
        iMean = torch.mean(iF, dim=2)
        iCov = GramMatrix()(input)

        tb, tc, th, tw = target.size()
        tF = target.view(tb, tc, -1)
        tMean = torch.mean(tF, dim=2)
        tCov = GramMatrix()(target)

        # `size_average=False` is deprecated; `reduction='sum'` is the equivalent.
        loss = nn.MSELoss(reduction='sum')(iMean, tMean) + nn.MSELoss(reduction='sum')(iCov, tCov)
        return loss / tb

class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b, c, h * w)  # b x c x (h*w)
        # torch.bmm(batch1, batch2): batch1 is b x m x p, batch2 is b x p x n -> b x m x n
        G = torch.bmm(f, f.transpose(1, 2))  # f: b x c x (h*w), f^T: b x (h*w) x c -> b x c x c
        return G.div_(c * h * w)
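
# Illustrative sketch (assumption): the Gram matrix of a b x c x h x w feature
# map has shape b x c x c, normalised by c*h*w.
def _example_gram_matrix():
    feat = torch.randn(2, 64, 32, 32)
    G = GramMatrix()(feat)            # -> shape (2, 64, 64)
    return G.shape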

class LossCriterion(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path='/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterion, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLoss()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, content=True, style=True):
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}
        # content loss
        if content:
            totalContentLoss = 0
            for i, layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i, cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i, layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i, sf_i)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses
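
# Illustrative sketch (assumption): constructing the criterion and computing the
# losses. The layer names follow the keys returned by encoder5; the loss weights
# are placeholders, and model_path must point at a directory containing vgg_r51.pth.
def _example_loss_criterion(model_path):
    criterion = LossCriterion(style_layers=['r11', 'r21', 'r31', 'r41'],
                              content_layers=['r41'],
                              style_weight=0.02,
                              content_weight=1.0,
                              model_path=model_path)
    transfer = torch.randn(1, 3, 224, 224)
    image = torch.randn(1, 3, 224, 224)
    losses = criterion(transfer, image)   # {'content': ..., 'style': ...}
    return losses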

class styleLossMask(nn.Module):
    def forward(self, input, target, mask):
        ib, ic, ih, iw = input.size()
        iF = input.view(ib, ic, -1)
        tb, tc, th, tw = target.size()
        tF = target.view(tb, tc, -1)

        loss = 0
        mb, mc, mh, mw = mask.shape
        for i in range(mb):
            # resize the mask to the spatial size of the feature map
            maski = F.interpolate(mask[i:i + 1], size=(ih, iw), mode='nearest')
            mask_flat = maski.view(mc, -1)
            for j in range(mc):
                # gather the features belonging to each region
                idx = torch.nonzero(mask_flat[j]).squeeze()
                if len(idx.shape) == 0 or idx.shape[0] == 0:
                    continue
                ipart = torch.index_select(iF, 2, idx)
                tpart = torch.index_select(tF, 2, idx)
                iMean = torch.mean(ipart, dim=2)
                iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ic * ih * iw)  # b x c x c
                tMean = torch.mean(tpart, dim=2)
                tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tc * th * tw)  # b x c x c
                loss += nn.MSELoss()(iMean, tMean) + nn.MSELoss()(iGram, tGram)
        return loss / tb
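
# Illustrative sketch (assumption): masked style loss between two feature maps,
# with a one-channel-per-region mask (here two regions, left and right halves).
def _example_style_loss_mask():
    feat_in = torch.randn(1, 128, 56, 56)
    feat_tgt = torch.randn(1, 128, 56, 56)
    mask = torch.zeros(1, 2, 224, 224)   # B x N x H x W, one channel per region
    mask[:, 0, :, :112] = 1.             # region 0: left half
    mask[:, 1, :, 112:] = 1.             # region 1: right half
    return styleLossMask()(feat_in, feat_tgt, mask)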

class LossCriterionMask(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path='/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterionMask, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLossMask()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, mask, content=True, style=True):
        # mask: B x N x H x W
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}
        # content loss
        if content:
            totalContentLoss = 0
            for i, layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i, cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i, layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i, sf_i, mask)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses

class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1. / K, 1. / K)

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x, return_index=False):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        if return_index:
            return z_q_x, z_q_x_bar, indices
        else:
            return z_q_x, z_q_x_bar
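
# Illustrative sketch (assumption): straight-through vector quantisation of a
# spatial latent with a codebook of K=512 entries of dimension D=256. The exact
# shape of the returned indices depends on the vq/vq_st helpers in vq_functions.
def _example_vq_embedding():
    codebook = VQEmbedding(K=512, D=256)
    z_e = torch.randn(4, 256, 14, 14)          # encoder output: B x D x H x W
    z_q_st, z_q = codebook.straight_through(z_e)
    codes = codebook(z_e)                      # discrete code indices per spatial position
    return z_q_st.shape, z_q.shape, codes.shape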