Vageesh1 committed on
Commit
1f9348b
1 Parent(s): d30a2ed

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +90 -0
  2. helper.py +66 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import torchaudio.functional as F
4
+ from torchaudio.utils import download_asset
5
+
6
+ from pesq import pesq
7
+ from pystoi import stoi
8
+ import mir_eval
9
+ from pydub import AudioSegment
10
+ import matplotlib.pyplot as plt
11
+
12
+ import streamlit as st
13
+ from helper import plot_spectrogram,plot_mask,si_snr,generate_mixture,evaluate,get_irms
14
+
15
# Signal-to-noise ratio (in dB) that the synthetic noisy mixture is built at.
target_snr = 3

# STFT settings shared by the forward and the inverse transform.
N_FFT = 1024
N_HOP = 256

# power=None keeps the complex-valued STFT instead of a magnitude/power spectrogram.
stft = torchaudio.transforms.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# Power-spectral-density estimator and the Souden MVDR beamformer.
psd_transform = torchaudio.transforms.PSD()
mvdr_transform = torchaudio.transforms.SoudenMVDR()

# Index of the microphone channel used as the beamforming reference.
REFERENCE_CHANNEL = 0

# Noise recording from the torchaudio tutorial assets; it is loaded once,
# converted to double precision and transformed to the STFT domain.
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
waveform_noise = waveform_noise.double()
stft_noise = stft(waveform_noise)
38
+
39
def _fit_noise(noise, noise_sr, target_sr, num_channels, num_frames):
    """Return a private copy of *noise* that can be mixed with the upload.

    The noise is resampled to ``target_sr`` when the rates differ, looped
    until it covers ``num_frames`` samples and cropped to exactly that
    length. For a multi-channel upload with no more channels than the
    noise, the noise is cut down to the same channel count; a mono upload
    keeps every noise channel so that broadcasting yields a multi-channel
    mixture.
    """
    if noise_sr != target_sr:
        noise = F.resample(noise, noise_sr, target_sr)
    if noise.shape[-1] < num_frames:
        repeats = -(-num_frames // noise.shape[-1])  # ceiling division
        noise = noise.repeat(1, repeats)
    if 1 < num_channels <= noise.shape[0]:
        noise = noise[:num_channels]
    # clone() so an in-place rescale inside generate_mixture can never
    # modify the module-level noise tensor.
    return noise[..., :num_frames].clone()


def _show_spectrogram(spectrum, caption):
    """Display the image produced by plot_spectrogram, if it produced one."""
    image = plot_spectrogram(spectrum)
    if image is not None:
        st.image(image, caption=caption)


def ui():
    """Render the Streamlit page: upload, add noise, enhance with MVDR.

    The uploaded WAV file is treated as the clean speech. The tutorial
    noise recording is mixed in at ``target_snr`` dB, ideal ratio masks are
    derived from the clean and noise STFTs, and the Souden MVDR beamformer
    is applied to the mixture. Audio players and spectrograms are shown for
    the noisy mixture and the enhanced result. Nothing is returned.
    """
    st.title("Speech Enhancer")
    st.markdown("Made by Vageesh")
    audio_file = st.file_uploader("Upload an audio file in wav format", type=["wav"])
    if audio_file is None:
        return

    waveform_clean, sr = torchaudio.load(audio_file)
    waveform_clean = waveform_clean.to(torch.double)
    if waveform_clean.shape[-1] == 0:
        st.error("The uploaded file contains no audio samples.")
        return
    stft_clean = stft(waveform_clean)
    st.text("Your uploaded audio")
    # st.audio needs a NumPy array plus its sample rate; it cannot play a tensor.
    st.audio(waveform_clean.numpy(), sample_rate=sr)

    # Build the noisy mixture from the upload and the length-matched noise.
    noise = _fit_noise(
        waveform_noise, sr2, sr, waveform_clean.shape[0], waveform_clean.shape[-1]
    )
    waveform_mix = generate_mixture(waveform_clean, noise, target_snr)
    waveform_mix = waveform_mix.to(torch.double)
    stft_mix = stft(waveform_mix)
    _show_spectrogram(
        stft_mix[REFERENCE_CHANNEL], "Spectrogram of Mixture Speech (dB)"
    )
    # Play only the reference channel: the mixture can have many channels.
    st.audio(waveform_mix[REFERENCE_CHANNEL].numpy(), sample_rate=sr)

    # The masks must be computed from the noise as it is present in the
    # mixture, i.e. after SNR scaling. Taking the difference gives that
    # signal whether or not generate_mixture scaled its argument in place.
    stft_noise_scaled = stft(waveform_mix - waveform_clean)
    irm_speech, irm_noise = get_irms(stft_clean, stft_noise_scaled)

    # Mask-weighted PSD matrices feed the beamformer.
    psd_speech = psd_transform(stft_mix, irm_speech)
    psd_noise = psd_transform(stft_mix, irm_noise)
    stft_souden = mvdr_transform(
        stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL
    )
    waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
    waveform_souden = waveform_souden.reshape(1, -1)

    _show_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
    st.audio(waveform_souden.numpy(), sample_rate=sr)


# Streamlit executes the script as __main__; without this call the page
# would stay empty because ui() was only defined, never invoked.
if __name__ == "__main__":
    ui()
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
helper.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Convert a complex STFT into a grayscale dB image.

    Args:
        stft: Complex tensor of shape (freq, time). A tensor with a leading
            channel dimension, (channel, freq, time), is reduced to its
            first channel.
        title: Accepted for backward compatibility; not used.
        xlim: Accepted for backward compatibility; not used.

    Returns:
        A 2-D NumPy float array of shape (freq, time) with values in
        [0, 1]. The magnitude in dB is clipped to [-100, 0] and rescaled
        linearly, and the frequency axis is flipped so the lowest
        frequency is the bottom image row. The array can be passed
        straight to ``st.image``.
    """
    # helper.py has no module-level imports, so bring torch in locally.
    import torch

    if stft.dim() == 3:
        stft = stft[0]
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8)
    # Same display range as the original plotting code (vmin=-100, vmax=0).
    image = (spectrogram.clamp(-100.0, 0.0) + 100.0) / 100.0
    # Equivalent of origin="lower": row 0 of an image is drawn at the top.
    return torch.flip(image, dims=[0]).numpy()
9
+
10
+
11
def plot_mask(mask, title="Mask", xlim=None):
    """Draw a time-frequency mask as a viridis image with a colour bar.

    Args:
        mask: 2-D tensor; it is converted with ``.numpy()`` and drawn with
            the first row at the bottom (``origin="lower"``).
        title: Figure title.
        xlim: Accepted for backward compatibility; not used.

    Returns:
        The matplotlib figure, so that callers without an interactive
        backend can render it themselves. ``plt.show()`` is still called.
    """
    # helper.py has no module-level imports, so bring pyplot in locally.
    import matplotlib.pyplot as plt

    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()
    return figure
18
+
19
def si_snr(estimate, reference, epsilon=1e-8):
    """Compute the scale-invariant signal-to-noise ratio in dB.

    Args:
        estimate: 2-D tensor holding the estimated signal in a single row.
        reference: 2-D tensor of the same shape holding the target signal.
        epsilon: Added to the reference power to avoid division by zero
            when computing the projection scale.

    Returns:
        The Si-SNR as a Python float. The result is collapsed with
        ``.item()``, so the inputs must contain exactly one row.
    """
    # helper.py has no module-level imports, so bring torch in locally.
    import torch

    # Remove the DC offset from both signals.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()

    # Project the estimate onto the reference to get the optimally scaled target.
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)
    target = scale * reference
    error = estimate - target

    target_pow = target.pow(2).mean(axis=1)
    error_pow = error.pow(2).mean(axis=1)

    score = 10 * torch.log10(target_pow) - 10 * torch.log10(error_pow)
    return score.item()
37
+
38
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech with noise at the requested signal-to-noise ratio.

    Args:
        waveform_clean: Tensor with the clean signal.
        waveform_noise: Tensor with the noise; it must be broadcastable
            against ``waveform_clean``. It is NOT modified.
        target_snr: Desired SNR of the mixture in dB.

    Returns:
        ``waveform_clean`` plus the noise scaled so that the ratio of their
        mean powers equals ``target_snr``.
    """
    # helper.py has no module-level imports, so bring torch in locally.
    import torch

    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    noise_gain = 10 ** (-(target_snr - current_snr) / 20)
    # Scale out of place: an in-place "*=" would silently rescale the
    # caller's noise tensor on every call.
    return waveform_clean + waveform_noise * noise_gain
44
+
45
def evaluate(estimate, reference, sample_rate=16000):
    """Print SDR, Si-SNR, PESQ and STOI of *estimate* against *reference*.

    Args:
        estimate: 2-D tensor (sources, samples) with the enhanced signal.
        reference: 2-D tensor of the same shape with the clean target.
        sample_rate: Sampling rate in Hz handed to PESQ and STOI. The
            original code read an undefined ``SAMPLE_RATE`` global; the
            default of 16000 is an assumption -- confirm it matches the
            audio being scored.

    Returns:
        None. The four scores are written to standard output.
    """
    # helper.py has no module-level imports, so bring the metric libraries
    # in locally.
    import mir_eval
    from pesq import pesq
    from pystoi import stoi

    si_snr_score = si_snr(estimate, reference)
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(
        reference.numpy(), estimate.numpy(), compute_permutation=False
    )

    reference_wave = reference[0].numpy()
    estimate_wave = estimate[0].numpy()
    # pesq expects (fs, reference, degraded, mode); the original call passed
    # the estimate in the reference slot.
    pesq_mix = pesq(sample_rate, reference_wave, estimate_wave, "wb")
    stoi_mix = stoi(reference_wave, estimate_wave, sample_rate, extended=False)

    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_mix}")
    print(f"STOI score: {stoi_mix}")
59
+
60
def get_irms(stft_clean, stft_noise, reference_channel=0):
    """Compute ideal ratio masks for speech and noise.

    Args:
        stft_clean: Complex STFT of the clean speech with a leading channel
            dimension.
        stft_noise: Complex STFT of the noise, broadcastable against
            ``stft_clean``.
        reference_channel: Index along the first dimension of the masks to
            return. Defaults to 0; the original code read an undefined
            ``REFERENCE_CHANNEL`` global here.

    Returns:
        A tuple ``(irm_speech, irm_noise)`` holding the masks of the
        selected channel. Each mask is that signal's share of the total
        power per bin; bins where both inputs are exactly zero get 0 in
        both masks instead of NaN.
    """
    # helper.py has no module-level imports, so bring torch in locally.
    import torch

    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    total = power_clean + power_noise
    # Only an all-zero bin is affected by the clamp: 0 / tiny == 0 rather
    # than 0 / 0 == NaN, which would poison the PSD matrices downstream.
    total = total.clamp_min(torch.finfo(total.dtype).tiny)
    irm_speech = power_clean / total
    irm_noise = power_noise / total
    return irm_speech[reference_channel], irm_noise[reference_channel]
66
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
torch
torchaudio
pesq
pystoi
mir_eval
# app.py imports AudioSegment from pydub, so it must be installed as well.
pydub
streamlit
matplotlib