File size: 11,937 Bytes
89a09d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# Reference: https://www.gradio.app/guides/real-time-speech-recognition
import os.path
from loguru import logger
from transformers import pipeline, WhisperProcessor
import gradio as gr
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from time import time
stream = True
import librosa
import soundfile
import psutil


def get_gpu_mem_info(gpu_id=0):
    """
    Look up the memory figures of one GPU through pynvml, expressed in MB.

    :param gpu_id: index of the GPU to inspect
    :return: tuple ``(total, used, free)`` rounded to two decimals, or
        ``(0, 0, 0)`` when ``gpu_id`` is negative or not below the device count
    """
    import pynvml

    pynvml.nvmlInit()
    # The device count is only queried for non-negative indices (short-circuit).
    index_is_valid = gpu_id >= 0 and gpu_id < pynvml.nvmlDeviceGetCount()
    if not index_is_valid:
        print(r'gpu_id {} 对应的显卡不存在!'.format(gpu_id))
        return 0, 0, 0

    device = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    mem = pynvml.nvmlDeviceGetMemoryInfo(device)
    # Bytes -> MB for each figure, in (total, used, free) order.
    return tuple(round(n_bytes / 1024 / 1024, 2) for n_bytes in (mem.total, mem.used, mem.free))


def get_cpu_mem_info():
    """
    Report host and current-process memory usage in MB (two decimals).

    :return: ``(mem_total, mem_free, mem_process_used)`` - total system memory,
        the amount psutil reports as available, and the resident set size of
        the current process
    """
    # Take a single snapshot so that total and available describe the same
    # instant (and the system is only queried once).
    vm = psutil.virtual_memory()
    mem_total = round(vm.total / 1024 / 1024, 2)
    mem_free = round(vm.available / 1024 / 1024, 2)
    mem_process_used = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 2)
    return mem_total, mem_free, mem_process_used

# Performance note: with the base model, GPU inference on 30 s of speech takes ~500 ms.

model_path = "yuxiang1990/asr-surg"

# Speech-recognition pipeline built from model_path.
# The stock "openai/whisper-base" checkpoint can be passed as `model` instead.
transcriber = pipeline(task="automatic-speech-recognition", model=model_path)

# Streaming state shared between callbacks.
whole_sentence = ""
listen_cnt = 0  # number of consecutive non-silent chunks heard so far
silence_flag = False  # True: the newest chunk was classified as silence
silence_cnt = 0  # number of silent chunks counted
max_sentence = 10  # maximum number of sentences shown in the UI
sentence_cnt = 0  # number of sentences currently accumulated

# Decoder prompt ids for Chinese transcription; the processor is loaded from
# local files only.
processor = WhisperProcessor.from_pretrained(model_path, local_files_only=True)
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="zh")

# Generation settings handed to the pipeline on every transcription.
generate_kwargs = dict(
    temperature=1.0,
    top_p=1.0,
    top_k=50,
    max_length=448,
    num_beams=1,
    do_sample=False,
    forced_decoder_ids=forced_decoder_ids,
    repetition_penalty=1.0,
    diversity_penalty=0,
    length_penalty=1.0,
    num_beam_groups=1,
)


def transcribe(stream, new_chunk, std_num, silence_num, max_speak_num, checkbox_record, record_path):
    """
    Streaming callback: buffer voiced audio and transcribe the buffer once the
    speaker pauses or has been talking for too long.

    Per the original notes the callback fires roughly every 500 ms, so the
    chunk counts below are in units of about 500 ms.

    :param stream: audio buffered so far for the current utterance, or None
        when no utterance is in progress (first call, or right after a
        transcription)
    :param new_chunk: ``(sample_rate, samples)`` for the newest chunk; samples
        are divided by 32767, i.e. treated as 16-bit PCM values
    :param std_num: a chunk whose normalised standard deviation is below this
        value is classified as silence
    :param silence_num: number of consecutive silent chunks after speech that
        ends the utterance and triggers transcription
    :param max_speak_num: transcription is forced once more than this many
        consecutive voiced chunks have been heard
    :param checkbox_record: when truthy, each transcribed utterance is saved
        as a WAV file
    :param record_path: directory for the WAV files; nothing is written unless
        it is an existing directory
    :return: ``(stream, whole_sentence, y.std(), y.max(), y.min(), S.std(),
        S.max(), S.min(), mel_shape)`` where ``stream`` is None right after a
        transcription
    """
    global whole_sentence, listen_cnt, silence_flag, silence_cnt, sentence_cnt
    logger.debug("std_num:{}, silence_num:{}, record:{}, max_speak_num:{}".format(std_num, silence_num, checkbox_record, max_speak_num))

    sr, y = new_chunk
    y = y.astype(np.float32)
    if y.ndim > 1:
        # Down-mix (samples, channels) input to mono; the rest of the function
        # works on a single-channel signal.
        y = y.mean(axis=1)
    y /= 32767

    # Mel spectrogram of the chunk, only used for the monitoring read-outs.
    S = librosa.feature.melspectrogram(y=y, sr=sr)
    mel_shape = "{}".format(S.shape)

    if np.std(y) < std_num:
        silence_flag = True
        silence_cnt += 1
        listen_cnt = 0
    else:
        silence_flag = False
        # Silence must be consecutive to end an utterance. Without this reset,
        # silence accumulated while idle would make the very first voiced
        # chunk be transcribed on its own.
        silence_cnt = 0
        listen_cnt += 1
        # Voiced chunk: start a new buffer or extend the current one.
        stream = y if stream is None else np.concatenate([stream, y])

    paused_long_enough = silence_cnt >= silence_num
    spoke_too_long = listen_cnt > max_speak_num
    if stream is not None and (paused_long_enough or spoke_too_long):
        now = time()
        logger.info("start transcriber")
        text = transcriber({"sampling_rate": sr, "raw": stream}, generate_kwargs=generate_kwargs)["text"]
        logger.info("transcribe:{} sr:{} shape:{} max:{} min:{}".format(text, sr, y.shape, y.max(), y.min()))
        logger.info("start transcriber done, cost:{}".format(time() - now))

        # Optionally keep the utterance on disk, named after its transcript.
        if checkbox_record and record_path and os.path.isdir(record_path):
            # The transcript is model output; keep path separators out of the
            # file name so the file always lands inside record_path.
            safe_text = text.replace("/", "_").replace("\\", "_")
            wav_path = os.path.join(record_path, "{}-{}.wav".format(safe_text, str(int(time()))))
            soundfile.write(wav_path, data=stream, samplerate=sr)

        if sentence_cnt >= max_sentence:
            # Start a fresh transcript; it already contains this sentence, so
            # the counter restarts at 1 to cap the display at max_sentence.
            sentence_cnt = 1
            whole_sentence = text + "\n"
        else:
            sentence_cnt += 1
            whole_sentence += text + "\n"

        silence_cnt = 0
        listen_cnt = 0
        stream = None  # the buffer has been consumed

    return stream, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape


def update_generate_kwargs(temperature, top_p, top_k, max_length,
                                       num_beams, do_sample, repetition_penalty,
                                       diversity_penalty, length_penalty, num_beam_groups):
    """
    Parse the generation settings entered in the UI and store them in the
    module-level ``generate_kwargs`` dict.

    Every value is parsed before anything is written, so a malformed entry
    raises without leaving ``generate_kwargs`` half-updated.

    :param temperature: parsed with ``float``
    :param top_p: parsed with ``float``
    :param top_k: parsed with ``int(float(...))``
    :param max_length: parsed with ``int(float(...))``
    :param num_beams: parsed with ``int(float(...))``
    :param do_sample: a bool, or a string; the strings ``false``, ``0``,
        ``no``, ``off`` and the empty string (case-insensitive, surrounding
        whitespace ignored) mean False, any other string means True
    :param repetition_penalty: parsed with ``float``
    :param diversity_penalty: parsed with ``float``
    :param length_penalty: parsed with ``float``
    :param num_beam_groups: parsed with ``int(float(...))``
    :raises ValueError: if a numeric field cannot be parsed
    """
    if isinstance(do_sample, str):
        sample_flag = do_sample.strip().lower() not in ('false', '0', 'no', 'off', '')
    else:
        sample_flag = bool(do_sample)

    parsed = {
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(float(top_k)),
        'max_length': int(float(max_length)),
        'num_beams': int(float(num_beams)),
        'do_sample': sample_flag,
        'repetition_penalty': float(repetition_penalty),
        'diversity_penalty': float(diversity_penalty),
        'length_penalty': float(length_penalty),
        'num_beam_groups': int(float(num_beam_groups)),
    }
    # Applied only after every field parsed successfully.
    generate_kwargs.update(parsed)

    logger.info("tmp:{} top_p:{} top_k:{} max_len:{}\n num_beams: {} do_sample:{} repetition_penalty:{} \n"
                "diversity_penalty:{} length_penalty:{} num_beam_groups:{}".format(
                    temperature, top_p, top_k, max_length, num_beams, do_sample,
                    repetition_penalty, diversity_penalty, length_penalty, num_beam_groups))


def clear():
    """Reset the accumulated transcript together with its sentence counter."""
    global whole_sentence, sentence_cnt
    whole_sentence = ""
    # Keep the counter in step with the emptied transcript; otherwise the next
    # rollover would happen before max_sentence sentences are shown.
    sentence_cnt = 0


def get_gpu_info():
    """Return the memory of GPU 0 as a ``used/total`` string in MB."""
    total_mb, used_mb, _free_mb = get_gpu_mem_info(0)
    usage_template = "{:.2f}/{:.2f}MB"
    return usage_template.format(used_mb, total_mb)


def get_cpu_info():
    """Return this process's memory use over total host memory, as a string in MB."""
    total_mb, _available_mb, process_mb = get_cpu_mem_info()
    usage_template = "{:.2f}/{:.2f}MB"
    return usage_template.format(process_mb, total_mb)


# Build the Gradio page: controls on the left, monitoring and generation
# settings in the middle, the transcript on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 语音识别助手--`内测`")
    gr.Markdown("#### 例子:\n"
                    "- 小技助手,隐藏全部器官、显示全部脉管、显示全部病灶、显示肝段、显示支气管\n"
                    "- 小技助手,显示肺动脉,肺静脉, 关闭肝门静脉,下腔静脉,肝静脉,打开全部病灶,显示胆囊\n"
                    "- 助手小技,仅显示肝脏S8段,尾状叶,显示全部肺结节,仅显示肝脏病灶\n")
    with gr.Group():
        with gr.Row():
            # with gr.Group():
            with gr.Column(scale=2):  # stacked vertically
                # Microphone input delivered to transcribe() as a stream.
                audio = gr.Audio(sources=["microphone"], streaming=True, every=True)
                btn = gr.Button("清除识别内容")
                # Carries the audio buffer between successive transcribe() calls
                # (first input and first output of the stream event below).
                status = gr.State()

                # Thresholds passed to transcribe() on every chunk.
                silence_num = gr.Slider(label="静音时间阈值 (单位:500ms)", minimum=1, maximum=8, value=2, step=1, interactive=True)
                max_speak_num = gr.Slider(label="最大stream时间 (单位:500ms)", minimum=10, maximum=100, value=50, step=1, interactive=True)
                signal_std_num = gr.Slider(label="信号方差阈值 (s)", minimum=0.001, maximum=0.01, value=0.001, step=0.001, interactive=True)

                # Optional recording of transcribed utterances to disk.
                with gr.Accordion("录制 options", open=False):
                    checkbox_record = gr.Checkbox(label="录制语音流")
                    record_path = gr.Textbox(label="保存路径")
            with gr.Column(scale=2):
                gr.Markdown("#### 信号监测:`")
                with gr.Group():
                    with gr.Row():
                        # Per-chunk waveform statistics returned by transcribe().
                        info1 = gr.Textbox(label="实时信号方差:", min_width=100)
                        info2 = gr.Textbox(label="实时信号最大值:", min_width=100)
                        info3 = gr.Textbox(label="实时信号最小值:", min_width=100)

                        # Per-chunk mel-spectrogram statistics returned by transcribe().
                        info1_med = gr.Textbox(label="梅尔频谱方差:", min_width=100)
                        info2_med = gr.Textbox(label="梅尔频谱最大值:", min_width=100)
                        info3_med = gr.Textbox(label="梅尔频谱最小值:", min_width=100)
                        info4_med = gr.Textbox(label="梅尔频谱shape:", min_width=100)

                        # Memory read-outs filled by the demo.load() calls at the bottom.
                        info_gpu = gr.Textbox(label="GPU:", min_width=100)
                        info_cpu = gr.Textbox(label="CPU:", min_width=100)

                gr.Markdown("#### 模型调参:")
                with gr.Group():
                    with gr.Row():
                        # Free-text generation settings; update_generate_kwargs()
                        # parses them when the button below is clicked.
                        temperature = gr.Textbox(label='temperature', min_width=50, value=1.0)
                        top_p = gr.Textbox(label='top_p', min_width=100, value=1.0)
                        top_k = gr.Textbox(label='top_k', min_width=100, value=50)
                        max_length = gr.Textbox(label='max_length', min_width=100, value=448)
                        num_beams = gr.Textbox(label='num_beams', min_width=100, value=1)
                        do_sample = gr.Textbox(label='do_sample', min_width=100, value=False)
                        repetition_penalty = gr.Textbox(label='repetition_penalty', min_width=100, value=1.0)
                        diversity_penalty = gr.Textbox(label='diversity_penalty', min_width=100, value=0)
                        length_penalty = gr.Textbox(label='length_penalty', min_width=100, value=1.0)
                        num_beam_groups = gr.Textbox(label='num_beam_groups', min_width=100, value=1)

                        update_buttom = gr.Button('确认参数更新')
                        update_buttom.click(fn=update_generate_kwargs,
                                            inputs=[temperature, top_p, top_k, max_length, num_beams, do_sample,
                                                    repetition_penalty, diversity_penalty, length_penalty, num_beam_groups], outputs=None)


            with gr.Column(scale=5):
                text = gr.Textbox("识别结果将在这里出现..", lines=40, max_lines=40, autoscroll=True, label="语音识别结果")
                btn.click(clear) # manually clear the accumulated transcript
                # Each streamed chunk goes through transcribe(); its outputs update
                # the buffer state, the transcript and the monitoring boxes.
                audio.stream(transcribe, inputs=[status, audio, signal_std_num, silence_num, max_speak_num, checkbox_record, record_path],
                             outputs=[status, text, info1, info2, info3, info1_med, info2_med, info3_med, info4_med])

    # Fill the GPU/CPU boxes on page load; `every=0.5` is presumably a
    # 0.5 s refresh interval - confirm against the installed Gradio version.
    demo.load(get_gpu_info, None, info_gpu, every=0.5)
    demo.load(get_cpu_info, None, info_cpu, every=0.5)

# Runs at import time: the script is not guarded by `if __name__ == "__main__"`.
demo.launch()