Yusen committed on
Commit
f8f1ec7
1 Parent(s): e218ea2

update sovits

Files changed (2)
  1. models.py +302 -235
  2. utils.py +144 -27
models.py CHANGED
@@ -1,123 +1,125 @@
1
- import copy
2
- import math
3
  import torch
4
  from torch import nn
 
5
  from torch.nn import functional as F
 
6
 
7
  import modules.attentions as attentions
8
  import modules.commons as commons
9
  import modules.modules as modules
10
-
11
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
-
14
  import utils
15
- from modules.commons import init_weights, get_padding
16
- from vdecoder.hifigan.models import Generator
17
  from utils import f0_to_coarse
18
 
 
19
  class ResidualCouplingBlock(nn.Module):
20
- def __init__(self,
21
- channels,
22
- hidden_channels,
23
- kernel_size,
24
- dilation_rate,
25
- n_layers,
26
- n_flows=4,
27
- gin_channels=0):
28
- super().__init__()
29
- self.channels = channels
30
- self.hidden_channels = hidden_channels
31
- self.kernel_size = kernel_size
32
- self.dilation_rate = dilation_rate
33
- self.n_layers = n_layers
34
- self.n_flows = n_flows
35
- self.gin_channels = gin_channels
36
-
37
- self.flows = nn.ModuleList()
38
- for i in range(n_flows):
39
- self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
40
- self.flows.append(modules.Flip())
41
-
42
- def forward(self, x, x_mask, g=None, reverse=False):
43
- if not reverse:
44
- for flow in self.flows:
45
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
46
- else:
47
- for flow in reversed(self.flows):
48
- x = flow(x, x_mask, g=g, reverse=reverse)
49
- return x
50
 
51
 
52
  class Encoder(nn.Module):
53
- def __init__(self,
54
- in_channels,
55
- out_channels,
56
- hidden_channels,
57
- kernel_size,
58
- dilation_rate,
59
- n_layers,
60
- gin_channels=0):
61
- super().__init__()
62
- self.in_channels = in_channels
63
- self.out_channels = out_channels
64
- self.hidden_channels = hidden_channels
65
- self.kernel_size = kernel_size
66
- self.dilation_rate = dilation_rate
67
- self.n_layers = n_layers
68
- self.gin_channels = gin_channels
69
-
70
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
71
- self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
72
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
73
-
74
- def forward(self, x, x_lengths, g=None):
75
- # print(x.shape,x_lengths.shape)
76
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
77
- x = self.pre(x) * x_mask
78
- x = self.enc(x, x_mask, g=g)
79
- stats = self.proj(x) * x_mask
80
- m, logs = torch.split(stats, self.out_channels, dim=1)
81
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
82
- return z, m, logs, x_mask
83
 
84
 
85
  class TextEncoder(nn.Module):
86
- def __init__(self,
87
- out_channels,
88
- hidden_channels,
89
- kernel_size,
90
- n_layers,
91
- gin_channels=0,
92
- filter_channels=None,
93
- n_heads=None,
94
- p_dropout=None):
95
- super().__init__()
96
- self.out_channels = out_channels
97
- self.hidden_channels = hidden_channels
98
- self.kernel_size = kernel_size
99
- self.n_layers = n_layers
100
- self.gin_channels = gin_channels
101
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
102
- self.f0_emb = nn.Embedding(256, hidden_channels)
103
-
104
- self.enc_ = attentions.Encoder(
105
- hidden_channels,
106
- filter_channels,
107
- n_heads,
108
- n_layers,
109
- kernel_size,
110
- p_dropout)
111
-
112
- def forward(self, x, x_mask, f0=None, noice_scale=1):
113
- x = x + self.f0_emb(f0).transpose(1,2)
114
- x = self.enc_(x * x_mask, x_mask)
115
- stats = self.proj(x) * x_mask
116
- m, logs = torch.split(stats, self.out_channels, dim=1)
117
- z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
118
-
119
- return z, m, logs, x_mask
120
 
121
 
122
 
123
  class DiscriminatorP(torch.nn.Module):
@@ -125,7 +127,7 @@ class DiscriminatorP(torch.nn.Module):
125
  super(DiscriminatorP, self).__init__()
126
  self.period = period
127
  self.use_spectral_norm = use_spectral_norm
128
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
129
  self.convs = nn.ModuleList([
130
  norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
131
  norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@@ -140,7 +142,7 @@ class DiscriminatorP(torch.nn.Module):
140
 
141
  # 1d to 2d
142
  b, c, t = x.shape
143
- if t % self.period != 0: # pad first
144
  n_pad = self.period - (t % self.period)
145
  x = F.pad(x, (0, n_pad), "reflect")
146
  t = t + n_pad
@@ -160,7 +162,7 @@ class DiscriminatorP(torch.nn.Module):
160
  class DiscriminatorS(torch.nn.Module):
161
  def __init__(self, use_spectral_norm=False):
162
  super(DiscriminatorS, self).__init__()
163
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
164
  self.convs = nn.ModuleList([
165
  norm_f(Conv1d(1, 16, 15, 1, padding=7)),
166
  norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
@@ -188,7 +190,7 @@ class DiscriminatorS(torch.nn.Module):
188
  class MultiPeriodDiscriminator(torch.nn.Module):
189
  def __init__(self, use_spectral_norm=False):
190
  super(MultiPeriodDiscriminator, self).__init__()
191
- periods = [2,3,5,7,11]
192
 
193
  discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
194
  discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@@ -225,26 +227,26 @@ class SpeakerEncoder(torch.nn.Module):
225
 
226
  def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
227
  mel_slices = []
228
- for i in range(0, total_frames-partial_frames, partial_hop):
229
- mel_range = torch.arange(i, i+partial_frames)
230
  mel_slices.append(mel_range)
231
 
232
  return mel_slices
233
 
234
  def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
235
  mel_len = mel.size(1)
236
- last_mel = mel[:,-partial_frames:]
237
 
238
  if mel_len > partial_frames:
239
  mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
240
- mels = list(mel[:,s] for s in mel_slices)
241
  mels.append(last_mel)
242
  mels = torch.stack(tuple(mels), 0).squeeze(1)
243
 
244
  with torch.no_grad():
245
  partial_embeds = self(mels)
246
  embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
247
- #embed = embed / torch.linalg.norm(embed, 2)
248
  else:
249
  with torch.no_grad():
250
  embed = self(last_mel)
@@ -280,7 +282,7 @@ class F0Decoder(nn.Module):
280
  kernel_size,
281
  p_dropout)
282
  self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
283
- self.f0_prenet = nn.Conv1d(1, hidden_channels , 3, padding=1)
284
  self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
285
 
286
  def forward(self, x, norm_f0, x_mask, spk_emb=None):
@@ -295,126 +297,191 @@ class F0Decoder(nn.Module):
295
 
296
 
297
  class SynthesizerTrn(nn.Module):
298
- """
299
- Synthesizer for Training
300
- """
301
-
302
- def __init__(self,
303
- spec_channels,
304
- segment_size,
305
- inter_channels,
306
- hidden_channels,
307
- filter_channels,
308
- n_heads,
309
- n_layers,
310
- kernel_size,
311
- p_dropout,
312
- resblock,
313
- resblock_kernel_sizes,
314
- resblock_dilation_sizes,
315
- upsample_rates,
316
- upsample_initial_channel,
317
- upsample_kernel_sizes,
318
- gin_channels,
319
- ssl_dim,
320
- n_speakers,
321
- sampling_rate=44100,
322
- **kwargs):
323
-
324
- super().__init__()
325
- self.spec_channels = spec_channels
326
- self.inter_channels = inter_channels
327
- self.hidden_channels = hidden_channels
328
- self.filter_channels = filter_channels
329
- self.n_heads = n_heads
330
- self.n_layers = n_layers
331
- self.kernel_size = kernel_size
332
- self.p_dropout = p_dropout
333
- self.resblock = resblock
334
- self.resblock_kernel_sizes = resblock_kernel_sizes
335
- self.resblock_dilation_sizes = resblock_dilation_sizes
336
- self.upsample_rates = upsample_rates
337
- self.upsample_initial_channel = upsample_initial_channel
338
- self.upsample_kernel_sizes = upsample_kernel_sizes
339
- self.segment_size = segment_size
340
- self.gin_channels = gin_channels
341
- self.ssl_dim = ssl_dim
342
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
343
-
344
- self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
345
-
346
- self.enc_p = TextEncoder(
347
- inter_channels,
348
- hidden_channels,
349
- filter_channels=filter_channels,
350
- n_heads=n_heads,
351
- n_layers=n_layers,
352
- kernel_size=kernel_size,
353
- p_dropout=p_dropout
354
- )
355
- hps = {
356
- "sampling_rate": sampling_rate,
357
- "inter_channels": inter_channels,
358
- "resblock": resblock,
359
- "resblock_kernel_sizes": resblock_kernel_sizes,
360
- "resblock_dilation_sizes": resblock_dilation_sizes,
361
- "upsample_rates": upsample_rates,
362
- "upsample_initial_channel": upsample_initial_channel,
363
- "upsample_kernel_sizes": upsample_kernel_sizes,
364
- "gin_channels": gin_channels,
365
- }
366
- self.dec = Generator(h=hps)
367
- self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
368
- self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
369
- self.f0_decoder = F0Decoder(
370
- 1,
371
- hidden_channels,
372
- filter_channels,
373
- n_heads,
374
- n_layers,
375
- kernel_size,
376
- p_dropout,
377
- spk_channels=gin_channels
378
- )
379
- self.emb_uv = nn.Embedding(2, hidden_channels)
380
-
381
- def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
382
- g = self.emb_g(g).transpose(1,2)
383
- # ssl prenet
384
- x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
385
- x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2)
386
-
387
- # f0 predict
388
- lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
389
- norm_lf0 = utils.normalize_f0(lf0, x_mask, uv)
390
- pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
391
-
392
- # encoder
393
- z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
394
- z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
395
-
396
- # flow
397
- z_p = self.flow(z, spec_mask, g=g)
398
- z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(z, f0, spec_lengths, self.segment_size)
399
-
400
- # nsf decoder
401
- o = self.dec(z_slice, g=g, f0=pitch_slice)
402
-
403
- return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0
404
-
405
- def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
406
- c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
407
- g = self.emb_g(g).transpose(1,2)
408
- x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
409
- x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2)
410
-
411
- if predict_f0:
412
- lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
413
- norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
414
- pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
415
- f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
416
-
417
- z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale)
418
- z = self.flow(z_p, c_mask, g=g, reverse=True)
419
- o = self.dec(z * c_mask, g=g, f0=f0)
420
- return o,f0
1
  import torch
2
  from torch import nn
3
+ from torch.nn import Conv1d, Conv2d
4
  from torch.nn import functional as F
5
+ from torch.nn.utils import spectral_norm, weight_norm
6
 
7
  import modules.attentions as attentions
8
  import modules.commons as commons
9
  import modules.modules as modules
 
 
 
 
10
  import utils
11
+ from modules.commons import get_padding
 
12
  from utils import f0_to_coarse
13
 
14
+
15
  class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0,
24
+ share_parameter=False
25
+ ):
26
+ super().__init__()
27
+ self.channels = channels
28
+ self.hidden_channels = hidden_channels
29
+ self.kernel_size = kernel_size
30
+ self.dilation_rate = dilation_rate
31
+ self.n_layers = n_layers
32
+ self.n_flows = n_flows
33
+ self.gin_channels = gin_channels
34
+
35
+ self.flows = nn.ModuleList()
36
+
37
+ self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None
38
+
39
+ for i in range(n_flows):
40
+ self.flows.append(
41
+ modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
42
+ gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn))
43
+ self.flows.append(modules.Flip())
44
+
45
+ def forward(self, x, x_mask, g=None, reverse=False):
46
+ if not reverse:
47
+ for flow in self.flows:
48
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
49
+ else:
50
+ for flow in reversed(self.flows):
51
+ x = flow(x, x_mask, g=g, reverse=reverse)
52
+ return x
53
 
54
 
55
  class Encoder(nn.Module):
56
+ def __init__(self,
57
+ in_channels,
58
+ out_channels,
59
+ hidden_channels,
60
+ kernel_size,
61
+ dilation_rate,
62
+ n_layers,
63
+ gin_channels=0):
64
+ super().__init__()
65
+ self.in_channels = in_channels
66
+ self.out_channels = out_channels
67
+ self.hidden_channels = hidden_channels
68
+ self.kernel_size = kernel_size
69
+ self.dilation_rate = dilation_rate
70
+ self.n_layers = n_layers
71
+ self.gin_channels = gin_channels
72
+
73
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
74
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
75
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
76
+
77
+ def forward(self, x, x_lengths, g=None):
78
+ # print(x.shape,x_lengths.shape)
79
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
80
+ x = self.pre(x) * x_mask
81
+ x = self.enc(x, x_mask, g=g)
82
+ stats = self.proj(x) * x_mask
83
+ m, logs = torch.split(stats, self.out_channels, dim=1)
84
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
85
+ return z, m, logs, x_mask
86
 
87
 
88
  class TextEncoder(nn.Module):
89
+ def __init__(self,
90
+ out_channels,
91
+ hidden_channels,
92
+ kernel_size,
93
+ n_layers,
94
+ gin_channels=0,
95
+ filter_channels=None,
96
+ n_heads=None,
97
+ p_dropout=None):
98
+ super().__init__()
99
+ self.out_channels = out_channels
100
+ self.hidden_channels = hidden_channels
101
+ self.kernel_size = kernel_size
102
+ self.n_layers = n_layers
103
+ self.gin_channels = gin_channels
104
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
105
+ self.f0_emb = nn.Embedding(256, hidden_channels)
106
+
107
+ self.enc_ = attentions.Encoder(
108
+ hidden_channels,
109
+ filter_channels,
110
+ n_heads,
111
+ n_layers,
112
+ kernel_size,
113
+ p_dropout)
114
 
115
+ def forward(self, x, x_mask, f0=None, noice_scale=1):
116
+ x = x + self.f0_emb(f0).transpose(1, 2)
117
+ x = self.enc_(x * x_mask, x_mask)
118
+ stats = self.proj(x) * x_mask
119
+ m, logs = torch.split(stats, self.out_channels, dim=1)
120
+ z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
121
+
122
+ return z, m, logs, x_mask
123
 
124
 
125
  class DiscriminatorP(torch.nn.Module):
 
127
  super(DiscriminatorP, self).__init__()
128
  self.period = period
129
  self.use_spectral_norm = use_spectral_norm
130
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
131
  self.convs = nn.ModuleList([
132
  norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
133
  norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
 
142
 
143
  # 1d to 2d
144
  b, c, t = x.shape
145
+ if t % self.period != 0: # pad first
146
  n_pad = self.period - (t % self.period)
147
  x = F.pad(x, (0, n_pad), "reflect")
148
  t = t + n_pad
 
162
  class DiscriminatorS(torch.nn.Module):
163
  def __init__(self, use_spectral_norm=False):
164
  super(DiscriminatorS, self).__init__()
165
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
166
  self.convs = nn.ModuleList([
167
  norm_f(Conv1d(1, 16, 15, 1, padding=7)),
168
  norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
 
190
  class MultiPeriodDiscriminator(torch.nn.Module):
191
  def __init__(self, use_spectral_norm=False):
192
  super(MultiPeriodDiscriminator, self).__init__()
193
+ periods = [2, 3, 5, 7, 11]
194
 
195
  discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
196
  discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
 
227
 
228
  def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
229
  mel_slices = []
230
+ for i in range(0, total_frames - partial_frames, partial_hop):
231
+ mel_range = torch.arange(i, i + partial_frames)
232
  mel_slices.append(mel_range)
233
 
234
  return mel_slices
235
 
236
  def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
237
  mel_len = mel.size(1)
238
+ last_mel = mel[:, -partial_frames:]
239
 
240
  if mel_len > partial_frames:
241
  mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
242
+ mels = list(mel[:, s] for s in mel_slices)
243
  mels.append(last_mel)
244
  mels = torch.stack(tuple(mels), 0).squeeze(1)
245
 
246
  with torch.no_grad():
247
  partial_embeds = self(mels)
248
  embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
249
+ # embed = embed / torch.linalg.norm(embed, 2)
250
  else:
251
  with torch.no_grad():
252
  embed = self(last_mel)
 
282
  kernel_size,
283
  p_dropout)
284
  self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
285
+ self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
286
  self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
287
 
288
  def forward(self, x, norm_f0, x_mask, spk_emb=None):
 
297
 
298
 
299
  class SynthesizerTrn(nn.Module):
300
+ """
301
+ Synthesizer for Training
302
+ """
303
+
304
+ def __init__(self,
305
+ spec_channels,
306
+ segment_size,
307
+ inter_channels,
308
+ hidden_channels,
309
+ filter_channels,
310
+ n_heads,
311
+ n_layers,
312
+ kernel_size,
313
+ p_dropout,
314
+ resblock,
315
+ resblock_kernel_sizes,
316
+ resblock_dilation_sizes,
317
+ upsample_rates,
318
+ upsample_initial_channel,
319
+ upsample_kernel_sizes,
320
+ gin_channels,
321
+ ssl_dim,
322
+ n_speakers,
323
+ sampling_rate=44100,
324
+ vol_embedding=False,
325
+ vocoder_name = "nsf-hifigan",
326
+ use_depthwise_conv = False,
327
+ use_automatic_f0_prediction = True,
328
+ flow_share_parameter = False,
329
+ n_flow_layer = 4,
330
+ **kwargs):
331
+
332
+ super().__init__()
333
+ self.spec_channels = spec_channels
334
+ self.inter_channels = inter_channels
335
+ self.hidden_channels = hidden_channels
336
+ self.filter_channels = filter_channels
337
+ self.n_heads = n_heads
338
+ self.n_layers = n_layers
339
+ self.kernel_size = kernel_size
340
+ self.p_dropout = p_dropout
341
+ self.resblock = resblock
342
+ self.resblock_kernel_sizes = resblock_kernel_sizes
343
+ self.resblock_dilation_sizes = resblock_dilation_sizes
344
+ self.upsample_rates = upsample_rates
345
+ self.upsample_initial_channel = upsample_initial_channel
346
+ self.upsample_kernel_sizes = upsample_kernel_sizes
347
+ self.segment_size = segment_size
348
+ self.gin_channels = gin_channels
349
+ self.ssl_dim = ssl_dim
350
+ self.vol_embedding = vol_embedding
351
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
352
+ self.use_depthwise_conv = use_depthwise_conv
353
+ self.use_automatic_f0_prediction = use_automatic_f0_prediction
354
+ if vol_embedding:
355
+ self.emb_vol = nn.Linear(1, hidden_channels)
356
+
357
+ self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
358
+
359
+ self.enc_p = TextEncoder(
360
+ inter_channels,
361
+ hidden_channels,
362
+ filter_channels=filter_channels,
363
+ n_heads=n_heads,
364
+ n_layers=n_layers,
365
+ kernel_size=kernel_size,
366
+ p_dropout=p_dropout
367
+ )
368
+ hps = {
369
+ "sampling_rate": sampling_rate,
370
+ "inter_channels": inter_channels,
371
+ "resblock": resblock,
372
+ "resblock_kernel_sizes": resblock_kernel_sizes,
373
+ "resblock_dilation_sizes": resblock_dilation_sizes,
374
+ "upsample_rates": upsample_rates,
375
+ "upsample_initial_channel": upsample_initial_channel,
376
+ "upsample_kernel_sizes": upsample_kernel_sizes,
377
+ "gin_channels": gin_channels,
378
+ "use_depthwise_conv":use_depthwise_conv
379
+ }
380
+
381
+ modules.set_Conv1dModel(self.use_depthwise_conv)
382
+
383
+ if vocoder_name == "nsf-hifigan":
384
+ from vdecoder.hifigan.models import Generator
385
+ self.dec = Generator(h=hps)
386
+ elif vocoder_name == "nsf-snake-hifigan":
387
+ from vdecoder.hifiganwithsnake.models import Generator
388
+ self.dec = Generator(h=hps)
389
+ else:
390
+ print("[?] Unkown vocoder: use default(nsf-hifigan)")
391
+ from vdecoder.hifigan.models import Generator
392
+ self.dec = Generator(h=hps)
393
+
394
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
395
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
396
+ if self.use_automatic_f0_prediction:
397
+ self.f0_decoder = F0Decoder(
398
+ 1,
399
+ hidden_channels,
400
+ filter_channels,
401
+ n_heads,
402
+ n_layers,
403
+ kernel_size,
404
+ p_dropout,
405
+ spk_channels=gin_channels
406
+ )
407
+ self.emb_uv = nn.Embedding(2, hidden_channels)
408
+ self.character_mix = False
409
+
410
+ def EnableCharacterMix(self, n_speakers_map, device):
411
+ self.speaker_map = torch.zeros((n_speakers_map, 1, 1, self.gin_channels)).to(device)
412
+ for i in range(n_speakers_map):
413
+ self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]).to(device))
414
+ self.speaker_map = self.speaker_map.unsqueeze(0).to(device)
415
+ self.character_mix = True
416
+
417
+ def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None, vol = None):
418
+ g = self.emb_g(g).transpose(1,2)
419
+
420
+ # vol proj
421
+ vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
422
+
423
+ # ssl prenet
424
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
425
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol
426
+
427
+ # f0 predict
428
+ if self.use_automatic_f0_prediction:
429
+ lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
430
+ norm_lf0 = utils.normalize_f0(lf0, x_mask, uv)
431
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
432
+ else:
433
+ lf0 = 0
434
+ norm_lf0 = 0
435
+ pred_lf0 = 0
436
+ # encoder
437
+ z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
438
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
439
+
440
+ # flow
441
+ z_p = self.flow(z, spec_mask, g=g)
442
+ z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(z, f0, spec_lengths, self.segment_size)
443
+
444
+ # nsf decoder
445
+ o = self.dec(z_slice, g=g, f0=pitch_slice)
446
+
447
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0
448
+
449
+ @torch.no_grad()
450
+ def infer(self, c, f0, uv, g=None, noice_scale=0.35, seed=52468, predict_f0=False, vol = None):
451
+
452
+ if c.device == torch.device("cuda"):
453
+ torch.cuda.manual_seed_all(seed)
454
+ else:
455
+ torch.manual_seed(seed)
456
+
457
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
458
+
459
+ if self.character_mix and len(g) > 1: # [N, S] * [S, B, 1, H]
460
+ g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
461
+ g = g * self.speaker_map # [N, S, B, 1, H]
462
+ g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
463
+ g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
464
+ else:
465
+ if g.dim() == 1:
466
+ g = g.unsqueeze(0)
467
+ g = self.emb_g(g).transpose(1, 2)
468
+
469
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
470
+ # vol proj
471
+
472
+ vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
473
+
474
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
475
+
476
+
477
+ if self.use_automatic_f0_prediction and predict_f0:
478
+ lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
479
+ norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
480
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
481
+ f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
482
+
483
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale)
484
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
485
+ o = self.dec(z * c_mask, g=g, f0=f0)
486
+ return o,f0
487
+
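For reference, a minimal inference sketch against the updated `SynthesizerTrn` above (the `vol`, `seed`, vocoder and volume-embedding options appear in this file). The config path, checkpoint path, and the hparams layout (`hps.data.filter_length`, `hps.train.segment_size`, `hps.model.ssl_dim`) are assumptions borrowed from the surrounding project, not something this commit fixes:

```python
import torch

import utils
from models import SynthesizerTrn

# Assumed paths and hparams layout (hps.data / hps.train / hps.model); adjust to the real config.
hps = utils.get_hparams_from_file("configs/config.json", infer_mode=True)
net_g = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model)
utils.load_checkpoint("logs/44k/G_0.pth", net_g, None, skip_optimizer=True)
net_g.eval()

# Dummy inputs: content features are [B, ssl_dim, T]; f0, uv and vol are [B, T]; g is a speaker id.
T = 100
c = torch.randn(1, hps.model.ssl_dim, T)
f0 = torch.full((1, T), 220.0)
uv = torch.ones(1, T)
vol = torch.rand(1, T)
sid = torch.LongTensor([0])

# infer is wrapped in @torch.no_grad() and seeds the RNG itself via the seed argument.
audio, f0_out = net_g.infer(c, f0, uv, g=sid, noice_scale=0.35,
                            seed=52468, predict_f0=False, vol=vol)
```

If `vol_embedding` is disabled in the config, the `vol` argument is simply ignored (it collapses to 0 inside `infer`).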
utils.py CHANGED
@@ -1,21 +1,21 @@
1
- import os
2
- import glob
3
- import re
4
- import sys
5
  import argparse
6
- import logging
7
  import json
8
  import subprocess
9
- import warnings
10
- import random
11
- import functools
12
 
 
13
  import librosa
14
  import numpy as np
15
- from scipy.io.wavfile import read
16
  import torch
 
 
17
  from torch.nn import functional as F
18
- from modules.commons import sequence_mask
19
 
20
  MATPLOTLIB_FLAG = False
21
 
@@ -110,25 +110,37 @@ def get_speech_encoder(speech_encoder,device=None,**kargs):
110
  speech_encoder_object = ContentVec256L9(device = device)
111
  elif speech_encoder == "vec256l9-onnx":
112
  from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx
113
- speech_encoder_object = ContentVec256L9(device = device)
114
  elif speech_encoder == "vec256l12-onnx":
115
  from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx
116
- speech_encoder_object = ContentVec256L9(device = device)
117
  elif speech_encoder == "vec768l9-onnx":
118
  from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx
119
- speech_encoder_object = ContentVec256L9(device = device)
120
  elif speech_encoder == "vec768l12-onnx":
121
  from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx
122
- speech_encoder_object = ContentVec256L9(device = device)
123
  elif speech_encoder == "hubertsoft-onnx":
124
  from vencoder.HubertSoft_Onnx import HubertSoft_Onnx
125
- speech_encoder_object = HubertSoft(device = device)
126
  elif speech_encoder == "hubertsoft":
127
  from vencoder.HubertSoft import HubertSoft
128
  speech_encoder_object = HubertSoft(device = device)
129
  elif speech_encoder == "whisper-ppg":
130
  from vencoder.WhisperPPG import WhisperPPG
131
  speech_encoder_object = WhisperPPG(device = device)
132
  else:
133
  raise Exception("Unknown speech encoder")
134
  return speech_encoder_object
@@ -152,7 +164,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
152
  # print("load", k)
153
  new_state_dict[k] = saved_state_dict[k]
154
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
155
- except:
156
  print("error, %s is not in the checkpoint" % k)
157
  logger.info("%s is not in the checkpoint" % k)
158
  new_state_dict[k] = v
@@ -188,15 +200,20 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim
188
  False -> lexicographically delete ckpts
189
  """
190
  ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
191
- name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
192
- time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
 
 
193
  sort_key = time_key if sort_by_time else name_key
194
- x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
 
195
  to_del = [os.path.join(path_to_models, fn) for fn in
196
  (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
197
- del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
198
- del_routine = lambda x: [os.remove(x), del_info(x)]
199
- rs = [del_routine(fn) for fn in to_del]
 
 
200
 
201
  def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
202
  for k, v in scalars.items():
@@ -324,11 +341,11 @@ def get_hparams_from_dir(model_dir):
324
  return hparams
325
 
326
 
327
- def get_hparams_from_file(config_path):
328
  with open(config_path, "r") as f:
329
  data = f.read()
330
  config = json.loads(data)
331
- hparams =HParams(**config)
332
  return hparams
333
 
334
 
@@ -367,7 +384,13 @@ def get_logger(model_dir, filename="train.log"):
367
  return logger
368
 
369
 
370
- def repeat_expand_2d(content, target_len):
371
  # content : [h, t]
372
 
373
  src_len = content.shape[-1]
@@ -384,6 +407,14 @@ def repeat_expand_2d(content, target_len):
384
  return target
385
 
386
 
387
  def mix_model(model_paths,mix_rate,mode):
388
  mix_rate = torch.FloatTensor(mix_rate)/100
389
  model_tem = torch.load(model_paths[0])
@@ -397,6 +428,80 @@ def mix_model(model_paths,mix_rate,mode):
397
  torch.save(model_tem,os.path.join(os.path.curdir,"output.pth"))
398
  return os.path.join(os.path.curdir,"output.pth")
399
 
400
  class HParams():
401
  def __init__(self, **kwargs):
402
  for k, v in kwargs.items():
@@ -431,6 +536,18 @@ class HParams():
431
  def get(self,index):
432
  return self.__dict__.get(index)
433
 
434
  class Volume_Extractor:
435
  def __init__(self, hop_size = 512):
436
  self.hop_size = hop_size
@@ -441,6 +558,6 @@ class Volume_Extractor:
441
  n_frames = int(audio.size(-1) // self.hop_size)
442
  audio2 = audio ** 2
443
  audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
444
- volume = torch.FloatTensor([torch.mean(audio2[:,int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
445
  volume = torch.sqrt(volume)
446
- return volume
1
  import argparse
2
+ import glob
3
  import json
4
+ import logging
5
+ import os
6
+ import re
7
  import subprocess
8
+ import sys
9
+ import traceback
10
+ from multiprocessing import cpu_count
11
 
12
+ import faiss
13
  import librosa
14
  import numpy as np
 
15
  import torch
16
+ from scipy.io.wavfile import read
17
+ from sklearn.cluster import MiniBatchKMeans
18
  from torch.nn import functional as F
 
19
 
20
  MATPLOTLIB_FLAG = False
21
 
 
110
  speech_encoder_object = ContentVec256L9(device = device)
111
  elif speech_encoder == "vec256l9-onnx":
112
  from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx
113
+ speech_encoder_object = ContentVec256L9_Onnx(device = device)
114
  elif speech_encoder == "vec256l12-onnx":
115
  from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx
116
+ speech_encoder_object = ContentVec256L12_Onnx(device = device)
117
  elif speech_encoder == "vec768l9-onnx":
118
  from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx
119
+ speech_encoder_object = ContentVec768L9_Onnx(device = device)
120
  elif speech_encoder == "vec768l12-onnx":
121
  from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx
122
+ speech_encoder_object = ContentVec768L12_Onnx(device = device)
123
  elif speech_encoder == "hubertsoft-onnx":
124
  from vencoder.HubertSoft_Onnx import HubertSoft_Onnx
125
+ speech_encoder_object = HubertSoft_Onnx(device = device)
126
  elif speech_encoder == "hubertsoft":
127
  from vencoder.HubertSoft import HubertSoft
128
  speech_encoder_object = HubertSoft(device = device)
129
  elif speech_encoder == "whisper-ppg":
130
  from vencoder.WhisperPPG import WhisperPPG
131
  speech_encoder_object = WhisperPPG(device = device)
132
+ elif speech_encoder == "cnhubertlarge":
133
+ from vencoder.CNHubertLarge import CNHubertLarge
134
+ speech_encoder_object = CNHubertLarge(device = device)
135
+ elif speech_encoder == "dphubert":
136
+ from vencoder.DPHubert import DPHubert
137
+ speech_encoder_object = DPHubert(device = device)
138
+ elif speech_encoder == "whisper-ppg-large":
139
+ from vencoder.WhisperPPGLarge import WhisperPPGLarge
140
+ speech_encoder_object = WhisperPPGLarge(device = device)
141
+ elif speech_encoder == "wavlmbase+":
142
+ from vencoder.WavLMBasePlus import WavLMBasePlus
143
+ speech_encoder_object = WavLMBasePlus(device = device)
144
  else:
145
  raise Exception("Unknown speech encoder")
146
  return speech_encoder_object
 
164
  # print("load", k)
165
  new_state_dict[k] = saved_state_dict[k]
166
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
167
+ except Exception:
168
  print("error, %s is not in the checkpoint" % k)
169
  logger.info("%s is not in the checkpoint" % k)
170
  new_state_dict[k] = v
 
200
  False -> lexicographically delete ckpts
201
  """
202
  ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
203
+ def name_key(_f):
204
+ return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
205
+ def time_key(_f):
206
+ return os.path.getmtime(os.path.join(path_to_models, _f))
207
  sort_key = time_key if sort_by_time else name_key
208
+ def x_sorted(_x):
209
+ return sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], key=sort_key)
210
  to_del = [os.path.join(path_to_models, fn) for fn in
211
  (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
212
+ def del_info(fn):
213
+ return logger.info(f".. Free up space by deleting ckpt {fn}")
214
+ def del_routine(x):
215
+ return [os.remove(x), del_info(x)]
216
+ [del_routine(fn) for fn in to_del]
217
 
218
  def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
219
  for k, v in scalars.items():
 
341
  return hparams
342
 
343
 
344
+ def get_hparams_from_file(config_path, infer_mode = False):
345
  with open(config_path, "r") as f:
346
  data = f.read()
347
  config = json.loads(data)
348
+ hparams =HParams(**config) if not infer_mode else InferHParams(**config)
349
  return hparams
350
 
351
 
 
384
  return logger
385
 
386
 
387
+ def repeat_expand_2d(content, target_len, mode = 'left'):
388
+ # content : [h, t]
389
+ return repeat_expand_2d_left(content, target_len) if mode == 'left' else repeat_expand_2d_other(content, target_len, mode)
390
+
391
+
392
+
393
+ def repeat_expand_2d_left(content, target_len):
394
  # content : [h, t]
395
 
396
  src_len = content.shape[-1]
 
407
  return target
408
 
409
 
410
+ # mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'
411
+ def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
412
+ # content : [h, t]
413
+ content = content[None,:,:]
414
+ target = F.interpolate(content,size=target_len,mode=mode)[0]
415
+ return target
416
+
417
+
418
  def mix_model(model_paths,mix_rate,mode):
419
  mix_rate = torch.FloatTensor(mix_rate)/100
420
  model_tem = torch.load(model_paths[0])
 
428
  torch.save(model_tem,os.path.join(os.path.curdir,"output.pth"))
429
  return os.path.join(os.path.curdir,"output.pth")
430
 
431
+ def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比 from RVC
432
+ # print(data1.max(),data2.max())
433
+ rms1 = librosa.feature.rms(
434
+ y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
435
+ ) # 每半秒一个点
436
+ rms2 = librosa.feature.rms(y=data2.detach().cpu().numpy(), frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
437
+ rms1 = torch.from_numpy(rms1).to(data2.device)
438
+ rms1 = F.interpolate(
439
+ rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
440
+ ).squeeze()
441
+ rms2 = torch.from_numpy(rms2).to(data2.device)
442
+ rms2 = F.interpolate(
443
+ rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
444
+ ).squeeze()
445
+ rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
446
+ data2 *= (
447
+ torch.pow(rms1, torch.tensor(1 - rate))
448
+ * torch.pow(rms2, torch.tensor(rate - 1))
449
+ )
450
+ return data2
451
+
452
+ def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI
453
+ n_cpu = cpu_count()
454
+ print("The feature index is constructing.")
455
+ exp_dir = os.path.join(root_dir,spk_name)
456
+ listdir_res = []
457
+ for file in os.listdir(exp_dir):
458
+ if ".wav.soft.pt" in file:
459
+ listdir_res.append(os.path.join(exp_dir,file))
460
+ if len(listdir_res) == 0:
461
+ raise Exception("You need to run preprocess_hubert_f0.py!")
462
+ npys = []
463
+ for name in sorted(listdir_res):
464
+ phone = torch.load(name)[0].transpose(-1,-2).numpy()
465
+ npys.append(phone)
466
+ big_npy = np.concatenate(npys, 0)
467
+ big_npy_idx = np.arange(big_npy.shape[0])
468
+ np.random.shuffle(big_npy_idx)
469
+ big_npy = big_npy[big_npy_idx]
470
+ if big_npy.shape[0] > 2e5:
471
+ # if(1):
472
+ info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
473
+ print(info)
474
+ try:
475
+ big_npy = (
476
+ MiniBatchKMeans(
477
+ n_clusters=10000,
478
+ verbose=True,
479
+ batch_size=256 * n_cpu,
480
+ compute_labels=False,
481
+ init="random",
482
+ )
483
+ .fit(big_npy)
484
+ .cluster_centers_
485
+ )
486
+ except Exception:
487
+ info = traceback.format_exc()
488
+ print(info)
489
+ n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
490
+ index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf)
491
+ index_ivf = faiss.extract_index_ivf(index) #
492
+ index_ivf.nprobe = 1
493
+ index.train(big_npy)
494
+ batch_size_add = 8192
495
+ for i in range(0, big_npy.shape[0], batch_size_add):
496
+ index.add(big_npy[i : i + batch_size_add])
497
+ # faiss.write_index(
498
+ # index,
499
+ # f"added_{spk_name}.index"
500
+ # )
501
+ print("Successfully build index")
502
+ return index
503
+
504
+
505
  class HParams():
506
  def __init__(self, **kwargs):
507
  for k, v in kwargs.items():
 
536
  def get(self,index):
537
  return self.__dict__.get(index)
538
 
539
+
540
+ class InferHParams(HParams):
541
+ def __init__(self, **kwargs):
542
+ for k, v in kwargs.items():
543
+ if type(v) == dict:
544
+ v = InferHParams(**v)
545
+ self[k] = v
546
+
547
+ def __getattr__(self,index):
548
+ return self.get(index)
549
+
550
+
551
  class Volume_Extractor:
552
  def __init__(self, hop_size = 512):
553
  self.hop_size = hop_size
 
558
  n_frames = int(audio.size(-1) // self.hop_size)
559
  audio2 = audio ** 2
560
  audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
561
+ volume = torch.nn.functional.unfold(audio2[:,None,None,:],(1,self.hop_size),stride=self.hop_size)[:,:,:n_frames].mean(dim=1)[0]
562
  volume = torch.sqrt(volume)
563
+ return volume
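Similarly, a small usage sketch for two of the reworked helpers in utils.py: `repeat_expand_2d`, which now dispatches on `mode`, and the RVC-derived `change_rms`. The tensor sizes and the `rate` value below are arbitrary illustrations:

```python
import numpy as np
import torch

from utils import change_rms, repeat_expand_2d

# repeat_expand_2d: stretch content features [h, t] to a new time length.
content = torch.randn(256, 100)
stretched_left = repeat_expand_2d(content, 150)                  # default 'left' repeat expansion
stretched_near = repeat_expand_2d(content, 150, mode="nearest")  # forwarded to F.interpolate
print(stretched_left.shape, stretched_near.shape)                # torch.Size([256, 150]) for both

# change_rms: move the generated audio's loudness envelope toward the input's.
# rate is the share of the output's own RMS to keep (0 = fully match the input, 1 = unchanged).
sr = 44100
input_audio = np.random.randn(sr).astype(np.float32)   # reference audio (numpy, 1-D)
output_audio = torch.randn(sr)                          # generated audio (torch, 1-D)
matched = change_rms(input_audio, sr, output_audio, sr, rate=0.5)
```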