Spaces:
Runtime error
Runtime error
File size: 4,309 Bytes
b456239 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
class ConvCIN(nn.Module):
    """Reflection-padded convolution followed by conditional instance norm.

    The input is reflection-padded, convolved, instance-normalized (without
    a learned affine transform) and then scaled/shifted with a per-style,
    per-channel ``gamma``/``beta`` pair selected by style index. Two styles
    can be linearly blended.

    Args:
        n_styles: Number of rows in the ``gamma``/``beta`` tables.
        C_in: Number of input channels of the convolution.
        C_out: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Reflection padding applied before the convolution.
        stride: Convolution stride.
        activation: ``'relu'``, ``'sigmoid'`` or ``None``. Any other value
            leaves the output without an activation.
    """

    def __init__(self, n_styles, C_in, C_out, kernel_size, padding, stride, activation=None):
        super().__init__()
        # Padding is done explicitly with reflection, so the convolution
        # itself is created without padding.
        self.reflection = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(in_channels=C_in, out_channels=C_out, kernel_size=kernel_size, stride=stride)
        nn.init.normal_(self.conv.weight, mean=0, std=1e-2)
        # Non-affine on purpose: the per-style gamma/beta below provide the
        # scale and shift instead of a single shared affine transform.
        self.instnorm = nn.InstanceNorm2d(C_out)
        # One (scale, shift) row per style; scales start near 1, shifts near 0.
        self.gamma = nn.Parameter(torch.randn(n_styles, C_out) * 1e-2 + 1, requires_grad=True)
        self.beta = nn.Parameter(torch.randn(n_styles, C_out) * 1e-2, requires_grad=True)
        self.activation = activation

    def forward(self, x, style_1, style_2, alpha):
        """Apply the layer for one style or a blend of two styles.

        Args:
            x: Input tensor of shape ``(B, C_in, H, W)``.
            style_1: Index into the style tables (anything usable to index
                the first axis of ``gamma``/``beta``, e.g. an int or an
                index tensor with one entry per sample).
            style_2: Optional second style index. ``None`` disables
                blending; any other value (including index ``0``) blends.
            alpha: Weight given to ``style_1``; ``style_2`` gets
                ``1 - alpha``. Ignored when ``style_2`` is ``None``.

        Returns:
            Tensor of shape ``(B, C_out, H', W')``.
        """
        x = self.instnorm(self.conv(self.reflection(x)))

        # Identity check, not ``!=``: style_2 may be a tensor, and index 0
        # is a valid style that must not be treated as "absent".
        if style_2 is not None:
            gamma = alpha * self.gamma[style_1] + (1 - alpha) * self.gamma[style_2]
            beta = alpha * self.beta[style_1] + (1 - alpha) * self.beta[style_2]
        else:
            gamma = self.gamma[style_1]
            beta = self.beta[style_1]

        # Append two singleton axes so a (C,) or (B, C) parameter slice
        # broadcasts over the spatial dimensions of x.
        x = x * gamma[..., None, None] + beta[..., None, None]

        if self.activation == 'relu':
            x = F.relu(x)
        elif self.activation == 'sigmoid':
            x = torch.sigmoid(x)
        return x
class ResidualBlock(nn.Module):
    """Two stacked ``ConvCIN`` layers wrapped in a skip connection.

    The first layer uses a ReLU activation, the second has none; the block
    returns ``x + f(x)``.

    Args:
        n_styles: Number of styles passed through to each ``ConvCIN``.
        C_in: Number of channels of the block input.
        C_out: Number of channels produced by both inner layers. The skip
            connection adds the input to the output, so this must be
            compatible with ``C_in`` for the addition.
    """

    def __init__(self, n_styles, C_in, C_out):
        super().__init__()
        self.convcin1 = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1, activation='relu')
        # The second layer consumes the output of the first, which has
        # C_out channels, so its input width is C_out rather than C_in.
        self.convcin2 = ConvCIN(n_styles, C_out, C_out, kernel_size=3, padding=1, stride=1)

    def forward(self, x, style_1, style_2, alpha):
        """Return ``x`` plus the output of the two conditioned convolutions.

        ``style_1``, ``style_2`` and ``alpha`` are forwarded unchanged to
        both inner ``ConvCIN`` layers.
        """
        out = self.convcin1(x, style_1, style_2, alpha)
        out = self.convcin2(out, style_1, style_2, alpha)
        return x + out
class UpSampling(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a ReLU ``ConvCIN`` layer.

    Args:
        n_styles: Number of styles passed through to the ``ConvCIN`` layer.
        C_in: Number of input channels.
        C_out: Number of output channels.
    """

    def __init__(self, n_styles, C_in, C_out):
        super(UpSampling, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.convcin = ConvCIN(
            n_styles,
            C_in,
            C_out,
            kernel_size=3,
            padding=1,
            stride=1,
            activation='relu',
        )

    def forward(self, x, style_1, style_2, alpha):
        """Upsample ``x`` and run it through the style-conditioned layer."""
        enlarged = self.upsample(x)
        return self.convcin(enlarged, style_1, style_2, alpha)
class STModel(nn.Module):
    """Style-conditioned image-to-image network built from ``ConvCIN`` stages.

    Layout: three ``ConvCIN`` layers (3 -> 32 -> 64 -> 128 channels, the
    last two with stride 2), five 128-channel residual blocks, two 2x
    upsampling stages (128 -> 64 -> 32) and a final sigmoid ``ConvCIN``
    layer that maps back to 3 channels.

    Args:
        n_styles: Number of styles every conditioned layer is built for.
    """

    def __init__(self, n_styles):
        super(STModel, self).__init__()
        # Encoder: a wide 9x9 stem, then two stride-2 layers.
        self.convcin1 = ConvCIN(n_styles, C_in=3, C_out=32, kernel_size=9, padding=4, stride=1, activation='relu')
        self.convcin2 = ConvCIN(n_styles, C_in=32, C_out=64, kernel_size=3, padding=1, stride=2, activation='relu')
        self.convcin3 = ConvCIN(n_styles, C_in=64, C_out=128, kernel_size=3, padding=1, stride=2, activation='relu')

        # Registered as rb1 .. rb5, in that order.
        for block_no in range(1, 6):
            setattr(self, 'rb%d' % block_no, ResidualBlock(n_styles, 128, 128))

        # Decoder: two 2x upsampling stages, then the sigmoid output layer.
        self.upsample1 = UpSampling(n_styles, 128, 64)
        self.upsample2 = UpSampling(n_styles, 64, 32)
        self.convcin4 = ConvCIN(n_styles, C_in=32, C_out=3, kernel_size=9, padding=4, stride=1, activation='sigmoid')

    def forward(self, x, style_1, style_2=None, alpha=0.5):
        """Run ``x`` through every stage with the same style arguments.

        Args:
            x: Input tensor with 3 channels.
            style_1: Style index forwarded to every stage.
            style_2: Optional second style index for blending.
            alpha: Blend weight of ``style_1``, forwarded to every stage.

        Returns:
            The output of the final sigmoid layer (3 channels).
        """
        stages = (
            self.convcin1, self.convcin2, self.convcin3,
            self.rb1, self.rb2, self.rb3, self.rb4, self.rb5,
            self.upsample1, self.upsample2,
            self.convcin4,
        )
        for stage in stages:
            x = stage(x, style_1, style_2, alpha)
        return x