Labbeti committed on
Commit
4ff8b3b
1 Parent(s): b7a5794

Mod: Rework UI, remove tmp files and clear cache after 10min.

Browse files
Files changed (1) hide show
  1. app.py +65 -19
app.py CHANGED
@@ -1,6 +1,8 @@
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
 
 
4
  from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
5
  from typing import Any, Optional, Union
6
 
@@ -16,14 +18,18 @@ from conette.utils.collections import dict_list_to_list_dict
16
 
17
 
18
  ALLOW_REP_MODES = ("stopwords", "all", "none")
 
19
  MAX_BEAM_SIZE = 20
20
  MAX_PRED_SIZE = 30
21
- MAX_BATCH_SIZE = 32
22
  RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
23
  DEFAULT_THRESHOLD = 0.3
24
  THRESHOLD_PRECISION = 100
25
  MIN_AUDIO_DURATION_SEC = 0.3
26
  MAX_AUDIO_DURATION_SEC = 60
 
 
 
27
 
28
 
29
  @st.cache_resource
@@ -46,7 +52,7 @@ def format_tags(tags: Optional[list[str]]) -> str:
46
 
47
 
48
  def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
49
- return f"{audio_fname}-{generate_kwds}"
50
 
51
 
52
  def get_results(
@@ -64,7 +70,7 @@ def get_results(
64
  # Save audio to be processed
65
  tmp_files: dict[str, _TemporaryFileWrapper] = {}
66
  for result_hash, (audio_fname, audio) in audio_to_predict.items():
67
- tmp_file = NamedTemporaryFile(delete=False)
68
  tmp_file.write(audio)
69
  tmp_file.close()
70
 
@@ -109,6 +115,9 @@ def get_results(
109
  output_i = st.session_state[result_hash]
110
  outputs[audio_fname] = output_i
111
 
 
 
 
112
  return outputs
113
 
114
 
@@ -145,20 +154,39 @@ def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
145
  else:
146
  header = f'##### Result for "{audio_fname}"'
147
 
148
- content = [
149
  header,
150
- f'- **Description:** "{cand}" ({prob*100:.1f}%)',
151
- f"- **Tags:** {tags}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  ]
153
  if len(mult_cands) > 0:
154
  msg = f"- **Other descriptions:**"
155
- content.append(msg)
156
 
157
  for cand_i, prob_i in zip(mult_cands, mult_probs):
158
  msg = f' - "{cand_i}" ({prob_i*100:.1f}%)'
159
- content.append(msg)
 
 
 
160
 
161
- st.success("\n".join(content))
 
162
  st.divider()
163
 
164
 
@@ -167,19 +195,28 @@ def main() -> None:
167
 
168
  st.header("Describe audio content with CoNeTTE")
169
  st.markdown(
170
- "This interface allows you to generate a short description of the sound events of any recording. You can try it from your microphone or upload a file below."
171
  )
172
-
173
- record_data = st_audiorec()
174
- audio_files: Optional[list[UploadedFile]] = st.file_uploader(
175
- "**Or upload audio files here:**",
176
- type=["wav", "flac", "mp3", "ogg", "avi"],
177
- accept_multiple_files=True,
178
- help="Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum.",
179
  )
 
 
 
 
 
 
 
 
 
180
 
181
- with st.expander("Model hyperparameters"):
182
- task = st.selectbox("Task embedding input", model.tasks, 0)
 
 
 
 
 
183
  allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
184
  beam_size: int = st.select_slider( # type: ignore
185
  "Beam size",
@@ -231,6 +268,15 @@ def main() -> None:
231
  st.header("Results:")
232
  show_results(outputs)
233
 
 
 
 
 
 
 
 
 
 
234
 
235
  if __name__ == "__main__":
236
  main()
 
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
4
+ import os
5
+ import time
6
  from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
7
  from typing import Any, Optional, Union
8
 
 
18
 
19
 
20
  ALLOW_REP_MODES = ("stopwords", "all", "none")
21
+ DEFAULT_TASK = "audiocaps"
22
  MAX_BEAM_SIZE = 20
23
  MAX_PRED_SIZE = 30
24
+ MAX_BATCH_SIZE = 16
25
  RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
26
  DEFAULT_THRESHOLD = 0.3
27
  THRESHOLD_PRECISION = 100
28
  MIN_AUDIO_DURATION_SEC = 0.3
29
  MAX_AUDIO_DURATION_SEC = 60
30
+ HASH_PREFIX = "hash_"
31
+ TMP_FILE_PREFIX = "audio_tmp_file_"
32
+ SECOND_BEFORE_CLEAR_CACHE = 10 * 60
33
 
34
 
35
  @st.cache_resource
 
52
 
53
 
54
  def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
55
+ return f"{HASH_PREFIX}{audio_fname}-{generate_kwds}"
56
 
57
 
58
  def get_results(
 
70
  # Save audio to be processed
71
  tmp_files: dict[str, _TemporaryFileWrapper] = {}
72
  for result_hash, (audio_fname, audio) in audio_to_predict.items():
73
+ tmp_file = NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX)
74
  tmp_file.write(audio)
75
  tmp_file.close()
76
 
 
115
  output_i = st.session_state[result_hash]
116
  outputs[audio_fname] = output_i
117
 
118
+ for tmp_file in tmp_files.values():
119
+ os.remove(tmp_file.name)
120
+
121
  return outputs
122
 
123
 
 
154
  else:
155
  header = f'##### Result for "{audio_fname}"'
156
 
157
+ lines = [
158
  header,
159
+ f'<center><p class="space"><p class="big-font">"{cand}"</p></p></center>',
160
+ ]
161
+
162
+ st.markdown("""
163
+ <style>
164
+ .big-font {
165
+ font-size:22px !important;
166
+ background-color: rgba(0, 255, 0, 0.1);
167
+ padding: 10px;
168
+ }
169
+ </style>
170
+ """, unsafe_allow_html=True)
171
+ content = "<br>".join(lines)
172
+ st.markdown(content, unsafe_allow_html=True)
173
+
174
+ lines = [
175
+ f"- **Probability**: {prob*100:.1f}%",
176
  ]
177
  if len(mult_cands) > 0:
178
  msg = f"- **Other descriptions:**"
179
+ lines.append(msg)
180
 
181
  for cand_i, prob_i in zip(mult_cands, mult_probs):
182
  msg = f' - "{cand_i}" ({prob_i*100:.1f}%)'
183
+ lines.append(msg)
184
+
185
+ msg = f"- **Tags:** {tags}"
186
+ lines.append(msg)
187
 
188
+ content = "\n".join(lines)
189
+ st.markdown(content, unsafe_allow_html=False)
190
  st.divider()
191
 
192
 
 
195
 
196
  st.header("Describe audio content with CoNeTTE")
197
  st.markdown(
198
+ "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
199
  )
200
+ st.markdown(
201
+ "Use '**Start Recording**' and '**Stop**' to record an audio from your microphone."
 
 
 
 
 
202
  )
203
+ record_data = st_audiorec()
204
+
205
+ with st.expander("Or upload audio files here:"):
206
+ audio_files: Optional[list[UploadedFile]] = st.file_uploader(
207
+ f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
208
+ type=["wav", "flac", "mp3", "ogg", "avi"],
209
+ accept_multiple_files=True,
210
+ help="Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum.",
211
+ )
212
 
213
+ with st.expander("Model options"):
214
+ if DEFAULT_TASK in model.tasks:
215
+ default_task_idx = list(model.tasks).index(DEFAULT_TASK)
216
+ else:
217
+ default_task_idx = 0
218
+
219
+ task = st.selectbox("Task embedding input", model.tasks, default_task_idx)
220
  allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
221
  beam_size: int = st.select_slider( # type: ignore
222
  "Beam size",
 
268
  st.header("Results:")
269
  show_results(outputs)
270
 
271
+ current = time.perf_counter()
272
+ last_generation = st.session_state.get("last_generation", current)
273
+ if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
274
+ print(f"Removing result cache...")
275
+ for key in st.session_state.keys():
276
+ if isinstance(key, str) and key.startswith(HASH_PREFIX):
277
+ del st.session_state[key]
278
+ st.session_state["last_generation"] = current
279
+
280
 
281
  if __name__ == "__main__":
282
  main()