|
import numpy as np |
|
|
|
import streamlit as st |
|
import librosa |
|
import soundfile as sf |
|
import librosa.display |
|
from config import CONFIG |
|
import torch |
|
from dataset import MaskGenerator |
|
import onnxruntime, onnx |
|
import matplotlib.pyplot as plt |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
|
from pystoi import stoi |
|
from pesq import pesq |
|
import pandas as pd |
|
import torchaudio |
|
|
|
|
|
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI |
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ |
|
|
|
|
|
from PLCMOS.plc_mos import PLCMOSEstimator |
|
from speechmos import dnsmos |
|
from speechmos import plcmos |
|
|
|
|
|
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns without hashing it.

    Uses ``st.cache_resource`` when this Streamlit version provides it. Otherwise
    it falls back to ``st.cache(allow_output_mutation=True)``, so Streamlit does
    not try to hash/copy the returned ONNX session to detect mutation (a bare
    ``@st.cache`` attempts that and cannot hash an ``InferenceSession``).
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model(path='lightning_logs/version_0/checkpoints/frn.onnx'):
    """Load the exported ONNX model once and create an inference session for it.

    Args:
        path: Location of the ``.onnx`` file. Also acts as the cache key.

    Returns:
        Tuple ``(session, onnx_model, input_names, output_names)``:
        the onnxruntime session, the parsed ONNX ``ModelProto`` (kept so callers
        can read the declared input shapes), and the session's input and output
        names in session order.
    """
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Pin the CPU provider explicitly: onnxruntime builds that ship several
    # execution providers refuse to create a session without `providers`.
    session = onnxruntime.InferenceSession(path, options, providers=['CPUExecutionProvider'])

    input_names = [node.name for node in session.get_inputs()]
    output_names = [node.name for node in session.get_outputs()]
    return session, onnx_model, input_names, output_names
|
|
|
def inference(re_im, session, onnx_model, input_names, output_names,
              n_fft=None, hop_length=None, istft_window=None):
    """Run the streaming ONNX model frame by frame and resynthesize a waveform.

    Args:
        re_im: Array of STFT frames indexed by time on axis 0. Each ``re_im[t]``
            is fed to the model's first input. This script builds it as a
            float32 array of shape (T, 1, F, 2) holding real/imag parts.
        session: onnxruntime session for the model.
        onnx_model: Parsed ONNX model. Only its declared input shapes are used,
            to zero-initialise the recurrent-state inputs.
        input_names: Session input names. Index 0 is the frame; indices 1-3 are
            the recurrent states fed back from outputs 1-3 on every step.
        output_names: Session output names, in the order
            (frame, prev_mag, predictor_state, mlp_state).
        n_fft, hop_length, istft_window: ISTFT settings. They default to the
            module-level ``window``, ``stride`` and ``hann`` used to build
            ``re_im``.

    Returns:
        1-D numpy array with the reconstructed waveform.

    Raises:
        ValueError: if ``re_im`` contains no frames.
    """
    if len(re_im) == 0:
        raise ValueError('re_im contains no STFT frames to process')
    if n_fft is None:
        n_fft = window
    if hop_length is None:
        hop_length = stride
    if istft_window is None:
        istft_window = hann

    # Declared shapes keyed by name. graph.input may also list initializers
    # (older exporters), so pairing it positionally with the session inputs
    # can overrun or mismatch; look up only the real inputs by name.
    # NOTE(review): symbolic dims presumably have dim_value == 0 and would
    # yield empty state arrays -- confirm the exported model uses static dims.
    declared_shapes = {
        graph_input.name: [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        for graph_input in onnx_model.graph.input
    }
    feed = {name: np.zeros(declared_shapes[name], dtype=np.float32) for name in input_names}

    frames = []
    for frame in re_im:
        feed[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, feed)
        # Carry the recurrent state into the next step.
        feed[input_names[1]] = prev_mag
        feed[input_names[2]] = predictor_state
        feed[input_names[3]] = mlp_state
        frames.append(out)

    # (T, F, 2) -> (F, T, 2), made contiguous so it can be viewed as complex.
    spectrum = torch.tensor(np.concatenate(frames, 0)).permute(1, 0, 2).contiguous()
    waveform = torch.istft(torch.view_as_complex(spectrum), n_fft, hop_length, window=istft_window)
    return waveform.numpy()
|
|
|
def _magnitude_spectrogram(signal, window, hop_length):
    """Return |STFT| of *signal*, scaled by 2 / sum(window).

    The 2 / sum(window) factor is the usual single-sided amplitude scaling: a
    sinusoid's spectral peak then approximates its time-domain amplitude.
    """
    spectrum = librosa.stft(signal, n_fft=len(window), hop_length=hop_length, window=window)
    return 2 * np.abs(spectrum) / np.sum(window)


def visualize(hr, lr, recon, sr):
    """Plot stacked log-frequency spectrograms of the three signals.

    Args:
        hr: Original (loss-free) waveform.
        lr: Waveform with dropped packets.
        recon: Waveform produced by the model.
        sr: Sampling rate in Hz, used to label the time/frequency axes.

    Returns:
        A matplotlib ``Figure`` with three vertically stacked axes sharing x/y.
        The figure is not registered with pyplot, so it is garbage-collected
        once the caller drops it instead of accumulating across reruns.
    """
    from matplotlib.figure import Figure

    window_size = 1024
    hop_length = 512  # also librosa.display.specshow's default hop_length
    window = np.hanning(window_size)

    fig = Figure(figsize=(16, 12))
    FigureCanvas(fig)  # attach the Agg canvas for headless rendering
    axes = fig.subplots(3, 1, sharex=True, sharey=True)

    panels = (
        (hr, 'Оригинальный сигнал'),
        (lr, 'Сигнал с потерями'),
        (recon, 'Улучшенный сигнал'),
    )
    for ax, (signal, panel_title) in zip(axes, panels):
        ax.title.set_text(panel_title)
        magnitude = _magnitude_spectrogram(signal, window, hop_length)
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=ax, y_axis='log', x_axis='time', sr=sr)
        # specshow writes its own axis labels, so set ours afterwards.
        ax.set_xlabel('Время, с')
        ax.set_ylabel('Частота, Гц')

    return fig
|
|
|
# STFT/packet geometry shared by the loss simulation and `inference`.
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size  # STFT size (n_fft) -- an int, not a window array
stride = CONFIG.DATA.stride  # STFT hop length

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

# --- Step 1: audio input (falls back to the bundled sample file) ---
st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")

is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

# Without an explicit `sr`, librosa.load resamples to its 22050 Hz default.
# The upload prompt asks for 48 kHz audio (presumably the model's rate --
# TODO confirm against CONFIG), so load/resample at 48 kHz.
target, sr = librosa.load(uploaded_file, sr=48000)
# Drop the trailing partial packet so the signal splits into whole packets.
target = target[:packet_size * (len(target) // packet_size)]

st.text('Ваше аудио')
st.audio(uploaded_file)

# torch.stft(center=True) reflect-pads by n_fft // 2, which requires a longer
# signal; stop with a message instead of a traceback for very short clips.
if len(target) <= window // 2:
    st.error('Аудио слишком короткое для обработки.')
    st.stop()
|
|
|
# --- Step 2: simulate packet loss on the clean signal ---
st.subheader('2. Выберите желаемый процент потерь')
loss_percent = st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

# One mask value per packet, broadcast over the packet axis, so whole packets
# are scaled (presumably zeroed) at once. seed=0 presumably makes the loss
# pattern reproducible across reruns -- confirm in MaskGenerator.
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

# Model input: STFT frames laid out as (T, 1, F, 2) float32 real/imag pairs.
# view_as_real(stft(..., return_complex=True)) yields the same values as the
# deprecated return_complex=False form.
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
spectrum = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True)
re_im = torch.view_as_real(spectrum).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model()
|
|
|
def _resample_all(signals, orig_sr, target_sr):
    """Return *signals* resampled from *orig_sr* to *target_sr* Hz, as a list."""
    if orig_sr == target_sr:
        return list(signals)
    return [librosa.resample(s, orig_sr=orig_sr, target_sr=target_sr) for s in signals]


def _quality_metrics(paths=('target.wav', 'lossy.wav', 'enhanced.wav')):
    """Score the clean, lossy and enhanced recordings with objective metrics.

    Args:
        paths: WAV files in the order (clean reference, lossy, enhanced).

    Returns:
        DataFrame with one row per file and the columns
        Audio, PESQ, STOI, PLCMOSv1, PLCMOSv2, DNSMOS.
    """
    clean, samplerate = sf.read(paths[0])
    lossy, _ = sf.read(paths[1])
    enhanced, _ = sf.read(paths[2])
    # The ISTFT output may differ in length from the input; compare equal spans.
    min_len = min(len(clean), len(lossy), len(enhanced))
    signals = [clean[:min_len], lossy[:min_len], enhanced[:min_len]]

    # STOI at the files' own rate; the first row scores the reference against itself.
    stoi_scores = [round(stoi(signals[0], deg, samplerate, extended=False), 5) for deg in signals]

    # Narrow-band PESQ at 8 kHz. Resample from the rate actually read from
    # disk rather than assuming the files are 48 kHz.
    narrowband = _resample_all(signals, samplerate, 8000)
    pesq_scores = [pesq(fs=8000, ref=narrowband[0], deg=deg, mode='nb') for deg in narrowband]

    # PLCMOS v1 on 16 kHz audio.
    wideband = _resample_all(signals, samplerate, 16000)
    estimator = PLCMOSEstimator()
    plcmos_v1 = [estimator.run(audio_degraded=deg, audio_clean=wideband[0])[0] for deg in wideband]

    # speechmos scores are computed directly from the files.
    plcmos_v2 = [plcmos.run(path, sr=16000)['plcmos'] for path in paths]
    dns_scores = [dnsmos.run(path, sr=16000)['ovrl_mos'] for path in paths]

    return pd.DataFrame({
        'Audio': ['Clean', 'Lossy', 'Enhanced'],
        'PESQ': pesq_scores,
        'STOI': stoi_scores,
        'PLCMOSv1': plcmos_v1,
        'PLCMOSv2': plcmos_v2,
        'DNSMOS': dns_scores,
    })


# --- Step 3: run concealment, show results and score them (on click) ---
if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    st.subheader('3. Визуализация')
    fig = visualize(target, lossy_input, output, sr)
    st.pyplot(fig)
    st.success('Сделано!')

    # NOTE(review): these fixed paths in the working directory are shared by
    # every concurrent session of the app; per-session temp files would keep
    # one user's run from overwriting another's.
    clips = (
        ('Оригинальное аудио', 'target.wav', target),
        ('Аудио с потерями', 'lossy.wav', lossy_input),
        ('Улучшенное аудио', 'enhanced.wav', output),
    )
    for _, path, audio in clips:
        sf.write(path, audio, sr)
    for caption, path, _ in clips:
        st.text(caption)
        st.audio(path)

    # The metrics read the files written above, so they belong in this branch;
    # outside it they could fail before the first click or score stale files.
    st.dataframe(_quality_metrics())
|
|
|
|
|
|