souljoy committed on
Commit
43f767e
1 Parent(s): 16653e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -8,11 +8,11 @@ from cnocr import CnOcr
8
  import numpy as np
9
  import openai
10
  from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
11
- from transformers import pipeline
12
  import opencc
13
  import scipy
14
  import torch
15
- import onnxruntime
16
 
17
  converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
18
  ocr = CnOcr() # 初始化ocr模型
@@ -21,8 +21,9 @@ all_max_len = 2000 # 输入的最大长度
21
  asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
22
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
  asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
24
- synthesiser = pipeline("text-to-speech", "suno/bark-small", device=device)
25
-
 
26
 
27
  def get_text_emb(open_ai_key, text): # 文本向量化
28
  openai.api_key = open_ai_key # 设置openai的key
@@ -145,14 +146,16 @@ def get_response_by_llama_index(open_ai_key, msg, bot, query_engine): # 获取
145
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
146
 
147
 
148
- import hashlib
149
-
150
-
151
- def get_audio_answer(answer): # 获取语音回答
152
- speech = synthesiser(answer, forward_params={"do_sample": True}) # 生成语音
153
- md5 = hashlib.md5(answer.encode('utf-8')).hexdigest() # 获取md5
154
- scipy.io.wavfile.write("{}.wav".format(md5), rate=speech["sampling_rate"], data=speech["audio"]) # 保存语音
155
- return "{}.wav".format(md5)
 
 
156
 
157
 
158
  def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
@@ -160,8 +163,7 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_eng
160
  bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
161
  else: # 如果是使用llama_index索引
162
  bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
163
- audio_answer_dir = get_audio_answer(bot[-1][1]) # 获取语音回答
164
- return bot, gr.Audio(audio_answer_dir)
165
 
166
 
167
  def up_file(files): # 上传文件
@@ -268,7 +270,7 @@ with gr.Blocks() as demo:
268
  audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
269
  chat_bu.click(get_response,
270
  [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
271
- [chat_bot, audio_answer]) # 发送消息
272
 
273
  if __name__ == "__main__":
274
  demo.queue(concurrency_count=4).launch()
 
8
  import numpy as np
9
  import openai
10
  from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
11
+ from transformers import pipeline, BarkModel, BarkProcessor
12
  import opencc
13
  import scipy
14
  import torch
15
+ import hashlib
16
 
17
  converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
18
  ocr = CnOcr() # 初始化ocr模型
 
21
  asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
22
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
  asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
24
+ bark_model = BarkModel.from_pretrained("suno/bark-small")
25
+ bark_processor = BarkProcessor.from_pretrained("suno/bark-small")
26
+ sampling_rate = bark_model.generation_config.sample_rate
27
 
28
  def get_text_emb(open_ai_key, text): # 文本向量化
29
  openai.api_key = open_ai_key # 设置openai的key
 
146
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
147
 
148
 
149
+ def get_audio_answer(bot): # 获取语音回答
150
+ answer = bot[-1][1]
151
+ inputs = bark_processor(
152
+ text=[answer],
153
+ return_tensors="pt",
154
+ )
155
+ speech_values = bark_model.generate(**inputs, do_sample=True)
156
+ au_dir = hashlib.md5(answer.encode('utf-8')).hexdigest() + '.wav' # 获取md5
157
+ scipy.io.wavfile.write(au_dir, rate=sampling_rate, data=speech_values.cpu().numpy().squeeze())
158
+ return gr.Audio().update(au_dir, autoplay=True)
159
 
160
 
161
  def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
 
163
  bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
164
  else: # 如果是使用llama_index索引
165
  bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
166
+ return bot
 
167
 
168
 
169
  def up_file(files): # 上传文件
 
270
  audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
271
  chat_bu.click(get_response,
272
  [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
273
+ [chat_bot])# .then(get_audio_answer, [chat_bot], [audio_answer]) # 发送消息
274
 
275
  if __name__ == "__main__":
276
  demo.queue(concurrency_count=4).launch()