File size: 3,794 Bytes
9667e74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import math
import functools
import sys

import torch
import torch.nn as nn

from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
from NTED.edge_attention_layer import Edge_Attn

class Encoder(nn.Module):
    def __init__(
        self, 
        size, 
        input_dim, 
        channels, 
        num_labels=None, 
        match_kernels=None, 
        blur_kernel=[1, 3, 3, 1], 
        ):
        super().__init__()
        self.first = EncoderLayer(input_dim, channels[size], 1)
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2))
        self.log_size = log_size

        in_channel = channels[size]
        for i in range(log_size-1, 3, -1):
            out_channel = channels[2 ** i]
            num_label = num_labels[2 ** i] if num_labels is not None else None
            match_kernel = match_kernels[2 ** i] if match_kernels is not None else None
            use_extraction = num_label and match_kernel
            conv = EncoderLayer(
                in_channel, 
                out_channel, 
                kernel_size=3, 
                downsample=True, 
                blur_kernel=blur_kernel,
                use_extraction=use_extraction,
                num_label=num_label,
                match_kernel=match_kernel
                )

            self.convs.append(conv)
            in_channel = out_channel

    def forward(self, input, recoder=None):
        out = self.first(input)
        for idx, layer in enumerate(self.convs):
            out = layer(out, recoder)
        return out

class Decoder(nn.Module):
    def __init__(
        self,
        size,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()


        self.convs = nn.ModuleList()
        # input at resolution 16*16
        in_channel = channels[16]
        self.log_size = int(math.log(size, 2))
        
        for i in range(4, self.log_size + 1):
            out_channel = channels[2 ** i]
            num_label, match_kernel = num_labels[2 ** i], match_kernels[2 ** i]
            use_distribution = num_label and match_kernel
            upsample = (i != 4)
            
            base_layer = functools.partial(
                DecoderLayer,
                out_channel=out_channel,
                kernel_size=3, 
                blur_kernel=blur_kernel,
                use_distribution=use_distribution,
                num_label=num_label,
                match_kernel=match_kernel
                )

            up = nn.Module()   
            up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
            up.conv1 = base_layer(in_channel=out_channel, upsample=False)
            up.to_rgb = ToRGB(out_channel, upsample=upsample)
            self.convs.append(up)
            in_channel = out_channel 
            
        self.num_labels, self.match_kernels = num_labels, match_kernels
        
        self.edge_attn_block = Edge_Attn(in_channels=3)
    
    def forward(self, input, neural_textures, recoder):
        counter = 0
        out, skip = input, None
        for i, up in enumerate(self.convs):
            if self.num_labels[2**(i+4)] and self.match_kernels[2**(i+4)]:
                neural_texture_conv0 = neural_textures[counter]
                neural_texture_conv1 = neural_textures[counter+1]
                counter += 2
            else:
                neural_texture_conv0, neural_texture_conv1 = None, None
            out = up.conv0(out, neural_texture=neural_texture_conv0, recoder=recoder)
            out = up.conv1(out, neural_texture=neural_texture_conv1, recoder=recoder)
            
            skip = up.to_rgb(out, skip)
        image = self.edge_attn_block(skip)
        # image = skip
        return image