# refer https://www.gradio.app/guides/real-time-speech-recognition
import os
from time import time
import gradio as gr
import librosa
import numpy as np
import psutil
import soundfile
from loguru import logger
from transformers import pipeline, WhisperProcessor
def get_gpu_mem_info(gpu_id=0):
    """
    Query memory usage of the given GPU, in MB.
    :param gpu_id: GPU index
    :return: total memory, used memory, free memory
    """
    import pynvml
    pynvml.nvmlInit()
    if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount():
        print('gpu_id {} does not correspond to an existing GPU!'.format(gpu_id))
        return 0, 0, 0
    handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    total = round(meminfo.total / 1024 / 1024, 2)
    used = round(meminfo.used / 1024 / 1024, 2)
    free = round(meminfo.free / 1024 / 1024, 2)
    return total, used, free
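# NOTE: pynvml is imported lazily above so the app still starts on machines
# without NVIDIA drivers. nvmlInit() is re-run on every call; if that overhead
# matters, a possible refactor (untested sketch) is to initialize once at
# module load and call pynvml.nvmlShutdown() on process exit.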
def get_cpu_mem_info():
    """
    Query host memory usage, in MB.
    :return: mem_total total machine memory, mem_free available memory,
             mem_process_used memory used by the current process
    """
    mem_total = round(psutil.virtual_memory().total / 1024 / 1024, 2)
    mem_free = round(psutil.virtual_memory().available / 1024 / 1024, 2)
    mem_process_used = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 2)
    return mem_total, mem_free, mem_process_used
""" | |
base模型: 30s语音GPU推理需要 ~500ms | |
""" | |
model_path = "yuxiang1990/asr-surg" | |
transcriber = pipeline(model=model_path, task="automatic-speech-recognition") | |
# transcriber = pipeline(task="automatic-speech-recognition", model="openai/whisper-base") | |
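# NOTE: transformers pipelines typically run on CPU unless a device is given;
# to pin the model to the first GPU you could pass e.g. device=0 (assumption:
# CUDA is available in this deployment).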
whole_sentence = ""
listen_cnt = 0        # count of consecutive non-silent chunks heard
silence_flag = False  # True: the current new_chunk is classified as silence
silence_cnt = 0       # consecutive silent-chunk counter
max_sentence = 10     # maximum number of sentences shown in the UI
sentence_cnt = 0      # sentence counter
processor = WhisperProcessor.from_pretrained(model_path, local_files_only=True)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
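# NOTE: recent transformers releases deprecate forced_decoder_ids in favor of
# passing language/task to generate() directly; it is kept here because this
# script targets the version it was written against.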
generate_kwargs = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "max_length": 448,
    "num_beams": 1,
    "do_sample": False,
    "forced_decoder_ids": forced_decoder_ids,
    "repetition_penalty": 1.0,
    "diversity_penalty": 0.0,
    "length_penalty": 1.0,
    "num_beam_groups": 1
}
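# With do_sample=False and num_beams=1 this is plain greedy decoding, so the
# sampling knobs (temperature, top_p, top_k) have no effect until do_sample is
# enabled from the UI below; forced_decoder_ids pins the output to Chinese
# transcription.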
def transcribe(stream, new_chunk, std_num, silence_num, max_speak_num, checkbox_record, record_path):
    """
    :param stream: state variable holding the buffered audio history
    :param new_chunk: numpy array of int16 samples in -32767~32767
    :param std_num: std-dev threshold on the normalized signal (silence detection)
    :param silence_num: trailing-silence duration, i.e. silence after speech;
                        a value of 2 means 500 ms * 2
    :param max_speak_num: maximum listening time, 500 ms * max_speak_num
    :param checkbox_record: whether to record the audio stream
    :return:
    """
    global whole_sentence, listen_cnt, silence_flag, silence_cnt, sentence_cnt, max_sentence, generate_kwargs
    # This callback fires only while the microphone is streaming audio,
    # roughly once every 500 ms.
    # A silent chunk while not listening is ignored;
    # once listening, ~1 s of sustained silence flushes the whole buffer to the model.
    logger.debug("std_num:{}, silence_num:{}, record:{}, max_speak_num:{}".format(std_num, silence_num, checkbox_record, max_speak_num))
    sr, y = new_chunk
    y = y.astype(np.float32)
    y /= 32767  # normalize int16 samples to [-1, 1]
    # Compute the mel spectrogram from the time-domain signal
    S = librosa.feature.melspectrogram(y=y, sr=sr)
    mel_shape = "{}".format(S.shape)
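    # S has shape (n_mels, frames); librosa defaults to n_mels=128 and
    # hop_length=512, so frames grows with chunk length. It is only used for
    # the diagnostic readouts in the UI, not for recognition.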
    if np.std(y) < std_num:
        silence_flag = True
        silence_cnt += 1
        listen_cnt = 0
    else:
        silence_flag = False
        listen_cnt += 1
    if not silence_flag:
        # Non-silent chunk: append it to the rolling buffer
        if stream is not None:  # stream is None on the first listen, or after the previous inference reset it
            stream = np.concatenate([stream, y])  # keep accumulating what we hear
        else:
            stream = y
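    # The buffer grows by one chunk (~500 ms) per callback and is bounded by
    # max_speak_num below, so at the default setting of 50 it holds at most
    # ~25 s of audio before a forced flush.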
    if stream is not None:
        if (silence_cnt >= silence_num) or (listen_cnt > max_speak_num):
            # Run inference when either:
            # 1. silence has lasted long enough, or
            # 2. the count of consecutive non-silent chunks exceeds the cap
            now = time()
            logger.info("start transcriber")
            text = transcriber({"sampling_rate": sr, "raw": stream}, generate_kwargs=generate_kwargs)["text"]
            logger.info("transcribe:{} sr:{} shape:{} max:{} min:{}".format(text, sr, y.shape, y.max(), y.min()))
            logger.info("start transcriber done, cost:{}".format(time() - now))
            # Optionally record the utterance to disk
            if checkbox_record:
                if os.path.exists(record_path):
                    wav_path = os.path.join(record_path, "{}-{}.wav".format(text, str(int(time()))))
                    soundfile.write(wav_path, data=stream, samplerate=sr)
            if sentence_cnt >= max_sentence:
                sentence_cnt = 0
                whole_sentence = text + "\n"
            else:
                sentence_cnt += 1
                whole_sentence += text + "\n"
            silence_cnt = 0
            listen_cnt = 0
            return None, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape
    return stream, whole_sentence, y.std(), y.max(), y.min(), S.std(), S.max(), S.min(), mel_shape
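# The 9-tuple returned above maps one-to-one onto the outputs list wired up in
# audio.stream() below: the first element resets (None) or carries forward the
# buffered state, the second is the transcript box, and the rest feed the
# signal-monitoring textboxes.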
def update_generate_kwargs(temperature, top_p, top_k, max_length,
                           num_beams, do_sample, repetition_penalty,
                           diversity_penalty, length_penalty, num_beam_groups):
    global generate_kwargs
    generate_kwargs['temperature'] = float(temperature)
    generate_kwargs['top_p'] = float(top_p)
    generate_kwargs['top_k'] = int(float(top_k))
    generate_kwargs['max_length'] = int(float(max_length))
    generate_kwargs['num_beams'] = int(float(num_beams))
    if do_sample.lower() == 'false':
        generate_kwargs['do_sample'] = False
    else:
        generate_kwargs['do_sample'] = True
    generate_kwargs['repetition_penalty'] = float(repetition_penalty)
    generate_kwargs['diversity_penalty'] = float(diversity_penalty)
    generate_kwargs['length_penalty'] = float(length_penalty)
    generate_kwargs['num_beam_groups'] = int(float(num_beam_groups))
    logger.info("temperature:{} top_p:{} top_k:{} max_len:{}\n num_beams:{} do_sample:{} repetition_penalty:{}\n"
                "diversity_penalty:{} length_penalty:{} num_beam_groups:{}".format(
                    temperature, top_p, top_k, max_length, num_beams, do_sample,
                    repetition_penalty, diversity_penalty, length_penalty, num_beam_groups))
def clear():
    global whole_sentence
    whole_sentence = ""
def get_gpu_info():
    total, used, free = get_gpu_mem_info(0)
    return "{:.2f}/{:.2f}MB".format(used, total)
def get_cpu_info():
    total, free, used = get_cpu_mem_info()
    return "{:.2f}/{:.2f}MB".format(used, total)
with gr.Blocks() as demo:
    gr.Markdown("# Speech Recognition Assistant -- `beta`")
    gr.Markdown("#### Examples (commands are spoken in Chinese):\n"
                "- 小技助手,隐藏全部器官、显示全部脉管、显示全部病灶、显示肝段、显示支气管 "
                "(hide all organs; show all vessels, all lesions, liver segments, bronchi)\n"
                "- 小技助手,显示肺动脉,肺静脉, 关闭肝门静脉,下腔静脉,肝静脉,打开全部病灶,显示胆囊 "
                "(show pulmonary arteries and veins; hide portal vein, inferior vena cava, hepatic veins; show all lesions, gallbladder)\n"
                "- 助手小技,仅显示肝脏S8段,尾状叶,显示全部肺结节,仅显示肝脏病灶 "
                "(show only liver segment S8 and the caudate lobe; show all lung nodules; show only liver lesions)\n")
    with gr.Group():
        with gr.Row():
            # with gr.Group():
            with gr.Column(scale=2):  # vertically
                audio = gr.Audio(sources=["microphone"], streaming=True)
                btn = gr.Button("Clear recognized text")
                status = gr.State()
                silence_num = gr.Slider(label="Silence threshold (unit: 500 ms)", minimum=1, maximum=8, value=2, step=1, interactive=True)
                max_speak_num = gr.Slider(label="Max stream duration (unit: 500 ms)", minimum=10, maximum=100, value=50, step=1, interactive=True)
                signal_std_num = gr.Slider(label="Signal std-dev threshold", minimum=0.001, maximum=0.01, value=0.001, step=0.001, interactive=True)
                with gr.Accordion("Recording options", open=False):
                    checkbox_record = gr.Checkbox(label="Record the audio stream")
                    record_path = gr.Textbox(label="Save path")
            with gr.Column(scale=2):
                gr.Markdown("#### Signal monitoring:")
                with gr.Group():
                    with gr.Row():
                        info1 = gr.Textbox(label="Signal std-dev:", min_width=100)
                        info2 = gr.Textbox(label="Signal max:", min_width=100)
                        info3 = gr.Textbox(label="Signal min:", min_width=100)
                        info1_med = gr.Textbox(label="Mel-spectrogram std-dev:", min_width=100)
                        info2_med = gr.Textbox(label="Mel-spectrogram max:", min_width=100)
                        info3_med = gr.Textbox(label="Mel-spectrogram min:", min_width=100)
                        info4_med = gr.Textbox(label="Mel-spectrogram shape:", min_width=100)
                        info_gpu = gr.Textbox(label="GPU:", min_width=100)
                        info_cpu = gr.Textbox(label="CPU:", min_width=100)
gr.Markdown("#### 模型调参:") | |
with gr.Group(): | |
with gr.Row(): | |
temperature = gr.Textbox(label='temperature', min_width=50, value=1.0) | |
top_p = gr.Textbox(label='top_p', min_width=100, value=1.0) | |
top_k = gr.Textbox(label='top_k', min_width=100, value=50) | |
max_length = gr.Textbox(label='max_length', min_width=100, value=448) | |
num_beams = gr.Textbox(label='num_beams', min_width=100, value=1) | |
do_sample = gr.Textbox(label='do_sample', min_width=100, value=False) | |
repetition_penalty = gr.Textbox(label='repetition_penalty', min_width=100, value=1.0) | |
diversity_penalty = gr.Textbox(label='diversity_penalty', min_width=100, value=0) | |
length_penalty = gr.Textbox(label='length_penalty', min_width=100, value=1.0) | |
num_beam_groups = gr.Textbox(label='num_beam_groups', min_width=100, value=1) | |
update_buttom = gr.Button('确认参数更新') | |
update_buttom.click(fn=update_generate_kwargs, | |
inputs=[temperature, top_p, top_k, max_length, num_beams, do_sample, | |
repetition_penalty, diversity_penalty, length_penalty, num_beam_groups], outputs=None) | |
            with gr.Column(scale=5):
                text = gr.Textbox("Recognition results will appear here..", lines=40, max_lines=40, autoscroll=True, label="Speech recognition output")
    btn.click(clear)  # manually clear the transcript
    audio.stream(transcribe, inputs=[status, audio, signal_std_num, silence_num, max_speak_num, checkbox_record, record_path],
                 outputs=[status, text, info1, info2, info3, info1_med, info2_med, info3_med, info4_med])
    demo.load(get_gpu_info, None, info_gpu, every=0.5)
    demo.load(get_cpu_info, None, info_cpu, every=0.5)
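# NOTE: periodic `every=` updates run through Gradio's queue; on Gradio
# versions where the queue is off by default, enabling it explicitly is
# required (assumption about the deployed version).
demo.queue()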
demo.launch()