Docker_v / app.py
XDHDD's picture
Update app.py
82a7041 verified
raw history blame
No virus
13.7 kB
import numpy as np
import streamlit as st
import librosa
import soundfile as sf
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime, onnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pystoi import stoi
from pesq import pesq
import pandas as pd
import torchaudio
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ
from PLCMOS.plc_mos import PLCMOSEstimator
from speechmos import dnsmos
from speechmos import plcmos
import speech_recognition as speech_r
from jiwer import wer
import time
def _cache_resource():
    """Return a decorator suitable for caching an unhashable, heavyweight resource.

    Newer Streamlit exposes ``st.cache_resource`` for exactly this purpose.
    Older versions only have ``st.cache``; by default it hashes the returned
    value to detect mutation, and that fails for an ONNX Runtime session, so
    output mutation checking is disabled there.
    """
    cache_resource = getattr(st, 'cache_resource', None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(allow_output_mutation=True)


@_cache_resource()
def load_model(model):
    """Load an ONNX checkpoint and create an inference session for it.

    Args:
        model: File name of the checkpoint inside
            ``lightning_logs/version_0/checkpoints``.

    Returns:
        A tuple ``(session, onnx_model, input_names, output_names)``:
        - the ONNX Runtime session;
        - the parsed ONNX graph, whose declared input shapes ``inference`` uses;
        - the session's input and output tensor names, in session order.
    """
    import os

    path = os.path.join('lightning_logs/version_0/checkpoints', str(model))
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)

    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names
def inference(re_im, session, onnx_model, input_names, output_names,
              n_fft=None, hop_length=None, istft_window=None):
    """Run the streaming ONNX model frame by frame and resynthesize audio.

    Args:
        re_im: Array of STFT frames iterated along axis 0. Assumed to be
            ``(frames, 1, freq, 2)`` float32, as built at module level.
            TODO confirm for other callers.
        session: ONNX Runtime session returned by ``load_model``.
        onnx_model: Parsed ONNX graph; its declared input shapes size the
            zero-initialised inputs.
        input_names: Session input names. Index 0 takes the current frame;
            indices 1-3 receive the recurrent outputs of the previous step.
        output_names: Session output names. ``session.run`` must return
            exactly four outputs: frame, previous magnitude, predictor state
            and MLP state.
        n_fft: FFT size for the inverse STFT. Defaults to the module-level
            ``window``.
        hop_length: Hop size for the inverse STFT. Defaults to the
            module-level ``stride``.
        istft_window: Synthesis window tensor. Defaults to the module-level
            ``hann``.

    Returns:
        The reconstructed time-domain signal as a NumPy array.
    """
    # Fall back to the module-level STFT settings so existing calls keep working.
    if n_fft is None:
        n_fft = window
    if hop_length is None:
        hop_length = stride
    if istft_window is None:
        istft_window = hann

    # Start every graph input (frame and recurrent states) at zero, shaped as
    # declared in the graph.
    inputs = {
        input_names[i]: np.zeros(
            [d.dim_value for d in graph_input.type.tensor_type.shape.dim],
            dtype=np.float32,
        )
        for i, graph_input in enumerate(onnx_model.graph.input)
    }

    frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Carry the recurrent state over to the next frame.
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        frames.append(out)

    spectrum = torch.tensor(np.concatenate(frames, 0))
    # (time, freq, 2) -> (freq, time, 2): the real/imag layout view_as_complex
    # and istft expect.
    spectrum = torch.view_as_complex(spectrum.permute(1, 0, 2).contiguous())
    audio = torch.istft(spectrum, n_fft, hop_length, window=istft_window)
    return audio.numpy()
def visualize(hr, lr, recon, sr):
    """Plot log-frequency spectrograms of the clean, lossy and restored signals.

    Args:
        hr: Original (reference) signal.
        lr: Signal with packet losses.
        recon: Signal restored by the model.
        sr: Sample rate of all three signals, in Hz.

    Returns:
        A matplotlib ``Figure`` with three vertically stacked panels that share
        both axes.
    """
    from matplotlib.figure import Figure

    window_size = 1024
    hop_length = 512
    window = np.hanning(window_size)

    # Build the Figure directly instead of going through pyplot. pyplot keeps
    # every figure it creates alive until plt.close(), so one figure per button
    # click on a long-running Streamlit server would accumulate.
    fig = Figure(figsize=(16, 12))
    FigureCanvas(fig)
    axes = fig.subplots(3, 1, sharex=True, sharey=True)

    panels = (
        (hr, 'Оригинальный сигнал'),
        (lr, 'Сигнал с потерями'),
        (recon, 'Улучшенный сигнал'),
    )
    for ax, (signal, panel_title) in zip(axes, panels):
        spec = librosa.stft(signal, n_fft=window_size, hop_length=hop_length, window=window)
        # Normalise magnitudes by the window's coherent gain (2 / sum(window)).
        spec = 2 * np.abs(spec) / np.sum(window)
        ax.title.set_text(panel_title)
        librosa.display.specshow(
            librosa.amplitude_to_db(spec), ax=ax, y_axis='log', x_axis='time',
            sr=sr, hop_length=hop_length,
        )
        # Override specshow's default axis labels.
        ax.set_xlabel('Время, с')
        ax.set_ylabel('Частота, Гц')
    return fig
# STFT / packet settings shared by preprocessing and inference().
packet_size = CONFIG.DATA.EVAL.packet_size
window, stride = CONFIG.DATA.window_size, CONFIG.DATA.stride

# Page header; set_page_config must be the first Streamlit command rendered.
title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)
st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    # Fall back to the bundled example so the page works without an upload.
    uploaded_file = 'sample.wav'
target, sr = librosa.load(uploaded_file, sr=48000)
# Trim to a whole number of packets so the signal can be reshaped per packet.
target = target[:packet_size * (len(target) // packet_size)]
st.text('Ваше аудио')
st.audio(uploaded_file)
if len(target) == 0:
    # Shorter than one packet: nothing to mask or transform, and the STFT
    # below would fail on an empty signal.
    st.error('Аудио слишком короткое: нужен хотя бы один пакет.')
    st.stop()

model_ver = st.selectbox(
    'Оригинал или Pruned ?',
    ('frn.onnx', 'frn_modified.onnx', 'frn_out_Q.onnx', 'frn_out_QF.onnx', 'frn_out_QInt16.onnx', 'frn_out_QInt8.onnx', 'frn_out_QUInt8.onnx', 'frn_out_QUInt16.onnx', 'frn_fp16 (1).onnx'))
st.write('Вы выбрали:', model_ver)
lang = st.selectbox(
    'Выберите язык',
    ('ru-RU', 'en-EN'))
st.write('Вы выбрали:', lang)

st.subheader('2. Выберите желаемый процент потерь')
loss_percent = st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
# One mask value per packet, broadcast over that packet's samples
# (presumably 0 = lost, 1 = kept; verify against MaskGenerator).
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# view_as_real(complex STFT) is the non-deprecated equivalent of
# return_complex=False: shape (freq, time, 2).
lossy_spec = torch.view_as_real(
    torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True))
# Reorder to (time, 1, freq, 2) so inference() can feed one frame per step.
re_im = lossy_spec.permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)
session, onnx_model, input_names, output_names = load_model(model_ver)
def _resample_to_16k(signal, orig_sr):
    """Return *signal* resampled from *orig_sr* Hz to 16 kHz.

    PESQ (wide-band) and PLCMOS v1 are computed on 16 kHz audio below. The
    signal is returned unchanged when it is already at that rate.
    """
    if orig_sr == 16000:
        return signal
    return librosa.resample(signal, orig_sr=orig_sr, target_sr=16000)


def _transcribe(recognizer, path, language):
    """Transcribe the WAV file at *path* with Google Speech Recognition.

    Returns an empty string when the recognizer cannot understand the audio,
    or when the request to the service fails (a warning is shown in that
    case), instead of aborting the whole page.
    """
    with speech_r.AudioFile(path) as source:
        audio = recognizer.record(source)
    try:
        return recognizer.recognize_google(audio, language=language)
    except speech_r.UnknownValueError:
        return ''
    except speech_r.RequestError as e:
        st.warning("Ошибка при запросе к сервису распознавания речи; {0}".format(e))
        return ''


if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        start_time = time.perf_counter()
        output = inference(re_im, session, onnx_model, input_names, output_names)
        # Elapsed inference time in seconds.
        st.text(str(time.perf_counter() - start_time))
    st.subheader('3. Визуализация')
    fig = visualize(target, lossy_input, output, sr)
    st.pyplot(fig)
    st.success('Сделано!')

    # NOTE(review): these fixed file names are shared by every session of the
    # app, so concurrent users can overwrite each other's audio. Consider a
    # per-session temporary directory.
    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')

    # Intrusive metrics are computed on the WAV files just written. All three
    # are trimmed to a common length so reference and degraded line up.
    data_clean, samplerate = sf.read('target.wav')
    data_lossy, _ = sf.read('lossy.wav')
    data_enhanced, _ = sf.read('enhanced.wav')
    min_len = min(len(data_clean), len(data_lossy), len(data_enhanced))
    data_clean = data_clean[:min_len]
    data_lossy = data_lossy[:min_len]
    data_enhanced = data_enhanced[:min_len]

    stoi_mass = [
        round(stoi(data_clean, degraded, samplerate, extended=False), 5)
        for degraded in (data_clean, data_lossy, data_enhanced)
    ]

    # Resample once from the real file rate and reuse the result for both
    # PESQ and PLCMOS v1.
    clean_16k = _resample_to_16k(data_clean, samplerate)
    lossy_16k = _resample_to_16k(data_lossy, samplerate)
    enhanced_16k = _resample_to_16k(data_enhanced, samplerate)
    signals_16k = (clean_16k, lossy_16k, enhanced_16k)

    psq_mas = [pesq(fs=16000, ref=clean_16k, deg=degraded, mode='wb') for degraded in signals_16k]

    plc_estimator = PLCMOSEstimator()
    PLC_massv1 = [
        plc_estimator.run(audio_degraded=degraded, audio_clean=clean_16k)[0]
        for degraded in signals_16k
    ]

    df_1 = pd.DataFrame({
        'Audio': ['Clean', 'Lossy', 'Enhanced'],
        'PESQ': psq_mas,
        'STOI': stoi_mass,
        'PLCMOSv1': PLC_massv1,
    })
    # PLCMOS v2 (speechmos) is given the file paths directly.
    df_1['PLCMOSv2'] = [
        plcmos.run(path, sr=16000)['plcmos']
        for path in ('target.wav', 'lossy.wav', 'enhanced.wav')
    ]

    recognizer = speech_r.Recognizer()
    orig = _transcribe(recognizer, 'target.wav', str(lang))
    lossy = _transcribe(recognizer, 'lossy.wav', str(lang))
    enhanced = _transcribe(recognizer, 'enhanced.wav', str(lang))
    if orig:
        WER_mass = [wer(orig, hypothesis) * 100 for hypothesis in (orig, lossy, enhanced)]
    else:
        # WER is undefined without a reference transcript (jiwer rejects an
        # empty reference), so report NaN instead of crashing.
        WER_mass = [float('nan')] * 3
    df_1['WER'] = WER_mass
    st.dataframe(df_1)