# Spaces:
# Runtime error
# Runtime error
import ast
import io
import json
import logging
import os
import secrets

import gradio as gr
import numpy as np
import openai
import pandas as pd
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload, MediaIoBaseUpload
from openai.embeddings_utils import distances_from_embeddings

from .gpt_processor import QuestionAnswerer
from .work_flow_controller import WorkFlowController
# The OpenAI key is read once, at import time, from the environment and
# installed on the module-level client (it is None when the variable is unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
class Chatbot:
    """Retrieval-based question-answering chatbot over uploaded documents.

    ``build_knowledge_base`` runs ``WorkFlowController`` on the uploaded files
    and reads the CSV / JSON result paths it exposes. The CSV rows (read via
    the ``page_content``, ``page_num``, ``file_name`` and ``page_embedding``
    columns) are either used for this session only, or appended to a shared
    CSV "database" stored on Google Drive. ``bot`` embeds the latest user
    message with OpenAI, picks the row with the smallest cosine distance and
    asks ``QuestionAnswerer`` to answer from that row's content.

    ``history`` is a list of ``[user_message, bot_message]`` pairs.
    """

    # Google Drive file ID of the shared knowledge-base CSV.
    _DB_FILE_ID = "1m3ozrphHP221hhdCFMFX9-10nzSDfNyW"
    # Google Drive folder that ``__upload_file`` creates per-session CSVs in.
    _UPLOAD_FOLDER_ID = "1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"
    # Number of attempts at writing the merged database back to Drive.
    _DB_UPLOAD_ATTEMPTS = 10
    # Rows whose cosine distance to the question exceeds this are not used.
    _MAX_CONTEXT_DISTANCE = 0.2
    _EMBEDDING_ENGINE = "text-embedding-ada-002"

    def __init__(self):
        self.history = []
        # "waiting" until a knowledge base has been loaded, then "done".
        self.upload_state = "waiting"
        self.uid = self.__generate_uid()
        self.g_drive_service = self.__init_drive_service()
        self.knowledge_base = None
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
        # Filled in by build_knowledge_base(); initialised so they always exist.
        self.csv_result_path = None
        self.json_result_path = None

    def build_knowledge_base(self, files, upload_mode="once"):
        """Process ``files`` and load the resulting knowledge base.

        Args:
            files: Uploaded files, passed straight to ``WorkFlowController``.
            upload_mode: ``"Upload to Database"`` merges the results into the
                shared Drive database; any other value (default ``"once"``)
                uses them for this session only.
        """
        work_flow_controller = WorkFlowController(files, self.uid)
        self.csv_result_path = work_flow_controller.csv_result_path
        self.json_result_path = work_flow_controller.json_result_path
        if upload_mode == "Upload to Database":
            self.__get_db_knowledge_base()
        else:
            self.__get_local_knowledge_base()

    def __get_db_knowledge_base(self):
        """Append this session's rows to the Drive database and load the result.

        The upload is attempted up to ``_DB_UPLOAD_ATTEMPTS`` times. If every
        attempt fails the error is logged and the merged rows are still used
        for this session, so questions about the new files can be answered.
        """
        db = self.__read_db(self.g_drive_service)
        cur_content = pd.read_csv(self.csv_result_path)
        merged = None
        for attempt in range(1, self._DB_UPLOAD_ATTEMPTS + 1):
            try:
                merged = self.__write_into_db(self.g_drive_service, db, cur_content)
                break
            except Exception:
                # Drive/network failures are transient in practice: log, retry.
                logging.exception(
                    "Failed to upload to database (attempt %d/%d), retrying...",
                    attempt,
                    self._DB_UPLOAD_ATTEMPTS,
                )
        if merged is None:
            logging.error(
                "Giving up uploading to database after %d attempts; "
                "using the merged rows for this session only.",
                self._DB_UPLOAD_ATTEMPTS,
            )
            merged = pd.concat([db, cur_content], ignore_index=True)
        # Use the merged rows (including this upload) with parsed embeddings,
        # exactly as the local path does, so similarity search can run on them.
        self.knowledge_base = self.__parse_embeddings(merged)
        self.upload_state = "done"

    def __get_local_knowledge_base(self):
        """Load this session's CSV results as the knowledge base."""
        with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = pd.read_csv(fp)
        self.knowledge_base = self.__parse_embeddings(knowledge_base)
        self.upload_state = "done"

    @staticmethod
    def __parse_embeddings(frame: pd.DataFrame) -> pd.DataFrame:
        """Convert the CSV-serialised ``page_embedding`` column to numpy arrays.

        After a CSV round trip each cell is a list literal such as
        ``"[0.1, -0.2]"``. ``ast.literal_eval`` is used rather than ``eval``
        because the database CSV is downloaded from Drive and must not be able
        to execute code. Non-string cells are passed to ``np.array`` as-is.
        Mutates and returns ``frame``.

        Raises:
            ValueError / SyntaxError: if a cell is not a valid Python literal.
        """
        frame["page_embedding"] = frame["page_embedding"].apply(
            lambda cell: np.array(ast.literal_eval(cell) if isinstance(cell, str) else cell)
        )
        return frame

    def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame) -> pd.DataFrame:
        """Append ``cur_content`` to ``db``, overwrite the Drive database, return the merge.

        The CSV is uploaded from memory, so no temporary file is left on disk.
        """
        merged = pd.concat([db, cur_content], ignore_index=True)
        payload = io.BytesIO(merged.to_csv(index=False).encode("utf-8"))
        media = MediaIoBaseUpload(payload, mimetype="text/csv", resumable=True)
        service.files().update(fileId=self._DB_FILE_ID, media_body=media).execute()
        return merged

    def __init_drive_service(self):
        """Build a Google Drive v3 client from the ``CREDENTIALS`` env variable.

        ``CREDENTIALS`` must contain service-account info as a JSON string.

        Raises:
            RuntimeError: if ``CREDENTIALS`` is unset or empty.
        """
        scopes = ["https://www.googleapis.com/auth/drive"]
        service_account_info = os.getenv("CREDENTIALS")
        if not service_account_info:
            raise RuntimeError(
                "CREDENTIALS environment variable is not set; it must hold the "
                "Google service-account JSON."
            )
        creds = Credentials.from_service_account_info(
            json.loads(service_account_info), scopes=scopes
        )
        return build("drive", "v3", credentials=creds)

    def __download_csv(self, service, file_id) -> pd.DataFrame:
        """Download Drive file ``file_id`` into memory and parse it as CSV."""
        request = service.files().get_media(fileId=file_id)
        buffer = io.BytesIO()
        downloader = MediaIoBaseDownload(buffer, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            print(f"Download {int(status.progress() * 100)}%.")
        buffer.seek(0)
        return pd.read_csv(buffer)

    def __read_db(self, service):
        """Return the shared knowledge-base CSV stored on Drive."""
        return self.__download_csv(service, self._DB_FILE_ID)

    def __read_file(self, service, filename) -> pd.DataFrame:
        """Download the first Drive file named ``filename`` and parse it as CSV.

        Raises:
            FileNotFoundError: if no visible Drive file has that name.
        """
        # Drive query values are single-quoted; escape backslashes and quotes.
        escaped = filename.replace("\\", "\\\\").replace("'", "\\'")
        results = service.files().list(q=f"name='{escaped}'").execute()
        files = results.get("files", [])
        if not files:
            raise FileNotFoundError(f"No Drive file named {filename!r}")
        return self.__download_csv(service, files[0]["id"])

    def __upload_file(self, service):
        """Upload this session's CSV into the Drive upload folder.

        First prints up to 10 files visible to the service account.
        """
        items = service.files().list(pageSize=10).execute().get("files", [])
        if items:
            print("Files:")
            for item in items:
                print(f"{item['name']} ({item['id']})")
        else:
            print("No files found.")
        media = MediaFileUpload(self.csv_result_path, resumable=True)
        filename = f"ex_bot_database_{self.uid}.csv"
        service.files().create(
            media_body=media,
            body={"name": filename, "parents": [self._UPLOAD_FOLDER_ID]},
        ).execute()

    def clear_state(self):
        """Forget the loaded knowledge base, selected context and chat history."""
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
        self.knowledge_base = None
        self.upload_state = "waiting"
        self.history = []

    def send_system_notification(self):
        """Return a one-turn conversation describing the upload status.

        Returns ``None`` when ``upload_state`` is neither "waiting" nor "done".
        """
        if self.upload_state == "waiting":
            # "Files uploaded" / "Processing (summary, translation, ...); will reply when done"
            return [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
        if self.upload_state == "done":
            # "Files uploaded" / "Processing finished, please start asking"
            return [["已上傳文件", "文件處理完成,請開始提問"]]
        return None

    def change_md(self):
        """Return a Gradio Markdown update showing the document summaries."""
        content = self.__construct_summary()
        return gr.Markdown.update(content, visible=True)

    def __construct_summary(self):
        """Build Markdown with one "document summary" section per JSON entry."""
        with open(self.json_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = json.load(fp)
        # Built from explicit "\n"-joined pieces (no source indentation inside
        # the literal) so the Markdown heading is not indented.
        return "".join(
            f"\n### 文件摘要\n"
            f"{entry['file_name']} (共 {entry['total_pages']} 頁)<br><br>\n"
            f"{entry['summarized_content']}<br><br>\n"
            for entry in knowledge_base.values()
        )

    def user(self, message):
        """Append a pending user turn; return ("", history) to clear the input box."""
        self.history += [[message, None]]
        return "", self.history

    def bot(self):
        """Answer the most recent user turn in place and return the history.

        Error replies (no knowledge base / no relevant page) only fill in the
        last turn, like successful answers, so earlier turns are kept.
        """
        if not self.history:
            return self.history
        user_message = self.history[-1][0]
        print(f"user_message: {user_message}")
        if self.knowledge_base is None:
            # "Please upload documents first"
            bot_message = "請先上傳文件"
        else:
            self.__get_index_file(user_message)
            if self.context is None:
                # "No relevant document found, please ask again"
                bot_message = "無法找到相關文件,請重新提問"
            else:
                qa_processor = QuestionAnswerer()
                bot_message = qa_processor.answer_question(
                    self.context,
                    self.context_page_num,
                    self.context_file_name,
                    self.history,
                )
                print(f"bot_message: {bot_message}")
        self.history[-1] = [user_message, bot_message]
        return self.history

    def __get_index_file(self, user_message, max_distance=None):
        """Embed ``user_message`` and select the closest knowledge-base row."""
        response = openai.Embedding.create(input=user_message, engine=self._EMBEDDING_ENGINE)
        user_message_embedding = response["data"][0]["embedding"]
        distances = distances_from_embeddings(
            user_message_embedding,
            self.knowledge_base["page_embedding"].values,
            distance_metric="cosine",
        )
        self.__select_context(distances, max_distance)

    def __select_context(self, distances, max_distance=None):
        """Store the row with the smallest distance as the answer context.

        ``distances`` must align positionally with ``self.knowledge_base``.
        The knowledge base gains a ``distance`` column and is re-sorted by it.
        When the knowledge base is empty or the best distance is greater than
        ``max_distance`` (default ``_MAX_CONTEXT_DISTANCE``), ``context`` and
        its page number / file name are all reset to ``None``.
        """
        if max_distance is None:
            max_distance = self._MAX_CONTEXT_DISTANCE
        knowledge_base = self.knowledge_base
        knowledge_base["distance"] = distances
        knowledge_base = knowledge_base.sort_values(by="distance", ascending=True)
        self.knowledge_base = knowledge_base
        # `not d <= max_distance` also rejects a NaN best distance.
        if knowledge_base.empty or not knowledge_base["distance"].iat[0] <= max_distance:
            self.context = None
            self.context_page_num = None
            self.context_file_name = None
            return
        self.context = knowledge_base["page_content"].iat[0]
        self.context_page_num = knowledge_base["page_num"].iat[0]
        self.context_file_name = knowledge_base["file_name"].iat[0]

    def __generate_uid(self):
        """Return a random 16-character hex session id."""
        return secrets.token_hex(8)