anhnv125 committed
Commit bdb2571
1 Parent(s): 8da3748

add streamlit app

Files changed (7)
  1. app.py +117 -0
  2. dataset.py +1 -1
  3. inference_onnx.py +4 -4
  4. main.py +3 -4
  5. models/frn.py +2 -2
  6. sample.wav +0 -0
  7. utils/utils.py +5 -5
app.py ADDED
@@ -0,0 +1,117 @@
+import streamlit as st
+import librosa
+import librosa.display
+from config import CONFIG
+import torch
+from dataset import MaskGenerator
+import onnxruntime, onnx
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+
+
+@st.cache_resource
+def load_model():
+    # Load the exported FRN ONNX model once and cache it across Streamlit reruns.
+    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
+    onnx_model = onnx.load(path)
+    options = onnxruntime.SessionOptions()
+    options.intra_op_num_threads = 2
+    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    session = onnxruntime.InferenceSession(path, options)
+    input_names = [x.name for x in session.get_inputs()]
+    output_names = [x.name for x in session.get_outputs()]
+    return session, onnx_model, input_names, output_names
+
+
+def inference(re_im, session, onnx_model, input_names, output_names):
+    # Zero-initialize every model input (the audio frame plus the recurrent states).
+    inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
+                                       dtype=np.float32)
+              for i, _input in enumerate(onnx_model.graph.input)}
+
+    output_audio = []
+    # Run frame by frame, feeding the returned states back in at each step.
+    for t in range(re_im.shape[0]):
+        inputs[input_names[0]] = re_im[t]
+        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
+        inputs[input_names[1]] = prev_mag
+        inputs[input_names[2]] = predictor_state
+        inputs[input_names[3]] = mlp_state
+        output_audio.append(out)
+
+    # Stack the enhanced frames and invert the STFT back to a waveform.
+    output_audio = torch.tensor(np.concatenate(output_audio, 0))
+    output_audio = output_audio.permute(1, 0, 2).contiguous()
+    output_audio = torch.view_as_complex(output_audio)
+    output_audio = torch.istft(output_audio, window, stride, window=hann)
+    return output_audio.numpy()
+
+
+def visualize(hr, lr, recon):
+    # Plot target, lossy, and enhanced spectrograms on shared axes.
+    sr = CONFIG.DATA.sr
+    window_size = 1024
+    window = np.hanning(window_size)
+
+    stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
+    stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
+
+    stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
+    stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
+
+    stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
+    stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
+
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
+    ax1.title.set_text('Target signal')
+    ax2.title.set_text('Lossy signal')
+    ax3.title.set_text('Enhanced signal')
+
+    canvas = FigureCanvas(fig)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='linear', x_axis='time', sr=sr)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='linear', x_axis='time', sr=sr)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='linear', x_axis='time', sr=sr)
+    return fig
+
+
+packet_size = CONFIG.DATA.EVAL.packet_size
+window = CONFIG.DATA.window_size
+stride = CONFIG.DATA.stride
+
+title = 'Packet Loss Concealment'
+st.set_page_config(page_title=title, page_icon=":sound:")
+st.title(title)
+
+uploaded_file = st.file_uploader("Upload your audio file (.wav)")
+
+is_file_uploaded = uploaded_file is not None
+if not is_file_uploaded:
+    uploaded_file = 'sample.wav'
+
+# Trim the signal to a whole number of packets.
+target, sr = librosa.load(uploaded_file, sr=48000)
+target = target[:packet_size * (len(target) // packet_size)]
+
+st.subheader('Original audio')
+st.audio(uploaded_file)
+
+st.subheader('Choose loss packet percentage')
+loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
+loss_percent = float(loss_percent[:-1]) / 100
+# Zero out packets according to a Markov-chain loss mask.
+mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
+lossy_input = target.copy().reshape(-1, packet_size)
+mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
+lossy_input *= mask
+lossy_input = lossy_input.reshape(-1)
+hann = torch.sqrt(torch.hann_window(window))
+lossy_input_tensor = torch.tensor(lossy_input)
+re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
+    1).numpy().astype(np.float32)
+session, onnx_model, input_names, output_names = load_model()
+
+if st.button('Conceal lossy audio!'):
+    with st.spinner('Please wait for completion'):
+        output = inference(re_im, session, onnx_model, input_names, output_names)
+
+    st.subheader('Visualization')
+    fig = visualize(target, lossy_input, output)
+    st.pyplot(fig)
+    st.success('Done!')
+    st.text('Original audio')
+    st.audio(target, sample_rate=sr)
+    st.text('Lossy audio')
+    st.audio(lossy_input, sample_rate=sr)
+    st.text('Enhanced audio')
+    st.audio(output, sample_rate=sr)
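
Note: the heart of app.py is the stateful frame loop in inference(): each session.run() call returns the enhanced frame together with updated recurrent states (prev_mag, predictor_state, mlp_state), which are fed back as inputs on the next step. A minimal sketch of that feed-back pattern in isolation, with a hypothetical EchoSession standing in for the real FRN ONNX session:

import numpy as np

class EchoSession:
    # Hypothetical stand-in for onnxruntime.InferenceSession: echoes the
    # input frame and increments a scalar state on every call.
    def run(self, output_names, inputs):
        return inputs['frame'], inputs['state'] + 1.0

session = EchoSession()
frames = np.random.randn(5, 1, 481, 2).astype(np.float32)  # (time, 1, freq, re/im)

inputs = {'frame': frames[0], 'state': np.zeros(1, dtype=np.float32)}
outputs = []
for t in range(frames.shape[0]):
    inputs['frame'] = frames[t]
    out, state = session.run(['out', 'state'], inputs)
    inputs['state'] = state  # carry the recurrent state to the next frame
    outputs.append(out)

enhanced = np.concatenate(outputs, 0)  # (time, freq, 2), ready for the iSTFT
print(enhanced.shape)

The app itself is launched with streamlit run app.py; when no file is uploaded it falls back to the bundled sample.wav.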
dataset.py CHANGED
@@ -67,7 +67,7 @@ class MaskGenerator:
         else:
             assert len(probs) == 1
             prob = self.probs[0]
-            self.mcs.append(MarkovChain([[probs[0], 1 - prob[0]], [1 - prob[1], prob[1]]], ['1', '0']))
+            self.mcs.append(MarkovChain([[prob[0], 1 - prob[0]], [1 - prob[1], prob[1]]], ['1', '0']))
 
     def gen_mask(self, length, seed=0):
         if self.is_train:
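
Note: the fix swaps probs[0] (the whole list of probability pairs) for prob[0] (a scalar), so the 2x2 transition matrix of the loss chain is built from numbers rather than a tuple. The chain has states '1' (packet received) and '0' (packet lost); prob[0] is the probability of staying in the received state and prob[1] of staying in the lost state, which makes losses bursty. A minimal numpy sketch of the same idea, independent of the repo's MarkovChain class:

import numpy as np

def gen_mask(length, p_keep=0.9, p_lose=0.1, seed=0):
    # Two-state Markov chain: state 1 = packet received, state 0 = lost.
    # Transition matrix: [[p_keep, 1 - p_keep], [1 - p_lose, p_lose]]
    rng = np.random.default_rng(seed)
    mask = np.empty(length, dtype=np.float32)
    state = 1
    for i in range(length):
        mask[i] = state
        stay = p_keep if state == 1 else p_lose
        if rng.random() >= stay:
            state = 1 - state  # switch between received and lost
    return mask

print(gen_mask(20, p_keep=0.9, p_lose=0.5))  # mostly 1s, with bursty runs of 0s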
inference_onnx.py CHANGED
@@ -38,8 +38,8 @@ if __name__ == '__main__':
     for file in tqdm.tqdm(audio_files, total=len(audio_files)):
         sig, _ = librosa.load(file, sr=48000)
         sig = torch.tensor(sig)
-        re_im = torch.stft(sig, window, stride, window=hann, return_complex=False).permute(2, 0, 1).unsqueeze(
-            0).numpy().astype(np.float32)
+        re_im = torch.stft(sig, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
+            1).numpy().astype(np.float32)
 
         inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
                                            dtype=np.float32)
@@ -47,8 +47,8 @@ if __name__ == '__main__':
         }
 
         output_audio = []
-        for t in range(re_im.shape[-1]):
-            ri_t = re_im[:, :, :, t:t + 1]
+        for t in range(re_im.shape[0]):
+            inputs[input_names[0]] = re_im[t]
             out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
             inputs[input_names[1]] = prev_mag
             inputs[input_names[2]] = predictor_state
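
Note: the reshape change puts the time axis first so the loop can feed one frame per step. torch.stft(..., return_complex=False) yields a (freq, time, 2) tensor; permute(1, 0, 2) makes it (time, freq, 2) and unsqueeze(1) adds a channel axis, so re_im[t] is a single (1, freq, 2) model input. A quick shape check (the window and hop sizes below are placeholders; the real values come from CONFIG):

import torch

sig = torch.randn(48000)            # 1 s of audio at 48 kHz
window, stride = 960, 480           # placeholder STFT sizes
hann = torch.sqrt(torch.hann_window(window))

re_im = torch.stft(sig, window, stride, window=hann, return_complex=False)
print(re_im.shape)                  # torch.Size([481, 101, 2]): (freq, time, 2)

re_im = re_im.permute(1, 0, 2).unsqueeze(1)
print(re_im.shape)                  # torch.Size([101, 1, 481, 2]): (time, 1, freq, 2)
print(re_im[0].shape)               # one per-frame input: torch.Size([1, 481, 2])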
main.py CHANGED
@@ -4,7 +4,7 @@ import os
 import pytorch_lightning as pl
 import soundfile as sf
 import torch
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
 from pytorch_lightning.utilities.model_summary import summarize
 from torch.utils.data import DataLoader
 
@@ -65,9 +65,8 @@ def train():
                          gradient_clip_val=CONFIG.TRAIN.clipping_val,
                          gpus=len(gpus),
                          max_epochs=CONFIG.TRAIN.epochs,
-                         accelerator="ddp" if len(gpus) > 1 else None,
-                         stochastic_weight_avg=True,
-                         callbacks=[checkpoint_callback]
+                         accelerator="gpu" if len(gpus) > 1 else None,
+                         callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-2)]
                          )
 
     print(model.hparams)
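
Note: both Trainer edits track PyTorch Lightning API changes. The stochastic_weight_avg=True flag was removed in favor of the StochasticWeightAveraging callback, and accelerator now names the hardware ("gpu"/"cpu") while the old "ddp" value belongs to the separate strategy argument. A minimal sketch of the updated construction, with placeholder hyperparameters (the real script reads them from CONFIG):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')

trainer = pl.Trainer(
    max_epochs=100,                # placeholder for CONFIG.TRAIN.epochs
    gradient_clip_val=1.0,         # placeholder for CONFIG.TRAIN.clipping_val
    accelerator='gpu',
    devices=2,                     # newer replacement for the gpus= argument
    strategy='ddp',                # where the old accelerator='ddp' moved
    callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-2)],
)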
models/frn.py CHANGED
@@ -92,11 +92,11 @@ class PLCModel(pl.LightningModule):
 
     def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=False, batch_size=self.hparams.batch_size,
-                          num_workers=CONFIG.TRAIN.workers)
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, shuffle=False, batch_size=self.hparams.batch_size,
-                          num_workers=CONFIG.TRAIN.workers)
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
 
     def training_step(self, batch, batch_idx):
         x_in, y = batch
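
Note: with persistent_workers=True, the DataLoader keeps its worker processes alive between epochs instead of tearing them down and respawning them at every epoch boundary, which removes per-epoch startup cost (it only takes effect when num_workers > 0). A self-contained example:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 16), torch.randn(1024, 1))

# Workers are spawned once and reused across all three epochs below.
loader = DataLoader(dataset, batch_size=32, num_workers=2, persistent_workers=True)

for epoch in range(3):
    for x, y in loader:
        pass  # training step would go here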
sample.wav ADDED
Binary file (797 kB).
 
utils/utils.py CHANGED
@@ -24,23 +24,23 @@ def mkdir_p(mypath):
             raise
 
 
-def visualize(hr, lr, recon, path):
+def visualize(target, input, recon, path):
     sr = CONFIG.DATA.sr
     window_size = 1024
     window = np.hanning(window_size)
 
-    stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
+    stft_hr = librosa.core.spectrum.stft(target, n_fft=window_size, hop_length=512, window=window)
     stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
 
-    stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
+    stft_lr = librosa.core.spectrum.stft(input, n_fft=window_size, hop_length=512, window=window)
     stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
 
     stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
     stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
 
     fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
-    ax1.title.set_text('HR signal')
-    ax2.title.set_text('LR signal')
+    ax1.title.set_text('Target signal')
+    ax2.title.set_text('Lossy signal')
     ax3.title.set_text('Reconstructed signal')
 
     canvas = FigureCanvas(fig)
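
Note: the 2 * np.abs(stft) / np.sum(window) step used throughout visualize() converts raw STFT bins into window-compensated amplitude: dividing by the window's sum undoes the scaling the Hann taper introduces, and the factor 2 accounts for the energy in the discarded negative-frequency half of the spectrum. A quick sanity check on a unit-amplitude tone (the sample rate and tone frequency below are arbitrary choices):

import numpy as np
import librosa

sr, window_size, hop = 16000, 1024, 512
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 1000 * t).astype(np.float32)  # unit-amplitude 1 kHz sine

window = np.hanning(window_size)
stft = librosa.stft(tone, n_fft=window_size, hop_length=hop, window=window)
amp = 2 * np.abs(stft) / np.sum(window)

# The peak bin in a steady frame recovers roughly the tone's amplitude of 1.0.
print(amp[:, amp.shape[1] // 2].max())  # ~1.0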