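"""Streamlit app for generating a LibriSpeech-like corpus.

Audio files uploaded through the UI are transcribed with Whisper at sentence
or word level; the resulting utterance splits and a timestamps CSV can be
downloaded as a LibriSpeech-like dataset.

Run with: streamlit run <path-to-this-file>
"""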
import streamlit as st
import whisper_timestamped as whisper  # assumption: whisper_timestamped backend, which exposes load_model/load_audio/transcribe and per-word timings
import pandas as pd
from utils.files import (
create_temp_directory,
save_temp_file,
compress_utterances_folder,
)
from utils.text import (
    get_sentence_data,
    get_word_data,
    generate_transcriptions_splits,
    check_ut_min_duration,
)
from utils.audio import generate_audio_splits
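
# UI labels mapped to the identifiers expected by the transcription backend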
STAMP_TYPES = {"Sentence-level": "sentence", "Word-level": "word"}
LANGUAGES = {"English": "en", "Spanish": "es"}
MODEL_SIZES = {"Medium": "medium", "Large": "large"}
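

# Model loading is cached so Streamlit reruns reuse the same Whisper instance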
@st.cache_resource(show_spinner=False)
def load_model(model_size: str):
    """Load the Whisper model of the given size.

    Parameters
    ----------
    model_size : str
        Display name of the model size; a key of MODEL_SIZES.

    Returns
    -------
    whisper.Whisper
        The loaded Whisper model.
    """
    return whisper.load_model(
        MODEL_SIZES[model_size], device="cpu", download_root="models"
    )


def main_app():
    st.title("🗣️💬 LibriSpeech Corpus Generator")
    st.divider()

    # Audio upload (multiple files allowed)
    audio_files = st.file_uploader(
        "Load audio files to process", type=["wav", "mp3"], accept_multiple_files=True
    )
    st.divider()
    stamp_type, lang, size = st.columns(3)
    with stamp_type:
        timestamp_type = st.selectbox(
            "Division level", options=list(STAMP_TYPES.keys())
        )
    with lang:
        language = st.selectbox("Language", options=list(LANGUAGES.keys()))
    with size:
        model_size = st.selectbox("Model size", options=list(MODEL_SIZES.keys()))
    st.divider()
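
    # Processing pipeline: transcribe each file, split it into utterances,
    # and collect timestamps for display and download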
    if st.button("Process audios", use_container_width=True):
        if not audio_files:
            st.warning("Upload at least one audio file before processing.")
            st.stop()

        with st.spinner("Loading model..."):
            model = load_model(model_size)

        timestamps_df = pd.DataFrame()
        temp_dir = create_temp_directory()
        utterances_folder = temp_dir / "utterances_segments"
        utterances_folder.mkdir(exist_ok=True)

        for audio_i in audio_files:
            with st.spinner(f"Processing audio: {audio_i.name}"):
                tmp_audio = save_temp_file(audio_i)

                # Whisper inference
                tmp_audio_array = whisper.load_audio(tmp_audio)
                timestamp_result = whisper.transcribe(
                    model, tmp_audio_array, language=LANGUAGES[language]
                )
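
                # timestamp_result is the Whisper result dict; the helpers
                # below read its "segments" (and per-word timings, for
                # word-level splitting) to build the dataframe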
                if timestamp_type == "Sentence-level":
                    audio_i_df = get_sentence_data(audio_i.name, timestamp_result)
                    # Enforce the minimum utterance duration
                    audio_i_df = check_ut_min_duration(audio_i_df)
                elif timestamp_type == "Word-level":
                    audio_i_df = get_word_data(audio_i.name, timestamp_result)

                # Accumulate timestamps across all processed audios
                timestamps_df = pd.concat(
                    [timestamps_df, audio_i_df], ignore_index=True
                )

                # Write per-utterance audio and transcription files
                generate_audio_splits(tmp_audio, audio_i_df, utterances_folder)
                generate_transcriptions_splits(
                    tmp_audio, audio_i_df, utterances_folder
                )
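
        # Display aggregated timestamps and offer the results for download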
        st.divider()
        st.markdown(
            "<h3 style='text-align: center;'>Timestamps</h3>",
            unsafe_allow_html=True,
        )
        st.dataframe(timestamps_df)
        st.divider()

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                "Download timestamps in .csv",
                timestamps_df.to_csv(index=False),
                file_name="timestamps.csv",
                mime="text/csv",
                use_container_width=True,
            )
        with col2:
            st.download_button(
                "Download LibriSpeech-like dataset",
                data=compress_utterances_folder(utterances_folder),
                file_name="librispeech-like-dataset.zip",
                mime="application/zip",
                use_container_width=True,
            )


if __name__ == "__main__":
    main_app()