File size: 6,105 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with orthogonal weight init and zero bias init.

    Accepts exactly the same arguments as ``nn.Conv1d``; only the initial
    parameter values differ.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.orthogonal_(self.weight)
        # With ``bias=False`` the parent leaves ``self.bias`` as None, so
        # there is nothing to zero (and ``zeros_`` would fail on None).
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class PositionalEncoding(nn.Module):
    """Positional encoding with noise level conditioning.

    The sinusoidal table is built lazily on the first ``forward`` call and
    cached in ``self.pe`` (a plain tensor attribute, not a registered
    buffer). It is rebuilt whenever the cached table no longer fits the
    incoming tensor.
    """

    def __init__(self, n_channels, max_len=10000):
        super().__init__()
        self.n_channels = n_channels
        self.max_len = max_len
        # The encoding is divided by this constant before being added.
        self.C = 5000
        # Empty placeholder; replaced by ``init_pe_matrix`` on first use.
        self.pe = torch.zeros(0, 0)

    def forward(self, x, noise_level):
        """Add the noise level and the scaled positional encoding to ``x``.

        Args:
            x: tensor whose dim 1 is used as the channel count and dim 2 as
                the sequence length of the encoding table.
            noise_level: tensor that receives two trailing singleton dims
                before being added to ``x``.

        Returns:
            ``x + noise_level[..., None, None] + pe / C``.
        """
        # ``self.pe`` is not a buffer, so it does not follow ``module.to()``
        # and is not resized automatically: rebuild it when the channel
        # count, required length, device or dtype no longer match ``x``.
        stale = (
            self.pe.shape[0] != x.shape[1]
            or self.pe.shape[1] < x.shape[2]
            or self.pe.device != x.device
            or self.pe.dtype != x.dtype
        )
        if stale:
            self.init_pe_matrix(x.shape[1], x.shape[2], x)
        # The (channels, length) table broadcasts over the leading batch
        # dim, so no explicit per-batch copy is needed.
        return x + noise_level[..., None, None] + self.pe[:, : x.size(2)] / self.C

    def init_pe_matrix(self, n_channels, max_len, x):
        """Build the ``(n_channels, max_len)`` table and cache it in ``self.pe``.

        Args:
            n_channels: number of rows (channels) of the table.
            max_len: number of positions (columns) of the table.
            x: reference tensor; the table is converted with ``.to(x)`` so it
                takes on the dtype and device of ``x``.
        """
        pe = torch.zeros(max_len, n_channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)

        pe[:, 0::2] = torch.sin(position / div_term)
        # For an odd channel count there is one cos column fewer than sin
        # columns, so trim the divisor to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(position / div_term[: n_channels // 2])
        self.pe = pe.transpose(0, 1).to(x)


class FiLM(nn.Module):
    """Compute a ``(shift, scale)`` pair from features and a noise scale.

    The input passes through a convolution, a leaky ReLU and the positional
    encoding; a second convolution then emits ``2 * output_size`` channels
    that are split in half along dim 1.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

        # Xavier-uniform weights and zero biases for both convolutions
        # (weights are drawn for ``input_conv`` first, then ``output_conv``).
        for conv in (self.input_conv, self.output_conv):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)

    def forward(self, x, noise_scale):
        """Return ``(shift, scale)``: the two halves of the output conv along dim 1."""
        hidden = F.leaky_relu(self.input_conv(x), 0.2)
        hidden = self.encoding(hidden, noise_scale)
        modulation = self.output_conv(hidden)
        shift, scale = modulation.chunk(2, dim=1)
        return shift, scale

    def remove_weight_norm(self):
        """Strip the ``weight`` parametrization from both convolutions."""
        for conv in (self.input_conv, self.output_conv):
            remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap both convolutions with ``weight_norm``."""
        normed_input = weight_norm(self.input_conv)
        self.input_conv = normed_input
        normed_output = weight_norm(self.output_conv)
        self.output_conv = normed_output


@torch.jit.script
def shif_and_scale(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Return the affine modulation ``shift + scale * x`` (TorchScript-compiled)."""
    return shift + scale * x


class UBlock(nn.Module):
    """Upsampling block with two residual stages modulated by shift/scale.

    The input is resized along its last dim by the integer ``factor`` and
    then processed by four dilated convolutions. ``shif_and_scale`` is
    applied with the given ``shift``/``scale`` before three of them.

    Args:
        input_size: number of input channels.
        hidden_size: number of channels produced by every convolution.
        factor: multiplier applied to the length of the last dim.
        dilation: list or tuple of exactly four dilation values, one per
            3-tap convolution (also used as that convolution's padding).
    """

    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple))
        assert len(dilation) == 4

        self.factor = factor
        # 1x1 convolution on the residual path to match the channel count.
        self.res_block = Conv1d(input_size, hidden_size, 1)
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
            ]
        )
        self.out_block = nn.ModuleList(
            [
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
            ]
        )

    def forward(self, x, shift, scale):
        """Upsample ``x`` and run both residual stages.

        Args:
            x: input tensor; its last dim is resized to ``len * factor``.
            shift: additive term passed to ``shif_and_scale``.
            scale: multiplicative term passed to ``shif_and_scale``.

        Returns:
            Output of the second stage plus the first stage's sum ``res2``.
        """
        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
        res = self.res_block(x_inter)
        # ``x_inter`` already has the target length, so the activation is
        # applied to it directly; resizing it again to that same length
        # (default interpolation mode) would only copy the tensor.
        o = F.leaky_relu(x_inter, 0.2)
        o = self.main_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block[1](o)
        # First residual connection; reused as the skip for the second stage.
        res2 = res + o
        o = shif_and_scale(res2, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[1](o)
        o = o + res2
        return o

    def remove_weight_norm(self):
        """Strip the ``weight`` parametrization from every convolution."""
        remove_parametrizations(self.res_block, "weight")
        for block in (self.main_block, self.out_block):
            for layer in block:
                # Only layers that actually hold state are touched.
                if layer.state_dict():
                    remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        """Wrap every convolution with ``weight_norm``."""
        self.res_block = weight_norm(self.res_block)
        for block in (self.main_block, self.out_block):
            for idx, layer in enumerate(block):
                # Only layers that actually hold state are wrapped.
                if layer.state_dict():
                    block[idx] = weight_norm(layer)


class DBlock(nn.Module):
    """Downsampling block: three dilated convolutions plus a 1x1 residual path.

    The last dim of the input is resized to ``length // factor`` on both the
    main path and the residual path before they are summed.
    """

    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        # Dilations 1, 2, 4 with matching padding; only the first layer
        # reads ``input_size`` channels.
        in_channels = (input_size, hidden_size, hidden_size)
        dilations = (1, 2, 4)
        self.main_block = nn.ModuleList(
            [
                Conv1d(channels, hidden_size, 3, dilation=rate, padding=rate)
                for channels, rate in zip(in_channels, dilations)
            ]
        )

    def forward(self, x):
        """Return the downsampled main-path output plus the residual path."""
        target_len = x.shape[-1] // self.factor
        # Residual path: 1x1 convolution first, then resize.
        skip = F.interpolate(self.res_block(x), size=target_len)
        # Main path: resize first, then leaky ReLU + convolution per layer.
        out = F.interpolate(x, size=target_len)
        for conv in self.main_block:
            out = conv(F.leaky_relu(out, 0.2))
        return out + skip

    def remove_weight_norm(self):
        """Strip the ``weight`` parametrization from every convolution."""
        remove_parametrizations(self.res_block, "weight")
        for conv in self.main_block:
            if conv.state_dict():
                remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap every convolution with ``weight_norm``."""
        self.res_block = weight_norm(self.res_block)
        for position, conv in enumerate(self.main_block):
            if conv.state_dict():
                self.main_block[position] = weight_norm(conv)