ethanNeuralImage's picture
models
92ec8d3
raw
history blame
No virus
5.59 kB
import numpy as np
from torch import nn
from torch.nn import Conv2d, Sequential, Module
from models.hyperstyle.encoders.helpers import SeparableBlock
from models.stylegan2.model import EqualLinear
# layer_idx: [kernel_size, in_channels, out_channels]
# Entries are laid out so that every row ends with a 1x1 layer that produces
# 3 output channels; those indices are exactly what TO_RGB_LAYERS collects.
PARAMETERS = {
    0: [3, 512, 512], 1: [1, 512, 3],
    2: [3, 512, 512], 3: [3, 512, 512], 4: [1, 512, 3],
    5: [3, 512, 512], 6: [3, 512, 512], 7: [1, 512, 3],
    8: [3, 512, 512], 9: [3, 512, 512], 10: [1, 512, 3],
    11: [3, 512, 512], 12: [3, 512, 512], 13: [1, 512, 3],
    14: [3, 512, 256], 15: [3, 256, 256], 16: [1, 256, 3],
    17: [3, 256, 128], 18: [3, 128, 128], 19: [1, 128, 3],
    20: [3, 128, 64], 21: [3, 64, 64], 22: [1, 64, 3],
    23: [3, 64, 32], 24: [3, 32, 32], 25: [1, 32, 3],
}

# Indices of the 1x1 layers in PARAMETERS (kernel_size == 1), in ascending
# order; the refinement blocks below take a toRGB-specific path for these.
TO_RGB_LAYERS = [idx for idx, (kernel_size, _, _) in PARAMETERS.items() if kernel_size == 1]
class RefinementBlock(Module):
    """Predict the weights of one layer described in ``PARAMETERS``.

    The input feature map goes through a stack of stride-2 3x3 convolutions
    (each followed by LeakyReLU) that ends in ``self.out_c`` (512) channels.
    A 1x1 convolution then emits ``in_channels * out_channels`` values, which
    ``forward`` reshapes to
    ``(-1, out_channels, in_channels, kernel_size, kernel_size)``, where
    ``(kernel_size, in_channels, out_channels) = PARAMETERS[layer_idx]``.

    Args:
        layer_idx: Key into ``PARAMETERS`` selecting the target layer.
        opts: Options object; stored as ``self.opts`` but not read here.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Spatial size of the input; the number of stride-2
            convolutions is derived from ``log2(spatial)``.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super(RefinementBlock, self).__init__()
        self.layer_idx = layer_idx
        self.opts = opts
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512
        num_pools = int(np.log2(self.spatial)) - 1
        if self.kernel_size == 3:
            # 3x3 layers use one fewer downsampling step. The AdaptiveAvgPool2d
            # in ``self.output`` averages whatever spatial extent remains
            # (no 3x3 layer is listed in TO_RGB_LAYERS).
            num_pools = num_pools - 1
        # Keep the layer list local. Assigning a list to ``self.modules`` would
        # shadow ``nn.Module.modules()`` on the instance and make
        # ``block.modules()`` raise TypeError. Only ``self.convs`` is registered,
        # so parameter names and state_dict keys are unaffected.
        modules = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        self.convs = nn.Sequential(*modules)
        if layer_idx in TO_RGB_LAYERS:
            # toRGB path: no pooling; the 1x1 conv is applied to the conv-stack
            # output as-is and reshaped directly in forward().
            self.output = Sequential(
                Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0))
        else:
            self.output = Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                     Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1,
                                            padding=0))

    def forward(self, x):
        """Map a feature map to predicted weights for this block's layer.

        Returns a tensor of shape
        ``(-1, out_channels, in_channels, kernel_size, kernel_size)``.
        For layers not in ``TO_RGB_LAYERS``, one value is predicted per
        (out, in) channel pair and tiled over the kernel window.
        """
        x = self.convs(x)
        x = self.output(x)
        if self.layer_idx in TO_RGB_LAYERS:
            return x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        x = x.view(-1, self.out_channels, self.in_channels, 1, 1)
        # ``repeat`` (unlike ``expand``) materializes a copy of the tiled values.
        return x.repeat(1, 1, 1, self.kernel_size, self.kernel_size)
class HyperRefinementBlock(Module):
    """Encode a feature map into a 512-d code and hand it to ``hypernet``.

    The encoder is a stride-1 3x3 convolution followed by
    ``max(1, int(log2(spatial)))`` stride-2 3x3 convolutions, each with a
    LeakyReLU. Its flattened output goes through an ``EqualLinear`` layer.
    ``forward`` returns whatever ``hypernet`` produces from that code.

    Args:
        hypernet: Called on the linear layer's output. If it is an
            ``nn.Module``, assigning it registers it as a submodule.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Spatial size of the input; sets the number of stride-2 convs.
    """

    def __init__(self, hypernet, n_channels=512, inner_c=128, spatial=16):
        super().__init__()
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512
        n_down = int(np.log2(spatial))

        def conv_act(c_in, c_out, stride):
            # One 3x3 conv (padding 1) followed by its activation.
            return [Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1), nn.LeakyReLU()]

        layers = conv_act(self.n_channels, self.inner_c, 1)
        for _ in range(n_down - 1):
            layers.extend(conv_act(self.inner_c, self.inner_c, 2))
        layers.extend(conv_act(self.inner_c, self.out_c, 2))
        self.convs = nn.Sequential(*layers)
        self.linear = EqualLinear(self.out_c, self.out_c, lr_mul=1)
        self.hypernet = hypernet

    def forward(self, features):
        """Return ``hypernet(linear(flattened conv features))``."""
        # NOTE(review): view(-1, out_c) assumes the conv stack ends at 1x1, so
        # each row is one sample. Confirm for non-power-of-two ``spatial``.
        flat = self.convs(features).view(-1, self.out_c)
        return self.hypernet(self.linear(flat))
class RefinementBlockSeparable(Module):
    """Refinement block whose non-toRGB head is a ``SeparableBlock``.

    The input feature map goes through a stack of stride-2 3x3 convolutions
    (each followed by LeakyReLU) that ends in ``self.out_c`` (512) channels.
    For layers in ``TO_RGB_LAYERS``, a 1x1 convolution's output is reshaped to
    ``(-1, out_channels, in_channels, kernel_size, kernel_size)``. For all
    other layers, the ``SeparableBlock`` output is returned unchanged.

    Args:
        layer_idx: Key into ``PARAMETERS`` selecting the target layer.
        opts: Options object; stored as ``self.opts`` but not read here.
        n_channels: Channel count of the input feature map.
        inner_c: Channel count of the intermediate convolutions.
        spatial: Spatial size of the input; the number of stride-2
            convolutions is derived from ``log2(spatial)``.
    """

    def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16):
        super(RefinementBlockSeparable, self).__init__()
        self.layer_idx = layer_idx
        self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx]
        self.spatial = spatial
        self.n_channels = n_channels
        self.inner_c = inner_c
        self.out_c = 512
        # Unlike RefinementBlock, the downsampling count does not depend on
        # kernel_size here.
        num_pools = int(np.log2(self.spatial)) - 1
        # Keep the layer list local. Assigning a list to ``self.modules`` would
        # shadow ``nn.Module.modules()`` on the instance and make
        # ``block.modules()`` raise TypeError. Only ``self.convs`` is registered,
        # so parameter names and state_dict keys are unaffected.
        modules = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        for _ in range(num_pools - 1):
            modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
        self.convs = nn.Sequential(*modules)
        self.opts = opts
        if self.layer_idx in TO_RGB_LAYERS:
            self.output = Sequential(Conv2d(self.out_c, self.in_channels * self.out_channels,
                                            kernel_size=1, stride=1, padding=0))
        else:
            self.output = Sequential(SeparableBlock(input_size=self.out_c,
                                                    kernel_channels_in=self.in_channels,
                                                    kernel_channels_out=self.out_channels,
                                                    kernel_size=self.kernel_size))

    def forward(self, x):
        """Map a feature map to predicted weights for this block's layer.

        For toRGB layers the result is reshaped to
        ``(-1, out_channels, in_channels, kernel_size, kernel_size)``.
        Otherwise the ``SeparableBlock`` output is returned as produced.
        """
        x = self.convs(x)
        x = self.output(x)
        if self.layer_idx in TO_RGB_LAYERS:
            x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        return x