# ChatPDF / app.py
# souljoy's picture
# Update app.py
# ce756cd
# raw
# history blame
# No virus
# 3.8 kB
import requests
import json
import gradio as gr
from concurrent.futures import ThreadPoolExecutor
from sentence_transformers import util
# Remote QA-maker endpoint (kept for reference; not used by the functions below).
url = 'https://souljoy-my-api.hf.space/qa_maker'

# Every request to the backend carries a JSON body.
headers = {'Content-Type': 'application/json'}

# Shared pool used to issue embedding requests concurrently.
thread_pool_executor = ThreadPoolExecutor(max_workers=16)

# Character budgets: question + chat history, and question + history + document.
history_max_len = 500
all_max_len = 2000
def get_emb(text, timeout=60):
    """Fetch the embedding of ``text`` from the remote embeddings endpoint.

    Posts ``{"content": text}`` as a JSON body and returns the value found at
    ``data[0].embedding`` in the JSON response.

    Args:
        text: The string to embed.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the calling worker thread
            indefinitely.

    Returns:
        The ``embedding`` field of the first ``data`` entry in the response
        (presumably a list of floats; confirm against the API).

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
        KeyError, IndexError: If the response JSON lacks the expected fields.
    """
    emb_url = 'https://souljoy-my-api.hf.space/embeddings'
    payload = {"content": text}
    response = requests.post(
        url=emb_url,
        data=json.dumps(payload),
        headers=headers,
        timeout=timeout,
    )
    # Fail loudly on HTTP errors instead of tripping over a missing key later.
    response.raise_for_status()
    return response.json()['data'][0]['embedding']
def doc_emb(doc: str):
    """Split ``doc`` into lines and embed every line concurrently.

    Each line (including empty ones) is submitted to the shared thread pool
    via ``get_emb``; results are collected in line order. The split document
    is echoed to stdout.

    Args:
        doc: The full document text; lines are separated by ``'\\n'``.

    Returns:
        A 5-tuple: the list of lines, the list of their embeddings (same
        order), and three Gradio updates that make a Textbox, a Button and a
        Markdown component visible.
    """
    lines = doc.split('\n')
    # Submit everything first so the requests overlap, then gather in order.
    pending = [thread_pool_executor.submit(get_emb, line) for line in lines]
    embeddings = [job.result() for job in pending]
    print('\n'.join(lines))
    show_textbox = gr.Textbox.update(visible=True)
    show_button = gr.Button.update(visible=True)
    show_markdown = gr.Markdown.update(visible=True)
    return lines, embeddings, show_textbox, show_button, show_markdown
def get_response(msg, bot, doc_text_list, doc_embeddings):
    """Answer ``msg`` using recent chat history and the most relevant document lines.

    The request is assembled under two character budgets: question plus
    history may not exceed ``history_max_len``, and question plus history plus
    document excerpt may not exceed ``all_max_len``.

    Args:
        msg: The user's question.
        bot: Chat history as a list of ``[user_message, reply]`` pairs. The
            new exchange is appended to this list in place.
        doc_text_list: Document lines.
        doc_embeddings: Embeddings of the document lines, index-aligned with
            ``doc_text_list``.

    Returns:
        A 2-tuple: the last three exchanges of ``bot`` (including the new one)
        and a Gradio update that hides a Markdown component.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status from the chat endpoint.
        KeyError: If the response JSON has no ``content`` field.
    """
    # Generous upper bound (seconds) so a stalled backend cannot hang the
    # handler forever.
    request_timeout = 120

    # Start embedding the question in the background while history is built.
    future = thread_pool_executor.submit(get_emb, msg)
    now_len = len(msg)
    req_json = {'question': msg}

    # Walk the history newest-first and keep the longest suffix that still
    # fits into the history budget together with the question.
    his_bg = -1
    for i in range(len(bot) - 1, -1, -1):
        turn_len = len(bot[i][0]) + len(bot[i][1])
        if now_len + turn_len > history_max_len:
            break
        now_len += turn_len
        his_bg = i
    req_json['history'] = [] if his_bg == -1 else bot[his_bg:]

    # Rank document lines by cosine similarity to the question.
    query_embedding = future.result()
    cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
    score_index = sorted(
        ([score, index] for index, score in enumerate(cos_scores)),
        key=lambda pair: pair[0],
        reverse=True,
    )
    print('score_index:\n', score_index)

    # Take lines best-first until the next one would exceed the total budget.
    index_list = []
    for _, index in score_index:
        doc = doc_text_list[index]
        if now_len + len(doc) > all_max_len:
            break
        index_list.append(index)
        now_len += len(doc)

    # Restore original document order so the excerpt reads naturally.
    index_list.sort()
    sub_doc_list = [doc_text_list[i] for i in index_list]
    req_json['doc'] = '\n'.join(sub_doc_list)

    # The backend expects the request JSON as a string under "content".
    data = {"content": json.dumps(req_json)}
    print('data:\n', req_json)
    result = requests.post(
        url='https://souljoy-my-api.hf.space/chatpdf',
        data=json.dumps(data),
        headers=headers,
        timeout=request_timeout,
    )
    # Surface HTTP errors directly instead of failing on a missing key.
    result.raise_for_status()
    res = result.json()['content']

    bot.append([msg, res])
    # Only the three most recent exchanges are sent back to the chat widget.
    return bot[-3:], gr.Markdown.update(visible=False)
def up_file(files):
    """Log uploaded file names and reveal the submit button.

    Args:
        files: The value delivered by the file-upload component. This may be
            ``None`` (upload cleared), a single file object, or a list of
            file objects, depending on how the component is configured
            (NOTE(review): the File component in this file is created without
            ``file_count``, so a single object is the likely case; confirm
            against the Gradio version in use).

    Returns:
        A Gradio Button update: visible when at least one file is present,
        hidden otherwise.
    """
    # Normalise to a list so a single file object is not iterated line by
    # line and a cleared upload (None) does not raise.
    if files is None:
        uploaded = []
    elif isinstance(files, (list, tuple)):
        uploaded = list(files)
    else:
        uploaded = [files]

    for item in uploaded:
        # Fall back to the item itself if it is a plain path string.
        print(getattr(item, 'name', item))
    return gr.Button.update(visible=bool(uploaded))
# Build the UI: a two-column layout with document upload/submission on the
# left and the chat interface on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # PDF upload widget; only .pdf files are accepted.
            file = gr.File(file_types=['.pdf'], label='上传PDF')
            # Textbox holding the parsed PDF text; starts hidden.
            # NOTE(review): nothing visible in this file fills or reveals this
            # textbox before submission - confirm where its content comes from.
            txt = gr.Textbox(label='PDF解析结果', visible=False)
            # Submit button; revealed by up_file once a file is uploaded.
            doc_bu = gr.Button(value='提交', visible=False)
            # Success note; shown by doc_emb, hidden again by get_response.
            md = gr.Markdown("""#### 文档提交成功 🙋 """, visible=False)
            # Per-session storage for the document lines and their embeddings.
            doc_text_state = gr.State([])
            doc_emb_state = gr.State([])
        with gr.Column():
            chat_bot = gr.Chatbot()
            # Message box and send button stay hidden until a document is submitted.
            msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
            chat_bu = gr.Button(value='发送', visible=False)
    # Submit: embed the textbox content, store lines/embeddings in the states,
    # and reveal the message box, send button and success note.
    doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md])
    # Send: answer the message using chat history and the stored document,
    # update the chat display and hide the success note.
    chat_bu.click(get_response, [msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot, md])
    # Upload change: reveal the submit button.
    file.change(up_file, [file], [doc_bu])
if __name__ == "__main__":
    # Launch the app with Gradio's request queue enabled.
    demo.queue().launch()
    # Alternative launch bound to a fixed host/port (kept for reference):
    # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)