"""Network Modules
- encoder3: vgg encoder up to relu31
- decoder3: mirror decoder to encoder3
- encoder4: vgg encoder up to relu41
- decoder4: mirror decoder to encoder4
- encoder5: vgg encoder up to relu51
- styleLoss: gram matrix loss for all style layers
- styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask
- GramMatrix: compute gram matrix for one layer
- LossCriterion: style transfer loss that include both content & style losses
- LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask
- VQEmbedding: codebook class for VQVAE
"""
import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vq_functions import vq, vq_st

class MetaModule(nn.Module):
"""
Base class for PyTorch meta-learning modules. These modules accept an
additional argument `params` in their `forward` method.
Notes
-----
Objects inherited from `MetaModule` are fully compatible with PyTorch
modules from `torch.nn.Module`. The argument `params` is a dictionary of
tensors, with full support of the computation graph (for differentiation).
"""
def meta_named_parameters(self, prefix='', recurse=True):
gen = self._named_members(
lambda module: module._parameters.items()
if isinstance(module, MetaModule) else [],
prefix=prefix, recurse=recurse)
for elem in gen:
yield elem
def meta_parameters(self, recurse=True):
for name, param in self.meta_named_parameters(recurse=recurse):
yield param
class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and
    biases, as for instance output by a hypernetwork.'''
    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        weight = params['weight']
        # swap the last two dims of the (possibly batched) weight matrix
        output = input.matmul(weight.transpose(-2, -1))
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output
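
# Usage sketch (not part of the original API): a minimal, hedged example of
# BatchLinear with per-sample weights, as a hypernetwork might emit them.
# All shapes below are illustrative assumptions.
def _demo_batch_linear():
    layer = BatchLinear(in_features=8, out_features=4)
    x = torch.randn(16, 10, 8)                   # (batch, points, in_features)
    params = OrderedDict([                       # one weight/bias per sample
        ('weight', torch.randn(16, 4, 8)),
        ('bias', torch.randn(16, 4)),
    ])
    out = layer(x, params=params)                # -> (16, 10, 4)
    assert out.shape == (16, 10, 4)
    # MetaModule bookkeeping: this layer exposes its own meta-parameters
    assert [n for n, _ in layer.meta_named_parameters()] == ['weight', 'bias']
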
class decoder1(nn.Module):
def __init__(self):
super(decoder1,self).__init__()
self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1))
# 226 x 226
self.conv3 = nn.Conv2d(64,3,3,1,0)
# 224 x 224
def forward(self,x):
out = self.reflecPad2(x)
out = self.conv3(out)
return out
class decoder2(nn.Module):
def __init__(self):
super(decoder2,self).__init__()
# decoder
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
self.conv5 = nn.Conv2d(128,64,3,1,0)
self.relu5 = nn.ReLU(inplace=True)
# 112 x 112
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
# 224 x 224
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
self.conv6 = nn.Conv2d(64,64,3,1,0)
self.relu6 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
self.conv7 = nn.Conv2d(64,3,3,1,0)
def forward(self,x):
out = self.reflecPad5(x)
out = self.conv5(out)
out = self.relu5(out)
out = self.unpool(out)
out = self.reflecPad6(out)
out = self.conv6(out)
out = self.relu6(out)
out = self.reflecPad7(out)
out = self.conv7(out)
return out
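
# Shape sketch (illustrative only): decoder1 maps relu11-level features
# (64 channels) back to RGB at the same resolution; decoder2 maps
# relu21-level features (128 channels) back to RGB at twice the resolution,
# matching the size comments in the layer definitions above.
def _demo_shallow_decoders():
    d1, d2 = decoder1(), decoder2()
    f1 = torch.randn(1, 64, 224, 224)
    f2 = torch.randn(1, 128, 112, 112)
    assert d1(f1).shape == (1, 3, 224, 224)
    assert d2(f2).shape == (1, 3, 224, 224)
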
class encoder3(nn.Module):
def __init__(self):
super(encoder3,self).__init__()
# vgg
# 224 x 224
self.conv1 = nn.Conv2d(3,3,1,1,0)
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
# 226 x 226
self.conv2 = nn.Conv2d(3,64,3,1,0)
self.relu2 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
self.conv3 = nn.Conv2d(64,64,3,1,0)
self.relu3 = nn.ReLU(inplace=True)
# 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
# 112 x 112
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
self.conv4 = nn.Conv2d(64,128,3,1,0)
self.relu4 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
self.conv5 = nn.Conv2d(128,128,3,1,0)
self.relu5 = nn.ReLU(inplace=True)
# 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
# 56 x 56
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
self.conv6 = nn.Conv2d(128,256,3,1,0)
self.relu6 = nn.ReLU(inplace=True)
# 56 x 56
def forward(self,x):
out = self.conv1(x)
out = self.reflecPad1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.reflecPad3(out)
out = self.conv3(out)
pool1 = self.relu3(out)
        out, pool_idx = self.maxPool(pool1)    # pooling indices are returned but unused here
out = self.reflecPad4(out)
out = self.conv4(out)
out = self.relu4(out)
out = self.reflecPad5(out)
out = self.conv5(out)
pool2 = self.relu5(out)
        out, pool_idx2 = self.maxPool2(pool2)  # pooling indices are returned but unused here
out = self.reflecPad6(out)
out = self.conv6(out)
out = self.relu6(out)
return out
class decoder3(nn.Module):
def __init__(self):
super(decoder3,self).__init__()
# decoder
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
self.conv7 = nn.Conv2d(256,128,3,1,0)
self.relu7 = nn.ReLU(inplace=True)
# 56 x 56
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
# 112 x 112
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
self.conv8 = nn.Conv2d(128,128,3,1,0)
self.relu8 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
self.conv9 = nn.Conv2d(128,64,3,1,0)
self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.Upsample(scale_factor=2, mode='nearest')
# 224 x 224
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
self.conv10 = nn.Conv2d(64,64,3,1,0)
self.relu10 = nn.ReLU(inplace=True)
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
self.conv11 = nn.Conv2d(64,3,3,1,0)
def forward(self,x):
out = self.reflecPad7(x)
out = self.conv7(out)
out = self.relu7(out)
out = self.unpool(out)
out = self.reflecPad8(out)
out = self.conv8(out)
out = self.relu8(out)
out = self.reflecPad9(out)
out = self.conv9(out)
out_relu9 = self.relu9(out)
out = self.unpool2(out_relu9)
out = self.reflecPad10(out)
out = self.conv10(out)
out = self.relu10(out)
out = self.reflecPad11(out)
out = self.conv11(out)
return out
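
# Round-trip sketch (illustrative only): encoder3 takes an RGB image to
# relu31-level features, and decoder3 mirrors it back to image space.
def _demo_autoencoder3():
    enc, dec = encoder3(), decoder3()
    img = torch.randn(1, 3, 224, 224)
    feat = enc(img)                          # relu31 features
    assert feat.shape == (1, 256, 56, 56)
    recon = dec(feat)
    assert recon.shape == (1, 3, 224, 224)
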
class encoder4(nn.Module):
def __init__(self):
super(encoder4,self).__init__()
# vgg
# 224 x 224
self.conv1 = nn.Conv2d(3,3,1,1,0)
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
# 226 x 226
self.conv2 = nn.Conv2d(3,64,3,1,0)
self.relu2 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
self.conv3 = nn.Conv2d(64,64,3,1,0)
self.relu3 = nn.ReLU(inplace=True)
# 224 x 224
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
# 112 x 112
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
self.conv4 = nn.Conv2d(64,128,3,1,0)
self.relu4 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
self.conv5 = nn.Conv2d(128,128,3,1,0)
self.relu5 = nn.ReLU(inplace=True)
# 112 x 112
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
# 56 x 56
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
self.conv6 = nn.Conv2d(128,256,3,1,0)
self.relu6 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
self.conv7 = nn.Conv2d(256,256,3,1,0)
self.relu7 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
self.conv8 = nn.Conv2d(256,256,3,1,0)
self.relu8 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
self.conv9 = nn.Conv2d(256,256,3,1,0)
self.relu9 = nn.ReLU(inplace=True)
# 56 x 56
self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
# 28 x 28
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
self.conv10 = nn.Conv2d(256,512,3,1,0)
self.relu10 = nn.ReLU(inplace=True)
# 28 x 28
def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
output = {}
out = self.conv1(x)
out = self.reflecPad1(out)
out = self.conv2(out)
output['r11'] = self.relu2(out)
        out = self.reflecPad3(output['r11'])
out = self.conv3(out)
output['r12'] = self.relu3(out)
output['p1'] = self.maxPool(output['r12'])
out = self.reflecPad4(output['p1'])
out = self.conv4(out)
output['r21'] = self.relu4(out)
        out = self.reflecPad5(output['r21'])
out = self.conv5(out)
output['r22'] = self.relu5(out)
output['p2'] = self.maxPool2(output['r22'])
out = self.reflecPad6(output['p2'])
out = self.conv6(out)
output['r31'] = self.relu6(out)
if(matrix31 is not None):
feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
out = self.reflecPad7(feature3)
else:
out = self.reflecPad7(output['r31'])
out = self.conv7(out)
output['r32'] = self.relu7(out)
out = self.reflecPad8(output['r32'])
out = self.conv8(out)
output['r33'] = self.relu8(out)
out = self.reflecPad9(output['r33'])
out = self.conv9(out)
output['r34'] = self.relu9(out)
output['p3'] = self.maxPool3(output['r34'])
out = self.reflecPad10(output['p3'])
out = self.conv10(out)
output['r41'] = self.relu10(out)
return output
class decoder4(nn.Module):
def __init__(self):
super(decoder4,self).__init__()
# decoder
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
self.conv11 = nn.Conv2d(512,256,3,1,0)
self.relu11 = nn.ReLU(inplace=True)
# 28 x 28
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
# 56 x 56
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
self.conv12 = nn.Conv2d(256,256,3,1,0)
self.relu12 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
self.conv13 = nn.Conv2d(256,256,3,1,0)
self.relu13 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
self.conv14 = nn.Conv2d(256,256,3,1,0)
self.relu14 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
self.conv15 = nn.Conv2d(256,128,3,1,0)
self.relu15 = nn.ReLU(inplace=True)
# 56 x 56
        self.unpool2 = nn.Upsample(scale_factor=2, mode='nearest')
# 112 x 112
self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
self.conv16 = nn.Conv2d(128,128,3,1,0)
self.relu16 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
self.conv17 = nn.Conv2d(128,64,3,1,0)
self.relu17 = nn.ReLU(inplace=True)
# 112 x 112
        self.unpool3 = nn.Upsample(scale_factor=2, mode='nearest')
# 224 x 224
self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
self.conv18 = nn.Conv2d(64,64,3,1,0)
self.relu18 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
self.conv19 = nn.Conv2d(64,3,3,1,0)
def forward(self,x):
# decoder
out = self.reflecPad11(x)
out = self.conv11(out)
out = self.relu11(out)
out = self.unpool(out)
out = self.reflecPad12(out)
out = self.conv12(out)
out = self.relu12(out)
out = self.reflecPad13(out)
out = self.conv13(out)
out = self.relu13(out)
out = self.reflecPad14(out)
out = self.conv14(out)
out = self.relu14(out)
out = self.reflecPad15(out)
out = self.conv15(out)
out = self.relu15(out)
out = self.unpool2(out)
out = self.reflecPad16(out)
out = self.conv16(out)
out = self.relu16(out)
out = self.reflecPad17(out)
out = self.conv17(out)
out = self.relu17(out)
out = self.unpool3(out)
out = self.reflecPad18(out)
out = self.conv18(out)
out = self.relu18(out)
out = self.reflecPad19(out)
out = self.conv19(out)
return out
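
# Round-trip sketch (illustrative only): encoder4 returns a dict of named
# intermediate activations; decoder4 reconstructs an image from the 'r41'
# entry. The feature-transform hooks (sF, matrix31) are left at their
# defaults here.
def _demo_autoencoder4():
    enc, dec = encoder4(), decoder4()
    img = torch.randn(1, 3, 224, 224)
    feats = enc(img)                         # dict of intermediate activations
    assert feats['r41'].shape == (1, 512, 28, 28)
    assert dec(feats['r41']).shape == (1, 3, 224, 224)
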
class encoder5(nn.Module):
def __init__(self):
super(encoder5,self).__init__()
# vgg
# 224 x 224
self.conv1 = nn.Conv2d(3,3,1,1,0)
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
# 226 x 226
self.conv2 = nn.Conv2d(3,64,3,1,0)
self.relu2 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
self.conv3 = nn.Conv2d(64,64,3,1,0)
self.relu3 = nn.ReLU(inplace=True)
# 224 x 224
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
# 112 x 112
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
self.conv4 = nn.Conv2d(64,128,3,1,0)
self.relu4 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
self.conv5 = nn.Conv2d(128,128,3,1,0)
self.relu5 = nn.ReLU(inplace=True)
# 112 x 112
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
# 56 x 56
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
self.conv6 = nn.Conv2d(128,256,3,1,0)
self.relu6 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
self.conv7 = nn.Conv2d(256,256,3,1,0)
self.relu7 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
self.conv8 = nn.Conv2d(256,256,3,1,0)
self.relu8 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
self.conv9 = nn.Conv2d(256,256,3,1,0)
self.relu9 = nn.ReLU(inplace=True)
# 56 x 56
self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
# 28 x 28
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
self.conv10 = nn.Conv2d(256,512,3,1,0)
self.relu10 = nn.ReLU(inplace=True)
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
self.conv11 = nn.Conv2d(512,512,3,1,0)
self.relu11 = nn.ReLU(inplace=True)
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
self.conv12 = nn.Conv2d(512,512,3,1,0)
self.relu12 = nn.ReLU(inplace=True)
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
self.conv13 = nn.Conv2d(512,512,3,1,0)
self.relu13 = nn.ReLU(inplace=True)
self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
self.conv14 = nn.Conv2d(512,512,3,1,0)
self.relu14 = nn.ReLU(inplace=True)
def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
output = {}
out = self.conv1(x)
out = self.reflecPad1(out)
out = self.conv2(out)
output['r11'] = self.relu2(out)
        out = self.reflecPad3(output['r11'])
out = self.conv3(out)
output['r12'] = self.relu3(out)
output['p1'] = self.maxPool(output['r12'])
out = self.reflecPad4(output['p1'])
out = self.conv4(out)
output['r21'] = self.relu4(out)
        out = self.reflecPad5(output['r21'])
out = self.conv5(out)
output['r22'] = self.relu5(out)
output['p2'] = self.maxPool2(output['r22'])
out = self.reflecPad6(output['p2'])
out = self.conv6(out)
output['r31'] = self.relu6(out)
if(styleV256 is not None):
feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
out = self.reflecPad7(feature)
else:
out = self.reflecPad7(output['r31'])
out = self.conv7(out)
output['r32'] = self.relu7(out)
out = self.reflecPad8(output['r32'])
out = self.conv8(out)
output['r33'] = self.relu8(out)
out = self.reflecPad9(output['r33'])
out = self.conv9(out)
output['r34'] = self.relu9(out)
output['p3'] = self.maxPool3(output['r34'])
out = self.reflecPad10(output['p3'])
out = self.conv10(out)
output['r41'] = self.relu10(out)
out = self.reflecPad11(out)
out = self.conv11(out)
out = self.relu11(out)
out = self.reflecPad12(out)
out = self.conv12(out)
out = self.relu12(out)
out = self.reflecPad13(out)
out = self.conv13(out)
out = self.relu13(out)
out = self.maxPool4(out)
out = self.reflecPad14(out)
out = self.conv14(out)
out = self.relu14(out)
output['r51'] = out
return output
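
# Feature-dict sketch (illustrative only): encoder5 exposes the layers the
# loss criteria below consume; these are the expected shapes for a 224x224
# input.
def _demo_encoder5_features():
    enc = encoder5()
    feats = enc(torch.randn(1, 3, 224, 224))
    assert feats['r11'].shape == (1, 64, 224, 224)
    assert feats['r21'].shape == (1, 128, 112, 112)
    assert feats['r31'].shape == (1, 256, 56, 56)
    assert feats['r41'].shape == (1, 512, 28, 28)
    assert feats['r51'].shape == (1, 512, 14, 14)
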
class styleLoss(nn.Module):
def forward(self, input, target):
ib,ic,ih,iw = input.size()
iF = input.view(ib,ic,-1)
iMean = torch.mean(iF,dim=2)
iCov = GramMatrix()(input)
tb,tc,th,tw = target.size()
tF = target.view(tb,tc,-1)
tMean = torch.mean(tF,dim=2)
tCov = GramMatrix()(target)
        loss = nn.MSELoss(reduction='sum')(iMean, tMean) + nn.MSELoss(reduction='sum')(iCov, tCov)
return loss/tb
class GramMatrix(nn.Module):
def forward(self, input):
b, c, h, w = input.size()
f = input.view(b,c,h*w) # bxcx(hxw)
# torch.bmm(batch1, batch2, out=None) #
# batch1: bxmxp, batch2: bxpxn -> bxmxn #
G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
return G.div_(c*h*w)
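
# Sanity-check sketch for GramMatrix and styleLoss: the Gram matrix is
# c x c per sample, and the style loss is exactly zero when input and
# target coincide. Shapes are illustrative.
def _demo_gram_and_style_loss():
    f = torch.randn(2, 64, 32, 32)
    G = GramMatrix()(f)
    assert G.shape == (2, 64, 64)            # one c x c Gram per sample
    assert styleLoss()(f, f).item() == 0.0
    assert styleLoss()(f, torch.randn_like(f)).item() > 0.0
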
class LossCriterion(nn.Module):
def __init__(self, style_layers, content_layers, style_weight, content_weight,
model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
super(LossCriterion,self).__init__()
self.style_layers = style_layers
self.content_layers = content_layers
self.style_weight = style_weight
self.content_weight = content_weight
self.styleLosses = [styleLoss()] * len(style_layers)
self.contentLosses = [nn.MSELoss()] * len(content_layers)
self.vgg5 = encoder5()
self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        # the loss network is fixed; freeze it so its weights are never updated
        for param in self.vgg5.parameters():
            param.requires_grad = False
    def forward(self, transfer, image, content=True, style=True):
        # content and style targets come from the same image,
        # so one pass through the loss network suffices
        cF = self.vgg5(image)
        sF = cF
        tF = self.vgg5(transfer)
losses = {}
# content loss
if content:
totalContentLoss = 0
for i,layer in enumerate(self.content_layers):
cf_i = cF[layer]
cf_i = cf_i.detach()
tf_i = tF[layer]
loss_i = self.contentLosses[i]
totalContentLoss += loss_i(tf_i,cf_i)
totalContentLoss = totalContentLoss * self.content_weight
losses['content'] = totalContentLoss
# style loss
if style:
totalStyleLoss = 0
for i,layer in enumerate(self.style_layers):
sf_i = sF[layer]
sf_i = sf_i.detach()
tf_i = tF[layer]
loss_i = self.styleLosses[i]
totalStyleLoss += loss_i(tf_i,sf_i)
totalStyleLoss = totalStyleLoss * self.style_weight
losses['style'] = totalStyleLoss
return losses
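
# Usage sketch for LossCriterion (requires the pretrained vgg_r51.pth
# weights under model_path). The layer choices and weightings below are
# illustrative, not defaults prescribed by this file.
def _demo_loss_criterion(model_path):
    crit = LossCriterion(style_layers=['r11', 'r21', 'r31', 'r41'],
                         content_layers=['r41'],
                         style_weight=0.02, content_weight=1.0,
                         model_path=model_path)
    transfer = torch.randn(1, 3, 224, 224)
    image = torch.randn(1, 3, 224, 224)
    losses = crit(transfer, image)           # {'content': ..., 'style': ...}
    return losses['content'] + losses['style']
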
class styleLossMask(nn.Module):
def forward(self, input, target, mask):
ib,ic,ih,iw = input.size()
iF = input.view(ib,ic,-1)
tb,tc,th,tw = target.size()
tF = target.view(tb,tc,-1)
loss = 0
mb, mc, mh, mw = mask.shape
for i in range(mb):
                # resize the mask to the spatial size of the feature map
maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest')
mask_flat = maski.view(mc, -1)
for j in range(mc):
# get features for each part
                idx = torch.nonzero(mask_flat[j], as_tuple=False).view(-1)
                if idx.numel() == 0:
                    continue
ipart = torch.index_select(iF, 2, idx)
tpart = torch.index_select(tF, 2, idx)
                iMean = torch.mean(ipart, dim=2)
                # Gram over the selected locations only, normalized by the
                # full feature size so parts of different areas stay comparable
                iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ic * ih * iw)
                tMean = torch.mean(tpart, dim=2)
                tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tc * th * tw)
                loss += nn.MSELoss()(iMean, tMean) + nn.MSELoss()(iGram, tGram)
return loss/tb
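
# Usage sketch for styleLossMask with a toy two-part mask (left half is
# part 0, right half is part 1). Shapes are illustrative; the mask follows
# the (B, N, H, W) convention used by LossCriterionMask below.
def _demo_style_loss_mask():
    f_in = torch.randn(1, 64, 32, 32)
    f_tgt = torch.randn(1, 64, 32, 32)
    mask = torch.zeros(1, 2, 32, 32)
    mask[:, 0, :, :16] = 1.0
    mask[:, 1, :, 16:] = 1.0
    loss = styleLossMask()(f_in, f_tgt, mask)
    assert loss.item() > 0.0
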
class LossCriterionMask(nn.Module):
def __init__(self, style_layers, content_layers, style_weight, content_weight,
model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
super(LossCriterionMask,self).__init__()
self.style_layers = style_layers
self.content_layers = content_layers
self.style_weight = style_weight
self.content_weight = content_weight
self.styleLosses = [styleLossMask()] * len(style_layers)
self.contentLosses = [nn.MSELoss()] * len(content_layers)
self.vgg5 = encoder5()
self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        # the loss network is fixed; freeze it so its weights are never updated
        for param in self.vgg5.parameters():
            param.requires_grad = False
    def forward(self, transfer, image, mask, content=True, style=True):
        # mask: B, N, H, W
        # content and style targets come from the same image,
        # so one pass through the loss network suffices
        cF = self.vgg5(image)
        sF = cF
        tF = self.vgg5(transfer)
losses = {}
# content loss
if content:
totalContentLoss = 0
for i,layer in enumerate(self.content_layers):
cf_i = cF[layer]
cf_i = cf_i.detach()
tf_i = tF[layer]
loss_i = self.contentLosses[i]
totalContentLoss += loss_i(tf_i,cf_i)
totalContentLoss = totalContentLoss * self.content_weight
losses['content'] = totalContentLoss
# style loss
if style:
totalStyleLoss = 0
for i,layer in enumerate(self.style_layers):
sf_i = sF[layer]
sf_i = sf_i.detach()
tf_i = tF[layer]
loss_i = self.styleLosses[i]
totalStyleLoss += loss_i(tf_i,sf_i, mask)
totalStyleLoss = totalStyleLoss * self.style_weight
losses['style'] = totalStyleLoss
return losses
class VQEmbedding(nn.Module):
def __init__(self, K, D):
super().__init__()
self.embedding = nn.Embedding(K, D)
self.embedding.weight.data.uniform_(-1./K, 1./K)
def forward(self, z_e_x):
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
latents = vq(z_e_x_, self.embedding.weight)
return latents
def straight_through(self, z_e_x, return_index=False):
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()
z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
dim=0, index=indices)
z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()
if return_index:
return z_q_x, z_q_x_bar, indices
else:
return z_q_x, z_q_x_bar
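
# Usage sketch for VQEmbedding, assuming vq/vq_st behave as in the
# reference VQ-VAE code this file imports from (vq returns nearest-code
# indices; vq_st returns quantized vectors with the input's shape plus the
# selected indices). Sizes K and D below are illustrative.
def _demo_vq_embedding():
    K, D = 512, 64                           # codebook size, code dimension
    codebook = VQEmbedding(K, D)
    z_e = torch.randn(2, D, 8, 8)            # encoder output (B, D, H, W)
    latents = codebook(z_e)                  # nearest-code indices
    z_q, z_q_bar = codebook.straight_through(z_e)
    # z_q carries straight-through gradients back to the encoder;
    # z_q_bar carries gradients into the codebook weights
    assert z_q.shape == z_e.shape and z_q_bar.shape == z_e.shape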