viktor-enzell's picture
Caching inference function.
4dce433
raw
history blame
No virus
2.92 kB
import streamlit as st
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
import torch
import torchaudio
import torchaudio.functional as F
class ASR:
    """Speech recognizer wrapping a wav2vec 2.0 CTC model and its LM-backed processor.

    Construction is cheap: the model and processor are only fetched and
    instantiated when ``load_model()`` is called. ``run_inference()`` then
    turns an audio file into a lower-cased transcript string.
    """

    # Sample rate (Hz) that audio is brought to before it is handed to the
    # processor; also passed to the processor as ``sampling_rate``.
    SAMPLE_RATE = 16_000

    def __init__(self):
        """Record the checkpoint name and pick the device; no weights are loaded."""
        self.model_name = "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram"
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # Both stay None until load_model() has been called.
        self.model = None
        self.processor = None

    def load_model(self):
        """Instantiate the model (moved to ``self.device``) and its processor."""
        self.model = Wav2Vec2ForCTC.from_pretrained(
            self.model_name).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM.from_pretrained(
            self.model_name)

    def run_inference(self, file):
        """Transcribe an audio file and return the transcript in lower case.

        Args:
            file: Whatever ``torchaudio.load`` accepts as its source (the app
                passes the uploaded file object).

        Returns:
            The decoded text of the single item in the batch, lower-cased.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        # Fail early with an actionable message instead of the opaque
        # "'NoneType' object is not callable" that would otherwise surface
        # only after the audio has already been decoded.
        if self.model is None or self.processor is None:
            raise RuntimeError(
                "ASR model is not loaded; call load_model() before run_inference().")

        waveform, sample_rate = torchaudio.load(file)

        # Only the first channel is transcribed. Select it before resampling
        # so the remaining channels are not resampled just to be discarded.
        waveform = waveform[0]
        if sample_rate != self.SAMPLE_RATE:
            waveform = F.resample(waveform, sample_rate, self.SAMPLE_RATE)

        inputs = self.processor(
            waveform,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Pure inference: skip gradient bookkeeping.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # The processor's decoder is fed a NumPy array on the CPU.
        return self.processor.batch_decode(logits.cpu().numpy()).text[0].lower()
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model():
    """Build an ``ASR`` instance with its model loaded and return it.

    The result is cached by Streamlit, so the model is not rebuilt on every
    script rerun; Streamlit's own spinner is disabled for this call.
    """
    recognizer = ASR()
    recognizer.load_model()
    return recognizer
@st.cache(
    allow_output_mutation=True,
    hash_funcs={ASR: lambda _: None},
    show_spinner=False,
)
def run_inference(asr, file):
    """Return the transcript of ``file`` produced by ``asr``, cached by Streamlit.

    Every ``ASR`` instance hashes to the same value (``None``), so the
    recognizer argument does not distinguish cache entries.
    """
    transcript = asr.run_inference(file)
    return transcript
def main():
    """Render the page: header, model loading, WAV upload, and transcript output."""
    st.set_page_config(
        page_title="Swedish Speech-to-Text",
        # Studio-microphone emoji (U+1F399 U+FE0F), the same glyph as the
        # header image below; the previous value was a mis-decoded copy of it.
        page_icon="🎙️"
    )
    st.image(
        "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/320/apple/325/studio-microphone_1f399-fe0f.png",
        width=100,
    )
    st.markdown(
        "\n"
        "# Swedish Speech-to-text\n"
        "Generate and download high-quality Swedish transcripts for your audio files. "
        "The speech-to-text model is KBLab's wav2vec 2.0 large VoxRex Swedish (C) "
        "with a 4-gram language model, which you can access "
        "[here](https://huggingface.co/viktor-enzell/wav2vec2-large-voxrex-swedish-4gram).\n"
    )

    # load_model() disables Streamlit's built-in cache spinner, so show our own.
    with st.spinner(text="Loading model..."):
        asr = load_model()

    uploaded_file = st.file_uploader("Choose a file", type=[".wav"])

    if uploaded_file is not None:
        if uploaded_file.type != "audio/wav":
            # TODO: convert non-WAV uploads to WAV before transcription,
            # e.g. by piping uploaded_file.getvalue() through ffmpeg.
            pass

        with st.spinner(text="Transcribing..."):
            transcript = run_inference(asr, uploaded_file)

        st.download_button("Download transcript", transcript, "transcript.txt")

        with st.expander("Transcript", expanded=True):
            st.write(transcript)


if __name__ == "__main__":
    main()