"""Network Modules - encoder3: vgg encoder up to relu31 - decoder3: mirror decoder to encoder3 - encoder4: vgg encoder up to relu41 - decoder4: mirror decoder to encoder4 - encoder5: vgg encoder up to relu51 - styleLoss: gram matrix loss for all style layers - styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask - GramMatrix: compute gram matrix for one layer - LossCriterion: style transfer loss that include both content & style losses - LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask - VQEmbedding: codebook class for VQVAE """ import os import torch import torch.nn as nn import torch.nn.functional as F from .vq_functions import vq, vq_st from collections import OrderedDict class MetaModule(nn.Module): """ Base class for PyTorch meta-learning modules. These modules accept an additional argument `params` in their `forward` method. Notes ----- Objects inherited from `MetaModule` are fully compatible with PyTorch modules from `torch.nn.Module`. The argument `params` is a dictionary of tensors, with full support of the computation graph (for differentiation). """ def meta_named_parameters(self, prefix='', recurse=True): gen = self._named_members( lambda module: module._parameters.items() if isinstance(module, MetaModule) else [], prefix=prefix, recurse=recurse) for elem in gen: yield elem def meta_parameters(self, recurse=True): for name, param in self.meta_named_parameters(recurse=recurse): yield param class BatchLinear(nn.Linear, MetaModule): '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a hypernetwork.''' __doc__ = nn.Linear.__doc__ def forward(self, input, params=None): if params is None: params = OrderedDict(self.named_parameters()) bias = params.get('bias', None) weight = params['weight'] output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2)) output += bias.unsqueeze(-2) return output class decoder1(nn.Module): def __init__(self): super(decoder1,self).__init__() self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1)) # 226 x 226 self.conv3 = nn.Conv2d(64,3,3,1,0) # 224 x 224 def forward(self,x): out = self.reflecPad2(x) out = self.conv3(out) return out class decoder2(nn.Module): def __init__(self): super(decoder2,self).__init__() # decoder self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) self.conv5 = nn.Conv2d(128,64,3,1,0) self.relu5 = nn.ReLU(inplace=True) # 112 x 112 self.unpool = nn.UpsamplingNearest2d(scale_factor=2) # 224 x 224 self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) self.conv6 = nn.Conv2d(64,64,3,1,0) self.relu6 = nn.ReLU(inplace=True) # 224 x 224 self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) self.conv7 = nn.Conv2d(64,3,3,1,0) def forward(self,x): out = self.reflecPad5(x) out = self.conv5(out) out = self.relu5(out) out = self.unpool(out) out = self.reflecPad6(out) out = self.conv6(out) out = self.relu6(out) out = self.reflecPad7(out) out = self.conv7(out) return out class encoder3(nn.Module): def __init__(self): super(encoder3,self).__init__() # vgg # 224 x 224 self.conv1 = nn.Conv2d(3,3,1,1,0) self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) # 226 x 226 self.conv2 = nn.Conv2d(3,64,3,1,0) self.relu2 = nn.ReLU(inplace=True) # 224 x 224 self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) self.conv3 = nn.Conv2d(64,64,3,1,0) self.relu3 = nn.ReLU(inplace=True) # 224 x 224 self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) # 112 x 112 self.reflecPad4 = 
class decoder3(nn.Module):
    def __init__(self):
        super(decoder3, self).__init__()
        # decoder
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        out = self.reflecPad7(x)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.unpool(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out_relu9 = self.relu9(out)
        out = self.unpool2(out_relu9)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        return out


class encoder4(nn.Module):
    def __init__(self):
        super(encoder4, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self, x, sF=None, matrix11=None, matrix21=None, matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        # All reflection pads are identical (1,1,1,1), so reflecPad7 is
        # reused below in place of reflecPad3/reflecPad5.
        out = self.reflecPad7(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if matrix31 is not None:
            feature3, transmatrix3 = matrix31(output['r31'], sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        return output
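
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): encoder4 returns a dict of intermediate
# VGG activations keyed by layer name ('r11', 'r21', 'r31', 'r41', ...),
# which is what the loss criteria below index into.
def _example_encoder4_features():
    enc = encoder4()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = enc(x)
    # relu4_1 features: 512 channels at 1/8 resolution
    assert feats['r41'].shape == (1, 512, 28, 28)
    return sorted(feats.keys())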
class decoder4(nn.Module):
    def __init__(self):
        super(decoder4, self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 256, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112
        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        # decoder
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out
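
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): decoder4 mirrors encoder4, so decoding
# the 'r41' features recovers an image-sized tensor.
def _example_encoder4_decoder4_roundtrip():
    enc, dec = encoder4(), decoder4()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        y = dec(enc(x)['r41'])
    assert y.shape == x.shape  # 1 x 3 x 224 x 224
    return y.shape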
class encoder5(nn.Module):
    def __init__(self):
        super(encoder5, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self, x, sF=None, contentV256=None, styleV256=None,
                matrix11=None, matrix21=None, matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        # All reflection pads are identical (1,1,1,1); reflecPad7 is reused
        # below in place of reflecPad3/reflecPad5.
        out = self.reflecPad7(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if styleV256 is not None:
            feature = matrix31(output['r31'], sF['r31'], contentV256, styleV256)
            out = self.reflecPad7(feature)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        output['r51'] = out
        return output
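
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): encoder5 exposes the relu1_1..relu5_1
# activations used as content/style layers by the loss criteria below.
def _example_encoder5_layers():
    enc = encoder5()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = enc(x)
    assert feats['r51'].shape == (1, 512, 14, 14)  # relu5_1 at 1/16 resolution
    return {k: tuple(v.shape) for k, v in feats.items() if k.startswith('r')}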
class styleLoss(nn.Module):
    def forward(self, input, target):
        ib, ic, ih, iw = input.size()
        iF = input.view(ib, ic, -1)
        iMean = torch.mean(iF, dim=2)
        iCov = GramMatrix()(input)

        tb, tc, th, tw = target.size()
        tF = target.view(tb, tc, -1)
        tMean = torch.mean(tF, dim=2)
        tCov = GramMatrix()(target)

        # Summed (not averaged) squared error over channel means and Gram
        # matrices, normalized by the batch size. reduction='sum' replaces
        # the deprecated size_average=False.
        loss = nn.MSELoss(reduction='sum')(iMean, tMean) + \
               nn.MSELoss(reduction='sum')(iCov, tCov)
        return loss / tb


class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b, c, h * w)  # b x c x (h*w)
        # torch.bmm: batch1 b x m x p, batch2 b x p x n -> b x m x n
        # f: b x c x (h*w), f.transpose: b x (h*w) x c -> G: b x c x c
        G = torch.bmm(f, f.transpose(1, 2))
        return G.div_(c * h * w)


class LossCriterion(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path='/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterion, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight
        self.styleLosses = [styleLoss()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)
        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        # Freeze the loss network: its weights stay fixed, while gradients
        # still flow back to `transfer` through the activations.
        for param in self.vgg5.parameters():
            param.requires_grad = False

    def forward(self, transfer, image, content=True, style=True):
        # Content and style references are the same image here, so encode once.
        cF = self.vgg5(image)
        sF = cF
        tF = self.vgg5(transfer)

        losses = {}
        # content loss
        if content:
            totalContentLoss = 0
            for i, layer in enumerate(self.content_layers):
                cf_i = cF[layer].detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i, cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i, layer in enumerate(self.style_layers):
                sf_i = sF[layer].detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i, sf_i)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses


class styleLossMask(nn.Module):
    def forward(self, input, target, mask):
        ib, ic, ih, iw = input.size()
        iF = input.view(ib, ic, -1)
        tb, tc, th, tw = target.size()
        tF = target.view(tb, tc, -1)
        loss = 0
        mb, mc, mh, mw = mask.shape
        for i in range(mb):
            # resize the mask to the spatial size of the feature map
            maski = F.interpolate(mask[i:i + 1], size=(ih, iw), mode='nearest')
            mask_flat = maski.view(mc, -1)
            for j in range(mc):
                # gather the feature columns belonging to each mask part;
                # view(-1) keeps idx 1-D even when a part has a single pixel
                # (squeeze() would produce a 0-dim tensor index_select rejects)
                idx = torch.nonzero(mask_flat[j]).view(-1)
                if idx.numel() == 0:
                    continue
                ipart = torch.index_select(iF, 2, idx)
                tpart = torch.index_select(tF, 2, idx)
                iMean = torch.mean(ipart, dim=2)
                iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ic * ih * iw)
                tMean = torch.mean(tpart, dim=2)
                tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tc * th * tw)
                loss += nn.MSELoss()(iMean, tMean) + nn.MSELoss()(iGram, tGram)
        return loss / tb
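
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the Gram matrix of a feature map is
# b x c x c, and styleLossMask compares Gram statistics per mask part. The
# two-part binary mask below is a made-up example.
def _example_gram_and_masked_loss():
    feat = torch.randn(1, 256, 56, 56)
    gram = GramMatrix()(feat)
    assert gram.shape == (1, 256, 256)

    # mask: B x N x H x W with N parts; left half is part 0, right half part 1
    mask = torch.zeros(1, 2, 56, 56)
    mask[:, 0, :, :28] = 1
    mask[:, 1, :, 28:] = 1
    loss = styleLossMask()(feat, torch.randn(1, 256, 56, 56), mask)
    return gram.shape, loss.item()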
class LossCriterionMask(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path='/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterionMask, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight
        self.styleLosses = [styleLossMask()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)
        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        # Freeze the loss network, as in LossCriterion.
        for param in self.vgg5.parameters():
            param.requires_grad = False

    def forward(self, transfer, image, mask, content=True, style=True):
        # mask: B x N x H x W
        # Content and style references are the same image here, so encode once.
        cF = self.vgg5(image)
        sF = cF
        tF = self.vgg5(transfer)

        losses = {}
        # content loss
        if content:
            totalContentLoss = 0
            for i, layer in enumerate(self.content_layers):
                cf_i = cF[layer].detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i, cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss (per mask part)
        if style:
            totalStyleLoss = 0
            for i, layer in enumerate(self.style_layers):
                sf_i = sF[layer].detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i, sf_i, mask)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses


class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1. / K, 1. / K)

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x, return_index=False):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight, dim=0,
                                               index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        if return_index:
            return z_q_x, z_q_x_bar, indices
        return z_q_x, z_q_x_bar
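
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, assuming vq_st from .vq_functions follows
# the convention straight_through() relies on, returning quantized codes plus
# flat codebook indices): VQEmbedding quantizes a B x D x H x W latent map
# against a K x D codebook. straight_through() returns both the
# straight-through estimate (gradients pass to the encoder) and the codebook
# lookup (gradients reach the embedding table).
def _example_vq_embedding():
    codebook = VQEmbedding(K=512, D=256)   # 512 codes of dimension 256
    z_e_x = torch.randn(2, 256, 16, 16)    # dummy encoder output
    z_q_x, z_q_x_bar = codebook.straight_through(z_e_x)
    assert z_q_x.shape == z_e_x.shape == z_q_x_bar.shape
    return z_q_x.shape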