Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +90 -0
- helper.py +66 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import torchaudio.functional as F
|
4 |
+
from torchaudio.utils import download_asset
|
5 |
+
|
6 |
+
from pesq import pesq
|
7 |
+
from pystoi import stoi
|
8 |
+
import mir_eval
|
9 |
+
from pydub import AudioSegment
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
|
12 |
+
import streamlit as st
|
13 |
+
from helper import plot_spectrogram,plot_mask,si_snr,generate_mixture,evaluate,get_irms
|
14 |
+
|
15 |
+
target_snr=3
|
16 |
+
|
17 |
+
#parameters for STFT
|
18 |
+
N_FFT = 1024
|
19 |
+
N_HOP = 256
|
20 |
+
stft = torchaudio.transforms.Spectrogram(
|
21 |
+
n_fft=N_FFT,
|
22 |
+
hop_length=N_HOP,
|
23 |
+
power=None,
|
24 |
+
)
|
25 |
+
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)
|
26 |
+
#defining a psd transform
|
27 |
+
psd_transform = torchaudio.transforms.PSD()
|
28 |
+
mvdr_transform = torchaudio.transforms.SoudenMVDR()
|
29 |
+
|
30 |
+
#defining the reference microphone
|
31 |
+
REFERENCE_CHANNEL = 0
|
32 |
+
|
33 |
+
#creating a random noise for better calculations
|
34 |
+
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
|
35 |
+
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
|
36 |
+
waveform_noise = waveform_noise.to(torch.double)
|
37 |
+
stft_noise = stft(waveform_noise)
|
38 |
+
|
39 |
+
def ui():
|
40 |
+
st.title("Speech Enhancer")
|
41 |
+
st.markdown("Made by Vageesh")
|
42 |
+
#making an audio developer uploader:
|
43 |
+
audio_file = st.file_uploader("Upload an audio file in wav format", type=[ "wav"])
|
44 |
+
|
45 |
+
if audio_file is not None:
|
46 |
+
waveform_clean,sr=torchaudio.load(audio_file)
|
47 |
+
waveform_clean = waveform_mix.to(torch.double)
|
48 |
+
stft_clean = stft(waveform_clean)
|
49 |
+
st.text("Your uploaded audio")
|
50 |
+
st.audio(waveform_clean)
|
51 |
+
#creating a mixture of our audio file and the noise file
|
52 |
+
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)
|
53 |
+
#making the files into torch double format
|
54 |
+
waveform_mix = waveform_mix.to(torch.double)
|
55 |
+
#computing STFT
|
56 |
+
stft_mix = stft(waveform_mix)
|
57 |
+
#plotting the spectogram
|
58 |
+
spec_img=plot_spectrogram(stft_mix)
|
59 |
+
st.image(spec_img,captions='Spectrogram of Mixture Speech (dB)')
|
60 |
+
#showing mixed audio in streamlit
|
61 |
+
st.audio(waveform_mix)
|
62 |
+
#getting the irms
|
63 |
+
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)
|
64 |
+
#getting the psd speech
|
65 |
+
psd_speech = psd_transform(stft_mix, irm_speech)
|
66 |
+
psd_noise = psd_transform(stft_mix, irm_noise)
|
67 |
+
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
|
68 |
+
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
|
69 |
+
#plotting the cleaned audio and hearing it
|
70 |
+
spec_clean_img=plot_spectrogram(stft_souden)
|
71 |
+
waveform_souden = waveform_souden.reshape(1, -1)
|
72 |
+
st.image(spec_clean_img,captions='Spectrogram of Mixture Speech (dB)')
|
73 |
+
st.audio(waveform_souden)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
helper.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
|
2 |
+
magnitude = stft.abs()
|
3 |
+
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
|
4 |
+
# figure, axis = plt.subplots(1, 1)
|
5 |
+
# img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
|
6 |
+
# figure.suptitle(title)
|
7 |
+
# plt.colorbar(img, ax=axis)
|
8 |
+
# plt.show()
|
9 |
+
|
10 |
+
|
11 |
+
def plot_mask(mask, title="Mask", xlim=None):
|
12 |
+
mask = mask.numpy()
|
13 |
+
figure, axis = plt.subplots(1, 1)
|
14 |
+
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
|
15 |
+
figure.suptitle(title)
|
16 |
+
plt.colorbar(img, ax=axis)
|
17 |
+
plt.show()
|
18 |
+
|
19 |
+
def si_snr(estimate, reference, epsilon=1e-8):
|
20 |
+
estimate = estimate - estimate.mean()
|
21 |
+
reference = reference - reference.mean()
|
22 |
+
reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
|
23 |
+
mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
|
24 |
+
scale = mix_pow / (reference_pow + epsilon)
|
25 |
+
|
26 |
+
reference = scale * reference
|
27 |
+
error = estimate - reference
|
28 |
+
|
29 |
+
reference_pow = reference.pow(2)
|
30 |
+
error_pow = error.pow(2)
|
31 |
+
|
32 |
+
reference_pow = reference_pow.mean(axis=1)
|
33 |
+
error_pow = error_pow.mean(axis=1)
|
34 |
+
|
35 |
+
si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
|
36 |
+
return si_snr.item()
|
37 |
+
|
38 |
+
def generate_mixture(waveform_clean, waveform_noise, target_snr):
|
39 |
+
power_clean_signal = waveform_clean.pow(2).mean()
|
40 |
+
power_noise_signal = waveform_noise.pow(2).mean()
|
41 |
+
current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
|
42 |
+
waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
|
43 |
+
return waveform_clean + waveform_noise
|
44 |
+
|
45 |
+
def evaluate(estimate, reference):
|
46 |
+
si_snr_score = si_snr(estimate, reference)
|
47 |
+
(
|
48 |
+
sdr,
|
49 |
+
_,
|
50 |
+
_,
|
51 |
+
_,
|
52 |
+
) = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
|
53 |
+
pesq_mix = pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), "wb")
|
54 |
+
stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)
|
55 |
+
print(f"SDR score: {sdr[0]}")
|
56 |
+
print(f"Si-SNR score: {si_snr_score}")
|
57 |
+
print(f"PESQ score: {pesq_mix}")
|
58 |
+
print(f"STOI score: {stoi_mix}")
|
59 |
+
|
60 |
+
def get_irms(stft_clean, stft_noise):
|
61 |
+
mag_clean = stft_clean.abs() ** 2
|
62 |
+
mag_noise = stft_noise.abs() ** 2
|
63 |
+
irm_speech = mag_clean / (mag_clean + mag_noise)
|
64 |
+
irm_noise = mag_noise / (mag_clean + mag_noise)
|
65 |
+
return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]
|
66 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchaudio
|
3 |
+
pesq
|
4 |
+
pystoi
|
5 |
+
mir_eval
|
6 |
+
streamlit
|
7 |
+
matplotlib
|