import streamlit as st
import librosa
import librosa.display
import soundfile as sf
import torch
import onnx
import onnxruntime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

from config import CONFIG
from dataset import MaskGenerator


@st.cache
def load_model():
    # Load the exported FRN ONNX checkpoint and build an ONNX Runtime session.
    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names):
    # Initialize every model input (current spectrogram frame plus recurrent states)
    # with zeros of the shapes declared in the ONNX graph.
    inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim], dtype=np.float32)
              for i, _input in enumerate(onnx_model.graph.input)}
    output_audio = []
    # Run the model frame by frame, feeding the returned states back in at each step.
    for t in range(re_im.shape[0]):
        inputs[input_names[0]] = re_im[t]
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        output_audio.append(out)
    # Stack the per-frame outputs, reinterpret the (real, imaginary) pairs as complex
    # STFT bins, and invert back to a waveform. `window`, `stride`, and `hann` are
    # module-level globals defined below.
    output_audio = torch.tensor(np.concatenate(output_audio, 0))
    output_audio = output_audio.permute(1, 0, 2).contiguous()
    output_audio = torch.view_as_complex(output_audio)
    output_audio = torch.istft(output_audio, window, stride, window=hann)
    return output_audio.numpy()


def visualize(hr, lr, recon):
    # Plot magnitude spectrograms of the target, lossy, and enhanced signals on shared axes.
    sr = CONFIG.DATA.sr
    window_size = 1024
    window = np.hanning(window_size)
    stft_hr = librosa.stft(hr, n_fft=window_size, hop_length=512, window=window)
    stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
    stft_lr = librosa.stft(lr, n_fft=window_size, hop_length=512, window=window)
    stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
    stft_recon = librosa.stft(recon, n_fft=window_size, hop_length=512, window=window)
    stft_recon = 2 * np.abs(stft_recon) / np.sum(window)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
    ax1.title.set_text('Target signal')
    ax2.title.set_text('Lossy signal')
    ax3.title.set_text('Enhanced signal')
    canvas = FigureCanvas(fig)
    librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='linear', x_axis='time', sr=sr)
    librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='linear', x_axis='time', sr=sr)
    librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='linear', x_axis='time', sr=sr)
    return fig


packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Packet Loss Concealment'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('Upload audio')
uploaded_file = st.file_uploader("Upload your audio file (.wav) at 48 kHz sampling rate")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'  # fall back to the bundled sample clip

# Load the audio and trim it to a whole number of packets.
target, sr = librosa.load(uploaded_file, sr=48000)
target = target[:packet_size * (len(target) // packet_size)]
st.text('Audio sample')
st.audio(uploaded_file)

st.subheader('Choose expected packet loss rate')
loss_percent = float(st.slider("Expected loss rate for Markov Chain loss generator", 0, 100, step=1)) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
# Simulate packet loss: zero out whole packets according to the Markov-chain mask.
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

# STFT of the lossy signal, kept as separate real/imaginary channels for the ONNX model.
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann,
                   return_complex=False).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)
session, onnx_model, input_names, output_names = load_model()

if st.button('Conceal lossy audio!'):
    with st.spinner('Please wait for completion'):
        output = inference(re_im, session, onnx_model, input_names, output_names)
    st.subheader('Visualization')
    fig = visualize(target, lossy_input, output)
    st.pyplot(fig)
    st.success('Done!')

    # Write the three signals to disk so Streamlit can serve them as playable audio.
    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Original audio')
    st.audio('target.wav')
    st.text('Lossy audio')
    st.audio('lossy.wav')
    st.text('Enhanced audio')
    st.audio('enhanced.wav')
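
# Usage note (an assumption, not part of the original script): if this file is saved as
# app.py at the repository root, with config.py, dataset.py, sample.wav, and the exported
# ONNX checkpoint in place, the demo can be launched with:
#
#   streamlit run app.py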