"""Streamlit demo of packet-loss concealment with an ONNX-exported model.

The page loads a WAV file (or the bundled ``sample.wav``), zeroes out packets
with a Markov-chain loss generator, reconstructs the signal frame by frame with
the ONNX model and reports STOI / PESQ for the clean, lossy and enhanced audio.
"""
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import onnx
import onnxruntime
import pandas as pd
import soundfile as sf
import streamlit as st
import torch
import torchaudio  # noqa: F401  -- kept from the original imports; currently unused
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pystoi import stoi
from torch_pesq import PesqLoss

from config import CONFIG
from dataset import MaskGenerator

MODEL_PATH = 'lightning_logs/version_0/checkpoints/frn.onnx'
# The upload prompt asks for 48 kHz audio, and all audio is loaded, written
# and scored (PESQ/STOI) at this rate.
SAMPLE_RATE = 48000

# `st.cache` is deprecated; `st.cache_resource` is the replacement meant for
# unhashable global resources such as inference sessions. Fall back to the
# legacy decorator (without output hashing) on old Streamlit versions.
if hasattr(st, 'cache_resource'):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the ONNX graph and create an ONNX Runtime session (cached across reruns).

    Returns:
        Tuple ``(session, onnx_model, input_names, output_names)``. The parsed
        ``onnx_model`` is returned so callers can read the declared input shapes.
    """
    onnx_model = onnx.load(MODEL_PATH)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Since ORT 1.9 builds with several execution providers require an explicit
    # provider list; the CPU provider is available in every build.
    session = onnxruntime.InferenceSession(
        MODEL_PATH, options, providers=['CPUExecutionProvider'])
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the model frame by frame and resynthesize a waveform.

    Args:
        re_im: float32 array of STFT frames indexed by time along axis 0; as
            built by this script it has shape (frames, 1, freq_bins, 2).
        session, onnx_model, input_names, output_names: as returned by
            ``load_model``.

    Returns:
        1-D numpy array holding the reconstructed waveform.

    Note:
        The inverse STFT reads the module-level ``window``, ``stride`` and
        ``hann`` globals, so this must be called after they are defined.
    """
    # Every graph input starts as zeros of the shape declared in the ONNX graph.
    # NOTE(review): dynamic dims have dim_value 0 and would yield empty arrays;
    # this assumes all inputs are exported with static shapes.
    inputs = {
        name: np.zeros([d.dim_value for d in graph_input.type.tensor_type.shape.dim],
                       dtype=np.float32)
        for name, graph_input in zip(input_names, onnx_model.graph.input)
    }

    frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Feed the recurrent state back in for the next frame. Assumes the graph
        # inputs are ordered (frame, prev_mag, predictor_state, mlp_state).
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        frames.append(out)

    # (time, freq, 2) -> (freq, time, 2) -> complex (freq, time) for istft.
    spectrum = torch.tensor(np.concatenate(frames, 0))
    spectrum = spectrum.permute(1, 0, 2).contiguous()
    spectrum = torch.view_as_complex(spectrum)
    waveform = torch.istft(spectrum, window, stride, window=hann)
    return waveform.numpy()


def _magnitude_spectrum(signal, window, hop_length):
    """Return the STFT magnitude of *signal*, scaled by 2 / sum(window)."""
    spectrum = librosa.stft(signal, n_fft=len(window), hop_length=hop_length, window=window)
    return 2 * np.abs(spectrum) / np.sum(window)


def visualize(hr, lr, recon, sr):
    """Plot log-frequency dB spectrograms of the original, lossy and enhanced signals.

    Args:
        hr: original waveform.
        lr: waveform with dropped packets.
        recon: reconstructed waveform.
        sr: sample rate used for the plot axes.

    Returns:
        The matplotlib figure with three vertically stacked spectrograms.
    """
    window_size = 1024
    hop_length = 512
    window = np.hanning(window_size)

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    FigureCanvas(fig)  # attach an Agg canvas to the figure
    panels = (
        ('Оригинальный сигнал', hr),
        ('Сигнал с потерями', lr),
        ('Улучшенный сигнал', recon),
    )
    for ax, (ax_title, signal) in zip(axes, panels):
        ax.title.set_text(ax_title)
        magnitude = _magnitude_spectrum(signal, window, hop_length)
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=ax, y_axis='log',
                                 x_axis='time', sr=sr, hop_length=hop_length)
    return fig


def _to_batch_tensor(signal):
    """Convert a 1-D numpy signal to a float32 tensor of shape (1, samples)."""
    return torch.from_numpy(np.ascontiguousarray(signal, dtype=np.float32)).unsqueeze(0)


def _evaluate(data_clean, data_lossy, data_enhanced, samplerate):
    """Score each signal against the clean reference with PESQ and STOI.

    Returns:
        DataFrame with one row per signal (Clean, Lossy, Enhanced). The PLCMOS
        and LSD columns are placeholders and are not computed.
    """
    signals = (data_clean, data_lossy, data_enhanced)

    stoi_scores = [round(stoi(data_clean, degraded, samplerate, extended=False), 5)
                   for degraded in signals]

    # PesqLoss is a torch module: it needs float32 tensors (sf.read returns
    # float64 numpy arrays), and no gradients are needed for scoring.
    pesq_metric = PesqLoss(0.5, sample_rate=samplerate)
    reference = _to_batch_tensor(data_clean)
    with torch.no_grad():
        pesq_scores = [pesq_metric.mos(reference, _to_batch_tensor(degraded)).mean().item()
                       for degraded in signals]

    df = pd.DataFrame(columns=['Audio', 'PESQ', 'STOI', 'PLCMOS', 'LSD'])
    df['Audio'] = ['Clean', 'Lossy', 'Enhanced']
    df['PESQ'] = pesq_scores
    df['STOI'] = stoi_scores
    return df


packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

# Without `sr`, librosa.load resamples to 22050 Hz, which mismatches the
# 48 kHz the UI asks for and the rate PESQ is evaluated at.
target, sr = librosa.load(uploaded_file, sr=SAMPLE_RATE)
# Trim to a whole number of packets so the signal reshapes into (packets, packet_size).
target = target[:packet_size * (len(target) // packet_size)]
st.text('Ваше аудио')
st.audio(uploaded_file)

st.subheader('2. Выберите желаемый процент потерь')
loss_percent = st.slider(
    "Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
# One mask value per packet, broadcast over the packet's samples.
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# view_as_real(complex STFT) is the non-deprecated equivalent of
# return_complex=False: shape (freq, time, 2) -> (time, 1, freq, 2).
lossy_spectrum = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True)
re_im = (torch.view_as_real(lossy_spectrum)
         .permute(1, 0, 2)
         .unsqueeze(1)
         .numpy()
         .astype(np.float32))

session, onnx_model, input_names, output_names = load_model()

if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

        st.subheader('3. Визуализация')
        fig = visualize(target, lossy_input, output, sr)
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; release it on each rerun.
        plt.close(fig)
        st.success('Сделано!')

        sf.write('target.wav', target, sr)
        sf.write('lossy.wav', lossy_input, sr)
        sf.write('enhanced.wav', output, sr)
        st.text('Оригинальное аудио')
        st.audio('target.wav')
        st.text('Аудио с потерями')
        st.audio('lossy.wav')
        st.text('Улучшенное аудио')
        st.audio('enhanced.wav')

        # Metrics are computed on the written (and therefore quantized) files.
        data_clean, samplerate = sf.read('target.wav')
        data_lossy, _ = sf.read('lossy.wav')
        data_enhanced, _ = sf.read('enhanced.wav')
        # istft output may differ slightly in length from the input signal.
        min_len = min(len(data_clean), len(data_lossy), len(data_enhanced))
        data_clean = data_clean[:min_len]
        data_lossy = data_lossy[:min_len]
        data_enhanced = data_enhanced[:min_len]

        st.table(_evaluate(data_clean, data_lossy, data_enhanced, samplerate))