DurreSudoku committed
Commit 89b26ec · verified · 1 Parent(s): 6f3415e

Added stereo-to-mono conversion


The model does not support stereo audio, so stereo input is now downmixed to mono before enhancement. Enhancing each channel separately could be explored in the future.
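
As a quick illustration of what the new downmix does (a minimal sketch, not part of the commit; the shapes assume the (channels, samples) layout that torchaudio.load returns):

import torch

# Hypothetical 1-second stereo clip at 16 kHz, shaped (channels, samples).
stereo = torch.randn(2, 16000)
# Averaging over the channel dimension collapses it to mono, matching the
# _mean(waveform, dim=0, keepdims=True) call added in this commit.
mono = stereo.mean(dim=0, keepdim=True)
print(mono.shape)  # torch.Size([1, 16000])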

Files changed (1)
  1. functions.py +91 -82
functions.py CHANGED
@@ -1,83 +1,92 @@
-import torchaudio
-from torch import hamming_window, log10, no_grad, exp
-
-
-def return_input(user_input):
-    if user_input is None:
-        return None
-    return user_input
-
-def load_audio(audio_path):
-
-    audio_tensor, sr = torchaudio.load(audio_path)
-    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
-    audio_tensor.type()
-    return audio_tensor
-
-def load_audio_numpy(audio_path):
-    audio_tensor, sr = torchaudio.load(audio_path)
-    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
-    audio_array = audio_tensor.numpy()
-    return (16000, audio_array.ravel())
-
-def audio_to_spectrogram(audio):
-    transform_fn = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=512//4, power=None, window_fn=hamming_window)
-    spectrogram = transform_fn(audio)
-    return spectrogram
-
-def extract_magnitude_and_phase(spectrogram):
-    magnitude, phase = spectrogram.abs(), spectrogram.angle()
-    return magnitude, phase
-
-def amplitude_to_db(magnitude_spec):
-    max_amplitude = magnitude_spec.max()
-    db_spectrogram = torchaudio.functional.amplitude_to_DB(magnitude_spec, 20, 10e-10, log10(max_amplitude), 100.0)
-    return db_spectrogram, max_amplitude
-
-def min_max_scaling(spectrogram, scaler):
-    # Min-Max scaling (soundness of the math is questionable due to the use of each spectrograms' max value during decibel-scaling)
-    spectrogram = scaler.transform(spectrogram)
-    return spectrogram
-
-def inverse_min_max(spectrogram, scaler):
-    spectrogram = scaler.inverse_transform(spectrogram)
-    return spectrogram
-
-def db_to_amplitude(db_spectrogram, max_amplitude):
-    return max_amplitude * 10**(db_spectrogram/20)
-
-def reconstruct_complex_spectrogram(magnitude, phase):
-    return magnitude * exp(1j*phase)
-
-def inverse_fft(spectrogram):
-    inverse_fn = torchaudio.transforms.InverseSpectrogram(n_fft=512, hop_length=512//4, window_fn=hamming_window)
-    return inverse_fn(spectrogram)
-
-def transform_audio(audio, scaler):
-    spectrogram = audio_to_spectrogram(audio)
-    magnitude, phase = extract_magnitude_and_phase(spectrogram)
-    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
-    db_spectrogram = min_max_scaling(db_spectrogram, scaler)
-    return db_spectrogram.unsqueeze(0), phase, max_amplitude
-
-def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
-    db_spectrogram = db_spectrogram.squeeze(0)
-    db_spectrogram = inverse_min_max(db_spectrogram, scaler)
-    spectrogram = db_to_amplitude(db_spectrogram, max_amplitude)
-    complex_spec = reconstruct_complex_spectrogram(spectrogram, phase)
-    audio = inverse_fft(complex_spec)
-    return audio
-
-def save_audio(audio):
-    torchaudio.save(r"enhanced_audio.wav", audio, 16000)
-    return r"enhanced_audio.wav"
-
-def predict(user_input, model, scaler):
-    audio = load_audio(user_input)
-    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)
-
-    with no_grad():
-        enhanced_spectrogram = model.forward(spectrogram)
-    enhanced_audio = spectrogram_to_audio(enhanced_spectrogram, scaler, phase, max_amplitude)
-    enhanced_audio_path = save_audio(enhanced_audio)
+import torchaudio
+from torch import mean as _mean
+from torch import hamming_window, log10, no_grad, exp
+
+
+def return_input(user_input):
+    if user_input is None:
+        return None
+    return user_input
+
+
+def stereo_to_mono_convertion(waveform):
+    if waveform.shape[0] > 1:
+        waveform = _mean(waveform, dim=0, keepdims=True)
+        return waveform
+    else:
+        return waveform
+
+def load_audio(audio_path):
+
+    audio_tensor, sr = torchaudio.load(audio_path)
+    audio_tensor = stereo_to_mono_convertion(audio_tensor)
+    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
+    return audio_tensor
+
+def load_audio_numpy(audio_path):
+    audio_tensor, sr = torchaudio.load(audio_path)
+    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
+    audio_array = audio_tensor.numpy()
+    return (16000, audio_array.ravel())
+
+def audio_to_spectrogram(audio):
+    transform_fn = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=512//4, power=None, window_fn=hamming_window)
+    spectrogram = transform_fn(audio)
+    return spectrogram
+
+def extract_magnitude_and_phase(spectrogram):
+    magnitude, phase = spectrogram.abs(), spectrogram.angle()
+    return magnitude, phase
+
+def amplitude_to_db(magnitude_spec):
+    max_amplitude = magnitude_spec.max()
+    db_spectrogram = torchaudio.functional.amplitude_to_DB(magnitude_spec, 20, 10e-10, log10(max_amplitude), 100.0)
+    return db_spectrogram, max_amplitude
+
+def min_max_scaling(spectrogram, scaler):
+    # Min-Max scaling (soundness of the math is questionable due to the use of each spectrograms' max value during decibel-scaling)
+    spectrogram = scaler.transform(spectrogram)
+    return spectrogram
+
+def inverse_min_max(spectrogram, scaler):
+    spectrogram = scaler.inverse_transform(spectrogram)
+    return spectrogram
+
+def db_to_amplitude(db_spectrogram, max_amplitude):
+    return max_amplitude * 10**(db_spectrogram/20)
+
+def reconstruct_complex_spectrogram(magnitude, phase):
+    return magnitude * exp(1j*phase)
+
+def inverse_fft(spectrogram):
+    inverse_fn = torchaudio.transforms.InverseSpectrogram(n_fft=512, hop_length=512//4, window_fn=hamming_window)
+    return inverse_fn(spectrogram)
+
+def transform_audio(audio, scaler):
+    spectrogram = audio_to_spectrogram(audio)
+    magnitude, phase = extract_magnitude_and_phase(spectrogram)
+    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
+    db_spectrogram = min_max_scaling(db_spectrogram, scaler)
+    return db_spectrogram.unsqueeze(0), phase, max_amplitude
+
+def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
+    db_spectrogram = db_spectrogram.squeeze(0)
+    db_spectrogram = inverse_min_max(db_spectrogram, scaler)
+    spectrogram = db_to_amplitude(db_spectrogram, max_amplitude)
+    complex_spec = reconstruct_complex_spectrogram(spectrogram, phase)
+    audio = inverse_fft(complex_spec)
+    return audio
+
+def save_audio(audio):
+    torchaudio.save(r"enhanced_audio.wav", audio, 16000)
+    return r"enhanced_audio.wav"
+
+def predict(user_input, model, scaler):
+    audio = load_audio(user_input)
+    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)
+
+    with no_grad():
+        enhanced_spectrogram = model.forward(spectrogram)
+    enhanced_audio = spectrogram_to_audio(enhanced_spectrogram, scaler, phase, max_amplitude)
+    enhanced_audio_path = save_audio(enhanced_audio)
     return enhanced_audio_path
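
For context, a hypothetical end-to-end call of the updated functions.py; the input file name and the identity model/scaler stand-ins below are placeholders for whatever trained network and fitted scaler the app actually loads, which are not part of this commit:

import torch
from functions import predict

class _IdentityScaler:
    # Placeholder for the fitted min-max scaler used during training.
    def transform(self, x):
        return x

    def inverse_transform(self, x):
        return x

class _IdentityModel(torch.nn.Module):
    # Placeholder for the trained enhancement network; passes spectrograms through unchanged.
    def forward(self, x):
        return x

# Any local clip works here; stereo input is downmixed to mono and resampled to 16 kHz.
enhanced_path = predict("noisy_clip.wav", _IdentityModel(), _IdentityScaler())
print(enhanced_path)  # enhanced_audio.wav

Swapping in the real model and scaler leaves the call unchanged.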