# Uploaded by incrl: "Initial Upload (attempt 2)", commit 5b557cf (verified).
import torch
import torch.nn as nn
class encoder3(nn.Module):
    """VGG-style encoder that maps a 3-channel image to 256-channel features.

    Every 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Two 2x2 max-pools shrink it by 4x
    overall (224 x 224 -> 56 x 56).
    """

    def __init__(self):
        super().__init__()
        # 1x1 convolution over the 3 input channels (no padding required).
        self.conv1 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)

        # Full-resolution stage: 3 -> 64 -> 64 channels.
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        self.relu3 = nn.ReLU(inplace=True)
        # The pool also returns argmax indices; forward() discards them.
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # Half-resolution stage: 64 -> 128 -> 128 channels.
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0)
        self.relu4 = nn.ReLU(inplace=True)
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # Quarter-resolution stage: 128 -> 256 channels.
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0)
        self.relu6 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Encode ``x`` and return the 256-channel map produced by ``relu6``."""
        full_res = (
            (self.reflecPad1, self.conv2, self.relu2),
            (self.reflecPad3, self.conv3, self.relu3),
        )
        half_res = (
            (self.reflecPad4, self.conv4, self.relu4),
            (self.reflecPad5, self.conv5, self.relu5),
        )

        features = self.conv1(x)
        for pad, conv, act in full_res:
            features = act(conv(pad(features)))
        features, _ = self.maxPool(features)

        for pad, conv, act in half_res:
            features = act(conv(pad(features)))
        features, _ = self.maxPool2(features)

        return self.relu6(self.conv6(self.reflecPad6(features)))
class decoder3(nn.Module):
    """Decoder that maps 256-channel features back to a 3-channel image.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Two nearest-neighbour 2x
    upsamplings enlarge it by 4x overall (56 x 56 -> 224 x 224). The final
    convolution has no activation.
    """

    def __init__(self):
        super(decoder3, self).__init__()
        # decoder
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        """Decode ``x`` into a 3-channel map 4x larger in each spatial dim.

        Args:
            x: feature batch; ``conv7`` expects 256 input channels.

        Returns:
            Output of ``conv11`` (3 channels, no final activation).
        """
        out = self.relu7(self.conv7(self.reflecPad7(x)))
        out = self.unpool(out)
        out = self.relu8(self.conv8(self.reflecPad8(out)))
        out = self.relu9(self.conv9(self.reflecPad9(out)))
        out = self.unpool2(out)
        out = self.relu10(self.conv10(self.reflecPad10(out)))
        return self.conv11(self.reflecPad11(out))
class encoder4(nn.Module):
    """VGG-style encoder that returns every intermediate activation up to 'r41'.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Three 2x2 max-pools shrink it by 8x
    overall (224 x 224 -> 28 x 28).
    """

    def __init__(self):
        super(encoder4, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self, x, sF=None, matrix11=None, matrix21=None, matrix31=None):
        """Run the encoder and collect intermediate feature maps.

        Args:
            x: image batch; ``conv1`` expects 3 input channels.
            sF: dict of style features. Only ``sF['r31']`` is read, and only
                when ``matrix31`` is given.
            matrix11, matrix21: accepted for call-site compatibility; unused.
            matrix31: optional callable
                ``(content_r31, style_r31) -> (feature, other)``. Its
                ``feature`` replaces ``r31`` as the input to ``conv7``.
                ``output['r31']`` still holds the untransformed activation.

        Returns:
            dict with keys 'r11', 'r12', 'p1', 'r21', 'r22', 'p2', 'r31',
            'r32', 'r33', 'r34', 'p3', 'r41' (in that insertion order).
            'rXY' entries are ReLU outputs and 'pN' entries are max-pooled maps.
        """
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        # Each conv uses the pad module declared for it. All pads are
        # identical, parameter-free ReflectionPad2d((1, 1, 1, 1)) layers.
        out = self.reflecPad3(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad5(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if matrix31 is not None:
            # Only the first return value is propagated; the second is unused here.
            feature3, _ = matrix31(output['r31'], sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        return output
class decoder4(nn.Module):
    """Decoder that maps 512-channel features to a 3-channel image.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Three nearest-neighbour 2x
    upsamplings enlarge it by 8x overall (28 x 28 -> 224 x 224). The last
    convolution has no activation.
    """

    def __init__(self):
        super().__init__()
        # 28 x 28: 512 -> 256 channels, then upsample.
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.relu11 = nn.ReLU(inplace=True)
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)

        # 56 x 56: three 256 -> 256 convs, one 256 -> 128 conv, then upsample.
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu13 = nn.ReLU(inplace=True)
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu14 = nn.ReLU(inplace=True)
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
        self.relu15 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)

        # 112 x 112: 128 -> 128 -> 64 channels, then upsample.
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)
        self.relu16 = nn.ReLU(inplace=True)
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0)
        self.relu17 = nn.ReLU(inplace=True)
        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)

        # 224 x 224: 64 -> 64 -> 3 channels (no activation on the last conv).
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        self.relu18 = nn.ReLU(inplace=True)
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        """Decode ``x`` by applying every layer in declaration order."""
        pipeline = (
            self.reflecPad11, self.conv11, self.relu11,
            self.unpool,
            self.reflecPad12, self.conv12, self.relu12,
            self.reflecPad13, self.conv13, self.relu13,
            self.reflecPad14, self.conv14, self.relu14,
            self.reflecPad15, self.conv15, self.relu15,
            self.unpool2,
            self.reflecPad16, self.conv16, self.relu16,
            self.reflecPad17, self.conv17, self.relu17,
            self.unpool3,
            self.reflecPad18, self.conv18, self.relu18,
            self.reflecPad19, self.conv19,
        )
        result = x
        for layer in pipeline:
            result = layer(result)
        return result
class decoder4(nn.Module):
    """Decoder mapping 512-channel features to a 3-channel image.

    NOTE(review): this class duplicates the earlier ``decoder4`` definition
    in this module (same layers, same forward pass). Because it appears
    later, this is the definition bound to the name ``decoder4`` once the
    module is imported. One of the two copies should be deleted.
    """
    def __init__(self):
        super(decoder4,self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,256,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56
        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(256,256,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(256,256,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(256,256,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
        self.conv15 = nn.Conv2d(256,128,3,1,0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112
        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
        self.conv16 = nn.Conv2d(128,128,3,1,0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
        self.conv17 = nn.Conv2d(128,64,3,1,0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112
        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224
        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
        self.conv18 = nn.Conv2d(64,64,3,1,0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
        self.conv19 = nn.Conv2d(64,3,3,1,0)
    def forward(self,x):
        """Decode ``x``: pad/conv/ReLU stages with three 2x nearest
        upsamplings (8x spatial growth); the final conv has no ReLU."""
        # decoder
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out
class encoder5(nn.Module):
    """VGG-style encoder that returns every intermediate activation up to 'r51'.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Four 2x2 max-pools shrink it by 16x
    overall (224 x 224 -> 14 x 14).
    """

    def __init__(self):
        super(encoder5, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 14 x 14
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self, x, sF=None, contentV256=None, styleV256=None,
                matrix11=None, matrix21=None, matrix31=None):
        """Run the encoder and collect intermediate feature maps.

        Args:
            x: image batch; ``conv1`` expects 3 input channels.
            sF: dict of style features. Only ``sF['r31']`` is read, and only
                when ``styleV256`` is not None.
            contentV256, styleV256: forwarded unchanged to ``matrix31``.
                A non-None ``styleV256`` is what enables the transform.
            matrix11, matrix21: accepted for call-site compatibility; unused.
            matrix31: callable
                ``(content_r31, style_r31, contentV256, styleV256) -> feature``.
                It must be supplied whenever ``styleV256`` is not None. Its
                result replaces ``r31`` as the input to ``conv7``.
                ``output['r31']`` still holds the untransformed activation.

        Returns:
            dict with keys 'r11', 'r12', 'p1', 'r21', 'r22', 'p2', 'r31',
            'r32', 'r33', 'r34', 'p3', 'r41', 'r42', 'r43', 'r44', 'p4',
            'r51' (in that insertion order). 'rXY' entries are ReLU outputs
            and 'pN' entries are max-pooled maps.
        """
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        # Each conv uses the pad module declared for it. All pads are
        # identical, parameter-free ReflectionPad2d((1, 1, 1, 1)) layers.
        out = self.reflecPad3(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)
        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad5(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)
        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if styleV256 is not None:
            feature = matrix31(output['r31'], sF['r31'], contentV256, styleV256)
            out = self.reflecPad7(feature)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)
        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)
        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)
        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)
        out = self.reflecPad11(output['r41'])
        out = self.conv11(out)
        output['r42'] = self.relu11(out)
        out = self.reflecPad12(output['r42'])
        out = self.conv12(out)
        output['r43'] = self.relu12(out)
        out = self.reflecPad13(output['r43'])
        out = self.conv13(out)
        output['r44'] = self.relu13(out)
        output['p4'] = self.maxPool4(output['r44'])
        out = self.reflecPad14(output['p4'])
        out = self.conv14(out)
        output['r51'] = self.relu14(out)
        return output
class decoder5(nn.Module):
    """Decoder that maps 512-channel features to a 3-channel image.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the
    convolutions preserve spatial size. Four nearest-neighbour 2x
    upsamplings enlarge it by 16x overall. The last convolution has no
    activation.
    """

    def __init__(self):
        super().__init__()
        # 512 -> 512 conv, then upsample (14 x 14 -> 28 x 28).
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)
        self.relu15 = nn.ReLU(inplace=True)
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)

        # 28 x 28: three 512 -> 512 convs, one 512 -> 256 conv, then upsample.
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)
        self.relu16 = nn.ReLU(inplace=True)
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)
        self.relu17 = nn.ReLU(inplace=True)
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)
        self.relu18 = nn.ReLU(inplace=True)
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.relu19 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)

        # 56 x 56: three 256 -> 256 convs, one 256 -> 128 conv, then upsample.
        self.reflecPad20 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv20 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu20 = nn.ReLU(inplace=True)
        self.reflecPad21 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv21 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu21 = nn.ReLU(inplace=True)
        self.reflecPad22 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv22 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
        self.relu22 = nn.ReLU(inplace=True)
        self.reflecPad23 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv23 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
        self.relu23 = nn.ReLU(inplace=True)
        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)

        # 112 x 112: 128 -> 128 -> 64 channels, then upsample.
        self.reflecPad24 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv24 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)
        self.relu24 = nn.ReLU(inplace=True)
        self.reflecPad25 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv25 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0)
        self.relu25 = nn.ReLU(inplace=True)
        self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)

        # 224 x 224: 64 -> 64 -> 3 channels (no activation on the last conv).
        self.reflecPad26 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv26 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        self.relu26 = nn.ReLU(inplace=True)
        self.reflecPad27 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv27 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        """Decode ``x`` by applying every layer in declaration order."""
        pipeline = (
            self.reflecPad15, self.conv15, self.relu15,
            self.unpool,
            self.reflecPad16, self.conv16, self.relu16,
            self.reflecPad17, self.conv17, self.relu17,
            self.reflecPad18, self.conv18, self.relu18,
            self.reflecPad19, self.conv19, self.relu19,
            self.unpool2,
            self.reflecPad20, self.conv20, self.relu20,
            self.reflecPad21, self.conv21, self.relu21,
            self.reflecPad22, self.conv22, self.relu22,
            self.reflecPad23, self.conv23, self.relu23,
            self.unpool3,
            self.reflecPad24, self.conv24, self.relu24,
            self.reflecPad25, self.conv25, self.relu25,
            self.unpool4,
            self.reflecPad26, self.conv26, self.relu26,
            self.reflecPad27, self.conv27,
        )
        result = x
        for layer in pipeline:
            result = layer(result)
        return result