# Reference: https://www.gradio.app/guides/real-time-speech-recognition
import os
from time import time

import gradio as gr
import librosa
import numpy as np
import psutil
import soundfile
from loguru import logger
from transformers import pipeline, WhisperProcessor


def get_gpu_mem_info(gpu_id=0):
    """
    Query memory usage of the GPU with the given id, in MB.
    :param gpu_id: GPU index
    :return: total memory, used memory, free memory (all in MB)
    """
    import pynvml
    pynvml.nvmlInit()
    if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount():
        print('GPU with id {} does not exist!'.format(gpu_id))
        return 0, 0, 0
    handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    total = round(meminfo.total / 1024 / 1024, 2)
    used = round(meminfo.used / 1024 / 1024, 2)
    free = round(meminfo.free / 1024 / 1024, 2)
    return total, used, free
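
# Example usage (requires an NVIDIA GPU and the pynvml package):
#   total, used, free = get_gpu_mem_info(gpu_id=0)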


def get_cpu_mem_info():
    """
    Query memory usage of the current machine, in MB.
    :return: mem_total - total machine memory; mem_free - available memory;
             mem_process_used - resident memory of the current process
    """
    mem_total = round(psutil.virtual_memory().total / 1024 / 1024, 2)
    mem_free = round(psutil.virtual_memory().available / 1024 / 1024, 2)
    mem_process_used = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 2)
    return mem_total, mem_free, mem_process_used
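
# Example usage (all values in MB):
#   mem_total, mem_free, mem_process_used = get_cpu_mem_info()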
"""
base模型: 30s语音GPU推理需要 ~500ms
"""
model_path = "yuxiang1990/asr-surg"
transcriber = pipeline(model=model_path, task="automatic-speech-recognition")
# transcriber = pipeline(task="automatic-speech-recognition", model="openai/whisper-base")
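# Note (assumption): with no device argument, a transformers pipeline defaults
# to CPU; pass device=0 to run on the first CUDA GPU, which matches the GPU
# timing noted above.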

whole_sentence = ""
listen_cnt = 0        # count of consecutive non-silent chunks heard
silence_flag = False  # True: the current new_chunk is silence
silence_cnt = 0       # count of consecutive silent chunks
max_sentence = 10     # maximum number of sentences shown in the UI transcript
sentence_cnt = 0      # sentence counter
processor = WhisperProcessor.from_pretrained(model_path, local_files_only=True)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
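# The forced decoder prompt pins the model to Chinese transcription, so it does
# not auto-detect the language or fall back to the translation task.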
generate_kwargs = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "max_length": 448,
    "num_beams": 1,
    "do_sample": False,
    "forced_decoder_ids": forced_decoder_ids,
    "repetition_penalty": 1.0,
    "diversity_penalty": 0.0,
    "length_penalty": 1.0,
    "num_beam_groups": 1
}
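# With do_sample=False, num_beams=1 and num_beam_groups=1 these defaults amount
# to plain greedy decoding; temperature/top_p/top_k only take effect once
# do_sample is enabled from the UI below.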


def transcribe(stream, new_chunk, std_num, silence_num, max_speak_num, checkbox_record, record_path):
    """
    :param stream: state variable holding the accumulated audio history
    :param new_chunk: tuple of (sampling_rate, int16 numpy array in -32767~32767)
    :param std_num: standard-deviation threshold on the normalised signal, below which a chunk counts as silence
    :param silence_num: silence duration after speech that triggers inference, in chunks (2 means 500 ms * 2)
    :param max_speak_num: maximum listening time, in chunks (500 ms * 30)
    :param checkbox_record: whether to record the audio stream to disk
    :param record_path: directory where recordings are saved
    :return:
    """
    global whole_sentence, listen_cnt, silence_flag, silence_cnt, sentence_cnt, max_sentence, generate_kwargs
    # This callback only fires while the audio component detects sound,
    # roughly once every 500 ms.
    # A silent chunk received while not listening is discarded.
    # While listening, once silence has lasted ~1 s the whole buffer is transcribed.
    logger.debug("std_num:{}, silence_num:{}, record:{}, max_speak_num:{}".format(std_num, silence_num, checkbox_record, max_speak_num))
    sr, y = new_chunk
    y = y.astype(np.float32)
    y /= 32767  # normalise int16 samples to roughly [-1, 1]
    # Compute the Mel spectrogram from the time-domain signal
    S = librosa.feature.melspectrogram(y=y, sr=sr)
    mel_shape = "{}".format(S.shape)
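    # S has shape (n_mels, n_frames); librosa's default is n_mels=128.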
    if np.std(y) < std_num:
        silence_flag = True
        silence_cnt += 1
        listen_cnt = 0
    else:
        silence_flag = False
        listen_cnt += 1
    if not silence_flag:
        # Non-silent chunk: merge it into the accumulated signal
        if stream is not None:  # stream is None on first listen, or after the previous inference reset it
            stream = np.concatenate([stream, y])  # keep accumulating everything heard so far
        else:
            stream = y
    if stream is not None:
        if (silence_cnt >= silence_num) or (listen_cnt > max_speak_num):
            # Run inference when either:
            # 1. silence has persisted long enough after speech, or
            # 2. continuous speech has exceeded the listening limit
            now = time()
            logger.info("start transcriber")
            text = transcriber({"sampling_rate": sr, "raw": stream}, generate_kwargs=generate_kwargs)["text"]
            logger.info("transcribe:{} sr:{} shape:{} max:{} min:{}".format(text, sr, y.shape, y.max(), y.min()))
            logger.info("start transcriber done, cost:{}".format(time() - now))
            # Optionally record the utterance to disk
            if checkbox_record:
                if os.path.exists(record_path):
                    wav_path = os.path.join(record_path, "{}-{}.wav".format(text, str(int(time()))))
                    soundfile.write(wav_path, data=stream, samplerate=sr)
            if sentence_cnt >= max_sentence:
                sentence_cnt = 0
                whole_sentence = text + "\n"
            else:
                sentence_cnt += 1
                whole_sentence += text + "\n"
            silence_cnt = 0
            listen_cnt = 0
            return None, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape
    return stream, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape
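
# Gating logic in transcribe(), summarised:
#   silent chunk while idle        -> dropped (stream stays None)
#   voiced chunk                   -> appended to stream
#   silence_cnt >= silence_num     -> flush the buffer through the model
#   listen_cnt  > max_speak_num    -> forced flush to bound latency and memory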


def update_generate_kwargs(temperature, top_p, top_k, max_length,
                           num_beams, do_sample, repetition_penalty,
                           diversity_penalty, length_penalty, num_beam_groups):
    global generate_kwargs
    generate_kwargs['temperature'] = float(temperature)
    generate_kwargs['top_p'] = float(top_p)
    generate_kwargs['top_k'] = int(float(top_k))
    generate_kwargs['max_length'] = int(float(max_length))
    generate_kwargs['num_beams'] = int(float(num_beams))
    generate_kwargs['do_sample'] = str(do_sample).strip().lower() == 'true'
    generate_kwargs['repetition_penalty'] = float(repetition_penalty)
    generate_kwargs['diversity_penalty'] = float(diversity_penalty)
    generate_kwargs['length_penalty'] = float(length_penalty)
    generate_kwargs['num_beam_groups'] = int(float(num_beam_groups))
    logger.info("tmp:{} top_p:{} top_k:{} max_len:{}\n num_beams:{} do_sample:{} repetition_penalty:{}\n"
                "diversity_penalty:{} length_penalty:{} num_beam_groups:{}".format(
                    temperature, top_p, top_k, max_length, num_beams, do_sample,
                    repetition_penalty, diversity_penalty, length_penalty, num_beam_groups))


def clear():
    global whole_sentence
    whole_sentence = ""


def get_gpu_info():
    total, used, free = get_gpu_mem_info(0)
    return "{:.2f}/{:.2f}MB".format(used, total)


def get_cpu_info():
    total, free, used = get_cpu_mem_info()
    return "{:.2f}/{:.2f}MB".format(used, total)


with gr.Blocks() as demo:
    gr.Markdown("# Speech Recognition Assistant -- `beta`")
    gr.Markdown("#### Examples:\n"
                "- Xiaoji assistant: hide all organs, show all vessels, show all lesions, show the liver segments, show the bronchi\n"
                "- Xiaoji assistant: show the pulmonary arteries and pulmonary veins, hide the hepatic portal vein, inferior vena cava and hepatic veins, open all lesions, show the gallbladder\n"
                "- Assistant Xiaoji: show only liver segment S8 and the caudate lobe, show all lung nodules, show only the liver lesions\n")
    with gr.Group():
        with gr.Row():
            # with gr.Group():
            with gr.Column(scale=2):  # vertically
                audio = gr.Audio(sources=["microphone"], streaming=True, every=0.5)
                btn = gr.Button("Clear recognised text")
                status = gr.State()
                silence_num = gr.Slider(label="Silence duration threshold (unit: 500 ms)", minimum=1, maximum=8, value=2, step=1, interactive=True)
                max_speak_num = gr.Slider(label="Max stream duration (unit: 500 ms)", minimum=10, maximum=100, value=50, step=1, interactive=True)
                signal_std_num = gr.Slider(label="Signal std threshold", minimum=0.001, maximum=0.01, value=0.001, step=0.001, interactive=True)
                with gr.Accordion("Recording options", open=False):
                    checkbox_record = gr.Checkbox(label="Record the audio stream")
                    record_path = gr.Textbox(label="Save path")
            with gr.Column(scale=2):
                gr.Markdown("#### Signal monitoring:")
                with gr.Group():
                    with gr.Row():
                        info1 = gr.Textbox(label="Live signal std:", min_width=100)
                        info2 = gr.Textbox(label="Live signal max:", min_width=100)
                        info3 = gr.Textbox(label="Live signal min:", min_width=100)
                        info1_med = gr.Textbox(label="Mel spectrogram std:", min_width=100)
                        info2_med = gr.Textbox(label="Mel spectrogram max:", min_width=100)
                        info3_med = gr.Textbox(label="Mel spectrogram min:", min_width=100)
                        info4_med = gr.Textbox(label="Mel spectrogram shape:", min_width=100)
                        info_gpu = gr.Textbox(label="GPU:", min_width=100)
                        info_cpu = gr.Textbox(label="CPU:", min_width=100)
gr.Markdown("#### 模型调参:")
with gr.Group():
with gr.Row():
temperature = gr.Textbox(label='temperature', min_width=50, value=1.0)
top_p = gr.Textbox(label='top_p', min_width=100, value=1.0)
top_k = gr.Textbox(label='top_k', min_width=100, value=50)
max_length = gr.Textbox(label='max_length', min_width=100, value=448)
num_beams = gr.Textbox(label='num_beams', min_width=100, value=1)
do_sample = gr.Textbox(label='do_sample', min_width=100, value=False)
repetition_penalty = gr.Textbox(label='repetition_penalty', min_width=100, value=1.0)
diversity_penalty = gr.Textbox(label='diversity_penalty', min_width=100, value=0)
length_penalty = gr.Textbox(label='length_penalty', min_width=100, value=1.0)
num_beam_groups = gr.Textbox(label='num_beam_groups', min_width=100, value=1)
update_buttom = gr.Button('确认参数更新')
update_buttom.click(fn=update_generate_kwargs,
inputs=[temperature, top_p, top_k, max_length, num_beams, do_sample,
repetition_penalty, diversity_penalty, length_penalty, num_beam_groups], outputs=None)
            with gr.Column(scale=5):
                text = gr.Textbox("Recognition results will appear here..", lines=40, max_lines=40, autoscroll=True, label="Speech recognition output")
    btn.click(clear)  # manually clear the transcript
    audio.stream(transcribe, inputs=[status, audio, signal_std_num, silence_num, max_speak_num, checkbox_record, record_path],
                 outputs=[status, text, info1, info2, info3, info1_med, info2_med, info3_med, info4_med])
    demo.load(get_gpu_info, None, info_gpu, every=0.5)
    demo.load(get_cpu_info, None, info_cpu, every=0.5)

demo.launch()