Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,11 @@ from cnocr import CnOcr
|
|
8 |
import numpy as np
|
9 |
import openai
|
10 |
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
|
11 |
-
from transformers import pipeline
|
12 |
import opencc
|
13 |
import scipy
|
14 |
import torch
|
15 |
-
import
|
16 |
|
17 |
converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
|
18 |
ocr = CnOcr() # 初始化ocr模型
|
@@ -21,8 +21,9 @@ all_max_len = 2000 # 输入的最大长度
|
|
21 |
asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
|
22 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
23 |
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
|
24 |
-
|
25 |
-
|
|
|
26 |
|
27 |
def get_text_emb(open_ai_key, text): # 文本向量化
|
28 |
openai.api_key = open_ai_key # 设置openai的key
|
@@ -145,14 +146,16 @@ def get_response_by_llama_index(open_ai_key, msg, bot, query_engine): # 获取
|
|
145 |
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
|
146 |
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
|
|
156 |
|
157 |
|
158 |
def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
|
@@ -160,8 +163,7 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_eng
|
|
160 |
bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
|
161 |
else: # 如果是使用llama_index索引
|
162 |
bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
|
163 |
-
|
164 |
-
return bot, gr.Audio(audio_answer_dir)
|
165 |
|
166 |
|
167 |
def up_file(files): # 上传文件
|
@@ -268,7 +270,7 @@ with gr.Blocks() as demo:
|
|
268 |
audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
|
269 |
chat_bu.click(get_response,
|
270 |
[open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
|
271 |
-
[chat_bot, audio_answer]) # 发送消息
|
272 |
|
273 |
if __name__ == "__main__":
|
274 |
demo.queue(concurrency_count=4).launch()
|
|
|
8 |
import numpy as np
|
9 |
import openai
|
10 |
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
|
11 |
+
from transformers import pipeline, BarkModel, BarkProcessor
|
12 |
import opencc
|
13 |
import scipy
|
14 |
import torch
|
15 |
+
import hashlib
|
16 |
|
17 |
converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
|
18 |
ocr = CnOcr() # 初始化ocr模型
|
|
|
21 |
asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
|
22 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
23 |
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
|
24 |
+
bark_model = BarkModel.from_pretrained("suno/bark-small")
|
25 |
+
bark_processor = BarkProcessor.from_pretrained("suno/bark-small")
|
26 |
+
sampling_rate = bark_model.generation_config.sample_rate
|
27 |
|
28 |
def get_text_emb(open_ai_key, text): # 文本向量化
|
29 |
openai.api_key = open_ai_key # 设置openai的key
|
|
|
146 |
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
|
147 |
|
148 |
|
149 |
+
def get_audio_answer(bot):
    """Synthesize the most recent chatbot answer to speech with Bark.

    Args:
        bot: chat history as a list of (user_msg, bot_msg) pairs; the last
            pair's bot_msg is the text that gets vocalized.

    Returns:
        A gradio ``Audio`` update pointing at the generated ``.wav`` file,
        configured to autoplay.
    """
    latest_answer = bot[-1][1]
    # Tokenize the answer text for the Bark TTS model.
    prompt_inputs = bark_processor(
        text=[latest_answer],
        return_tensors="pt",
    )
    # Sampling-based generation, so repeated calls on the same text can
    # produce different audio.
    waveform = bark_model.generate(**prompt_inputs, do_sample=True)
    # Name the wav file by the MD5 of the answer text so identical answers
    # map to the same path. NOTE(review): the file is still regenerated on
    # every call — confirm whether caching was intended before relying on it.
    wav_path = hashlib.md5(latest_answer.encode('utf-8')).hexdigest() + '.wav'
    scipy.io.wavfile.write(wav_path, rate=sampling_rate, data=waveform.cpu().numpy().squeeze())
    return gr.Audio().update(wav_path, autoplay=True)
|
159 |
|
160 |
|
161 |
def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
|
|
|
163 |
bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
|
164 |
else: # 如果是使用llama_index索引
|
165 |
bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
|
166 |
+
return bot
|
|
|
167 |
|
168 |
|
169 |
def up_file(files): # 上传文件
|
|
|
270 |
audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
|
271 |
chat_bu.click(get_response,
|
272 |
[open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
|
273 |
+
[chat_bot])# .then(get_audio_answer, [chat_bot], [audio_answer]) # 发送消息
|
274 |
|
275 |
if __name__ == "__main__":
|
276 |
demo.queue(concurrency_count=4).launch()
|