phonemize-audio / app.py
cdleong's picture
off by one on progress bar
ac96dc0
raw
history blame
4.18 kB
import streamlit as st
import langcodes
from allosaurus.app import read_recognizer
from pathlib import Path
import string
from itertools import permutations
from collections import defaultdict
import torchaudio
@st.cache
def get_supported_codes():
    """Return the language codes the default Allosaurus recognizer supports.

    ``"ipa"`` (the default universal setting) is always the first entry.
    After it come every three-letter lowercase code for which
    ``model.is_available(code)`` is true, in alphabetical order.
    The probe is expensive (17,576 candidates), hence ``st.cache``.
    """
    # Local import: the module-level import only brings in ``permutations``.
    from itertools import product

    model = read_recognizer()
    supported_codes = ["ipa"]  # default option
    # product(..., repeat=3) instead of permutations(..., r=3): permutations never
    # repeats a letter, so codes such as "ell" (Greek) were never even probed.
    for letters in product(string.ascii_lowercase, repeat=3):
        code = "".join(letters)
        # "ipa" is already first; don't list it a second time.
        if code != "ipa" and model.is_available(code):
            supported_codes.append(code)
    return supported_codes
def get_path_to_wav_format(uploaded_file, suppress_outputs=False):
    """Save an uploaded audio file to disk and return the path of a .wav version.

    Allosaurus requires .wav input. A .wav upload is written as-is. A .mp3
    upload is written, decoded with torchaudio, and re-saved alongside it as
    16-bit signed PCM .wav.

    Args:
        uploaded_file: object with ``.name`` and ``.getvalue()`` (e.g. a
            Streamlit ``UploadedFile``).
        suppress_outputs: if True, skip the ``st.info`` conversion messages.

    Returns:
        ``Path`` of the .wav file, relative to the current working directory.

    Raises:
        ValueError: if the file extension is neither .wav nor .mp3
            (case-insensitive).
    """
    # Keep only the final path component. The name comes from the client, so
    # any directory parts (e.g. "../") must not decide where we write.
    actual_file_path = Path(Path(uploaded_file.name).name)

    # Compare the real extension case-insensitively. A substring test would
    # misclassify "take.wav.mp3" and would reject "CLIP.WAV".
    suffix = actual_file_path.suffix.lower()
    if suffix not in (".wav", ".mp3"):
        raise ValueError(
            f"Unsupported audio file {uploaded_file.name!r}: expected a .wav or .mp3 file"
        )

    actual_file_path.write_bytes(uploaded_file.getvalue())
    if suffix == ".wav":
        return actual_file_path

    new_desired_path = actual_file_path.with_suffix(".wav")
    encoding = "PCM_S"  # Prevent encoding errors. https://stackoverflow.com/questions/60352850/wave-error-unknown-format-3-arises-when-trying-to-convert-a-wav-file-into-text
    bits_per_sample = 16
    waveform, sample_rate = torchaudio.load(actual_file_path)
    if not suppress_outputs:
        st.info(f"Allosaurus requires .wav files. Converting with torchaudio, encoding={encoding}, bits_per_sample={bits_per_sample}")
        st.info(f"Uploaded file sample_rate: {sample_rate}")
    torchaudio.save(
        new_desired_path,
        waveform,
        sample_rate,
        encoding=encoding,
        bits_per_sample=bits_per_sample,
    )
    return new_desired_path
@st.cache
def get_langcode_description(input_code, url=False):
    """Return a human-readable description of a language code.

    Args:
        input_code: language code selected in the UI. An empty value or
            ``"ipa"`` means Allosaurus's default universal recognizer.
        url: if True, render the language name as a Markdown link to its
            SIL ISO 639-3 code page.

    Returns:
        The ``langcodes`` display name (optionally as a Markdown link). For an
        empty code, ``"ipa"``, or a code ``langcodes`` cannot resolve, returns
        the default-recognizer description instead.
    """
    default_description = "the default universal setting, not specific to any language"
    if not input_code or input_code == "ipa":  # "ipa" is the default allosaurus recognizer
        return default_description
    try:
        lang = langcodes.get(input_code)
        alpha3 = lang.to_alpha3()
        display_name = lang.display_name()
    except (langcodes.LanguageTagError, LookupError):
        # LanguageTagError: malformed tag. LookupError: to_alpha3() found no
        # 3-letter code. Previously the latter propagated and crashed the UI.
        # NOTE(review): the fallback reuses the default-recognizer text, which is
        # misleading for a real-but-unknown code; consider a dedicated message.
        return default_description
    if url:
        return f"[{display_name}](https://iso639-3.sil.org/code/{alpha3})"
    return display_name
@st.cache
def get_langcode_with_description(input_code):
    """Label ``input_code`` as ``"<code>: <description>"`` for the language picker."""
    description = get_langcode_description(input_code)
    return f"{input_code}: {description}"
if __name__ == "__main__":
# input_code = st.text_input("(optional) 2 or 3-letter ISO code for input language. 2-letter codes will be converted to 3-letter codes", max_chars=3)
supported_codes = get_supported_codes()
index_of_desired_default = supported_codes.index("ipa")
langcode = st.selectbox("ISO code for input language. Allosaurus doesn't need this, but it can improve accuracy",
options=supported_codes,
index=index_of_desired_default,
format_func=get_langcode_with_description
)
model = read_recognizer()
description = get_langcode_description(langcode, url=True)
st.write(f"Instructing Allosaurus to recognize using language {langcode}. That is, {description}")
uploaded_files = st.file_uploader("Choose a file", type=[
".wav",
".mp3",
],
accept_multiple_files=True,
)
results = {} # for better download/display
uploaded_files_count = len(uploaded_files)
suppress_output_threshold = 2
my_bar = st.progress(0)
for i, uploaded_file in enumerate(uploaded_files):
if uploaded_file is not None:
wav_file = get_path_to_wav_format(uploaded_file, uploaded_files_count>suppress_output_threshold)
result = model.recognize(wav_file, langcode)
results[uploaded_file.name] = result
my_bar.progress(i+1/uploaded_files_count)
st.write(results)