# Page-scrape header (not part of the program): "Spaces" status "Runtime error";
# file size 8,944 bytes; identifier 6e7b2f8. The viewer's line-number gutter (1-192) was dropped.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResNeXtBottleneck(nn.Module):
    """Residual bottleneck block built around a grouped, optionally dilated, convolution.

    The main branch is a 1x1 "reduce" convolution, a grouped convolution and a
    1x1 "expand" convolution, with an in-place LeakyReLU(0.2) after each of the
    first two. Its output is added to a shortcut branch, which is the identity
    when ``stride == 1`` and an average pool otherwise.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the main branch. The shortcut
            never changes the channel count, so the residual sum only lines up
            when this matches ``in_channels``.
        stride: Stride of the grouped convolution; its kernel size is
            ``2 + stride``.
        cardinality: Number of groups of the grouped convolution. ``Conv2d``
            requires it to divide ``out_channels // 2``.
        dilate: Dilation of the grouped convolution, also used as its padding.
    """

    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ResNeXtBottleneck, self).__init__()
        # Width of the bottleneck: half of the block's output width.
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_conv = nn.Conv2d(
            D,
            D,
            kernel_size=2 + stride,
            stride=stride,
            padding=dilate,
            dilation=dilate,
            groups=cardinality,
            bias=False,
        )
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1:
            # Pool by the same factor the grouped convolution strides by, so both
            # branches end up with the same spatial size. For stride=2 this is
            # the previous AvgPool2d(2, stride=2); the pool has no parameters,
            # so state_dict keys are unaffected.
            self.shortcut.add_module('shortcut', nn.AvgPool2d(stride, stride=stride))

    def forward(self, x):
        """Return ``shortcut(x) + main_branch(x)``."""
        # Invoke the sub-modules themselves (not ``.forward``) so that any
        # registered forward hooks are executed.
        bottleneck = self.conv_reduce(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_expand(bottleneck)
        return self.shortcut(x) + bottleneck
class Generator(nn.Module):
    """Encoder-decoder generator with skip connections.

    The encoder turns a 1-channel sketch into feature maps ``x0`` .. ``x4``
    (each ``to1`` .. ``to4`` step halves the resolution) and merges a
    4-channel hint after the second downsampling step. The decoder
    (``tunnel4`` .. ``tunnel1``) upsamples by 2 per stage with ``PixelShuffle``
    and concatenates the matching encoder map before each following stage.
    The final 3x3 convolution followed by ``tanh`` yields a 3-channel output.

    Args:
        ngf: Base channel width of the network.
        feat: When true, a 512-channel feature map is concatenated to the
            deepest encoder output, and ``forward`` requires ``sketch_feat``.
    """

    def __init__(self, ngf=64, feat=True):
        super(Generator, self).__init__()
        self.feat = feat
        # Extra input channels of the deepest decoder stage when features are used.
        add_channels = 512 if feat else 0

        # Encoder. ``toH`` keeps the hint's resolution; ``to1`` .. ``to4`` each halve it.
        self.toH = self._block(4, ngf, kernel_size=7, stride=1, padding=3)
        self.to0 = self._block(1, ngf // 2, kernel_size=3, stride=1, padding=1)
        self.to1 = self._block(ngf // 2, ngf, kernel_size=4, stride=2, padding=1)
        self.to2 = self._block(ngf, ngf * 2, kernel_size=4, stride=2, padding=1)
        # ``to3`` consumes x2 (ngf * 2 channels) concatenated with the hint (ngf channels).
        self.to3 = self._block(ngf * 3, ngf * 4, kernel_size=4, stride=2, padding=1)
        self.to4 = self._block(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1)

        # Decoder. Each stage's input width is the previous stage's output
        # plus the skip connection concatenated in ``forward``.
        self.tunnel4 = self._up_stage(
            ngf * 8 + add_channels,
            ngf * 8,
            [ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)],
        )
        depth = 2
        self.tunnel3 = self._up_stage(
            ngf * 8, ngf * 4, self._dilated_blocks(ngf * 4, cardinality=32, depth=depth)
        )
        self.tunnel2 = self._up_stage(
            ngf * 4, ngf * 2, self._dilated_blocks(ngf * 2, cardinality=32, depth=depth)
        )
        self.tunnel1 = self._up_stage(
            ngf * 2, ngf, self._dilated_blocks(ngf, cardinality=16, depth=1)
        )
        # tunnel1 emits ngf // 2 channels; joined with x0 (ngf // 2) that gives ngf.
        self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free ``Conv2d`` followed by an in-place LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.LeakyReLU(0.2, True),
        )

    def _dilated_blocks(self, channels, cardinality, depth):
        """Return ResNeXt blocks whose dilation goes 1, 2, 4 (``depth`` each), then 2, 1."""
        dilations = [1] * depth + [2] * depth + [4] * depth + [2, 1]
        return [
            ResNeXtBottleneck(channels, channels, cardinality=cardinality, dilate=d)
            for d in dilations
        ]

    def _up_stage(self, in_channels, channels, blocks):
        """Assemble one decoder stage: 3x3 conv, residual blocks, 2x pixel-shuffle upsampling."""
        head = self._block(in_channels, channels, kernel_size=3, stride=1, padding=1)
        return nn.Sequential(
            head,
            nn.Sequential(*blocks),
            # PixelShuffle(2) divides the channel count by 4, so the stage
            # outputs ``channels // 2`` channels at twice the resolution.
            nn.Conv2d(channels, channels * 2, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, sketch, hint, sketch_feat=None):
        """Generate a 3-channel image with values in ``[-1, 1]`` (``tanh`` output).

        Args:
            sketch: 1-channel input tensor.
            hint: 4-channel tensor; it is concatenated with ``x2`` and so must
                match the sketch resolution after two stride-2 convolutions.
            sketch_feat: 512-channel tensor concatenated with ``x4``; required
                when the model was built with ``feat=True``, ignored otherwise.

        Raises:
            TypeError: If ``feat`` is true and ``sketch_feat`` is ``None``.
        """
        # Fail early with a clear message instead of inside ``torch.cat``.
        if self.feat and sketch_feat is None:
            raise TypeError("sketch_feat is required when Generator is built with feat=True")

        hint = self.toH(hint)
        x0 = self.to0(sketch)
        x1 = self.to1(x0)
        x2 = self.to2(x1)
        x3 = self.to3(torch.cat([x2, hint], 1))
        x4 = self.to4(x3)

        # Only the input of the deepest stage depends on ``feat``.
        deepest = torch.cat([x4, sketch_feat], 1) if self.feat else x4
        x = self.tunnel4(deepest)
        x = self.tunnel3(torch.cat([x, x3], 1))
        x = self.tunnel2(torch.cat([x, x2], 1))
        x = self.tunnel1(torch.cat([x, x1], 1))
        return torch.tanh(self.exit(torch.cat([x, x0], 1)))
class Discriminator(nn.Module):
    """Convolutional discriminator producing one score per input image.

    ``feed`` processes a 3-channel image down to ``ndf * 4`` channels.
    ``feed2`` optionally takes an additional feature map concatenated on the
    channel axis and ends in ``ndf * 8`` channels. The result is flattened and
    passed through a linear layer with ``ndf * 8`` inputs, so the final feature
    map must be 1x1 spatially.

    Args:
        ndf: Base channel width of the network.
        feat: When true, ``forward`` requires ``sketch_feat`` with ``ndf * 8``
            channels and the last convolution uses a 4x4 kernel; otherwise no
            extra features are used and that kernel is 3x3.
    """

    def __init__(self, ndf=64, feat=True):
        super(Discriminator, self).__init__()
        self.feat = feat
        if feat:
            add_channels = ndf * 8
            ks = 4
        else:
            add_channels = 0
            ks = 3

        self.feed = nn.Sequential(
            self._block(3, ndf, kernel_size=7, stride=1, padding=1),
            self._block(ndf, ndf, kernel_size=4, stride=2, padding=1),
            ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),
            # 1x1 convolutions widen the channels between residual groups.
            self._block(ndf, ndf * 2, kernel_size=1, stride=1, padding=0),
            ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),
            self._block(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0),
            ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),
        )

        self.feed2 = nn.Sequential(
            self._block(ndf * 4 + add_channels, ndf * 8, kernel_size=3, stride=1, padding=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            # Unpadded convolution that collapses the remaining spatial extent.
            self._block(ndf * 8, ndf * 8, kernel_size=ks, stride=1, padding=0),
        )

        # Derived from ndf (512 for the default ndf=64) so that non-default
        # widths still match the flattened ``feed2`` output.
        self.out = nn.Linear(ndf * 8, 1)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free ``Conv2d`` followed by an in-place LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, color, sketch_feat=None):
        """Score a batch of 3-channel images.

        Args:
            color: 3-channel image tensor; its first dimension is the batch size.
            sketch_feat: Feature map concatenated to the ``feed`` output on the
                channel axis; required when the model was built with
                ``feat=True``, ignored otherwise.

        Returns:
            Tensor with one value per batch element (shape ``(batch, 1)``).

        Raises:
            TypeError: If ``feat`` is true and ``sketch_feat`` is ``None``.
        """
        # Fail early with a clear message instead of inside ``torch.cat``.
        if self.feat and sketch_feat is None:
            raise TypeError("sketch_feat is required when Discriminator is built with feat=True")

        x = self.feed(color)
        if self.feat:
            x = torch.cat([x, sketch_feat], 1)
        x = self.feed2(x)
        # Flatten per sample before the linear scoring layer.
        return self.out(x.view(color.size(0), -1))