"""Streamlit demo: packet-loss concealment with an ONNX-exported FRN model.

The script loads a 48 kHz recording, zeroes out packets according to a mask
from ``MaskGenerator``, runs the selected ONNX model frame by frame over the
STFT of the lossy signal, and then reports spectrogram/waveform plots and a
table of quality metrics (PESQ, STOI, PLCMOS v1/v2, WER) for the clean, lossy
and enhanced signals.
"""
import time

import numpy as np
import streamlit as st
import librosa
import soundfile as sf
import librosa.display
import torch
import onnxruntime
import onnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pystoi import stoi
from pesq import pesq
import pandas as pd
import torchaudio
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ
from speechmos import dnsmos
from speechmos import plcmos
import speech_recognition as speech_r
from jiwer import wer

from config import CONFIG
from dataset import MaskGenerator
from PLCMOS.plc_mos import PLCMOSEstimator

# Sample rate expected by the wide-band PESQ call and used for PLCMOS input.
_METRIC_SR = 16000

# Panel titles shared by the spectrogram and waveform figures
# (clean / lossy / enhanced, in that order).
_PANEL_TITLES = ('Оригинальный сигнал', 'Сигнал с потерями', 'Улучшенный сигнал')


# st.cache is deprecated; cache_resource is the Streamlit API intended for
# unserialisable, shared objects such as an inference session.
@st.cache_resource
def load_model(model):
    """Load an ONNX checkpoint and create an onnxruntime session for it.

    Args:
        model: File name of the checkpoint inside
            ``lightning_logs/version_0/checkpoints/``.

    Returns:
        Tuple ``(session, onnx_model, input_names, output_names)`` where
        ``session`` is the ``onnxruntime.InferenceSession``, ``onnx_model`` is
        the graph loaded with ``onnx.load`` and the two lists hold the
        session's input and output tensor names.
    """
    path = 'lightning_logs/version_0/checkpoints/' + str(model)
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the model frame by frame and reconstruct the waveform.

    Every graph input starts as a zero tensor of its declared shape. For each
    frame, the frame is fed as the first input and the three extra outputs of
    the previous step are fed back as inputs 1-3.

    NOTE: reads the module-level ``window``, ``stride`` and ``hann`` defined
    further down in this script, so it must only be called after they exist.

    Args:
        re_im: Array of STFT frames indexed along axis 0; the script builds it
            as real/imaginary pairs via ``torch.stft``.
        session: ``onnxruntime.InferenceSession`` from ``load_model``.
        onnx_model: ONNX graph, used only to read the input shapes.
        input_names: Names of the session inputs.
        output_names: Names of the session outputs.

    Returns:
        The inverse-STFT of the concatenated per-frame outputs as a NumPy
        array.
    """
    inputs = {
        input_names[i]: np.zeros(
            [d.dim_value for d in _input.type.tensor_type.shape.dim],
            dtype=np.float32,
        )
        for i, _input in enumerate(onnx_model.graph.input)
    }

    frames = []
    for t in range(re_im.shape[0]):
        inputs[input_names[0]] = re_im[t]
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Carry the recurrent state over to the next frame.
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        frames.append(out)

    spectrum = torch.tensor(np.concatenate(frames, 0))
    spectrum = spectrum.permute(1, 0, 2).contiguous()
    spectrum = torch.view_as_complex(spectrum)
    waveform = torch.istft(spectrum, window, stride, window=hann)
    return waveform.numpy()


def _amplitude_spectrogram(signal, n_fft=1024, hop_length=512):
    """Return the Hann-windowed STFT magnitude scaled by ``2 / sum(window)``."""
    analysis_window = np.hanning(n_fft)
    spectrum = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length,
                            window=analysis_window)
    return 2 * np.abs(spectrum) / np.sum(analysis_window)


def visualize(hr, lr, recon, sr):
    """Plot log-frequency dB spectrograms of the three signals.

    Args:
        hr: Clean signal.
        lr: Lossy signal.
        recon: Enhanced (model output) signal.
        sr: Sample rate passed to ``librosa.display.specshow``.

    Returns:
        The matplotlib figure with three stacked panels. The caller is
        responsible for closing it.
    """
    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    # Kept from the original script: attaches an Agg canvas to the figure.
    FigureCanvas(fig)
    for ax, signal, panel_title in zip(axes, (hr, lr, recon), _PANEL_TITLES):
        ax.title.set_text(panel_title)
        librosa.display.specshow(
            librosa.amplitude_to_db(_amplitude_spectrogram(signal)),
            ax=ax, y_axis='log', x_axis='time', sr=sr,
        )
        # Set after specshow so these labels replace the ones it installs.
        ax.set_xlabel('Время, с')
        ax.set_ylabel('Частота, Гц')
    return fig


def waveplot(hr, lr, recon, sr):
    """Plot the time-domain waveforms of the three signals.

    Args:
        hr: Clean signal.
        lr: Lossy signal.
        recon: Enhanced (model output) signal.
        sr: Sample rate passed to ``librosa.display.waveshow``.

    Returns:
        The matplotlib figure with three stacked panels. The caller is
        responsible for closing it.
    """
    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    # Kept from the original script: attaches an Agg canvas to the figure.
    FigureCanvas(fig)
    for ax, signal, panel_title in zip(axes, (hr, lr, recon), _PANEL_TITLES):
        ax.title.set_text(panel_title)
        librosa.display.waveshow(signal, ax=ax, sr=sr)
        ax.set_xlabel('Время, с')
    return fig


def _to_metric_rate(signal, orig_sr):
    """Resample ``signal`` to ``_METRIC_SR`` unless it is already at that rate."""
    if orig_sr == _METRIC_SR:
        return signal
    return librosa.resample(signal, orig_sr=orig_sr, target_sr=_METRIC_SR)


def _transcribe(recognizer, wav_path, language):
    """Transcribe a WAV file with the Google Web Speech recogniser.

    Returns:
        The recognised text, ``''`` when the speech could not be understood,
        or ``None`` when the recognition service could not be reached (an
        error message is shown in the page in that case).
    """
    with speech_r.AudioFile(wav_path) as source:
        audio = recognizer.record(source)
    try:
        return recognizer.recognize_google(audio, language=language)
    except speech_r.UnknownValueError:
        return ''
    except speech_r.RequestError as e:
        st.text("Ошибка при запросе к сервису распознавания речи; {0}".format(e))
        return None


def _wer_percent(reference, hypothesis):
    """Return WER in percent, or NaN when it cannot be computed.

    jiwer rejects an empty reference, and a missing transcript (``None``)
    means the recognition request itself failed, so both yield NaN instead
    of an exception or a misleading 100 %.
    """
    if not reference or hypothesis is None:
        return float('nan')
    return wer(reference, hypothesis) * 100


# --- STFT / packet configuration (``window`` and ``stride`` are also read by
# --- ``inference`` as globals).
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

# --- 1. Audio upload -------------------------------------------------------
st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
if uploaded_file is None:
    # Fall back to the bundled sample when nothing was uploaded.
    uploaded_file = 'sample.wav'

target, sr = librosa.load(uploaded_file, sr=48000)
# Drop the tail so the signal is a whole number of packets.
target = target[:packet_size * (len(target) // packet_size)]

st.text('Ваше аудио')
st.audio(uploaded_file)

model_ver = st.selectbox(
    'Веса оригинальной модели выбраны по умолчанию. Выберите модель',
    ('frn.onnx', 'frn_modified.onnx', 'frn_out_Q.onnx', 'frn_out_QF.onnx',
     'frn_out_QInt16.onnx', 'frn_out_QInt8.onnx', 'frn_out_QUInt8.onnx',
     'frn_out_QUInt16.onnx', 'frn_fp16 (1).onnx'))
st.write('Вы выбрали:', model_ver)

lang = st.selectbox(
    'Выберите язык вашего аудио для корректной работы распознавания речи',
    ('ru-RU', 'en-EN'))
st.write('Вы выбрали:', lang)

# --- 2. Loss simulation ----------------------------------------------------
st.subheader('2. Выберите желаемый процент потерь')
loss_percent = float(
    st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова",
              0, 100, step=1)
) / 100

mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
# One mask value per packet, broadcast over the samples of that packet.
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
re_im = (
    torch.stft(lossy_input_tensor, window, stride, window=hann,
               return_complex=False)
    .permute(1, 0, 2)
    .unsqueeze(1)
    .numpy()
    .astype(np.float32)
)

session, onnx_model, input_names, output_names = load_model(model_ver)

# --- Sidebar: model and metric references ----------------------------------
with st.sidebar:
    st.title('Full-band Reccurent Network', help='https://arxiv.org/abs/2211.04071')
    st.text('Авторы модели: Viet-Anh Nguyen and Anh H. T. Nguyen and Andy W. H. Khong')
    st.link_button("Github авторов", "https://github.com/Crystalsound/FRN",
                   help='Кликни на меня')
    st.header("Метрики")
    st.subheader("PESQ", help='https://ieeexplore.ieee.org/document/941023')
    st.text('Перцептивная оценка качества речи')
    st.subheader("STOI", help='https://ieeexplore.ieee.org/document/5495701')
    st.text('Индекс объективной кратковременной разборчивости')
    st.subheader("PLCMOS_v1&2", help='https://arxiv.org/abs/2305.15127')
    st.text('Эталонная и неэталонная метрики от Microsoft')
    st.subheader("WER", help='https://deepgram.com/learn/what-is-word-error-rate')
    st.text('Процент нераспознанных слов')

if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        output = inference(re_im, session, onnx_model, input_names, output_names)
        st.text(str(time.perf_counter() - start_time))

    # --- 3. Visualisation --------------------------------------------------
    st.subheader('3. Визуализация аудио')
    fig_1 = visualize(target, lossy_input, output, sr)
    fig_2 = waveplot(target, lossy_input, output, sr)
    tab1, tab2 = st.tabs(["Частотная область", "Временная область"])
    with tab1:
        st.header("Частотная область")
        st.pyplot(fig_1)
    with tab2:
        st.header("Временная область")
        st.pyplot(fig_2)
    # pyplot keeps every figure alive until closed; without this each rerun
    # of the script would leak two figures.
    plt.close(fig_1)
    plt.close(fig_2)
    st.success('Сделано!')

    wav_paths = ('target.wav', 'lossy.wav', 'enhanced.wav')
    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')

    # --- Metrics -----------------------------------------------------------
    # The metrics are computed on the signals as stored on disk, so read the
    # files back once and reuse the arrays for every metric.
    full_signals = []
    for wav_path in wav_paths:
        data, samplerate = sf.read(wav_path)
        full_signals.append(data)

    # STOI and PESQ compare sample by sample, so cut all three signals to the
    # shortest one first.
    min_len = min(signal.shape[0] for signal in full_signals)
    trimmed = [signal[:min_len] for signal in full_signals]
    data_clean = trimmed[0]

    stoi_mass = [
        round(stoi(data_clean, degraded, samplerate, extended=False), 5)
        for degraded in trimmed
    ]

    trimmed_16k = [_to_metric_rate(signal, samplerate) for signal in trimmed]
    psq_mas = [
        pesq(fs=_METRIC_SR, ref=trimmed_16k[0], deg=degraded, mode='wb')
        for degraded in trimmed_16k
    ]

    # PLCMOS v1 receives the untrimmed recordings, resampled to 16 kHz.
    full_16k = [_to_metric_rate(signal, samplerate) for signal in full_signals]
    PLC_example = PLCMOSEstimator()
    PLC_massv1 = [
        PLC_example.run(audio_degraded=degraded, audio_clean=full_16k[0])[0]
        for degraded in full_16k
    ]

    # PLCMOS v2 takes the file paths directly.
    PLC_massv2 = [
        plcmos.run(wav_path, sr=_METRIC_SR)['plcmos'] for wav_path in wav_paths
    ]

    # WER: the transcript of the clean file is the reference for all three.
    r = speech_r.Recognizer()
    transcripts = [_transcribe(r, wav_path, str(lang)) for wav_path in wav_paths]
    reference = transcripts[0]
    if reference == '':
        st.text("Система не смогла распознать аудио")
    WER_mass = [_wer_percent(reference, hypothesis) for hypothesis in transcripts]

    df_1 = pd.DataFrame({
        'Audio': ['Clean', 'Lossy', 'Enhanced'],
        'PESQ': psq_mas,
        'STOI': stoi_mass,
        'PLCMOSv1': PLC_massv1,
        'PLCMOSv2': PLC_massv2,
        'WER': WER_mass,
    })

    # --- 4. Metric table and per-metric charts -----------------------------
    st.subheader('4. Метрики аудио')
    st.dataframe(df_1)
    metric_names = ["PESQ", "STOI", "PLCMOSv1", "PLCMOSv2", "WER"]
    for metric_tab, metric in zip(st.tabs(metric_names), metric_names):
        with metric_tab:
            st.header(metric)
            st.bar_chart(df_1, x="Audio", y=metric)