Spaces:
Sleeping
Sleeping
Commit
•
89a09d6
1
Parent(s):
5c9a2c0
commit
Browse files- app.py +259 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# refer https://www.gradio.app/guides/real-time-speech-recognition
|
2 |
+
import os.path
|
3 |
+
from loguru import logger
|
4 |
+
from transformers import pipeline, WhisperProcessor
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib
|
8 |
+
matplotlib.use('TkAgg')
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from time import time
|
11 |
+
stream = True
|
12 |
+
import librosa
|
13 |
+
import soundfile
|
14 |
+
import psutil
|
15 |
+
|
16 |
+
|
17 |
+
def get_gpu_mem_info(gpu_id=0):
    """
    Report the memory usage of one NVIDIA GPU, in MB.

    NVML is initialised for the duration of the call and always shut down
    again, so polling this function repeatedly does not leave NVML sessions
    open.

    :param gpu_id: zero-based index of the GPU to query
    :return: tuple ``(total, used, free)``, each rounded to two decimals.
        ``(0, 0, 0)`` is returned when ``gpu_id`` is out of range, when
        ``pynvml`` is not installed, or when NVML cannot be initialised or
        queried (for example on a host without an NVIDIA driver).
    """
    try:
        import pynvml
    except ImportError:
        # No NVML bindings: report "no GPU" rather than crash the caller.
        return 0, 0, 0

    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        # Driver or NVML library missing; nothing to shut down in this case.
        return 0, 0, 0

    try:
        if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount():
            print(r'gpu_id {} 对应的显卡不存在!'.format(gpu_id))
            return 0, 0, 0

        handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    except pynvml.NVMLError:
        return 0, 0, 0
    finally:
        # Pair every successful nvmlInit() with a shutdown.
        pynvml.nvmlShutdown()

    total = round(meminfo.total / 1024 / 1024, 2)
    used = round(meminfo.used / 1024 / 1024, 2)
    free = round(meminfo.free / 1024 / 1024, 2)
    return total, used, free
|
35 |
+
|
36 |
+
|
37 |
+
def get_cpu_mem_info():
    """
    Report host and process memory figures, in MB.

    :return: tuple ``(mem_total, mem_free, mem_process_used)`` -- total
        system memory, memory currently available, and the resident set
        size of the current process, each rounded to two decimals
    """
    def _as_mb(num_bytes):
        # Bytes -> MB, rounded to two decimals.
        return round(num_bytes / 1024 / 1024, 2)

    system_total = _as_mb(psutil.virtual_memory().total)
    system_available = _as_mb(psutil.virtual_memory().available)
    current_process = psutil.Process(os.getpid())
    process_rss = _as_mb(current_process.memory_info().rss)
    return system_total, system_available, process_rss
|
46 |
+
|
47 |
+
"""
|
48 |
+
base模型: 30s语音GPU推理需要 ~500ms
|
49 |
+
"""
|
50 |
+
|
51 |
+
# Model identifier handed to both the ASR pipeline and the Whisper processor.
model_path = "yuxiang1990/asr-surg"

transcriber = pipeline(task="automatic-speech-recognition", model=model_path)
# transcriber = pipeline(task="automatic-speech-recognition", model="openai/whisper-base")

# Mutable recognition state shared with the streaming callback.
whole_sentence = ""
listen_cnt = 0  # count of consecutive non-silent chunks while listening
silence_flag = False  # True: the latest new_chunk was classified as silence
silence_cnt = 0  # count of silent chunks
max_sentence = 10  # maximum number of sentences shown in the result box
sentence_cnt = 0  # number of sentences currently shown

processor = WhisperProcessor.from_pretrained(model_path, local_files_only=True)
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="zh")

# Default generation settings; update_generate_kwargs() overwrites them in place.
generate_kwargs = dict(
    temperature=1.0,
    top_p=1.0,
    top_k=50,
    max_length=448,
    num_beams=1,
    do_sample=False,
    forced_decoder_ids=forced_decoder_ids,
    repetition_penalty=1.0,
    diversity_penalty=0,
    length_penalty=1.0,
    num_beam_groups=1,
)
|
79 |
+
|
80 |
+
|
81 |
+
def transcribe(stream, new_chunk, std_num, silence_num, max_speak_num, checkbox_record, record_path):
    """
    Streaming callback: buffer voiced audio and transcribe it once the speaker pauses.

    Per the original notes, this is invoked about every 500 ms while the
    microphone delivers audio. A chunk whose standard deviation (after
    scaling by 1/32767) is below ``std_num`` counts as silence. Voiced chunks
    are appended to ``stream``. The buffered audio is transcribed when either
    ``silence_num`` consecutive silent chunks follow speech, or more than
    ``max_speak_num`` consecutive voiced chunks have been heard.

    Silence is only counted while an utterance is buffered, and the counter
    is reset by every voiced chunk. Otherwise silence accumulated before the
    speaker starts would trigger transcription of the very first voiced chunk.

    :param stream: audio buffered so far for the current utterance, or None
        when nothing is buffered (first call, or right after a transcription)
    :param new_chunk: ``(sample_rate, samples)``; samples are documented by
        the original author as a numpy array in the range -32767~32767
    :param std_num: silence threshold on the scaled chunk's standard deviation
    :param silence_num: number of trailing silent chunks that ends an
        utterance (2 means about 2 * 500 ms)
    :param max_speak_num: maximum number of consecutive voiced chunks before
        transcription is forced
    :param checkbox_record: whether to save each transcribed utterance as a wav file
    :param record_path: existing directory the wav files are written to
    :return: ``(stream, transcript, chunk_std, chunk_max, chunk_min,
        mel_std, mel_max, mel_min, mel_shape)``; ``stream`` is None right
        after a transcription so that the next utterance starts fresh
    """
    global whole_sentence, listen_cnt, silence_flag, silence_cnt, sentence_cnt
    logger.debug("std_num:{}, silence_num:{}, record:{}, max_speak_num:{}".format(std_num, silence_num, checkbox_record, max_speak_num))

    sr, y = new_chunk
    y = y.astype(np.float32)
    if y.ndim > 1:
        # Multi-channel capture: average to mono.
        # NOTE(review): assumes a (samples, channels) layout -- confirm.
        y = y.mean(axis=1)
    y /= 32767

    # Mel spectrogram of the chunk; only used for the monitoring outputs.
    S = librosa.feature.melspectrogram(y=y, sr=sr)
    mel_shape = "{}".format(S.shape)

    if np.std(y) < std_num:
        silence_flag = True
        listen_cnt = 0
        if stream is not None:
            # Count only the silence that trails buffered speech.
            silence_cnt += 1
    else:
        silence_flag = False
        listen_cnt += 1
        silence_cnt = 0  # the pause must be consecutive to end the utterance
        # Voiced chunk: start a new buffer or extend the current one.
        stream = y if stream is None else np.concatenate([stream, y])

    if stream is not None and (silence_cnt >= silence_num or listen_cnt > max_speak_num):
        # 1. the pause after speech lasted long enough, or
        # 2. the speaker has talked without pause for too long.
        now = time()
        logger.info("start transcriber")
        text = transcriber({"sampling_rate": sr, "raw": stream}, generate_kwargs=generate_kwargs)["text"]
        logger.info("transcribe:{} sr:{} shape:{} max:{} min:{}".format(text, sr, y.shape, y.max(), y.min()))
        logger.info("start transcriber done, cost:{}".format(time() - now))

        # Optionally save the utterance, named after its transcript.
        if checkbox_record and record_path and os.path.isdir(record_path):
            # Path separators in the transcript would point outside record_path.
            safe_text = text.replace("/", "_").replace("\\", "_")
            wav_path = os.path.join(record_path, "{}-{}.wav".format(safe_text, str(int(time()))))
            soundfile.write(wav_path, data=stream, samplerate=sr)

        if sentence_cnt >= max_sentence:
            # Result box is full: start over with the newest sentence.
            sentence_cnt = 0
            whole_sentence = text + "\n"
        else:
            sentence_cnt += 1
            whole_sentence += text + "\n"
        silence_cnt = 0
        listen_cnt = 0

        return None, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape

    return stream, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape
|
152 |
+
|
153 |
+
|
154 |
+
def update_generate_kwargs(temperature, top_p, top_k, max_length,
                           num_beams, do_sample, repetition_penalty,
                           diversity_penalty, length_penalty, num_beam_groups):
    """
    Overwrite the shared generation settings with values entered in the UI.

    Every value is parsed before ``generate_kwargs`` is touched, so an
    unparsable field raises without leaving the settings half-updated.

    :param temperature, top_p, repetition_penalty, diversity_penalty,
        length_penalty: values convertible with ``float``
    :param top_k, max_length, num_beams, num_beam_groups: values convertible
        with ``int(float(...))`` (so ``"50"`` and ``"50.0"`` both work)
    :param do_sample: bool, or text; ``false``/``0``/``no``/``off``/empty
        (case-insensitive, surrounding whitespace ignored) disable sampling,
        any other text enables it
    :raises ValueError: if a numeric field cannot be parsed
    """
    if isinstance(do_sample, str):
        sampling_on = do_sample.strip().lower() not in ('false', '0', 'no', 'off', '')
    else:
        sampling_on = bool(do_sample)

    parsed = {
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(float(top_k)),
        'max_length': int(float(max_length)),
        'num_beams': int(float(num_beams)),
        'do_sample': sampling_on,
        'repetition_penalty': float(repetition_penalty),
        'diversity_penalty': float(diversity_penalty),
        'length_penalty': float(length_penalty),
        'num_beam_groups': int(float(num_beam_groups)),
    }
    # In-place update: other code holds a reference to this same dict.
    generate_kwargs.update(parsed)

    logger.info("tmp:{} top_p:{} top_k:{} max_len:{}\n num_beams: {} do_sample:{} repetition_penalty:{} \n"
                "diversity_penalty:{} length_penalty:{}".format(temperature, top_p, top_k, max_length, num_beams, do_sample,
                                                                 repetition_penalty, diversity_penalty, length_penalty))
|
176 |
+
|
177 |
+
|
178 |
+
def clear():
    """
    Reset the accumulated transcript.

    Empties ``whole_sentence`` and zeroes ``sentence_cnt`` together, so the
    sentence counter used for the display rollover matches the (now empty)
    text instead of carrying over the count from before the reset.
    """
    global whole_sentence, sentence_cnt
    whole_sentence = ""
    sentence_cnt = 0
|
181 |
+
|
182 |
+
|
183 |
+
def get_gpu_info():
    """Return GPU 0 memory usage formatted as ``<used>/<total>MB`` (two decimals)."""
    gpu_total, gpu_used, _gpu_free = get_gpu_mem_info(0)
    template = "{:.2f}/{:.2f}MB"
    return template.format(gpu_used, gpu_total)
|
186 |
+
|
187 |
+
|
188 |
+
def get_cpu_info():
    """Return ``<process memory>/<total system memory>MB`` (two decimals)."""
    sys_total, _sys_available, process_used = get_cpu_mem_info()
    template = "{:.2f}/{:.2f}MB"
    return template.format(process_used, sys_total)
|
191 |
+
|
192 |
+
|
193 |
+
# Gradio UI: microphone and thresholds on the left, signal monitoring and
# generation parameters in the middle, recognition results on the right.
# NOTE(review): the layout nesting below is reconstructed from statement
# order -- confirm against the running app.
with gr.Blocks() as demo:
    gr.Markdown("# 语音识别助手--`内测`")
    gr.Markdown("#### 例子:\n"
                "- 小技助手,隐藏全部器官、显示全部脉管、显示全部病灶、显示肝段、显示支气管\n"
                "- 小技助手,显示肺动脉,肺静脉, 关闭肝门静脉,下腔静脉,肝静脉,打开全部病灶,显示胆囊\n"
                "- 助手小技,仅显示肝脏S8段,尾状叶,显示全部肺结节,仅显示肝脏病灶\n")
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=2):  # vertically
                audio = gr.Audio(sources=["microphone"], streaming=True, every=True)
                btn = gr.Button("清除识别内容")
                status = gr.State()  # per-session audio buffer passed to transcribe()

                silence_num = gr.Slider(label="静音时间阈值 (单位:500ms)", minimum=1, maximum=8, value=2, step=1, interactive=True)
                max_speak_num = gr.Slider(label="最大stream时间 (单位:500ms)", minimum=10, maximum=100, value=50, step=1, interactive=True)
                signal_std_num = gr.Slider(label="信号方差阈值 (s)", minimum=0.001, maximum=0.01, value=0.001, step=0.001, interactive=True)

                with gr.Accordion("录制 options", open=False):
                    checkbox_record = gr.Checkbox(label="录制语音流")
                    record_path = gr.Textbox(label="保存路径")
            with gr.Column(scale=2):
                # Header previously ended with a stray backtick (broken Markdown).
                gr.Markdown("#### 信号监测:")
                with gr.Group():
                    with gr.Row():
                        info1 = gr.Textbox(label="实时信号方差:", min_width=100)
                        info2 = gr.Textbox(label="实时信号最大值:", min_width=100)
                        info3 = gr.Textbox(label="实时信号最小值:", min_width=100)

                        info1_med = gr.Textbox(label="梅尔频谱方差:", min_width=100)
                        info2_med = gr.Textbox(label="梅尔频谱最大值:", min_width=100)
                        info3_med = gr.Textbox(label="梅尔频谱最小值:", min_width=100)
                        info4_med = gr.Textbox(label="梅尔频谱shape:", min_width=100)

                        info_gpu = gr.Textbox(label="GPU:", min_width=100)
                        info_cpu = gr.Textbox(label="CPU:", min_width=100)

                gr.Markdown("#### 模型调参:")
                with gr.Group():
                    with gr.Row():
                        temperature = gr.Textbox(label='temperature', min_width=50, value=1.0)
                        top_p = gr.Textbox(label='top_p', min_width=100, value=1.0)
                        top_k = gr.Textbox(label='top_k', min_width=100, value=50)
                        max_length = gr.Textbox(label='max_length', min_width=100, value=448)
                        num_beams = gr.Textbox(label='num_beams', min_width=100, value=1)
                        do_sample = gr.Textbox(label='do_sample', min_width=100, value=False)
                        repetition_penalty = gr.Textbox(label='repetition_penalty', min_width=100, value=1.0)
                        diversity_penalty = gr.Textbox(label='diversity_penalty', min_width=100, value=0)
                        length_penalty = gr.Textbox(label='length_penalty', min_width=100, value=1.0)
                        num_beam_groups = gr.Textbox(label='num_beam_groups', min_width=100, value=1)

                update_buttom = gr.Button('确认参数更新')
                update_buttom.click(fn=update_generate_kwargs,
                                    inputs=[temperature, top_p, top_k, max_length, num_beams, do_sample,
                                            repetition_penalty, diversity_penalty, length_penalty, num_beam_groups], outputs=None)

            with gr.Column(scale=5):
                text = gr.Textbox("识别结果将在这里出现..", lines=40, max_lines=40, autoscroll=True, label="语音识别结果")

    # Manually clear the transcript. clear() returns None, and routing that to
    # the result box blanks it at once instead of at the next transcription.
    btn.click(clear, inputs=None, outputs=text)
    audio.stream(transcribe, inputs=[status, audio, signal_std_num, silence_num, max_speak_num, checkbox_record, record_path],
                 outputs=[status, text, info1, info2, info3, info1_med, info2_med, info3_med, info4_med])

    # Refresh the GPU / CPU memory read-outs every 0.5 s.
    demo.load(get_gpu_info, None, info_gpu, every=0.5)
    demo.load(get_cpu_info, None, info_cpu, every=0.5)

demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
librosa
|
2 |
+
soundfile
|
3 |
+
psutil
|
4 |
+
loguru
|
5 |
+
pynvml
|