XDHDD commited on
Commit
a700e01
1 Parent(s): 00e8424

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -29,8 +29,8 @@ from jiwer import wer
29
 
30
 
31
  @st.cache
32
- def load_model():
33
- path = 'lightning_logs/version_0/checkpoints/frn_modified.onnx'
34
  onnx_model = onnx.load(path)
35
  options = onnxruntime.SessionOptions()
36
  options.intra_op_num_threads = 2
@@ -114,6 +114,13 @@ target = target[:packet_size * (len(target) // packet_size)]
114
  st.text('Ваше аудио')
115
  st.audio(uploaded_file)
116
 
 
 
 
 
 
 
 
117
  st.subheader('2. Выберите желаемый процент потерь')
118
  slider = [st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1)]
119
  loss_percent = float(slider[0])/100
@@ -126,7 +133,9 @@ hann = torch.sqrt(torch.hann_window(window))
126
  lossy_input_tensor = torch.tensor(lossy_input)
127
  re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
128
  1).numpy().astype(np.float32)
129
- session, onnx_model, input_names, output_names = load_model()
 
 
130
 
131
  if st.button('Сгенерировать потери'):
132
  with st.spinner('Ожидайте...'):
 
29
 
30
 
31
  @st.cache
32
+ def load_model(model):
33
+ path = 'lightning_logs/version_0/checkpoints/' + str(model)
34
  onnx_model = onnx.load(path)
35
  options = onnxruntime.SessionOptions()
36
  options.intra_op_num_threads = 2
 
114
  st.text('Ваше аудио')
115
  st.audio(uploaded_file)
116
 
117
+ option = st.selectbox(
118
+ '1 or 2 onnx?',
119
+ ('frn.onnx', 'frn_modified.onnx'))
120
+
121
+ st.write('You selected:', option)
122
+
123
+
124
  st.subheader('2. Выберите желаемый процент потерь')
125
  slider = [st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1)]
126
  loss_percent = float(slider[0])/100
 
133
  lossy_input_tensor = torch.tensor(lossy_input)
134
  re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
135
  1).numpy().astype(np.float32)
136
+
137
+
138
+ session, onnx_model, input_names, output_names = load_model(option)
139
 
140
  if st.button('Сгенерировать потери'):
141
  with st.spinner('Ожидайте...'):