# NOTE: extracted from a hosted "Spaces" file view (status shown: Runtime error).
# File size: 5,588 bytes; revision 92ec8d3.
import numpy as np
from torch import nn
from torch.nn import Conv2d, Sequential, Module
from models.hyperstyle.encoders.helpers import SeparableBlock
from models.stylegan2.model import EqualLinear
# Per-layer kernel description, keyed by layer index.
# Each value is [kernel_size, in_channels, out_channels].
PARAMETERS = {
    0: [3, 512, 512],
    1: [1, 512, 3],

    2: [3, 512, 512],
    3: [3, 512, 512],
    4: [1, 512, 3],

    5: [3, 512, 512],
    6: [3, 512, 512],
    7: [1, 512, 3],

    8: [3, 512, 512],
    9: [3, 512, 512],
    10: [1, 512, 3],

    11: [3, 512, 512],
    12: [3, 512, 512],
    13: [1, 512, 3],

    14: [3, 512, 256],
    15: [3, 256, 256],
    16: [1, 256, 3],

    17: [3, 256, 128],
    18: [3, 128, 128],
    19: [1, 128, 3],

    20: [3, 128, 64],
    21: [3, 64, 64],
    22: [1, 64, 3],

    23: [3, 64, 32],
    24: [3, 32, 32],
    25: [1, 32, 3],
}

# Indices of the entries above that use a 1x1 kernel (each has 3 output
# channels); evaluates to [1, 4, 7, 10, 13, 16, 19, 22, 25].
TO_RGB_LAYERS = [
    layer_idx
    for layer_idx, (kernel_size, _, _) in PARAMETERS.items()
    if kernel_size == 1
]
class RefinementBlock(Module):
    """Turn a feature map into a 5-D, conv-weight-shaped tensor for one layer.

    ``layer_idx`` selects ``(kernel_size, in_channels, out_channels)`` from
    ``PARAMETERS``. The input goes through a stack of stride-2 3x3
    convolutions (each followed by ``LeakyReLU``) and then a 1x1 convolution
    that emits ``in_channels * out_channels`` channels.

    Args:
        layer_idx: Key into ``PARAMETERS``; also decides (through
            ``TO_RGB_LAYERS``) which output head and reshape are used.
        opts: Stored on the instance; not otherwise used by this class.
        n_channels: Number of channels of the input feature map.
        inner_c: Channel width of the intermediate convolutions.
        spatial: Only used to choose how many stride-2 convolutions to
            stack. Presumably the height/width of the input feature map,
            expected to be a power of two -- confirm against the caller.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super().__init__()
        self.layer_idx = layer_idx
        self.opts = opts
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        num_pools = int(np.log2(self.spatial)) - 1
        if self.kernel_size == 3:
            # One fewer down-sampling step for 3x3 kernels; the remaining
            # spatial extent is averaged away by the output head below.
            num_pools = num_pools - 1

        # Kept in a local list on purpose: assigning it to ``self.modules``
        # would shadow the ``modules()`` method inherited from ``nn.Module``.
        layers = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            layers += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        layers += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        self.convs = nn.Sequential(*layers)

        if layer_idx in TO_RGB_LAYERS:
            self.output = Sequential(
                Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0))
        else:
            # Collapse whatever spatial extent is left to 1x1 before the
            # 1x1 projection.
            self.output = Sequential(
                nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a tensor viewed as ``(-1, out_channels, in_channels, k, k)``.

        ``k`` is ``self.kernel_size``. For layers outside ``TO_RGB_LAYERS``
        a single value is predicted per ``(out, in)`` pair and then tiled
        over the ``k x k`` kernel positions.
        """
        x = self.convs(x)
        x = self.output(x)
        if self.layer_idx in TO_RGB_LAYERS:
            x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        else:
            x = x.view(-1, self.out_channels, self.in_channels)
            # Broadcast each per-channel-pair value across the k x k window.
            x = x.unsqueeze(3).repeat(1, 1, 1, self.kernel_size).unsqueeze(4).repeat(1, 1, 1, 1, self.kernel_size)
        return x
class HyperRefinementBlock(Module):
    """Encode a feature map into a 512-d code and pass it to ``hypernet``.

    The feature map goes through one stride-1 3x3 convolution followed by
    stride-2 3x3 convolutions (each with ``LeakyReLU``), is reshaped to
    ``(-1, 512)``, passed through an ``EqualLinear`` layer and finally
    handed to the supplied ``hypernet`` module.

    Args:
        hypernet: Callable/module applied to the 512-d code; its return
            value is returned unchanged by ``forward``.
        n_channels: Number of channels of the input feature map.
        inner_c: Channel width of the intermediate convolutions.
        spatial: Only used to choose how many stride-2 convolutions to
            stack (``int(log2(spatial))`` in total). Presumably the
            height/width of the input, expected to be a power of two so the
            stack ends at 1x1 -- confirm against the caller.
    """

    def __init__(self, hypernet, n_channels=512, inner_c=128, spatial=16):
        super().__init__()
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        num_pools = int(np.log2(spatial))
        # The first convolution keeps the resolution (stride 1); all the
        # following ones halve it.
        modules = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=1, padding=1), nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        self.convs = nn.Sequential(*modules)

        self.linear = EqualLinear(self.out_c, self.out_c, lr_mul=1)
        self.hypernet = hypernet

    def forward(self, features):
        """Return ``hypernet(linear(code))`` for the encoded ``features``."""
        code = self.convs(features)
        # NOTE(review): this keeps one row per sample only if the conv stack
        # reduced the spatial size to 1x1; otherwise the extra positions are
        # folded into the leading dimension.
        code = code.view(-1, self.out_c)
        code = self.linear(code)
        weight_delta = self.hypernet(code)
        return weight_delta
class RefinementBlockSeparable(Module):
    """Variant of the refinement block that uses a ``SeparableBlock`` head.

    ``layer_idx`` selects ``(kernel_size, in_channels, out_channels)`` from
    ``PARAMETERS``. The input goes through a stack of stride-2 3x3
    convolutions (each followed by ``LeakyReLU``). Layers listed in
    ``TO_RGB_LAYERS`` then use a 1x1 convolution head; every other layer
    uses a ``SeparableBlock`` head.

    Args:
        layer_idx: Key into ``PARAMETERS``; also decides (through
            ``TO_RGB_LAYERS``) which output head is used.
        opts: Stored on the instance; not otherwise used by this class.
        n_channels: Number of channels of the input feature map.
        inner_c: Channel width of the intermediate convolutions.
        spatial: Only used to choose how many stride-2 convolutions to
            stack. Presumably the height/width of the input feature map,
            expected to be a power of two -- confirm against the caller.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super().__init__()
        self.layer_idx = layer_idx
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        num_pools = int(np.log2(self.spatial)) - 1

        # Kept in a local list on purpose: assigning it to ``self.modules``
        # would shadow the ``modules()`` method inherited from ``nn.Module``.
        layers = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            layers += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        layers += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        self.convs = nn.Sequential(*layers)

        self.opts = opts

        if self.layer_idx in TO_RGB_LAYERS:
            self.output = Sequential(
                Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0))
        else:
            self.output = Sequential(
                SeparableBlock(input_size=self.out_c,
                               kernel_channels_in=self.in_channels,
                               kernel_channels_out=self.out_channels,
                               kernel_size=self.kernel_size))

    def forward(self, x):
        """Run the conv stack and the output head.

        For layers in ``TO_RGB_LAYERS`` the result is viewed as
        ``(-1, out_channels, in_channels, kernel_size, kernel_size)``; for
        all other layers the ``SeparableBlock`` output is returned as-is.
        """
        x = self.convs(x)
        x = self.output(x)
        if self.layer_idx in TO_RGB_LAYERS:
            x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        return x