tracert committed on
Commit 48c021d · verified · 1 Parent(s): c8f8511

Upload 6 files

hifi-gan/__pycache__/vocoder.cpython-38.pyc ADDED
Binary file (8.74 kB)

hifi-gan/__pycache__/vocoder.cpython-39.pyc ADDED
Binary file (8.76 kB)

hifi-gan/__pycache__/vocoder_utils.cpython-38.pyc ADDED
Binary file (2.01 kB)

hifi-gan/__pycache__/vocoder_utils.cpython-39.pyc ADDED
Binary file (2.02 kB)
hifi-gan/vocoder.py ADDED
@@ -0,0 +1,283 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+ from vocoder_utils import init_weights, get_padding
+
+ LRELU_SLOPE = 0.1
+
+
+ class ResBlock1(torch.nn.Module):
+     # Residual block with a stack of dilated convolutions (convs1), each
+     # followed by a non-dilated convolution (convs2).
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock2(torch.nn.Module):
+     # Lighter residual block with a single stack of dilated convolutions.
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+         super(ResBlock2, self).__init__()
+         self.h = h
+         self.convs = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1])))
+         ])
+         self.convs.apply(init_weights)
+
+     def forward(self, x):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class Generator(torch.nn.Module):
+     # Upsamples a mel spectrogram to a waveform: transposed convolutions
+     # interleaved with multi-receptive-field residual blocks, averaged.
+     def __init__(self, h):
+         super(Generator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
+                                                         h.upsample_initial_channel//(2**(i+1)),
+                                                         k, u, padding=(u//2 + u%2), output_padding=u%2)))
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel//(2**(i+1))
+             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x):
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i*self.num_kernels+j](x)
+                 else:
+                     xs += self.resblocks[i*self.num_kernels+j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print('Removing weight norm...')
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ class DiscriminatorP(torch.nn.Module):
+     # Period discriminator: folds the waveform into 2-D frames of width
+     # `period` and applies strided 2-D convolutions along time.
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super(DiscriminatorP, self).__init__()
+         self.period = period
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+         ])
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super(MultiPeriodDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorP(2),
+             DiscriminatorP(3),
+             DiscriminatorP(5),
+             DiscriminatorP(7),
+             DiscriminatorP(11),
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(torch.nn.Module):
+     # Scale discriminator: grouped 1-D convolutions over the raw waveform.
+     def __init__(self, use_spectral_norm=False):
+         super(DiscriminatorS, self).__init__()
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+             norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+             norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+             norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+         ])
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(torch.nn.Module):
+     # Three scale discriminators on the raw, 2x-, and 4x-downsampled signal.
+     def __init__(self):
+         super(MultiScaleDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorS(use_spectral_norm=True),
+             DiscriminatorS(),
+             DiscriminatorS(),
+         ])
+         self.meanpools = nn.ModuleList([
+             AvgPool1d(4, 2, padding=2),
+             AvgPool1d(4, 2, padding=2)
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             if i != 0:
+                 y = self.meanpools[i-1](y)
+                 y_hat = self.meanpools[i-1](y_hat)
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def feature_loss(fmap_r, fmap_g):
+     # L1 feature-matching loss over all discriminator feature maps.
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss*2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     # Least-squares GAN loss: real outputs pushed to 1, generated to 0.
+     loss = 0
+     r_losses = []
+     g_losses = []
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1-dr)**2)
+         g_loss = torch.mean(dg**2)
+         loss += (r_loss + g_loss)
+         r_losses.append(r_loss.item())
+         g_losses.append(g_loss.item())
+
+     return loss, r_losses, g_losses
+
+
+ def generator_loss(disc_outputs):
+     # Least-squares GAN loss for the generator: outputs pushed toward 1.
+     loss = 0
+     gen_losses = []
+     for dg in disc_outputs:
+         l = torch.mean((1-dg)**2)
+         gen_losses.append(l)
+         loss += l
+
+     return loss, gen_losses
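
A minimal inference sketch (not part of the commit) for the Generator above. It assumes the hifi-gan/ directory is on sys.path, a flat JSON config whose keys match the h.* fields used in the code (num_mels, upsample_rates, upsample_kernel_sizes, upsample_initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes), and a checkpoint saved as {'generator': state_dict}; both file paths are hypothetical.

import json
import torch
from types import SimpleNamespace

from vocoder import Generator

# Hypothetical config path; SimpleNamespace exposes keys as h.num_mels etc.
with open('config.json') as f:
    h = SimpleNamespace(**json.load(f))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator(h).to(device)

# Hypothetical checkpoint path, assumed saved as {'generator': state_dict}.
state = torch.load('generator_checkpoint.pt', map_location=device)
generator.load_state_dict(state['generator'])
generator.eval()
generator.remove_weight_norm()  # fuse weight norm for inference

with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 100, device=device)  # dummy mel input
    audio = generator(mel).squeeze()  # waveform in [-1, 1], length 100 * prod(upsample_rates)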
hifi-gan/vocoder_utils.py ADDED
@@ -0,0 +1,58 @@
+ import glob
+ import os
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     # 'Same' padding for a stride-1 dilated convolution.
+     return int((kernel_size*dilation - dilation)/2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print("Loading '{}'".format(filepath))
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print("Saving checkpoint to {}".format(filepath))
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix):
+     # Newest checkpoint matching the prefix plus an eight-character suffix,
+     # e.g. 'g_00012345'; returns None if the directory has no match.
+     pattern = os.path.join(cp_dir, prefix + '????????')
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]
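
A short usage sketch (not part of the commit) showing how the checkpoint helpers chain together when resuming training. The 'checkpoints' directory and 'g_' prefix are assumptions; the eight-'?' glob in scan_checkpoint implies filenames carry an eight-digit step suffix such as g_00012345.

import torch
from vocoder_utils import scan_checkpoint, load_checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Find the newest generator checkpoint, if any (hypothetical dir/prefix).
cp_g = scan_checkpoint('checkpoints', 'g_')
state_dict_g = load_checkpoint(cp_g, device) if cp_g is not None else None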