File size: 3,724 Bytes
c20b6de
 
 
 
 
 
 
 
 
cfb23ad
442ed91
 
c20b6de
 
9430ecf
c20b6de
 
 
 
 
 
442ed91
 
 
 
14485b0
 
 
 
 
 
 
 
c20b6de
 
14485b0
c20b6de
 
 
 
 
 
 
 
442ed91
 
 
 
 
 
 
 
 
 
c20b6de
14485b0
 
 
c20b6de
14485b0
c20b6de
14485b0
 
 
c20b6de
14485b0
c20b6de
442ed91
 
14485b0
 
 
 
 
 
c20b6de
 
14485b0
 
 
c20b6de
 
 
 
cfb23ad
c20b6de
 
1040275
c20b6de
 
cfb23ad
c20b6de
 
 
14485b0
 
fc9ce6a
c20b6de
 
14485b0
c20b6de
 
14485b0
c20b6de
14485b0
 
c20b6de
 
14485b0
c20b6de
 
 
9430ecf
c20b6de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from io import BytesIO
from typing import Tuple
import wave
import gradio as gr
import numpy as np
from pydub.audio_segment import AudioSegment
import requests
from os.path import exists
from stt import Model
from datetime import datetime
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch

# Release artifacts for the Coqui STT (DeepSpeech) acoustic model and its
# KenLM scorer; `download` fetches both from this GitHub release before the
# interface is launched.
version = "v0.4"
storage_url = f"https://github.com/robinhad/voice-recognition-ua/releases/download/{version}"
model_name = "uk.tflite"
scorer_name = "kenlm.scorer"
model_link = "/".join((storage_url, model_name))
scorer_link = "/".join((storage_url, scorer_name))

# Hugging Face checkpoint for the Wav2Vec2 CTC model and its matching
# processor. Both are loaded eagerly at import time; no `.to(device)` call is
# made, so the model stays on the default device.
wav2vec2_checkpoint = "robinhad/wav2vec2-xls-r-300m-uk"
model = Wav2Vec2ForCTC.from_pretrained(wav2vec2_checkpoint)
processor = Wav2Vec2Processor.from_pretrained(wav2vec2_checkpoint)
# TODO: download config.json, pytorch_model.bin, preprocessor_config.json, tokenizer_config.json, vocab.json, added_tokens.json, special_tokens.json

def download(url, file_name):
    """Fetch ``url`` into ``file_name`` unless that file already exists.

    The response body is streamed into a temporary ``<file_name>.part`` file
    and only renamed into place after it has been fully written. An
    interrupted or failed download therefore never leaves a truncated file
    that a later run would mistake for a finished one and skip.

    Args:
        url: Address to download from (redirects are followed).
        file_name: Local destination path.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            original code silently saved the error page as the model file.
        requests.RequestException: On connection problems or timeouts.
    """
    import os

    if exists(file_name):
        print(f"Found {file_name}. Skipping download...")
        return

    print(f"Downloading {file_name}")
    partial_name = f"{file_name}.part"
    # stream=True avoids buffering the whole model/scorer file in memory;
    # the timeout bounds each socket wait, not the total transfer time.
    with requests.get(url, allow_redirects=True, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_name, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    # Atomic rename: file_name only ever appears once complete.
    os.replace(partial_name, file_name)


def deepspeech(audio: np.array, use_scorer=False):
    """Transcribe ``audio`` with the Coqui STT (DeepSpeech) model.

    A new ``Model`` is built from ``model_name`` on every call, and the scorer
    is attached to that instance only. A scorer-enabled call therefore does
    not affect a later call that asks for plain acoustic decoding.

    Args:
        audio: Samples handed straight to ``Model.stt``. In this file,
            ``inference`` passes int16 samples decoded from a WAV resampled
            to 16 kHz.
        use_scorer: When true, enable the KenLM external scorer stored at
            ``scorer_name`` before decoding.

    Returns:
        The text produced by ``Model.stt``.
    """
    ds = Model(model_name)
    if use_scorer:
        # Use the shared constant (the file name `download` writes to)
        # instead of a duplicated literal that could drift out of sync.
        ds.enableExternalScorer(scorer_name)
    return ds.stt(audio)

def wav2vec2(audio: np.array):
    """Transcribe ``audio`` with the Hugging Face Wav2Vec2 CTC model.

    Uses greedy decoding. The highest-scoring token id is taken at every
    frame of the first batch item, and the processor turns that id sequence
    into text.

    Args:
        audio: Waveform sampled at 16 kHz (the rate declared to the
            processor). ``inference`` passes a 1-D int16 array.

    Returns:
        The decoded transcript.
    """
    features = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    values = features.input_values.float()
    with torch.no_grad():
        logits = model(values).logits
    best_tokens = logits.argmax(dim=-1)[0]
    return processor.decode(best_tokens)

def inference(audio: Tuple[int, np.array]):
    """Run every recogniser on one Gradio audio clip and return the transcripts.

    Args:
        audio: ``(sample_rate, samples)`` pair as delivered by the numpy-typed
            ``gr.inputs.Audio`` component.

    Returns:
        Tuple of ``(wav2vec2, deepspeech_with_lm, deepspeech)`` transcripts,
        in the same order as the interface's output text boxes.
    """
    print("=============================")
    # The original format string ended with a stray backtick.
    print(f"Time: {datetime.utcnow()}.")

    sample_rate, samples = audio
    wav_stream = _convert_audio(samples, sample_rate)

    # _convert_audio exports 16-bit PCM (pcm_s16le), so frames decode as int16.
    # The context manager closes the reader even if decoding raises.
    with wave.open(wav_stream, "rb") as wav_in:
        pcm = np.frombuffer(wav_in.readframes(wav_in.getnframes()), np.int16)

    transcripts = []

    transcripts.append(wav2vec2(pcm))
    print(f"Wav2Vec2: `{transcripts[-1]}`")
    transcripts.append(deepspeech(pcm, use_scorer=True))
    print(f"Deepspeech with LM: `{transcripts[-1]}`")
    transcripts.append(deepspeech(pcm))
    print(f"Deepspeech: `{transcripts[-1]}`")
    return tuple(transcripts)
    

def _convert_audio(audio_data: np.array, sample_rate: int):
    """Encode raw samples as an in-memory, mono, 16 kHz, 16-bit PCM WAV file.

    Input longer than two minutes is truncated. A 2-D ``(frames, channels)``
    array is decoded with its real channel count and then mixed down to mono.
    Previously it was declared as mono, which misread interleaved samples as
    consecutive frames. NOTE(review): presumably Gradio yields such 2-D arrays
    for stereo uploads; confirm for the Gradio version in use.

    Args:
        audio_data: Integer samples, 1-D for mono or 2-D ``(frames, channels)``.
        sample_rate: Sampling rate of ``audio_data`` in Hz.

    Returns:
        A ``BytesIO`` rewound to position 0 that holds the WAV file.
    """
    audio_limit = sample_rate * 60 * 2  # limit audio to 2 minutes max
    if audio_data.shape[0] > audio_limit:
        audio_data = audio_data[0:audio_limit]
    channels = 1 if audio_data.ndim == 1 else audio_data.shape[1]
    # tobytes() yields C-ordered, frame-interleaved bytes even for
    # non-contiguous arrays, which is the layout from_raw expects.
    source_audio = BytesIO(audio_data.tobytes())
    wav_file: AudioSegment = AudioSegment.from_raw(
        source_audio,
        channels=channels,
        sample_width=audio_data.dtype.itemsize,
        frame_rate=sample_rate,
    )
    # Callers read the result as single-channel int16 frames.
    wav_file = wav_file.set_channels(1)
    output_audio = BytesIO()
    wav_file.export(output_audio, "wav", codec="pcm_s16le", parameters=["-ar", "16k"])
    output_audio.seek(0)
    return output_audio

# Use README.md as the article text shown below the interface, dropping the
# header that precedes the second "---" line (presumably YAML front matter).
# Read as UTF-8 explicitly so non-ASCII text does not depend on the platform
# default encoding.
with open("README.md", encoding="utf-8") as file:
    article = file.read()
# Search from offset 4 to skip the opening "---\n" delimiter.
front_matter_end = article.find("---\n", 4)
if front_matter_end != -1:
    # Drop the closing delimiter and any blank lines after it. The previous
    # fixed "+5" offset could eat a real character, and when no delimiter
    # was found it silently cut the first 4 characters.
    article = article[front_matter_end + len("---\n"):].lstrip("\n")

# Gradio UI: one required numpy-typed audio input (inference receives a
# (sample_rate, samples) tuple) and one text box per recogniser, in the order
# inference returns its transcripts.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Audio(type="numpy", label="Аудіо", optional=False),
    ],
    outputs=[
        gr.outputs.Textbox(label="Wav2Vec2"),
        gr.outputs.Textbox(label="DeepSpeech with LM"),
        gr.outputs.Textbox(label="DeepSpeech"),
    ],
    title="🇺🇦 Ukrainian Speech-to-Text models",
    description="Україномовний🇺🇦 Speech-to-Text за допомогою Coqui STT",
    article=article,
    theme="huggingface",
)

# Fetch the Coqui STT acoustic model, then its KenLM scorer (each is skipped
# when already on disk), and finally start serving the interface.
for _link, _target in ((model_link, model_name), (scorer_link, scorer_name)):
    download(_link, _target)
iface.launch()