File size: 2,924 Bytes
38229d4
 
 
 
 
 
5b95586
091b848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dce433
091b848
 
 
 
 
 
4dce433
 
 
 
 
091b848
 
 
 
 
 
 
 
 
 
4dce433
091b848
4dce433
091b848
 
4dce433
 
091b848
 
 
 
 
 
 
 
 
 
 
4dce433
 
091b848
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import streamlit as st
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
import torch
import torchaudio
import torchaudio.functional as F


class ASR:
    """Speech recognizer wrapping a wav2vec 2.0 CTC model and its LM-backed processor.

    Construction is cheap: the model and processor are only fetched when
    ``load_model`` is called, after which ``run_inference`` can be used.
    """

    # Checkpoint used when no ``model_name`` is passed to the constructor.
    DEFAULT_MODEL_NAME = "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram"

    # Sample rate (Hz) audio is resampled to and that is reported to the processor.
    SAMPLE_RATE = 16_000

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Record the checkpoint name and pick the compute device.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                model and the processor.
        """
        self.model_name = model_name
        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # Both stay None until load_model() has been called.
        self.model = None
        self.processor = None

    def load_model(self):
        """Load the model (moved to ``self.device``) and its processor."""
        self.model = Wav2Vec2ForCTC.from_pretrained(
            self.model_name).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM.from_pretrained(
            self.model_name)

    def run_inference(self, file):
        """Transcribe an audio file and return the lower-cased transcript.

        Args:
            file: Anything ``torchaudio.load`` accepts as its source.

        Returns:
            The decoded text of the first (only) batch item, lower-cased.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not callable" after the audio was decoded.
        if self.model is None or self.processor is None:
            raise RuntimeError(
                "ASR model is not loaded; call load_model() before "
                "run_inference()."
            )

        waveform, sample_rate = torchaudio.load(file)

        if sample_rate != self.SAMPLE_RATE:
            waveform = F.resample(waveform, sample_rate, self.SAMPLE_RATE)

        # Only index 0 along the first axis is transcribed (the first channel
        # under torchaudio's default channels-first layout); any further
        # channels are ignored.
        waveform = waveform[0]

        inputs = self.processor(
            waveform,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # The LM-backed decoder is fed a NumPy array, so move the logits to
        # the CPU first.
        return self.processor.batch_decode(logits.cpu().numpy()).text[0].lower()


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model():
    """Create an ``ASR`` instance, load its model and processor, and return it.

    The function is wrapped in ``st.cache`` with the built-in spinner turned
    off; the caller shows its own spinner while this runs.
    """
    recognizer = ASR()
    recognizer.load_model()
    return recognizer


@st.cache(
    allow_output_mutation=True,
    hash_funcs={ASR: lambda _: None},
    show_spinner=False,
)
def run_inference(asr, file):
    """Return the transcript that ``asr`` produces for ``file``.

    The ``hash_funcs`` entry maps every ``ASR`` instance to the same hash
    value (``None``), so the recognizer itself does not influence the
    ``st.cache`` key.
    """
    transcript = asr.run_inference(file)
    return transcript


def main():
    """Render the page: header, model loading, file upload and transcript."""
    st.set_page_config(page_title="Swedish Speech-to-Text", page_icon="🎙️")
    st.image(
        "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/320/apple/325/studio-microphone_1f399-fe0f.png",
        width=100,
    )
    st.markdown("""
    # Swedish Speech-to-text

    Generate and download high-quality Swedish transcripts for your audio files. The speech-to-text model is KBLab's wav2vec 2.0 large VoxRex Swedish (C) with a 4-gram language model, which you can access [here](https://huggingface.co/viktor-enzell/wav2vec2-large-voxrex-swedish-4gram).
    """)

    with st.spinner(text="Loading model..."):
        recognizer = load_model()

    audio_file = st.file_uploader("Choose a file", type=[".wav"])
    if audio_file is None:
        # Nothing has been uploaded yet, so there is nothing to transcribe.
        return

    # TODO: convert uploads whose reported type is not "audio/wav" to WAV
    # before transcribing (e.g. with ffmpeg). Until then every upload is
    # passed to the recognizer unchanged.

    with st.spinner(text="Transcribing..."):
        transcript = run_inference(recognizer, audio_file)

    st.download_button("Download transcript", transcript, "transcript.txt")

    with st.expander("Transcript", expanded=True):
        st.write(transcript)


if __name__ == "__main__":
    main()