# Source: Hugging Face Space file, commit a2f42ca — "add upload feature and optimize user experience"
import json
import time
import random
import os
import openai
import gradio as gr
import pandas as pd
import numpy as np
from openai.embeddings_utils import distances_from_embeddings
from utils.gpt_processor import QuestionAnswerer
from utils.work_flow_controller import WorkFlowController
# GPT-backed question answerer (project helper from utils.gpt_processor).
qa_processor = QuestionAnswerer()

# Module-level state shared across the Gradio callbacks below.
CSV_FILE_PATHS = ''   # path to the embeddings CSV produced by the upload pipeline
JSON_FILE_PATHS = ''  # path to the summaries JSON produced by the upload pipeline
KNOWLEDGE_BASE = None  # pandas DataFrame of pages + embeddings (set after upload)
CONTEXT = None  # page content selected for the current conversation
CONTEXT_PAGE_NUM = None  # page number of the selected page
CONTEXT_FILE_NAME = None  # file name of the selected page
def build_knowledge_base(files):
    """Run the processing pipeline over uploaded files and cache the results.

    Drives WorkFlowController on *files*, records the paths of the CSV and
    JSON artifacts it emits, and loads the CSV — with its 'page_embedding'
    column decoded back into numpy arrays — into the module-level
    KNOWLEDGE_BASE used for retrieval.
    """
    global CSV_FILE_PATHS, JSON_FILE_PATHS, KNOWLEDGE_BASE

    controller = WorkFlowController(files)
    CSV_FILE_PATHS = controller.csv_result_path
    JSON_FILE_PATHS = controller.result_path

    with open(CSV_FILE_PATHS, 'r', encoding='UTF-8') as fp:
        table = pd.read_csv(fp)

    # Embeddings were serialized as list literals; eval() restores them.
    # NOTE(review): eval on file content is unsafe if the CSV could ever be
    # attacker-controlled — confirm the pipeline fully owns this file.
    table['page_embedding'] = table['page_embedding'].apply(eval).apply(np.array)
    KNOWLEDGE_BASE = table
def construct_summary(json_path=None):
    """Build a markdown digest of every file in the knowledge base.

    Args:
        json_path: Path to the knowledge-base JSON. Defaults to the
            module-level JSON_FILE_PATHS set by build_knowledge_base(),
            preserving the original call signature.

    Returns:
        A markdown string with one "### 文件摘要" section per processed file
        (file name, page count, and summarized content).
    """
    path = JSON_FILE_PATHS if json_path is None else json_path
    with open(path, 'r', encoding='UTF-8') as fp:
        knowledge_base = json.load(fp)

    # Join once instead of repeated string concatenation.
    sections = []
    for entry in knowledge_base.values():
        sections.append(
            f"""
### 文件摘要
{entry['file_name']} (共 {entry['total_pages']} 頁)<br><br>
{entry['summarized_content']}<br><br>
"""
        )
    return ''.join(sections)
def change_md():
    """Refresh the summary Markdown panel from the current knowledge base."""
    summary_text = construct_summary()
    return gr.Markdown.update(summary_text, visible=True)
def user(message, history):
    """Append the user's turn (with a pending bot slot) and clear the textbox.

    Returns ("", new_history) so Gradio empties the input field while the
    chatbot shows the question awaiting its answer.
    """
    updated_history = history + [[message, None]]
    return "", updated_history
def system_notification(action):
    """Return a one-turn chatbot transcript announcing upload progress.

    'upload' yields the processing-in-progress notice; any other value
    yields the processing-finished notice.
    """
    if action == 'upload':
        status = '文件處理中(摘要、翻譯等),結束後將自動回覆'
    else:
        status = '文件處理完成,請開始提問'
    return [['已上傳文件', status]]
def get_index_file(user_message):
    """Select the single most relevant knowledge-base page for *user_message*.

    Embeds the question with text-embedding-ada-002, computes cosine
    distances against every page embedding in KNOWLEDGE_BASE, and stores the
    best page's content / page number / file name into the CONTEXT_* globals.
    If even the closest page is farther than the 0.2 distance threshold,
    CONTEXT is set to None ("no relevant document").

    Bug fix: the original sorted and reassigned the global KNOWLEDGE_BASE to
    its single best row (.head(1)), so every conversation after the first
    could only ever retrieve from that one page. KNOWLEDGE_BASE is now
    treated as read-only here.
    """
    global CONTEXT
    global CONTEXT_PAGE_NUM
    global CONTEXT_FILE_NAME

    user_message_embedding = openai.Embedding.create(
        input=user_message, engine='text-embedding-ada-002'
    )['data'][0]['embedding']

    distances = distances_from_embeddings(
        user_message_embedding,
        KNOWLEDGE_BASE['page_embedding'].values,
        distance_metric='cosine',
    )
    best_idx = int(np.argmin(distances))
    if distances[best_idx] > 0.2:
        # Closest match is still too far away: signal "no relevant document".
        CONTEXT = None
    else:
        best_row = KNOWLEDGE_BASE.iloc[best_idx]
        CONTEXT = best_row['page_content']
        CONTEXT_PAGE_NUM = best_row['page_num']
        CONTEXT_FILE_NAME = best_row['file_name']
def bot(history):
    """Produce the bot reply for the newest user turn in *history*.

    Flow: with no knowledge base loaded, ask the user to upload files first;
    otherwise, on the first question of a conversation (CONTEXT is None),
    retrieve the most relevant page via get_index_file(), then either answer
    from that page or report that no relevant document was found.
    """
    user_message = history[-1][0]
    global CONTEXT
    print(f'user_message: {user_message}')
    if KNOWLEDGE_BASE is None:
        # Nothing uploaded yet — show an upload prompt.
        # NOTE(review): this replaces the whole transcript, discarding prior
        # turns; presumably intentional for a system notice — confirm.
        response = [
            [user_message, "請先上傳文件"],
        ]
        history = response
        return history
    elif CONTEXT is None:
        # First question of this conversation: pick the best-matching page.
        get_index_file(user_message)
        print(f'CONTEXT: {CONTEXT}')
        if CONTEXT is None:
            # Closest page exceeded the retrieval distance threshold.
            response = [
                [user_message, "無法找到相關文件,請重新提問"],
            ]
            history = response
            return history
        else:
            pass
    if CONTEXT is not None:
        # Answer from the retrieved page and fill the pending bot slot.
        bot_message = qa_processor.answer_question(CONTEXT, CONTEXT_PAGE_NUM, CONTEXT_FILE_NAME, history)
        print(f'bot_message: {bot_message}')
        response = [
            [user_message, bot_message],
        ]
        history[-1] = response[0]
        return history
def clear_state():
    """Forget the currently selected document context.

    After this runs, the next question triggers a fresh retrieval pass over
    the whole knowledge base.
    """
    global CONTEXT, CONTEXT_PAGE_NUM, CONTEXT_FILE_NAME
    CONTEXT = None
    CONTEXT_PAGE_NUM = None
    CONTEXT_FILE_NAME = None
# Gradio UI: chat column with textbox + send/clear buttons, a multi-PDF
# uploader, usage instructions, and a summary panel filled after processing.
with gr.Blocks() as demo:
    history = gr.State([])
    upload_state = gr.State("upload")   # sentinel passed to system_notification
    finished = gr.State("finished")     # sentinel passed to system_notification
    user_question = gr.State("")
    with gr.Row():
        gr.HTML('Junyi Academy Chatbot')
        #status_display = gr.Markdown("Success", elem_id="status_display")
    with gr.Row(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot()
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text",
                        container=False,
                    )
                # with gr.Column(min_width=70, scale=1):
                #     submit_btn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    clear_btn = gr.Button("清除")
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("傳送")
    # Enter in the textbox: append the user turn, then generate the answer.
    response = user_input.submit(user,
                                 [user_input, chatbot],
                                 [user_input, chatbot],
                                 queue=False,
                                 ).then(bot, chatbot, chatbot)
    response.then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
    # NOTE(review): clear_btn has two click handlers — this one empties the
    # chatbot UI, and the one below resets the retrieval globals.
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    # NOTE(review): the extra positional `chatbot` argument below lands in
    # gradio's next positional slot after outputs — it looks like a stray
    # argument left over from a refactor; confirm intended.
    submit_btn.click(user,
                     [user_input, chatbot],
                     [user_input, chatbot],
                     chatbot,
                     queue=False).then(bot, chatbot, chatbot).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
    clear_btn.click(clear_state, None, None, queue=False)
    with gr.Row():
        index_file = gr.File(file_count="multiple", file_types=["pdf"], label="Upload PDF file")
    with gr.Row():
        instruction = gr.Markdown("""
## 使用說明
1. 上傳一個或多個 PDF 檔案,系統將自動進行摘要、翻譯等處理後建立知識庫
2. 在上方輸入欄輸入問題,系統將自動回覆
3. 可以根據下方的摘要內容來提問
4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
5. 要切換檢索的文件,請點選「清除對話記錄」按鈕後再重新提問
""")
    with gr.Row():
        describe = gr.Markdown('', visible=True)
    # Upload pipeline: notify "processing" → build knowledge base → notify
    # "finished" → render the per-file summaries into the describe panel.
    index_file.upload(system_notification, [upload_state], chatbot) \
        .then(lambda: gr.update(interactive=True), None, None, queue=False) \
        .then(build_knowledge_base, [index_file]) \
        .then(system_notification, [finished], chatbot) \
        .then(lambda: gr.update(interactive=True), None, None, queue=False) \
        .then(change_md, None, describe)

if __name__ == "__main__":
    demo.launch()