Hecheng0625 committed on
Commit
a63132d
1 Parent(s): 54d9af8

Create melspec.py

Browse files
Files changed (1) hide show
  1. Amphion/models/ns3_codec/melspec.py +102 -0
Amphion/models/ns3_codec/melspec.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pyworld as pw
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import os
6
+ from torchaudio.functional import pitch_shift
7
+ import librosa
8
+ from librosa.filters import mel as librosa_mel_fn
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (NumPy version).

    Values below ``clip_val`` are raised to ``clip_val`` before the log so
    that zeros do not produce ``-inf``; the floored values are scaled by
    ``C`` and the natural logarithm is returned.
    """
    floored = np.clip(x, a_min=clip_val, a_max=None)
    scaled = floored * C
    return np.log(scaled)
15
+
16
+
17
def dynamic_range_decompression(x, C=1):
    """Undo the log compression (NumPy version): return ``exp(x) / C``."""
    expanded = np.exp(x)
    return expanded / C
19
+
20
+
21
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress the tensor ``x`` (PyTorch version).

    Entries are clamped from below at ``clip_val`` to avoid ``log(0)``,
    multiplied by ``C``, and passed through the natural logarithm.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
23
+
24
+
25
def dynamic_range_decompression_torch(x, C=1):
    """Undo the log compression (PyTorch version): return ``exp(x) / C``."""
    expanded = torch.exp(x)
    return expanded / C
27
+
28
+
29
def spectral_normalize_torch(magnitudes):
    """Apply the default log dynamic-range compression to ``magnitudes``.

    Thin wrapper around ``dynamic_range_compression_torch`` using its
    default ``C`` and ``clip_val``.
    """
    return dynamic_range_compression_torch(magnitudes)
32
+
33
+
34
def spectral_de_normalize_torch(magnitudes):
    """Invert ``spectral_normalize_torch`` by exponentiating ``magnitudes``.

    Thin wrapper around ``dynamic_range_decompression_torch`` using its
    default ``C``.
    """
    return dynamic_range_decompression_torch(magnitudes)
37
+
38
+
39
class MelSpectrogram(nn.Module):
    """Compute a log-compressed mel spectrogram from a waveform tensor.

    The mel filterbank and the Hann analysis window are built once in the
    constructor and registered as buffers, so they follow the module across
    ``.to(device)`` calls and are included in its ``state_dict``.

    Args:
        n_fft: FFT size passed to ``torch.stft`` and to the mel filterbank.
        num_mels: Number of mel bands in the filterbank.
        sampling_rate: Sampling rate used to build the mel filterbank.
        hop_size: Hop length between successive STFT frames.
        win_size: Length of the Hann window.
        fmin: Lowest frequency of the mel filterbank.
        fmax: Highest frequency of the mel filterbank.
        center: Forwarded to ``torch.stft`` as its ``center`` argument.
    """

    def __init__(
        self,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
        center=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.num_mels = num_mels
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )

        # Buffers (not parameters): moved with the module, never trained.
        self.register_buffer("mel_basis", torch.from_numpy(mel).float())
        self.register_buffer("hann_window", torch.hann_window(win_size))

    def forward(self, y):
        """Return the log-mel spectrogram of ``y``.

        Args:
            y: Waveform tensor. Presumably shaped ``(batch, samples)``; a
                channel axis is inserted temporarily for the reflect padding
                -- TODO confirm against callers.

        Returns:
            Natural log of the mel-projected STFT magnitudes, as produced by
            ``spectral_normalize_torch``.
        """
        # Pad both ends by the same amount before the STFT; the amount is
        # computed once and truncated to int exactly as before.
        pad = int((self.n_fft - self.hop_size) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # Split the complex STFT into a trailing (real, imag) axis.
        spec = torch.view_as_real(spec)

        # Magnitude; the small epsilon is added before the square root.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        # Project the linear-frequency magnitudes onto the mel filterbank.
        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec