File size: 5,026 Bytes
687e655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f6cc9
 
 
687e655
 
 
 
 
 
 
 
 
 
 
69f6cc9
687e655
 
 
69f6cc9
687e655
 
 
 
 
 
 
 
 
 
 
 
69f6cc9
687e655
 
 
 
 
 
 
 
 
 
 
 
 
69f6cc9
 
687e655
 
69f6cc9
687e655
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import librosa
import soundfile as sf
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime, onnx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Prefer st.cache_resource, the supported API for caching unserializable
# global resources such as inference sessions. Older Streamlit releases only
# have the legacy st.cache; there, disable output hashing so the cached
# session object is returned as-is.
_cache_model = getattr(st, 'cache_resource', None)
if _cache_model is None:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the ONNX model and create an ONNX Runtime session for it.

    The model is read from the hard-coded checkpoint path below. The result
    is cached by Streamlit so the model is not reloaded on every script rerun.

    Returns:
        A tuple ``(session, onnx_model, input_names, output_names)``: the
        ``onnxruntime.InferenceSession``, the model as loaded by
        ``onnx.load``, and the names of the session's inputs and outputs.
    """
    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names

def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the ONNX session step by step over ``re_im`` and rebuild a waveform.

    Every model input starts as a float32 zero array whose shape is taken
    from the ONNX graph. On each step the first input is replaced by
    ``re_im[step]`` and the three extra outputs of the previous step are fed
    back as inputs 1-3. The per-step first outputs are concatenated, viewed
    as a complex tensor and inverted with ``torch.istft``.

    Relies on the module-level ``window``, ``stride`` and ``hann`` values.

    Args:
        re_im: array indexed along its first axis, one entry per step.
        session: object whose ``run(output_names, feed)`` returns four arrays.
        onnx_model: loaded ONNX model; ``graph.input`` supplies input shapes.
        input_names: names of the session inputs (at least four are used).
        output_names: names of the session outputs passed to ``session.run``.

    Returns:
        The reconstructed signal as a NumPy array.
    """
    # Zero-initialise every graph input with its declared static shape.
    feed = {}
    for idx, graph_input in enumerate(onnx_model.graph.input):
        dims = [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        feed[input_names[idx]] = np.zeros(dims, dtype=np.float32)

    step_outputs = []
    num_steps = re_im.shape[0]
    for step in range(num_steps):
        feed[input_names[0]] = re_im[step]
        frame_out, mag, pred_state, mlp = session.run(output_names, feed)
        # Carry the recurrent outputs over to the next step.
        feed[input_names[1]] = mag
        feed[input_names[2]] = pred_state
        feed[input_names[3]] = mlp
        step_outputs.append(frame_out)

    stacked = torch.tensor(np.concatenate(step_outputs, 0))
    # view_as_complex needs a contiguous tensor whose last dimension is 2.
    spectrum = torch.view_as_complex(stacked.permute(1, 0, 2).contiguous())
    waveform = torch.istft(spectrum, window, stride, window=hann)
    return waveform.numpy()

def visualize(hr, lr, recon):
    """Draw dB spectrograms of three signals stacked in a single figure.

    Each signal is transformed with a 1024-point STFT (hop 512, Hann window),
    and its magnitude is scaled by ``2 / sum(window)`` before conversion to
    decibels. The three panels share both axes and use the sampling rate
    from ``CONFIG.DATA.sr`` for the time and frequency axes.

    Args:
        hr: signal plotted in the top panel.
        lr: signal plotted in the middle panel.
        recon: signal plotted in the bottom panel.

    Returns:
        The matplotlib figure holding the three spectrograms.
    """
    sr = CONFIG.DATA.sr
    n_fft = 1024
    hop = 512
    hann_win = np.hanning(n_fft)
    win_sum = np.sum(hann_win)

    def scaled_magnitude(signal):
        # Magnitude spectrogram scaled by 2 / sum(window).
        spec = librosa.core.spectrum.stft(signal, n_fft=n_fft, hop_length=hop, window=hann_win)
        return 2 * np.abs(spec) / win_sum

    magnitudes = [scaled_magnitude(signal) for signal in (hr, lr, recon)]

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
    captions = ('Оригинальный сигнал', 'Сигнал с потерями', 'Улучшенный сигнал')
    for axis, caption in zip(axes, captions):
        axis.title.set_text(caption)

    # Attach an Agg canvas to the figure, as the original code did.
    FigureCanvas(fig)
    for axis, magnitude in zip(axes, magnitudes):
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=axis,
                                 y_axis='linear', x_axis='time', sr=sr)
    return fig

# Framing parameters from the project CONFIG: packet length in samples and
# the STFT window/stride. `window` and `stride` are also read by inference().
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('Загрузить аудио')
uploaded_file = st.file_uploader("Upload your audio file (.wav) at 48 kHz sampling rate")

# Fall back to the bundled sample when nothing has been uploaded.
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

target, sr = librosa.load(uploaded_file, sr=48000)
# Drop the trailing partial packet so the signal splits into whole packets.
target = target[:packet_size * (len(target) // packet_size)]
if target.size == 0:
    # Less than one full packet of audio: nothing to mask or transform, and
    # the STFT below would fail. Report it and end this script run.
    st.error('The audio is too short: at least one full packet of samples is required.')
    st.stop()

st.text('Audio sample')
st.audio(uploaded_file)

st.subheader('Выберите желаемый процент потерь')
loss_rate_percent = st.slider("Expected loss rate for Markov Chain loss generator", 0, 100, step=1)
loss_percent = float(loss_rate_percent) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

# Apply the per-packet mask to a copy of the signal, one row per packet.
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

# Square-root Hann window; inference() reuses it for the inverse STFT.
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# STFT as real/imaginary pairs, reordered so the first axis is the one
# inference() steps over, with an extra singleton axis inserted after it.
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
    1).numpy().astype(np.float32)
session, onnx_model, input_names, output_names = load_model()

if st.button('Улучшить'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

        st.subheader('Визуализация')
        fig = visualize(target, lossy_input, output)
        st.pyplot(fig)
    st.success('Done!')
    # Write the three signals to WAV files and play them back.
    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Original audio')
    st.audio('target.wav')
    st.text('Lossy audio')
    st.audio('lossy.wav')
    st.text('Enhanced audio')
    st.audio('enhanced.wav')