import functools
import torch
import torch.nn as nn

from models.transformer import RETURNX, Transformer
from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
                               FFCADAINResBlocks, Jump, FinalBlock2d


class Visual_Encoder(nn.Module):
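    """Two-branch pyramid encoder.

    The masked target frame and the reference frame are downsampled in parallel
    branches; at each scale the reference features are fused into the target
    branch (pass-through at the first two scales, cross-attention Transformer at
    deeper scales). Returns the per-scale target features, with the reference
    features concatenated at the coarsest scale.
    """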
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Visual_Encoder, self).__init__()
        self.layers = layers
        self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf*(2**i), img_f)
            out_channels = min(ngf*(2**(i+1)), img_f)
            model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
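            # Pass-through at the two highest-resolution scales; cross-attention
            # between target and reference features at deeper scales.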
            if i < 2:
                ca_layer = RETURNX()
            else:
                ca_layer = Transformer(2**(i+1) * ngf, 2, 4, ngf, ngf * 4)
            setattr(self, 'ca' + str(i), ca_layer)
            setattr(self, 'ref_down' + str(i), model_ref)
            setattr(self, 'inp_down' + str(i), model_inp)
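        # Doubled: the deepest scale carries concatenated target and reference features.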
        self.output_nc = out_channels * 2

    def forward(self, maskGT, ref):
        x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref)
        out = [x_maskGT]
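        # Downsample both branches and fuse reference features into the target branch.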
        for i in range(self.layers):
            model_ref = getattr(self, 'ref_down'+str(i))
            model_inp = getattr(self, 'inp_down'+str(i))
            ca_layer = getattr(self, 'ca'+str(i))
            x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref)
            x_maskGT = ca_layer(x_maskGT, x_ref)
            if i < self.layers - 1:
                out.append(x_maskGT)
            else:
                out.append(torch.cat([x_maskGT, x_ref], dim=1))  # concatenate reference features at the coarsest scale
        return out


class Decoder(nn.Module):
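    """U-Net-style decoder.

    Starting from the coarsest encoder feature, each scale is refined by
    FFC/AdaIN residual blocks conditioned on the audio descriptor, upsampled,
    and combined with a skip connection from the matching encoder scale.
    The final block maps to an image with a sigmoid activation.
    """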
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Decoder, self).__init__()
        self.layers = layers
        for i in reversed(range(layers)):
            if i == layers - 1:
                # the deepest encoder output carries concatenated target and reference features
                in_channels = ngf*(2**(i+1)) * 2
            else:
                in_channels = min(ngf*(2**(i+1)), img_f)
            out_channels = min(ngf*(2**i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)            
            setattr(self, 'jump' + str(i), jump)

        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
        self.output_nc = out_channels

    def forward(self, x, z):
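        # x: encoder feature pyramid (coarsest feature last); z: audio descriptor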
        out = x.pop()
        for i in reversed(range(self.layers)):
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out  # skip connection from the matching encoder scale
        out_image = self.final(out)
        return out_image


class LNet(nn.Module): 
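    """Audio-driven face generator.

    A Visual_Encoder / Decoder pair conditioned on an audio descriptor: the
    convolutional audio encoder compresses a spectrogram chunk into a
    descriptor_nc-channel feature that modulates the decoder.
    """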
    def __init__(
        self, 
        image_nc=3, 
        descriptor_nc=512, 
        layer=3, 
        base_nc=64, 
        max_nc=512, 
        num_res_blocks=9, 
        use_spect=True,
        encoder=Visual_Encoder,
        decoder=Decoder
        ):  
        super(LNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True) 
        kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
        self.descriptor_nc = descriptor_nc

        self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
        self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
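        # Convolutional audio encoder: compresses a single-channel spectrogram chunk
        # into a descriptor_nc-channel feature used to condition the decoder.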
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
            )

    def forward(self, audio_sequences, face_sequences):
        B = audio_sequences.size(0)
        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            # fold the temporal dimension into the batch dimension:
            # audio (B, T, 1, H, W) -> (B*T, 1, H, W), faces (B, C, T, H, W) -> (B*T, C, H, W)
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        # the face tensor stacks the masked target crop and the reference frame along channels
        cropped, ref = torch.split(face_sequences, 3, dim=1)

        vis_feat = self.encoder(cropped, ref)
        audio_feat = self.audio_encoder(audio_sequences) 
        _outputs = self.decoder(vis_feat, audio_feat)

        if input_dim_size > 4:
            # unfold the batch dimension back into (B, C, T, H, W)
            _outputs = torch.split(_outputs, B, dim=0)
            outputs = torch.stack(_outputs, dim=2)
        else:
            outputs = _outputs
        return outputs
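

if __name__ == '__main__':
    # Minimal shape-check sketch (not part of the original module). The input sizes
    # below are assumptions: an 80x16 spectrogram chunk per frame and a 96x96 face
    # crop (any spatial size divisible by 2**layer should work); adjust them to the
    # shapes used by the real data pipeline.
    model = LNet()
    B, T = 2, 5
    audio = torch.randn(B, T, 1, 80, 16)   # (B, T, 1, mel_bins, mel_steps), assumed
    faces = torch.randn(B, 6, T, 96, 96)   # masked crop + reference stacked on channels
    with torch.no_grad():
        out = model(audio, faces)
    print(out.shape)                       # expected: (B, 3, T, 96, 96)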