# Page header captured together with the source: "Spaces: Running".
import json
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import requests
from sentence_transformers import util
# QA-maker endpoint (not referenced by the functions visible in this file).
url = 'https://souljoy-my-api.hf.space/qa_maker'

# Headers shared by every POST: request bodies are JSON-encoded.
headers = {'Content-Type': 'application/json'}

# Shared worker pool used to run embedding requests concurrently.
thread_pool_executor = ThreadPoolExecutor(max_workers=16)

# Character budgets applied in get_response():
#   history_max_len - question plus the chat turns that are replayed
#   all_max_len     - the above plus the document lines that are attached
history_max_len = 500
all_max_len = 2000
def get_emb(text):
    """Return the embedding of ``text`` from the remote embeddings endpoint.

    Args:
        text: The string to embed; sent as ``{"content": text}``.

    Returns:
        The value stored at ``['data'][0]['embedding']`` in the JSON reply
        (presumably a list of floats; the schema is defined by the service).

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
        KeyError, IndexError: If the reply lacks the expected structure.
    """
    emb_url = 'https://souljoy-my-api.hf.space/embeddings'
    payload = json.dumps({"content": text})
    # This function runs on the shared thread pool; without a timeout a
    # stalled connection would occupy a worker indefinitely.
    response = requests.post(
        url=emb_url,
        data=payload,
        headers=headers,
        timeout=60,
    )
    # Report HTTP failures as such instead of as a KeyError on the body below.
    response.raise_for_status()
    return response.json()['data'][0]['embedding']
def doc_emb(doc: str):
    """Embed each line of ``doc`` and reveal the chat controls.

    The text is split on newlines and every resulting line (blank ones
    included) is embedded with ``get_emb`` on the shared thread pool.

    Args:
        doc: The document text.

    Returns:
        A 5-tuple: the list of lines, the list of their embeddings in the
        same order, and three Gradio updates (Textbox, Button, Markdown)
        that each set ``visible=True``.
    """
    lines = doc.split('\n')
    # Submit every request before waiting on any of them so they overlap;
    # collecting in submission order keeps embeddings aligned with lines.
    pending = [thread_pool_executor.submit(get_emb, line) for line in lines]
    embeddings = [job.result() for job in pending]
    print('\n'.join(lines))
    return (
        lines,
        embeddings,
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Markdown.update(visible=True),
    )
def get_response(msg, bot, doc_text_list, doc_embeddings):
    """Answer ``msg`` using the submitted document and recent chat history.

    Steps:
      1. Embed the question in the background on the shared thread pool.
      2. Walking backwards from the newest turn, keep chat turns while the
         question plus kept turns stay within ``history_max_len`` characters.
      3. Rank document lines by cosine similarity to the question and take
         them best-first until the next line would push the running total
         past ``all_max_len`` characters (selection stops at that line; it
         does not skip ahead). Chosen lines are sent in document order.
      4. POST question, history and document excerpt to the chat endpoint.

    Args:
        msg: The user's question.
        bot: Chat history as ``[user_message, reply]`` pairs. ``None`` or an
            empty value is treated as no history.
        doc_text_list: Document lines, as returned by ``doc_emb``.
        doc_embeddings: One embedding per entry of ``doc_text_list``.

    Returns:
        A pair: the last three chat turns (including the new one) and a
        Gradio update that hides the Markdown notice.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status from the chat endpoint.
        KeyError: If the reply has no ``'content'`` field.
    """
    # Tolerate a missing history and avoid mutating the caller's list.
    bot = list(bot) if bot else []

    # Start the embedding request now so it overlaps with the history scan.
    future = thread_pool_executor.submit(get_emb, msg)

    used_len = len(msg)
    history_start = len(bot)  # slice start; len(bot) means "no history"
    for i in range(len(bot) - 1, -1, -1):
        turn_len = len(bot[i][0]) + len(bot[i][1])
        if used_len + turn_len > history_max_len:
            break
        used_len += turn_len
        history_start = i
    req_json = {'question': msg, 'history': bot[history_start:]}

    query_embedding = future.result()
    sub_doc_list = []
    # Skip similarity search when no document has been embedded.
    if doc_text_list and doc_embeddings:
        # Plain floats sort and print more cleanly than 0-d tensors.
        cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0].tolist()
        score_index = [[score, index] for index, score in enumerate(cos_scores)]
        score_index.sort(key=lambda pair: pair[0], reverse=True)
        print('score_index:\n', score_index)

        index_list = []
        for _score, index in score_index:
            doc_len = len(doc_text_list[index])
            if used_len + doc_len > all_max_len:
                break
            index_list.append(index)
            used_len += doc_len
        # Restore document order so the excerpt reads naturally.
        sub_doc_list = [doc_text_list[i] for i in sorted(index_list)]
    req_json['doc'] = '\n'.join(sub_doc_list)

    # The endpoint expects the request JSON as a string under "content".
    data = {"content": json.dumps(req_json)}
    print('data:\n', req_json)
    result = requests.post(
        url='https://souljoy-my-api.hf.space/chatpdf',
        data=json.dumps(data),
        headers=headers,
        # Generous cap: never hang forever on a stalled connection.
        timeout=120,
    )
    # Report HTTP failures as such instead of as a KeyError on the body below.
    result.raise_for_status()
    res = result.json()['content']
    bot.append([msg, res])
    # Only the three most recent turns are shown (and fed back next call,
    # since the Chatbot is both input and output of this handler).
    return bot[-3:], gr.Markdown.update(visible=False)
def up_file(files):
    """Log the uploaded file name(s) and toggle the submit button.

    Args:
        files: The value of the File component. Accepts a single uploaded
            file object, a list/tuple of them, or ``None``.
            NOTE(review): the File component here is created without
            ``file_count="multiple"``, so Gradio presumably passes a single
            object (and ``None`` when the upload is cleared) — confirm
            against the installed Gradio version.

    Returns:
        A Gradio Button update: visible when something was uploaded, hidden
        when ``files`` is ``None``.
    """
    if files is None:
        # Nothing uploaded (e.g. the file was cleared): nothing to submit.
        return gr.Button.update(visible=False)
    if not isinstance(files, (list, tuple)):
        # Iterating a lone file object would walk its contents, not files.
        files = [files]
    for uploaded in files:
        print(uploaded.name)
    return gr.Button.update(visible=True)
# Two-column layout: document upload/submission on the left, chat on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Left column: PDF upload plus controls that start hidden.
            file = gr.File(file_types=['.pdf'], label='上传PDF')
            txt = gr.Textbox(label='PDF解析结果', visible=False)
            doc_bu = gr.Button(value='提交', visible=False)
            md = gr.Markdown("""#### 文档提交成功 🙋 """, visible=False)
            # Per-session storage for the document lines and their embeddings.
            doc_text_state = gr.State([])
            doc_emb_state = gr.State([])
        with gr.Column():
            # Right column: chat log plus input controls that start hidden.
            chat_bot = gr.Chatbot()
            msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
            chat_bu = gr.Button(value='发送', visible=False)

    # Submitting the document embeds it and reveals the chat controls.
    doc_bu.click(
        fn=doc_emb,
        inputs=[txt],
        outputs=[doc_text_state, doc_emb_state, msg_txt, chat_bu, md],
    )
    # Sending a message updates the chat log and hides the success notice.
    chat_bu.click(
        fn=get_response,
        inputs=[msg_txt, chat_bot, doc_text_state, doc_emb_state],
        outputs=[chat_bot, md],
    )
    # Uploading a file reveals the submit button.
    file.change(fn=up_file, inputs=[file], outputs=[doc_bu])

if __name__ == "__main__":
    demo.queue().launch()
    # Alternative for a fixed host/port deployment:
    # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)