File size: 3,582 Bytes
7b59ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4902da8
 
 
7b59ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4902da8
 
7b59ebe
 
 
 
 
 
 
 
 
 
4902da8
7b59ebe
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json

import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
from vosk import Model, KaldiRecognizer

def load_vosk(model_id: str):
    """Turn a Hugging Face Hub repository id into a loaded Vosk ``Model``.

    Registered below as the OmegaConf resolver ``load_vosk``, so it is invoked
    for ``${load_vosk:...}`` interpolations in the model config (presumably
    configs/models.yaml — confirm against that file).

    Args:
        model_id: Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``vosk.Model`` loaded from the downloaded snapshot directory.
    """
    return Model(model_path=snapshot_download(model_id))

# Expose load_vosk to the YAML config as the "load_vosk" resolver so that
# entries using it are turned into loaded Vosk models.
OmegaConf.register_new_resolver("load_vosk", load_vosk)

# Plain-Python (dict) view of the model config. OmegaConf.to_object resolves
# interpolations, which is when the load_vosk resolver actually runs.
# Shape as used below: models_config[model_id]["model"][dialect_id] is a Vosk
# model, and models_config[model_id]["dialect_mapping"] maps labels to ids.
models_config = OmegaConf.to_object(
    OmegaConf.load("configs/models.yaml")
)

# Bytes handed to the recognizer per AcceptWaveform call
# (4000 bytes == 2000 int16 samples, so no sample is split across chunks).
_CHUNK_BYTES = 4000


def _to_mono_pcm16(audio_array):
    """Return ``audio_array`` as a 1-D little-endian int16 array.

    Vosk's recognizer consumes raw 16-bit PCM bytes, so the samples must be
    int16 before ``tobytes()``; any other dtype would be misread. Only the
    first channel of a 2-D (samples, channels) array is kept.
    """
    if audio_array.ndim == 2:
        audio_array = audio_array[:, 0]

    kind = audio_array.dtype.kind
    if kind == "f":
        # Assumes floating-point audio is normalized to [-1.0, 1.0].
        audio_array = audio_array.clip(-1.0, 1.0) * 32767
    elif kind == "i":
        # Rescale other signed integer widths to 16 bits by shifting, which
        # keeps full scale mapped to full scale instead of wrapping around.
        bits = 8 * audio_array.dtype.itemsize
        if bits > 16:
            audio_array = audio_array >> (bits - 16)
        elif bits < 16:
            audio_array = audio_array.astype("<i2") << (16 - bits)
    return audio_array.astype("<i2", copy=False)


def automatic_speech_recognition(
    model_id: str, dialect_id: str, audio_data: tuple
) -> str:
    """Transcribe uploaded or recorded audio with the selected Vosk model.

    Args:
        model_id: Key into ``models_config`` (value of the "模型" dropdown).
        dialect_id: Key into ``models_config[model_id]["model"]`` (value of
            the "腔調" dropdown).
        audio_data: ``(sample_rate, samples)`` from a
            ``gr.Audio(type="numpy")`` input, or ``None`` when the user
            submits without any audio.

    Returns:
        The recognized segments with their inner spaces removed, joined by
        ``","`` and terminated by ``"。"``; ``""`` when no audio was given.
    """
    if audio_data is None:
        # An empty audio component yields None; unpacking it would raise.
        return ""

    model = models_config[model_id]["model"][dialect_id]
    sample_rate, audio_array = audio_data
    audio_bytes = _to_mono_pcm16(audio_array).tobytes()

    rec = KaldiRecognizer(model, sample_rate)
    rec.SetWords(True)

    results = []
    # Stream the audio in fixed-size chunks; whenever AcceptWaveform reports
    # a completed utterance, collect that utterance's result.
    for start in range(0, len(audio_bytes), _CHUNK_BYTES):
        if rec.AcceptWaveform(audio_bytes[start:start + _CHUNK_BYTES]):
            results.append(json.loads(rec.Result()))
    # Flush whatever audio remains after the last utterance boundary.
    results.append(json.loads(rec.FinalResult()))

    # Drop the spaces between recognized tokens and skip empty segments.
    texts = [result.get("text", "").replace(" ", "") for result in results]
    return ",".join(text for text in texts if text) + "。"


def when_model_selected(model_id: str):
    """Rebuild the dialect dropdown for a newly selected model.

    Args:
        model_id: Key into ``models_config`` chosen in the model dropdown.

    Returns:
        A ``gr.update`` whose choices are the ``(label, value)`` pairs of the
        model's ``dialect_mapping`` and whose value is the first pair's value.
    """
    dialect_mapping = models_config[model_id]["dialect_mapping"]
    choices = list(dialect_mapping.items())
    _, first_value = choices[0]
    return gr.update(choices=choices, value=first_value)


# Font stack for the page: "tauhu-oo" first (presumably defined by the
# tauhu-oo.css stylesheet imported via css= below), then Source Sans Pro from
# Google Fonts and generic sans-serif fallbacks.
_font_stack = (
    "tauhu-oo",
    gr.themes.GoogleFont("Source Sans Pro"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
)

demo = gr.Blocks(
    title="臺灣客語語音辨識系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(font=_font_stack),
)

with demo:
    # Initial selection: the first configured model and its first dialect.
    default_model_id = list(models_config.keys())[0]
    default_dialects = models_config[default_model_id]["dialect_mapping"]

    model_drop_down = gr.Dropdown(
        models_config.keys(),
        value=default_model_id,
        label="模型",
    )

    # Choices are (label, value) pairs, the same shape when_model_selected
    # returns when the model changes.
    dialect_drop_down = gr.Dropdown(
        choices=list(default_dialects.items()),
        value=list(default_dialects.values())[0],
        label="腔調",
    )

    # Repopulate the dialect dropdown whenever the user picks another model.
    model_drop_down.input(
        fn=when_model_selected,
        inputs=[model_drop_down],
        outputs=[dialect_drop_down],
    )

    gr.Markdown(
        """
        # 臺灣客語語音辨識系統
        ### Taiwanese Hakka Automatic-Speech-Recognition System
        ### 研發
        - **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)(聯和科創 United Link Co., Ltd.)**
        - **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)(聯和科創 United Link Co., Ltd.)**
        """
    )

    # type="numpy" delivers (sample_rate, array) to automatic_speech_recognition;
    # the waveform options request a 16 kHz sample rate.
    audio_input = gr.Audio(
        label="上傳或錄音",
        type="numpy",
        format="wav",
        waveform_options=gr.WaveformOptions(sample_rate=16000),
    )
    transcript_output = gr.Text(interactive=False, label="客語漢字")

    # The recognizer UI: (model, dialect, audio) -> Hakka characters text.
    gr.Interface(
        automatic_speech_recognition,
        inputs=[model_drop_down, dialect_drop_down, audio_input],
        outputs=[transcript_output],
        allow_flagging="auto",
    )

# Start the Gradio server for the UI assembled above. This runs at module
# level; NOTE(review): consider an `if __name__ == "__main__":` guard if this
# module is ever imported instead of executed directly.
demo.launch()