Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM | |
import torch | |
import torchaudio | |
import torchaudio.functional as F | |
# Page metadata; Streamlit requires this to be the first st.* call in the script.
# page_icon was mis-encoded text (UTF-8 emoji bytes decoded as TIS-620);
# restored to a real emoji so Streamlit can render it as the favicon.
st.set_page_config(
    page_title='Swedish Speech-to-Text',
    page_icon='🎙️',
)
# Load model and processor
model_name = 'viktor-enzell/wav2vec2-large-voxrex-swedish-4gram'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model_and_processor(name, device_name):
    """Load the CTC acoustic model and its LM-backed processor once per process.

    Streamlit re-runs this whole script on every widget interaction; caching
    the loaded objects avoids re-reading (and re-moving to the device) the
    model weights and the n-gram decoder on each rerun.

    Args:
        name: Hugging Face hub identifier of the model/processor.
        device_name: Torch device string (e.g. ``'cpu'`` or ``'cuda'``); passed
            as a plain string so it is trivially hashable as a cache key.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = Wav2Vec2ForCTC.from_pretrained(name).to(device_name)
    loaded_model.eval()  # inference only: make sure dropout etc. are disabled
    loaded_processor = Wav2Vec2ProcessorWithLM.from_pretrained(name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(model_name, str(device))
def run_inference(file):
    """Transcribe a speech recording with the module-level model and processor.

    The audio is downmixed to mono and resampled to the rate the processor's
    feature extractor expects. It is then decoded with the LM-backed
    processor, and the transcript is returned lower-cased.

    Args:
        file: A path or binary file-like object (e.g. a Streamlit
            ``UploadedFile``) that ``torchaudio.load`` can read.

    Returns:
        The lower-cased transcript, or ``''`` if the audio contains no samples.
    """
    target_rate = processor.feature_extractor.sampling_rate

    # torchaudio.load returns a (channels, frames) tensor by default.
    waveform, sample_rate = torchaudio.load(file)
    # Downmix to mono by averaging so no channel's speech is discarded;
    # doing this before resampling avoids resampling channels we then drop.
    waveform = waveform.mean(dim=0)
    if waveform.numel() == 0:
        # The model cannot run on zero-length input; nothing to transcribe.
        return ''
    if sample_rate != target_rate:
        waveform = F.resample(waveform, sample_rate, target_rate)

    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors='pt',
        padding=True,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # The LM decoder works on numpy logits on the CPU.
    return processor.batch_decode(logits.cpu().numpy()).text[0].lower()
# UI: accept a WAV upload, transcribe it, and display the result.
uploaded_file = st.file_uploader('Choose a file', type=['.wav'])
if uploaded_file is not None:
    try:
        # Inference can take a while; show progress instead of a frozen page.
        with st.spinner('Transcribing...'):
            transcript = run_inference(uploaded_file)
    except RuntimeError as err:
        # NOTE(review): audio decoding failures from torchaudio are expected to
        # surface as RuntimeError (confirm for the installed backend); report
        # them to the user rather than dumping a traceback.
        st.error(f'Could not transcribe the file: {err}')
    else:
        st.write(transcript)