Added stereo to mono conversion
Browse files

The model does not support stereo audio. Separate per-channel enhancement may be explored in the future.
- functions.py +91 -82
functions.py
CHANGED
@@ -1,83 +1,92 @@
|
|
1 |
-
import torchaudio
|
2 |
-
from torch import
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
return
|
31 |
-
|
32 |
-
def
|
33 |
-
|
34 |
-
|
35 |
-
return
|
36 |
-
|
37 |
-
def
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
return
|
45 |
-
|
46 |
-
def
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
spectrogram =
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
enhanced_audio
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
return enhanced_audio_path
|
|
|
1 |
+
import torchaudio
|
2 |
+
from torch import mean as _mean
|
3 |
+
from torch import hamming_window, log10, no_grad, exp
|
4 |
+
|
5 |
+
|
6 |
+
def return_input(user_input):
    """Identity passthrough for user input.

    The original's explicit ``if user_input is None: return None`` branch
    was redundant — returning the value directly yields ``None`` when the
    input is ``None``.

    Args:
        user_input: Any value, possibly ``None``.

    Returns:
        The value unchanged.
    """
    return user_input
|
10 |
+
|
11 |
+
|
12 |
+
def stereo_to_mono_convertion(waveform):
    """Collapse a multi-channel waveform to mono by averaging channels.

    The enhancement model does not support stereo audio, so any input with
    more than one channel is averaged across the channel dimension.

    Args:
        waveform: Tensor of shape ``(channels, samples)``.

    Returns:
        Tensor of shape ``(1, samples)``; mono input passes through
        unchanged.
    """
    if waveform.shape[0] > 1:
        # keepdim=True preserves the (1, samples) channel-first layout.
        # (``keepdim`` is the canonical torch kwarg; the original used the
        # numpy-style alias ``keepdims``.)
        waveform = _mean(waveform, dim=0, keepdim=True)
    # The original's duplicated-return if/else collapses to one return.
    return waveform
|
18 |
+
|
19 |
+
def load_audio(audio_path, target_sr=16000):
    """Load an audio file as a mono tensor resampled to ``target_sr``.

    Args:
        audio_path: Path to an audio file readable by torchaudio.
        target_sr: Desired sample rate in Hz. Defaults to 16000, which is
            presumably the rate the enhancement model expects — the
            original hard-coded this value.

    Returns:
        Tensor of shape ``(1, samples)`` at ``target_sr``.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    # The model only supports mono audio.
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, target_sr)
    return audio_tensor
|
25 |
+
|
26 |
+
def load_audio_numpy(audio_path):
    """Load audio as a ``(sample_rate, 1-D numpy array)`` pair.

    Mirrors :func:`load_audio` (mono, 16 kHz) but returns the numpy tuple
    format (e.g. as used by gradio audio components).

    Note: the original skipped the mono conversion, so ``ravel()`` on a
    stereo file concatenated the two channels one after another, producing
    garbled audio; converting to mono first makes it consistent with
    ``load_audio``.

    Args:
        audio_path: Path to an audio file readable by torchaudio.

    Returns:
        Tuple ``(16000, samples)`` where ``samples`` is a 1-D float array.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    return (16000, audio_tensor.numpy().ravel())
|
31 |
+
|
32 |
+
def audio_to_spectrogram(audio):
    """Compute the complex STFT spectrogram of ``audio``.

    Uses a 512-point FFT with 75% overlap and a Hamming window;
    ``power=None`` keeps complex values so the phase can be recovered
    later by ``extract_magnitude_and_phase``.
    """
    to_spec = torchaudio.transforms.Spectrogram(
        n_fft=512,
        hop_length=512 // 4,
        power=None,
        window_fn=hamming_window,
    )
    return to_spec(audio)
|
36 |
+
|
37 |
+
def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into magnitude and phase tensors."""
    magnitude = spectrogram.abs()
    phase = spectrogram.angle()
    return magnitude, phase
|
40 |
+
|
41 |
+
def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to decibels relative to its peak.

    Returns:
        (db_spectrogram, max_amplitude): the dB-scaled spectrogram and the
        original peak magnitude, which callers keep so the scaling can be
        inverted later (see ``db_to_amplitude``).
    """
    max_amplitude = magnitude_spec.max()
    # amplitude_to_DB positional args: (x, multiplier, amin, db_multiplier,
    # top_db). multiplier=20 treats the input as amplitude (not power);
    # amin clamps near-zero values before the log — NOTE(review): 10e-10
    # is 1e-9; confirm 1e-10 was not intended. db_multiplier=log10(peak)
    # makes the result relative to this spectrogram's own maximum;
    # top_db=100.0 limits the dynamic range.
    db_spectrogram = torchaudio.functional.amplitude_to_DB(magnitude_spec, 20, 10e-10, log10(max_amplitude), 100.0)
    return db_spectrogram, max_amplitude
|
45 |
+
|
46 |
+
def min_max_scaling(spectrogram, scaler):
    """Apply the fitted min-max scaler to a dB spectrogram.

    NOTE: the soundness of the math is questionable because each
    spectrogram's own max value was used during decibel-scaling (carried
    over from the original comment).
    """
    return scaler.transform(spectrogram)
|
50 |
+
|
51 |
+
def inverse_min_max(spectrogram, scaler):
    """Undo the scaling applied by ``min_max_scaling``."""
    return scaler.inverse_transform(spectrogram)
|
54 |
+
|
55 |
+
def db_to_amplitude(db_spectrogram, max_amplitude):
    """Invert the peak-referenced dB scaling back to linear amplitude."""
    linear = 10.0 ** (db_spectrogram / 20.0)
    return max_amplitude * linear
|
57 |
+
|
58 |
+
def reconstruct_complex_spectrogram(magnitude, phase):
    """Recombine magnitude and phase into a complex spectrogram."""
    unit_phasor = exp(1j * phase)
    return magnitude * unit_phasor
|
60 |
+
|
61 |
+
def inverse_fft(spectrogram):
    """Invert a complex spectrogram back into a waveform.

    The FFT size, hop length and window match the forward transform in
    ``audio_to_spectrogram``.
    """
    to_audio = torchaudio.transforms.InverseSpectrogram(
        n_fft=512,
        hop_length=512 // 4,
        window_fn=hamming_window,
    )
    return to_audio(spectrogram)
|
64 |
+
|
65 |
+
def transform_audio(audio, scaler):
    """Preprocess a waveform into the model's input representation.

    Returns:
        (spectrogram, phase, max_amplitude): the scaled dB spectrogram
        with a leading batch dimension, plus the phase and peak magnitude
        needed by ``spectrogram_to_audio`` to reconstruct the waveform.
    """
    complex_spec = audio_to_spectrogram(audio)
    magnitude, phase = extract_magnitude_and_phase(complex_spec)
    db_spec, max_amplitude = amplitude_to_db(magnitude)
    db_spec = min_max_scaling(db_spec, scaler)
    # unsqueeze(0) adds the batch dimension the model expects.
    return db_spec.unsqueeze(0), phase, max_amplitude
|
71 |
+
|
72 |
+
def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
    """Reverse the preprocessing chain: model output -> waveform."""
    db_spec = db_spectrogram.squeeze(0)         # drop the batch dimension
    db_spec = inverse_min_max(db_spec, scaler)  # undo min-max scaling
    magnitude = db_to_amplitude(db_spec, max_amplitude)
    complex_spec = reconstruct_complex_spectrogram(magnitude, phase)
    return inverse_fft(complex_spec)
|
79 |
+
|
80 |
+
def save_audio(audio):
    """Write the enhanced waveform to disk at 16 kHz and return its path."""
    output_path = r"enhanced_audio.wav"
    torchaudio.save(output_path, audio, 16000)
    return output_path
|
83 |
+
|
84 |
+
def predict(user_input, model, scaler):
    """End-to-end enhancement: load audio, run the model, save the result.

    Args:
        user_input: Path to the input audio file.
        model: Trained enhancement model mapping scaled dB spectrograms to
            enhanced spectrograms.
        scaler: Fitted min-max scaler used during training.

    Returns:
        Path of the saved enhanced audio file (``enhanced_audio.wav``).
    """
    audio = load_audio(user_input)
    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)

    # Inference only: no_grad avoids building the autograd graph.
    with no_grad():
        # Call the model itself rather than model.forward(...) so that
        # nn.Module __call__ machinery (hooks, etc.) runs as intended.
        enhanced_spectrogram = model(spectrogram)
    enhanced_audio = spectrogram_to_audio(enhanced_spectrogram, scaler, phase, max_amplitude)
    enhanced_audio_path = save_audio(enhanced_audio)
    return enhanced_audio_path
|