import streamlit as st
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
import torch
import torchaudio
import torchaudio.functional as F


class ASR:
    """Wav2vec 2.0 model with a 4-gram language model for Swedish speech-to-text."""

    def __init__(self):
        self.model_name = "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram"
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None

    def load_model(self):
        self.model = Wav2Vec2ForCTC.from_pretrained(
            self.model_name).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM.from_pretrained(
            self.model_name)

    def run_inference(self, file):
        waveform, sample_rate = torchaudio.load(file)

        # The model expects 16 kHz mono audio; resample if needed
        # and keep only the first channel.
        if sample_rate == 16_000:
            waveform = waveform[0]
        else:
            waveform = F.resample(waveform, sample_rate, 16_000)[0]

        inputs = self.processor(
            waveform,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Decode the CTC logits with the 4-gram language model.
        return self.processor.batch_decode(
            logits.cpu().numpy()).text[0].lower()


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model():
    # Cache the loaded model across Streamlit reruns.
    asr = ASR()
    asr.load_model()
    return asr


@st.cache(allow_output_mutation=True, hash_funcs={ASR: lambda _: None}, show_spinner=False)
def run_inference(asr, file):
    # The ASR instance is not hashable, so exclude it from the cache key.
    return asr.run_inference(file)


if __name__ == "__main__":
    st.set_page_config(
        page_title="Swedish Speech-to-Text",
        page_icon="🎙️"
    )
    st.image(
        "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/320/apple/325/studio-microphone_1f399-fe0f.png",
        width=100,
    )
    st.markdown("""
    # Swedish Speech-to-Text
    Generate and download high-quality Swedish transcripts for your audio files. The speech-to-text model is KBLab's wav2vec 2.0 large VoxRex Swedish (C) with a 4-gram language model, which you can access [here](https://huggingface.co/viktor-enzell/wav2vec2-large-voxrex-swedish-4gram).
    """)

    with st.spinner(text="Loading model..."):
        asr = load_model()

    uploaded_file = st.file_uploader("Choose a file", type=[".wav"])

    if uploaded_file is not None:
        if uploaded_file.type != "audio/wav":
            pass  # TODO: convert to wav (one possible approach is sketched at the bottom of this file)
            # bytes = uploaded_file.getvalue()
            # audio_input = ffmpeg.input(bytes).audio
            # audio_output = ffmpeg.output(audio_input, "tmp.wav", format="wav")
            # ffmpeg.run(audio_output)

        with st.spinner(text="Transcribing..."):
            transcript = run_inference(asr, uploaded_file)

        st.download_button("Download transcript", transcript, "transcript.txt")

        with st.expander("Transcript", expanded=True):
            st.write(transcript)
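

# --- Sketch for the TODO above: converting non-wav uploads ------------------
# A minimal, hypothetical sketch of the unimplemented conversion branch. It
# assumes pydub is installed and a system ffmpeg binary is on PATH; neither
# is a dependency of this app, and the helper name convert_to_wav is made up
# for illustration. Left commented out, like the ffmpeg-python draft above,
# so it stays inert.
#
# from io import BytesIO
# from pydub import AudioSegment
#
# def convert_to_wav(uploaded_file):
#     # pydub infers the input codec from the file header via ffmpeg.
#     audio = AudioSegment.from_file(uploaded_file)
#     # Re-export as wav into an in-memory buffer that torchaudio.load accepts.
#     buffer = BytesIO()
#     audio.export(buffer, format="wav")
#     buffer.seek(0)
#     return buffer
#
# The non-wav branch could then read:
#     uploaded_file = convert_to_wav(uploaded_file)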