anhnv125 commited on
Commit
042c5b7
Β·
2 Parent(s): 28d7565 16e5af6

Merge branch 'main' of https://huggingface.co/spaces/anhnv125/FRN

Browse files
Files changed (3) hide show
  1. README.md +4 -2
  2. app.py +11 -7
  3. requirements.txt +1 -1
README.md CHANGED
@@ -3,8 +3,10 @@ title: FRN
3
  emoji: πŸ“‰
4
  colorFrom: gray
5
  colorTo: red
6
- sdk: static
7
- pinned: false
 
 
8
  ---
9
 
10
  # FRN - Full-band Recurrent Network Official Implementation
 
3
  emoji: πŸ“‰
4
  colorFrom: gray
5
  colorTo: red
6
+ sdk: streamlit
7
+ pinned: true
8
+ app_file: app.py
9
+ sdk_version: 1.10.0
10
  ---
11
 
12
  # FRN - Full-band Recurrent Network Official Implementation
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import librosa
 
3
  import librosa.display
4
  from config import CONFIG
5
  import torch
@@ -9,7 +10,7 @@ import matplotlib.pyplot as plt
9
  import numpy as np
10
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
11
 
12
- @st.cache_resource
13
  def load_model():
14
  path = 'lightning_logs/version_0/checkpoints/frn.onnx'
15
  onnx_model = onnx.load(path)
@@ -87,9 +88,9 @@ target = target[:packet_size * (len(target) // packet_size)]
87
  st.subheader('Original audio')
88
  st.audio(uploaded_file)
89
 
90
- st.subheader('Choose loss packet percentage')
91
- loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
92
- loss_percent = float(loss_percent[:-1])/100
93
  mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
94
  lossy_input = target.copy().reshape(-1, packet_size)
95
  mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
@@ -109,9 +110,12 @@ if st.button('Conceal lossy audio!'):
109
  fig = visualize(target, lossy_input, output)
110
  st.pyplot(fig)
111
  st.success('Done!')
 
 
 
112
  st.text('Original audio')
113
- st.audio(target, sample_rate=sr)
114
  st.text('Lossy audio')
115
- st.audio(lossy_input, sample_rate=sr)
116
  st.text('Enhanced audio')
117
- st.audio(output, sample_rate=sr)
 
1
  import streamlit as st
2
  import librosa
3
+ import soundfile as sf
4
  import librosa.display
5
  from config import CONFIG
6
  import torch
 
10
  import numpy as np
11
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
12
 
13
+ @st.cache
14
  def load_model():
15
  path = 'lightning_logs/version_0/checkpoints/frn.onnx'
16
  onnx_model = onnx.load(path)
 
88
  st.subheader('Original audio')
89
  st.audio(uploaded_file)
90
 
91
+ st.subheader('Choose expected packet loss rate')
92
+ slider = [st.slider("Expected loss rate for Markov Chain loss generator", 0, 100, step=1)]
93
+ loss_percent = float(slider[0])/100
94
  mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
95
  lossy_input = target.copy().reshape(-1, packet_size)
96
  mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
 
110
  fig = visualize(target, lossy_input, output)
111
  st.pyplot(fig)
112
  st.success('Done!')
113
+ sf.write('target.wav', target, sr)
114
+ sf.write('lossy.wav', lossy_input, sr)
115
+ sf.write('enhanced.wav', output, sr)
116
  st.text('Original audio')
117
+ st.audio('target.wav')
118
  st.text('Lossy audio')
119
+ st.audio('lossy.wav')
120
  st.text('Enhanced audio')
121
+ st.audio('enhanced.wav')
requirements.txt CHANGED
@@ -12,6 +12,6 @@ soundfile==0.11.0
12
  torch==1.13.1
13
  torchmetrics==0.11.0
14
  tqdm==4.64.0
15
- stoi==0.3.3
16
  pesq==0.0.4
17
  onnx==1.13.0
 
12
  torch==1.13.1
13
  torchmetrics==0.11.0
14
  tqdm==4.64.0
15
+ pystoi==0.3.3
16
  pesq==0.0.4
17
  onnx==1.13.0