conette / app.py
Labbeti's picture
Mod: Rework UI, remove tmp files and clear cache after 10min.
4ff8b3b
raw
history blame
No virus
9.7 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union
import streamlit as st
import torchaudio
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor
from conette import CoNeTTEModel, conette
from conette.utils.collections import dict_list_to_list_dict
# Choices offered by the "Allow repetition of words" select box.
ALLOW_REP_MODES = ("stopwords", "all", "none")
# Task pre-selected in the "Task embedding input" select box when the model provides it.
DEFAULT_TASK = "audiocaps"
# Upper bound of the "Beam size" slider.
MAX_BEAM_SIZE = 20
# Upper bound of the "Minimal and maximal number of words" slider.
MAX_PRED_SIZE = 30
# Maximal number of audio files passed to the model in a single call.
MAX_BATCH_SIZE = 16
# Name under which the microphone recording is stored among the audio inputs.
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
# Initial value of the "Tags threshold" slider.
DEFAULT_THRESHOLD = 0.3
# The threshold slider offers the values i / THRESHOLD_PRECISION for i in [0, THRESHOLD_PRECISION].
THRESHOLD_PRECISION = 100
# Accepted audio duration range, in seconds; files outside of it get an error message.
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60
# Prefix of the session-state keys that hold cached results.
HASH_PREFIX = "hash_"
# Prefix of the temporary files written to disk before calling the model.
TMP_FILE_PREFIX = "audio_tmp_file_"
# Delay in seconds (10 minutes) after which the cached results are removed.
SECOND_BEFORE_CLEAR_CACHE = 10 * 60
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model through `conette`, forwarding every argument.

    The function is wrapped by `st.cache_resource`, so Streamlit handles the
    caching of the returned model.
    """
    model = conette(*args, **kwargs)
    return model
def format_candidate(candidate: str) -> str:
    """Turn a raw caption into a sentence.

    The first character is title-cased and a final period is appended.
    An empty caption is returned as an empty string.
    """
    if not candidate:
        return ""
    head, tail = candidate[0], candidate[1:]
    return head.title() + tail + "."
def format_tags(tags: Optional[list[str]]) -> str:
    """Render a list of tags as a comma-separated string.

    Returns the placeholder "None." when `tags` is None or empty.
    """
    if tags:
        return ", ".join(tags)
    return "None."
def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    """Build the session-state key of a result.

    The key is made of HASH_PREFIX, the audio file name, a dash and the
    string form of the generation options.
    """
    key_body = "{}-{}".format(audio_fname, generate_kwds)
    return HASH_PREFIX + key_body
def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    """Return one result per audio file, reusing results kept in the session state.

    Args:
        model: CoNeTTE model, called with a list of audio file paths and the
            generation options.
        audio_files: Mapping from audio file name to the raw bytes of the file.
        generate_kwds: Keyword arguments forwarded to the model; they are also
            part of the cache key, so changing an option triggers a new prediction.

    Returns:
        A mapping from audio file name to either the model output for this
        file (a dict) or an error message (a str) when the audio duration is
        outside [MIN_AUDIO_DURATION_SEC, MAX_AUDIO_DURATION_SEC].

    Every temporary file written to disk is removed before returning, including
    the ones of rejected audio and when the audio loading or the model raises.
    """
    # Select the audio that needs a prediction: anything without a cached
    # result, plus the microphone recording, which is always processed again.
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)

    # Paths of ALL temporary files created below, valid or not, for the cleanup.
    created_paths: list[str] = []
    # Paths of the files with an accepted duration, by result hash.
    valid_paths: dict[str, str] = {}

    try:
        # Save the audio to disk, since the model is called with file paths.
        for result_hash, (audio_fname, audio) in audio_to_predict.items():
            # delete=False keeps the file on disk once closed; it is removed
            # explicitly in the `finally` clause.
            with NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX) as tmp_file:
                created_paths.append(tmp_file.name)
                tmp_file.write(audio)
            tmp_path = tmp_file.name

            metadata = torchaudio.info(tmp_path)  # type: ignore
            duration = metadata.num_frames / metadata.sample_rate

            if duration < MIN_AUDIO_DURATION_SEC:
                issue = "short"
            elif duration > MAX_AUDIO_DURATION_SEC:
                issue = "long"
            else:
                valid_paths[result_hash] = tmp_path
                continue

            # Rejected audio: the cached "result" is the error message itself.
            st.session_state[result_hash] = (
                f'\n##### Result for "{audio_fname}"\n'
                f"Audio file is too {issue}. (found {duration:.2f}s but the model expect audio "
                f"in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])\n"
            )

        # Generate predictions by batches and store them in the session state.
        result_hashes = list(valid_paths)
        for start in range(0, len(result_hashes), MAX_BATCH_SIZE):
            result_hashes_j = result_hashes[start : start + MAX_BATCH_SIZE]
            tmp_paths_j = [valid_paths[result_hash] for result_hash in result_hashes_j]
            outputs_j = model(
                tmp_paths_j,
                **generate_kwds,
            )
            outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
            for result_hash, output_i in zip(result_hashes_j, outputs_lst):
                st.session_state[result_hash] = output_i
    finally:
        # Best-effort cleanup: a file that is already gone is not an error.
        for path in created_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    # Collect the results of every requested audio, cached or freshly computed.
    return {
        audio_fname: st.session_state[get_result_hash(audio_fname, generate_kwds)]
        for audio_fname in audio_files
    }
def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
    """Display the results on the page, last entry of `outputs` first.

    Args:
        outputs: Mapping from audio file name to either an error message (str),
            shown with `st.error`, or a model output (dict) read through the
            keys "cands", "lprobs", "mult_cands", "mult_lprobs" and the
            optional key "tags".
    """
    # Local import: only needed to escape the file name below.
    import html

    # Style of the main caption. It does not depend on the result, so it is
    # emitted only once, right before the first displayed caption.
    caption_style = (
        "\n<style>\n"
        ".big-font {\n"
        "font-size:22px !important;\n"
        "background-color: rgba(0, 255, 0, 0.1);\n"
        "padding: 10px;\n"
        "}\n"
        "</style>\n"
    )
    style_emitted = False

    st.divider()

    for audio_fname, output in reversed(outputs.items()):
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue

        cand: str = output["cands"]
        lprobs: Tensor = output["lprobs"]
        tags_lst = output.get("tags")
        mult_cands: list[str] = output["mult_cands"]
        mult_lprobs: Tensor = output["mult_lprobs"]

        cand = format_candidate(cand)
        # Log-probabilities are converted to probabilities for display.
        prob = lprobs.exp().tolist()
        tags = format_tags(tags_lst)

        # Sort the candidates by decreasing probability and skip the first one
        # (presumably the best candidate is the main caption shown above -- to confirm).
        mult_probs = mult_lprobs.exp()
        order = mult_probs.argsort(descending=True)[1:]
        other_probs = mult_probs[order].tolist()
        other_cands = [format_candidate(mult_cands[idx]) for idx in order.tolist()]

        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            # The file name is chosen by the user and this markdown is rendered
            # with raw HTML enabled: escape it so it cannot inject markup.
            header = f'##### Result for "{html.escape(audio_fname)}"'

        if not style_emitted:
            st.markdown(caption_style, unsafe_allow_html=True)
            style_emitted = True

        caption_lines = [
            header,
            f'<center><p class="space"><p class="big-font">"{cand}"</p></p></center>',
        ]
        st.markdown("<br>".join(caption_lines), unsafe_allow_html=True)

        detail_lines = [
            f"- **Probability**: {prob*100:.1f}%",
        ]
        if len(other_cands) > 0:
            detail_lines.append("- **Other descriptions:**")
            for cand_i, prob_i in zip(other_cands, other_probs):
                detail_lines.append(f' - "{cand_i}" ({prob_i*100:.1f}%)')
        detail_lines.append(f"- **Tags:** {tags}")

        st.markdown("\n".join(detail_lines), unsafe_allow_html=False)
        st.divider()
def main() -> None:
    """Build the Streamlit page: inputs, model options, results and cache expiry.

    The page records audio from the microphone and/or accepts uploaded files,
    exposes the generation options, displays the generated descriptions and
    finally removes the cached results when more than
    SECOND_BEFORE_CLEAR_CACHE seconds passed since the previous run.

    Raises:
        ValueError: If the selected repetition mode is not one of ALLOW_REP_MODES.
    """
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record an audio from your microphone."
    )
    record_data = st_audiorec()

    with st.expander("Or upload audio files here:"):
        audio_files: Optional[list[UploadedFile]] = st.file_uploader(
            f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
            type=["wav", "flac", "mp3", "ogg", "avi"],
            accept_multiple_files=True,
            help="Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum.",
        )

    with st.expander("Model options"):
        # Pre-select DEFAULT_TASK when the model provides it, else the first task.
        if DEFAULT_TASK in model.tasks:
            default_task_idx = list(model.tasks).index(DEFAULT_TASK)
        else:
            default_task_idx = 0

        task = st.selectbox("Task embedding input", model.tasks, default_task_idx)
        allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, MAX_BEAM_SIZE + 1)),
            model.config.beam_size,
        )
        min_pred_size, max_pred_size = st.slider(
            "Minimal and maximal number of words",
            1,
            MAX_PRED_SIZE,
            (model.config.min_pred_size, model.config.max_pred_size),
        )
        threshold = st.select_slider(
            "Tags threshold",
            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
            DEFAULT_THRESHOLD,
        )

    # The UI asks which words MAY be repeated, while the generation option
    # states which words may NOT be repeated: convert one into the other.
    forbid_rep_modes = {
        "all": "none",
        "none": "all",
        "stopwords": "content_words",
    }
    forbid_rep_mode = forbid_rep_modes.get(allow_rep_mode)  # type: ignore
    if forbid_rep_mode is None:
        raise ValueError(
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )

    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
        threshold=threshold,
    )

    # Gather every audio input (uploaded files and microphone record) by name.
    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios |= {audio.name: audio.getvalue() for audio in audio_files}
    if record_data is not None:
        audios |= {RECORD_AUDIO_FNAME: record_data}

    if len(audios) > 0:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)
        st.header("Results:")
        show_results(outputs)

    # Remove the cached results when the previous run is older than
    # SECOND_BEFORE_CLEAR_CACHE seconds.
    current = time.perf_counter()
    last_generation = st.session_state.get("last_generation", current)
    if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
        print("Removing result cache...")
        # Snapshot the keys first: entries must not be deleted from the
        # session state while iterating over it.
        stale_keys = [
            key
            for key in st.session_state.keys()
            if isinstance(key, str) and key.startswith(HASH_PREFIX)
        ]
        for key in stale_keys:
            del st.session_state[key]
    st.session_state["last_generation"] = current
# Build the page only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()