Spaces:
Runtime error
Runtime error
import numpy as np | |
from torch import nn | |
from torch.nn import Conv2d, Sequential, Module | |
from models.hyperstyle.encoders.helpers import SeparableBlock | |
from models.stylegan2.model import EqualLinear | |
# Per-layer weight spec, keyed by layer_idx (0..25):
#     [kernel_size, in_channels, out_channels]
# The entries with kernel_size 1 (and 3 output channels) are the to-RGB layers.
PARAMETERS = dict(enumerate([
    [3, 512, 512],  # 0
    [1, 512, 3],    # 1
    [3, 512, 512],  # 2
    [3, 512, 512],  # 3
    [1, 512, 3],    # 4
    [3, 512, 512],  # 5
    [3, 512, 512],  # 6
    [1, 512, 3],    # 7
    [3, 512, 512],  # 8
    [3, 512, 512],  # 9
    [1, 512, 3],    # 10
    [3, 512, 512],  # 11
    [3, 512, 512],  # 12
    [1, 512, 3],    # 13
    [3, 512, 256],  # 14
    [3, 256, 256],  # 15
    [1, 256, 3],    # 16
    [3, 256, 128],  # 17
    [3, 128, 128],  # 18
    [1, 128, 3],    # 19
    [3, 128, 64],   # 20
    [3, 64, 64],    # 21
    [1, 64, 3],     # 22
    [3, 64, 32],    # 23
    [3, 32, 32],    # 24
    [1, 32, 3],     # 25
]))

# Indices of the 1x1 entries above, in ascending order: [1, 4, 7, ..., 25].
TO_RGB_LAYERS = [idx for idx, (ksize, _, _) in PARAMETERS.items() if ksize == 1]
class RefinementBlock(Module):
    """Predict the weight tensor of one generator layer from a feature map.

    The layer's ``[kernel_size, in_channels, out_channels]`` spec is looked up in
    ``PARAMETERS`` by ``layer_idx``. A stack of stride-2 3x3 convolutions (each
    followed by LeakyReLU) maps ``n_channels`` to ``out_c`` (512) channels. A
    1x1 convolution then emits ``in_channels * out_channels`` values per
    position. Layers in ``TO_RGB_LAYERS`` use that conv directly. All other
    layers average-pool to 1x1 first.

    Args:
        layer_idx: Key into ``PARAMETERS``. It also selects the to-RGB vs.
            pooled output head.
        opts: Stored as ``self.opts``. It is not read by this class.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Controls how many stride-2 convolutions are stacked. It is
            presumably the (power-of-two) height/width of the input; confirm
            against callers.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super().__init__()
        self.layer_idx = layer_idx
        self.opts = opts
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        num_pools = int(np.log2(self.spatial)) - 1
        if self.kernel_size == 3:
            # 3x3 targets use one fewer stride-2 conv. The spatial extent that
            # remains is averaged away by AdaptiveAvgPool2d in ``self.output``.
            num_pools = num_pools - 1

        # Keep the layer list local. An instance attribute named ``modules``
        # would shadow ``nn.Module.modules()`` and make ``block.modules()`` raise
        # TypeError. The registered submodules (``convs.*``) are unaffected.
        layers = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1),
                  nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            layers += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1),
                       nn.LeakyReLU()]
        layers += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1),
                   nn.LeakyReLU()]
        self.convs = nn.Sequential(*layers)

        n_weights = self.in_channels * self.out_channels
        if layer_idx in TO_RGB_LAYERS:
            self.output = Sequential(
                Conv2d(self.out_c, n_weights, kernel_size=1, stride=1, padding=0))
        else:
            self.output = Sequential(
                nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(self.out_c, n_weights, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a weight tensor of shape
        ``(-1, out_channels, in_channels, kernel_size, kernel_size)``.

        For to-RGB layers, the head output is reshaped directly. For other
        layers, there is one value per (out, in) channel pair, and it is tiled
        across the ``kernel_size x kernel_size`` window.
        """
        x = self.output(self.convs(x))
        if self.layer_idx in TO_RGB_LAYERS:
            return x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        # Same values as unsqueeze/repeat on each trailing axis, in a single step.
        x = x.view(-1, self.out_channels, self.in_channels, 1, 1)
        return x.repeat(1, 1, 1, self.kernel_size, self.kernel_size)
class HyperRefinementBlock(Module):
    """Encode a feature map into a code vector and feed it to ``hypernet``.

    The encoder is built as follows:

    - A stride-1 3x3 conv maps ``n_channels`` to ``inner_c`` channels.
    - It is followed by ``int(log2(spatial))`` stride-2 3x3 convs (at least
      one). The last of these widens to ``out_c`` (512) channels.
    - Every conv is followed by LeakyReLU.

    The conv output is reshaped into rows of ``out_c`` values and projected by
    ``EqualLinear``. The result of ``hypernet`` on that projection is returned
    unchanged.

    Args:
        hypernet: Module applied to the projected code. It is registered as a
            submodule.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Controls how many stride-2 convs are stacked. For a
            power-of-two input of this height/width, they reduce it to 1x1.
            Presumably callers pass that; confirm.
    """

    def __init__(self, hypernet, n_channels=512, inner_c=128, spatial=16):
        super().__init__()
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        downsample_steps = int(np.log2(spatial))
        layers = [Conv2d(n_channels, inner_c, kernel_size=3, stride=1, padding=1), nn.LeakyReLU()]
        for _ in range(downsample_steps - 1):
            layers.extend([Conv2d(inner_c, inner_c, kernel_size=3, stride=2, padding=1),
                           nn.LeakyReLU()])
        layers.extend([Conv2d(inner_c, self.out_c, kernel_size=3, stride=2, padding=1),
                       nn.LeakyReLU()])

        # Registration order (convs, linear, hypernet) fixes parameters() order.
        self.convs = nn.Sequential(*layers)
        self.linear = EqualLinear(self.out_c, self.out_c, lr_mul=1)
        self.hypernet = hypernet

    def forward(self, features):
        """Return ``hypernet(linear(convs(features).view(-1, out_c)))``.

        NOTE(review): if the conv output is not 1x1 spatially, the ``view``
        still succeeds, but it mixes spatial positions into the rows instead of
        failing.
        """
        encoded = self.convs(features).view(-1, self.out_c)
        return self.hypernet(self.linear(encoded))
class RefinementBlockSeparable(Module):
    """Variant of the refinement block whose non-RGB head is a ``SeparableBlock``.

    The layer's ``[kernel_size, in_channels, out_channels]`` spec is looked up in
    ``PARAMETERS`` by ``layer_idx``. ``int(log2(spatial))`` stride-2 3x3
    convolutions (at least two, each followed by LeakyReLU) map ``n_channels``
    to ``out_c`` (512) channels. The output head then depends on the layer:

    - Layers in ``TO_RGB_LAYERS`` use a 1x1 conv that emits
      ``in_channels * out_channels`` values. These are reshaped to
      ``(-1, out_channels, in_channels, kernel_size, kernel_size)``.
    - All other layers use ``SeparableBlock``, and its output is returned
      as-is (its shape is defined by that helper, not here).

    Args:
        layer_idx: Key into ``PARAMETERS``. It also selects the output head.
        opts: Stored as ``self.opts``. It is not read by this class.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Controls how many stride-2 convolutions are stacked. It is
            presumably the (power-of-two) height/width of the input; confirm
            against callers.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super().__init__()
        self.layer_idx = layer_idx
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512

        num_pools = int(np.log2(self.spatial)) - 1
        # Keep the layer list local. An instance attribute named ``modules``
        # would shadow ``nn.Module.modules()`` and make ``block.modules()`` raise
        # TypeError. The registered submodules (``convs.*``) are unaffected.
        layers = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1),
                  nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            layers += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1),
                       nn.LeakyReLU()]
        layers += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1),
                   nn.LeakyReLU()]
        self.convs = nn.Sequential(*layers)
        self.opts = opts

        if self.layer_idx in TO_RGB_LAYERS:
            self.output = Sequential(Conv2d(self.out_c, self.in_channels * self.out_channels,
                                            kernel_size=1, stride=1, padding=0))
        else:
            self.output = Sequential(SeparableBlock(input_size=self.out_c,
                                                    kernel_channels_in=self.in_channels,
                                                    kernel_channels_out=self.out_channels,
                                                    kernel_size=self.kernel_size))

    def forward(self, x):
        """Run the conv stack and the output head.

        For to-RGB layers, the result is reshaped to
        ``(-1, out_channels, in_channels, kernel_size, kernel_size)``. Otherwise,
        the ``SeparableBlock`` output is returned unchanged.
        """
        x = self.output(self.convs(x))
        if self.layer_idx in TO_RGB_LAYERS:
            x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        return x