ChatPDF / app.py
souljoy's picture
Update app.py
43f767e
import requests
import json
import gradio as gr
import pdfplumber
import pandas as pd
import time
from cnocr import CnOcr
import numpy as np
import openai
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
from transformers import pipeline, BarkModel, BarkProcessor
import opencc
import scipy
import torch
import hashlib
converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
ocr = CnOcr() # 初始化ocr模型
history_max_len = 500 # 机器人记忆的最大长度
all_max_len = 2000 # 输入的最大长度
asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
bark_model = BarkModel.from_pretrained("suno/bark-small")
bark_processor = BarkProcessor.from_pretrained("suno/bark-small")
sampling_rate = bark_model.generation_config.sample_rate
def get_text_emb(open_ai_key, text): # 文本向量化
openai.api_key = open_ai_key # 设置openai的key
response = openai.Embedding.create(
input=text,
model="text-embedding-ada-002"
) # 调用openai的api
return response['data'][0]['embedding'] # 返回向量
def doc_index_self(open_ai_key, doc): # 文档向量化
texts = doc.split('\n') # 按行切分
emb_list = [] # 用于存储向量
for text in texts: # 遍历每一行
emb_list.append(get_text_emb(open_ai_key, text)) # 获取向量
return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
value="""操作说明 step 3:建立索引(by self)成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True), 1, gr.Audio.update(
visible=True), gr.Radio.update(visible=True)
def doc_index_llama(open_ai_key, txt): # 建立索引
# 根据时间戳新建目录,保存txt文件
path = str(time.time())
import os
os.mkdir(path)
with open(path + '/doc.txt', mode='w', encoding='utf-8') as f:
f.write(txt)
openai.api_key = open_ai_key # 设置OpenAI API Key
documents = SimpleDirectoryReader(path).load_data() # 读取文档
index = GPTVectorStoreIndex.from_documents(documents) # 建立索引
template = (
"你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。文章内容如下: \n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"{query_str}\n"
"请你回复用户。\n"
) # 定义模板
qa_template = Prompt(template) # 将模板转换成Prompt对象
query_engine = index.as_query_engine(text_qa_template=qa_template) # 建立查询引擎
return query_engine, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
value="""操作说明 step 3:建立索引(by llama_index)成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(
visible=True), 0, gr.Audio.update(visible=True), gr.Radio.update(visible=True)
def get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获取机器人回复
now_len = len(msg) # 当前输入的长度
his_bg = -1 # 历史记录的起始位置
for i in range(len(bot) - 1, -1, -1): # 从后往前遍历历史记录
if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len: # 如果超过了历史记录的最大长度,就不再加入
break
now_len += len(bot[i][0]) + len(bot[i][1]) # 更新当前长度
his_bg = i # 更新历史记录的起始位置
history = [] if his_bg == -1 else bot[his_bg:] # 获取历史记录
query_embedding = get_text_emb(open_ai_key, msg) # 获取输入的向量
cos_scores = [] # 用于存储相似度
def cos_sim(a, b): # 计算余弦相似度
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) # 返回相似度
for doc_embedding in doc_embeddings: # 遍历文档向量
cos_scores.append(cos_sim(query_embedding, doc_embedding)) # 计算相似度
score_index = [] # 用于存储相似度和索引对应
for i in range(len(cos_scores)): # 遍历相似度
score_index.append((cos_scores[i], i)) # 加入相似度和索引对应
score_index.sort(key=lambda x: x[0], reverse=True) # 按相似度排序
print('score_index:\n', score_index)
index_set, sub_doc_list = set(), [] # 用于存储最终的索引和文档
for s_i in score_index: # 遍历相似度和索引对应
doc = doc_text_list[s_i[1]] # 获取文档
if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
break
index_set.add(s_i[1]) # 加入索引
now_len += len(doc) # 更新当前长度
# 可能段落截断错误,所以把上下段也加入进来
if s_i[1] > 0 and s_i[1] - 1 not in index_set: # 如果上一段没有加入
doc = doc_text_list[s_i[1] - 1] # 获取上一段
if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
break
index_set.add(s_i[1] - 1) # 加入索引
now_len += len(doc) # 更新当前长度
if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set: # 如果下一段没有加入
doc = doc_text_list[s_i[1] + 1] # 获取下一段
if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
break
index_set.add(s_i[1] + 1) # 加入索引
now_len += len(doc) # 更新当前长度
index_list = list(index_set) # 转换成list
index_list.sort() # 排序
for i in index_list: # 遍历索引
sub_doc_list.append(doc_text_list[i]) # 加入文档
document = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list) # 拼接文档
messages = [{
"role": "system",
"content": "你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。"
}, {"role": "system", "content": "文章内容:\n" + document}] # 角色人物定义
for his in history: # 遍历历史记录
messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
openai.api_key = open_ai_key
chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages) # 获取机器人的回复
res = chat_completion.choices[0].message.content # 获取机器人的回复
bot.append([msg, res]) # 加入历史记录
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
def get_response_by_llama_index(open_ai_key, msg, bot, query_engine): # 获取机器人回复
openai.api_key = open_ai_key
query_str = "历史对话如下:\n"
for his in bot: # 遍历历史记录
query_str += "用户:" + his[0] + "\n" # 加入用户的历史记录
query_str += "助手:" + his[1] + "\n" # 加入机器人的历史记录
query_str += "用户:" + msg + "\n" # 加入用户的当前输入
res = query_engine.query(query_str) # 获取回答
print(res) # 显示回答
bot.append([msg, str(res)]) # 加入历史记录
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
def get_audio_answer(bot): # 获取语音回答
answer = bot[-1][1]
inputs = bark_processor(
text=[answer],
return_tensors="pt",
)
speech_values = bark_model.generate(**inputs, do_sample=True)
au_dir = hashlib.md5(answer.encode('utf-8')).hexdigest() + '.wav' # 获取md5
scipy.io.wavfile.write(au_dir, rate=sampling_rate, data=speech_values.cpu().numpy().squeeze())
return gr.Audio().update(au_dir, autoplay=True)
def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
if index_type == 1: # 如果是使用自己的索引
bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
else: # 如果是使用llama_index索引
bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
return bot
def up_file(files): # 上传文件
doc_text_list = [] # 用于存储文档
for idx, file in enumerate(files): # 遍历文件
print(file.name)
with pdfplumber.open(file.name) as pdf: # 打开pdf
for i in range(len(pdf.pages)): # 遍历pdf的每一页
# 读取PDF文档第i+1页
page = pdf.pages[i]
res_list = page.extract_text().split('\n')[:-1] # 提取文本
for j in range(len(page.images)): # 遍历图片
# 获取图片的二进制流
img = page.images[j]
file_name = '{}-{}-{}.png'.format(str(time.time()), str(i), str(j)) # 生成文件名
with open(file_name, mode='wb') as f: # 保存图片
f.write(img['stream'].get_data())
try:
res = ocr.ocr(file_name) # 识别图片
except Exception as e:
res = [] # 识别失败
if len(res) > 0: # 如果识别成功
res_list.append(' '.join([re['text'] for re in res])) # 加入识别结果
tables = page.extract_tables() # 提取表格
for table in tables: # 遍历表格
# 第一列当成表头:
df = pd.DataFrame(table[1:], columns=table[0])
try:
records = json.loads(df.to_json(orient="records", force_ascii=False)) # 转换成json
for rec in records: # 遍历json
res_list.append(json.dumps(rec, ensure_ascii=False)) # 加入json
except Exception as e:
res_list.append(str(df)) # 如果转换识别,直接把表格转为str
doc_text_list += res_list # 加入文档
doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0] # 去除空格
print(doc_text_list)
return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
visible=True), gr.Button.update(
visible=True), gr.Markdown.update(
value="操作说明 step 2:确认PDF解析结果(可修正),点击“建立索引”,随后进行对话")
def transcribe_speech_by_self(filepath):
output = asr_pipe(
filepath,
max_new_tokens=256,
generate_kwargs={
"task": "transcribe",
"language": "chinese",
},
chunk_length_s=30,
batch_size=8,
) # 识别语音
simplified_text = converter.convert(output["text"]) # 转换为简体字
return simplified_text
def transcribe_speech_by_openai(openai_key, filepath):
openai.api_key = openai_key # 设置OpenAI API Key
audio_file = open(filepath, "rb")
transcript = openai.Audio.transcribe("whisper-1", audio_file)
print(transcript)
return transcript['text']
def transcribe_speech(openai_key, filepath, a_type):
if a_type == 'self':
return transcribe_speech_by_self(filepath)
else:
return transcribe_speech_by_openai(openai_key, filepath)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
open_ai_key = gr.Textbox(label='OpenAI API Key', placeholder='输入你的OpenAI API Key') # 你的OpenAI API Key
file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
file_count='multiple') # 支持多文档、表格、OCR
txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
with gr.Row():
index_llama_bu = gr.Button(value='建立索引(by llama_index)', visible=False) # 建立索引(by llama_index)
index_self_bu = gr.Button(value='建立索引(by self)', visible=False) # 建立索引(by self)
doc_text_state = gr.State([]) # 存储PDF解析结果
doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
query_engine = gr.State([]) # 存储查询引擎
index_type = gr.State([]) # 存储索引类型
with gr.Column():
md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
chat_bot = gr.Chatbot(visible=False) # 聊天机器人
audio_answer = gr.Audio() # 语音回答
with gr.Row():
asr_type = gr.Radio(value='self', choices=['self', 'openai'], label='语音识别方式', visible=False) # 语音识别方式
audio_inputs = gr.Audio(source="microphone", type="filepath", label="点击录音输入", visible=False) # 录音输入
msg_txt = gr.Textbox(label='消息框', placeholder='输入消息', visible=False) # 消息框
chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
file.change(up_file, [file], [txt, index_self_bu, index_llama_bu, md]) # 上传文件
index_self_bu.click(doc_index_self, [open_ai_key, txt],
[doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot, index_type,
audio_inputs, asr_type]) # 提交解析结果
index_llama_bu.click(doc_index_llama, [open_ai_key, txt],
[query_engine, msg_txt, chat_bu, md, chat_bot, index_type, audio_inputs, asr_type]) # 提交解析结果
audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
chat_bu.click(get_response,
[open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
[chat_bot])# .then(get_audio_answer, [chat_bot], [audio_answer]) # 发送消息
if __name__ == "__main__":
demo.queue(concurrency_count=4).launch()