Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://huggingface.co/spaces/fffiloni/langchain-chat-with-pdf-openai | |
""" | |
import argparse | |
import json | |
import logging | |
import time | |
from typing import List, Tuple | |
logging.basicConfig( | |
level=logging.DEBUG, | |
format="%(asctime)s %(levelname)s %(message)s", | |
datefmt="%Y-%m-%d %H:%M:%S", | |
) | |
import gradio as gr | |
import openai | |
from openai import OpenAI | |
from threading import Thread | |
import _queue | |
from queue import Queue | |
import project_settings as settings | |
from project_settings import project_path | |
logger = logging.getLogger(__name__) | |
def get_args():
    """
    Parse the command-line arguments of the demo.

    The only option is ``--openai_api_key``; when it is not given on the
    command line its default is looked up under the ``openai_api_key`` key
    of ``settings.environment``.

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_api_key = settings.environment.get("openai_api_key", default=None, dtype=str)

    parser = argparse.ArgumentParser()
    parser.add_argument("--openai_api_key", default=default_api_key, type=str)
    return parser.parse_args()
def greet(question: str, history: List[Tuple[str, str]]):
    """
    Toy chat handler: answer ``question`` with a greeting.

    :param question: text typed by the user.
    :param history: previous ``(question, answer)`` pairs; it is not mutated.
    :return: a new list made of ``history`` followed by the new pair.
    """
    reply = "".join(("Hello ", question, "!"))
    return history + [(question, reply)]
def click_create_assistant(openai_api_key: str,
                           name: str,
                           instructions: str,
                           description: str,
                           tools: str,
                           files: List[str],
                           file_ids: str,
                           model: str,
                           ):
    """
    Create an assistant and a fresh thread for it.

    :param openai_api_key: API key used to build the OpenAI client.
    :param name: assistant name.
    :param instructions: assistant instructions.
    :param description: assistant description.
    :param tools: tool definitions, one JSON object per line; blank lines are
        ignored. ``None`` or an empty string means "no tools".
    :param files: local paths of files to upload with purpose ``assistants``;
        ``None`` or an empty list means "no uploads".
    :param file_ids: ids of already uploaded files, one per line; ``None`` or
        an empty string means "none".
    :param model: model name passed to the API.
    :return: tuple ``(assistant_id, thread_id)``.
    :raises json.JSONDecodeError: when a non-blank line of ``tools`` is not valid JSON.
    """
    logger.info("click create assistant, name: {}".format(name))

    client = OpenAI(
        api_key=openai_api_key,
    )

    # tools
    # None has to be filtered out before str(): str(None) is the text "None",
    # which is neither valid JSON nor an empty string.
    tool_specs = list()
    if tools is not None:
        tool_specs = [
            json.loads(line)
            for line in str(tools).split("\n")
            if len(line.strip()) != 0
        ]

    # files
    # Each file is opened in a `with` block so that the handle is closed as
    # soon as the upload request has returned.
    uploaded_file_ids = list()
    for path in files or []:
        with open(path, "rb") as f:
            uploaded = client.files.create(
                file=f,
                purpose='assistants'
            )
        uploaded_file_ids.append(uploaded.id)

    # file_ids
    # Same None handling as for tools, otherwise "None" would be sent as a file id.
    existing_file_ids = list()
    if file_ids is not None:
        existing_file_ids = [
            line.strip()
            for line in str(file_ids).split("\n")
            if len(line.strip()) != 0
        ]

    # assistant
    assistant = client.beta.assistants.create(
        name=name,
        instructions=instructions,
        description=description,
        tools=tool_specs,
        file_ids=existing_file_ids + uploaded_file_ids,
        model=model,
    )
    assistant_id = assistant.id

    # thread
    thread = client.beta.threads.create()
    thread_id = thread.id

    return assistant_id, thread_id
def click_list_assistant(openai_api_key: str) -> str:
    """
    List the assistants visible to the given API key as plain text.

    :param openai_api_key: API key used to build the OpenAI client.
    :return: one ``id / name / description`` paragraph per assistant,
        concatenated; an empty string when the listing has no entries.
    """
    client = OpenAI(
        api_key=openai_api_key,
    )
    assistant_list = client.beta.assistants.list()
    # The dumped dict has to be kept: the page object returned by the client
    # is not subscriptable, only its JSON dump is.
    assistant_list = assistant_list.model_dump(mode="json")

    template = "id: \n{}\nname: \n{}\ndescription: \n{}\n\n"
    return "".join(
        template.format(a["id"], a["name"], a["description"])
        for a in assistant_list["data"]
    )
def click_delete_assistant(openai_api_key: str,
                           assistant_id: str) -> str:
    """
    Delete one assistant.

    :param openai_api_key: API key used to build the OpenAI client.
    :param assistant_id: id of the assistant to delete.
    :return: ``"success"`` or ``"failed"`` according to the ``deleted`` flag of
        the API response, or the error message when the API raises
        ``openai.NotFoundError``.
    """
    logger.info("click delete assistant, assistant_id: {}".format(assistant_id))

    client = OpenAI(api_key=openai_api_key)

    try:
        deletion = client.beta.assistants.delete(assistant_id=assistant_id)
    except openai.NotFoundError as e:
        return e.message

    if deletion.deleted:
        return "success"
    return "failed"
def click_list_file(openai_api_key: str):
    """
    List the files visible to the given API key as plain text.

    :param openai_api_key: API key used to build the OpenAI client.
    :return: one ``id / filename / bytes / status`` paragraph per file,
        concatenated; an empty string when the listing has no entries.
    """
    client = OpenAI(api_key=openai_api_key)
    listing = client.files.list().model_dump(mode="json")

    template = "id: \n{}\nfilename: \n{}\nbytes: \n{}\nstatus: \n{}\n\n"
    paragraphs = [
        template.format(f["id"], f["filename"], f["bytes"], f["status"])
        for f in listing["data"]
    ]
    return "".join(paragraphs)
def click_delete_file(openai_api_key: str,
                      file_id: str) -> str:
    """
    Delete one uploaded file.

    :param openai_api_key: API key used to build the OpenAI client.
    :param file_id: id of the file to delete.
    :return: ``"success"`` or ``"failed"`` according to the ``deleted`` flag of
        the API response, or the error message when the API raises
        ``openai.NotFoundError``.
    """
    logger.info("click delete file, file_id: {}".format(file_id))

    client = OpenAI(api_key=openai_api_key)

    try:
        deletion = client.files.delete(file_id=file_id)
    except openai.NotFoundError as e:
        return e.message

    if deletion.deleted:
        return "success"
    return "failed"
def click_upload_files(openai_api_key: str,
                       files: List[str],
                       ):
    """
    Upload local files with purpose ``assistants``.

    :param openai_api_key: API key used to build the OpenAI client.
    :param files: local paths of the files to upload; ``None`` or an empty
        list uploads nothing.
    :return: list with the id of every uploaded file, in input order.
    """
    logger.info("click upload files, files: {}".format(files))

    client = OpenAI(
        api_key=openai_api_key,
    )

    result = list()
    for path in files or []:
        # `with` closes the handle once the upload request has returned;
        # a bare open() here would leak one descriptor per file.
        with open(path, "rb") as f:
            uploaded = client.files.create(
                file=f,
                purpose='assistants'
            )
        result.append(uploaded.id)
    return result
def get_message_list(client: "OpenAI", thread_id: str):
    """
    Fetch the messages of a thread and convert them to plain dicts.

    The client returns a page object whose ``data`` attribute holds the
    messages. A message looks roughly like this (abridged)::

        ThreadMessage(
            id='msg_...',
            assistant_id='asst_...',        # None for user messages
            content=[
                MessageContentText(
                    text=Text(
                        annotations=[
                            TextAnnotationFileCitation(
                                start_index=34,
                                end_index=44,
                                text='[7+source]',
                                type='file_citation',
                                file_citation=...(file_id='file-...', quote='...'),
                            )
                        ],
                        value='answer text',
                    ),
                    type='text',
                )
            ],
            created_at=1699493845,
            file_ids=[],
            metadata={},
            object='thread.message',
            role='assistant',               # or 'user'
            run_id='run_...',               # None for user messages
            thread_id='thread_...',
        )

    :param client: OpenAI client (only ``client.beta.threads.messages.list`` is used).
    :param thread_id: id of the thread to read.
    :return: list of dicts, sorted by ascending ``created_at``, with the keys
        ``id``, ``assistant_id``, ``content``, ``created_at``, ``file_ids``,
        ``metadata``, ``object``, ``role``, ``run_id`` and ``thread_id``.
        Text content parts become ``{"text": {"annotations": [...], "value": ...}, "type": "text"}``;
        any other content part is reduced to ``{"type": <its type>}``.
    """
    messages = client.beta.threads.messages.list(
        thread_id=thread_id
    )

    result = list()
    for message in messages.data:
        content = list()
        for part in message.content:
            if part.type != "text":
                # Non-text parts carry no `.text` attribute. Keep a marker so
                # that consumers filtering on "type" can skip it.
                content.append({"type": part.type})
                continue

            annotations = list()
            for annotation in part.text.annotations:
                a = {
                    "start_index": annotation.start_index,
                    "end_index": annotation.end_index,
                    "text": annotation.text,
                    "type": annotation.type,
                }
                if annotation.type == "file_citation":
                    a["file_citation"] = {
                        "file_id": annotation.file_citation.file_id,
                        "quote": annotation.file_citation.quote,
                    }
                elif annotation.type == "file_path":
                    a["file_path"] = {
                        "file_id": annotation.file_path.file_id,
                    }
                annotations.append(a)

            content.append({
                "text": {
                    "annotations": annotations,
                    "value": part.text.value,
                },
                "type": part.type,
            })

        result.append({
            "id": message.id,
            "assistant_id": message.assistant_id,
            "content": content,
            "created_at": message.created_at,
            "file_ids": message.file_ids,
            "metadata": message.metadata,
            "object": message.object,
            "role": message.role,
            "run_id": message.run_id,
            "thread_id": message.thread_id,
        })

    # Oldest message first, whatever order the API returned.
    result.sort(key=lambda x: x["created_at"])
    return result
def convert_message_list_to_response(message_list: List[dict]) -> str:
    """
    Render a message list as one plain-text transcript.

    Every text content part becomes ``"<role>: \\n<value>\\n"`` followed by a
    line of 80 dashes; content parts whose type is not ``"text"`` are skipped.

    :param message_list: message dicts with the keys ``role`` and ``content``.
    :return: the concatenated transcript, ``""`` for an empty list.
    """
    separator = "-" * 80 + "\n"

    pieces = list()
    for message in message_list:
        role = message["role"]
        for part in message["content"]:
            if part["type"] == "text":
                pieces.append("{}: \n{}\n".format(role, part["text"]["value"]))
                pieces.append(separator)
    return "".join(pieces)
def convert_message_list_to_conversation(message_list: List[dict]) -> List[Tuple[str, str]]:
    """
    Convert a message list to chatbot conversation entries.

    Each text content part produces one two-element list: ``[None, text]``
    when the message role is ``"assistant"``, ``[text, None]`` for any other
    role. The quote and file id of every ``file_citation`` annotation are
    appended to the text. Annotations of any other type are logged and left
    out instead of aborting the whole conversion. Content parts whose type is
    not ``"text"`` are skipped.

    :param message_list: message dicts with the keys ``role`` and ``content``.
    :return: list of two-element lists, in input order.
    """
    conversation = list()
    for message in message_list:
        role = message["role"]

        for part in message["content"]:
            if part["type"] != "text":
                continue

            text: dict = part["text"]
            msg = text["value"]

            for annotation in text["annotations"]:
                a_type = annotation["type"]
                if a_type == "file_citation":
                    msg += "\n\n"
                    msg += "\nquote: \n{}\nfile_id: \n{}".format(
                        annotation["file_citation"]["quote"],
                        annotation["file_citation"]["file_id"],
                    )
                else:
                    # The annotation marker is already part of the text value,
                    # so one unsupported annotation must not hide the message.
                    logging.getLogger(__name__).warning(
                        "unsupported annotation type ignored: %s", a_type
                    )

            if role == "assistant":
                conversation.append([None, msg])
            else:
                conversation.append([msg, None])

    return conversation
def streaming_refresh(openai_api_key: str,
                      thread_id: str,
                      queue: Queue,
                      ):
    """
    Poll the conversation of a thread and push every snapshot to ``queue``.

    The loop sleeps 0.3 seconds, calls ``refresh`` and puts the result on the
    queue (blocking for at most 2 seconds). It stops once a snapshot has been
    equal to the previous one 5 times in total; the snapshot that triggers the
    stop is not put on the queue.

    :param openai_api_key: API key forwarded to ``refresh``.
    :param thread_id: id of the thread to poll.
    :param queue: queue that receives the conversation snapshots.
    :return: the last snapshot that was put on the queue, ``None`` if there was none.
    """
    poll_interval = 0.3
    idle_limit = 5

    previous = None
    idle_polls = 0
    while True:
        time.sleep(poll_interval)

        current = refresh(openai_api_key, thread_id)
        if current == previous:
            idle_polls += 1
            if idle_polls >= idle_limit:
                break

        previous = current
        queue.put(current, block=True, timeout=2)

    return previous
def refresh(openai_api_key: str,
            thread_id: str,
            ):
    """
    Read the current conversation of a thread.

    The raw message list is both printed and logged at INFO level before it
    is converted.

    :param openai_api_key: API key used to build the OpenAI client.
    :param thread_id: id of the thread to read.
    :return: the result of ``convert_message_list_to_conversation`` for the
        messages of the thread.
    """
    client = OpenAI(api_key=openai_api_key)

    message_list = get_message_list(client, thread_id=thread_id)
    print(message_list)
    logger.info("message_list: {}".format(message_list))

    return convert_message_list_to_conversation(message_list)
def add_and_run(openai_api_key: str,
                assistant_id: str,
                thread_id: str,
                name: str,
                instructions: str,
                description: str,
                tools: str,
                files: List[str],
                file_ids: str,
                model: str,
                query: str,
                ):
    """
    Add the user query to a thread, start a run and stream the conversation.

    When ``assistant_id`` is empty a new assistant is created from ``name``,
    ``instructions``, ``description``, ``tools``, ``files``, ``file_ids`` and
    ``model``. When ``thread_id`` is empty a thread is created (the one that
    comes with a newly created assistant is reused if there is one).

    A background thread polls the conversation through ``streaming_refresh``
    and hands the snapshots over through a bounded queue. This generator
    yields one result per snapshot received.

    :param openai_api_key: API key used to build the OpenAI client.
    :param assistant_id: id of an existing assistant, or empty / ``None``.
    :param thread_id: id of an existing thread, or empty / ``None``.
    :param query: user message added to the thread.
    :return: generator of ``[assistant_id, thread_id, conversation, []]``.
        At least one item is always yielded so the ids reach the caller even
        when no snapshot arrived in time.
    """
    client = OpenAI(
        api_key=openai_api_key,
    )

    if assistant_id is None or len(assistant_id.strip()) == 0:
        # click_create_assistant returns (assistant_id, thread_id); binding
        # the whole tuple to assistant_id would send a tuple to the run API.
        assistant_id, created_thread_id = click_create_assistant(
            openai_api_key,
            name, instructions, description, tools, files, file_ids, model
        )
        if thread_id is None or len(thread_id.strip()) == 0:
            thread_id = created_thread_id

    if thread_id is None or len(thread_id.strip()) == 0:
        thread = client.beta.threads.create()
        thread_id = thread.id

    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=query
    )
    client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
    )

    response_queue = Queue(maxsize=10)
    refresh_kwargs = dict(
        openai_api_key=openai_api_key,
        thread_id=thread_id,
        queue=response_queue,
    )
    worker = Thread(target=streaming_refresh, kwargs=refresh_kwargs)
    worker.start()

    delta_time = 0.1
    last_response = None
    no_updates_count = 0
    max_no_updates_count = 10
    yielded = False
    while True:
        time.sleep(delta_time)

        try:
            this_response = response_queue.get(block=True, timeout=2)
        except _queue.Empty:
            # The poller produced nothing for 2 seconds: treat it as finished.
            break

        if this_response == last_response:
            no_updates_count += 1
            if no_updates_count >= max_no_updates_count:
                break
        last_response = this_response

        # The trailing empty list is kept so the shape of the yielded value
        # stays what existing callers receive.
        result = [
            assistant_id, thread_id,
            last_response,
            []
        ]
        yielded = True
        yield result

    if not yielded:
        yield [
            assistant_id, thread_id,
            last_response,
            []
        ]
def main():
    """
    Build the Gradio UI of the OpenAI Assistant demo and launch it.

    Layout: a settings column with three tabs (create assistant, list
    assistant, list file), a chat column (chatbot, query box, "Add and run"
    and "Refresh" buttons) and a state column holding ``assistant_id`` and
    ``thread_id``, followed by a row of examples. Every button is bound to
    the matching ``click_*`` / ``add_and_run`` / ``refresh`` handler.
    """
    args = get_args()

    gr_description = """
    OpenAI Assistant
    """

    model_name = "gpt-4-1106-preview"

    # One row per example, in the order of the `inputs` list given to gr.Examples.
    example_rows = [
        [
            "Math Tutor",
            "You are a personal math tutor. Write and run code to answer math questions.",
            "Official math test cases",
            None,
            None,
            model_name,
            "123 * 524 等于多少?"
        ],
        [
            "小说专家",
            "根据小说内容回答问题。",
            "三国演义文档问答测试",
            "{\"type\": \"retrieval\"}",
            [
                (project_path / "data/三国演义.txt").as_posix()
            ],
            model_name,
            "刘备和张飞是什么关系。"
        ],
    ]

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=gr_description)

        with gr.Row():
            # settings
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("create assistant"):
                        openai_api_key = gr.Text(
                            value=args.openai_api_key,
                            label="openai_api_key",
                            placeholder="Fill with your `openai_api_key`"
                        )

                        name = gr.Textbox(label="name")
                        instructions = gr.Textbox(label="instructions")
                        description = gr.Textbox(label="description")
                        model = gr.Dropdown([model_name], value=model_name, label="model")

                        # functions
                        tools = gr.TextArea(label="functions")

                        # upload files
                        retrieval_files = gr.Files(label="retrieval_files")
                        retrieval_file_ids = gr.TextArea(label="retrieval_file_ids")

                        # create assistant
                        create_assistant_button = gr.Button("create assistant")

                    with gr.TabItem("list assistant"):
                        list_assistant_button = gr.Button("list assistant")
                        assistant_list = gr.TextArea(label="assistant_list")

                        delete_assistant_id = gr.Textbox(label="delete_assistant_id")
                        delete_assistant_button = gr.Button("delete assistant")

                    with gr.TabItem("list file"):
                        list_file_button = gr.Button("list file")
                        file_list = gr.TextArea(label="file_list")

                        delete_file_id = gr.Textbox(label="delete_file_id")
                        delete_file_button = gr.Button("delete file")

                        upload_files = gr.Files(label="upload_files")
                        upload_files_button = gr.Button("upload file")

            # chat
            with gr.Column(scale=5):
                chat_bot = gr.Chatbot(label="conversation", height=600)
                query = gr.Textbox(lines=1, label="query")

                with gr.Row():
                    with gr.Column(scale=1):
                        add_and_run_button = gr.Button("Add and run")
                    with gr.Column(scale=1):
                        refresh_button = gr.Button("Refresh")

            # states
            with gr.Column(scale=2):
                assistant_id = gr.Textbox(value=None, label="assistant_id")
                thread_id = gr.Textbox(value=None, label="thread_id")

        # examples
        with gr.Row():
            gr.Examples(
                examples=example_rows,
                inputs=[name, instructions, description, tools, retrieval_files, model, query],
                examples_per_page=5
            )

        # Inputs shared by the handlers that may have to create an assistant.
        assistant_fields = [
            name, instructions, description, tools, retrieval_files, retrieval_file_ids, model,
        ]

        # create assistant
        create_assistant_button.click(
            click_create_assistant,
            inputs=[openai_api_key] + assistant_fields,
            outputs=[assistant_id, thread_id]
        )

        # list assistant
        list_assistant_button.click(
            click_list_assistant,
            inputs=[openai_api_key],
            outputs=[assistant_list]
        )

        # delete assistant button
        # The handler's status text is written back into the id textbox.
        delete_assistant_button.click(
            click_delete_assistant,
            inputs=[openai_api_key, delete_assistant_id],
            outputs=[delete_assistant_id]
        )

        # list file
        list_file_button.click(
            click_list_file,
            inputs=[openai_api_key],
            outputs=[file_list],
        )

        # delete file
        # The handler's status text is written back into the id textbox.
        delete_file_button.click(
            click_delete_file,
            inputs=[openai_api_key, delete_file_id],
            outputs=[delete_file_id]
        )

        # upload files
        upload_files_button.click(
            click_upload_files,
            inputs=[openai_api_key, upload_files],
            outputs=[]
        )

        # add and run
        add_and_run_button.click(
            add_and_run,
            inputs=[openai_api_key, assistant_id, thread_id] + assistant_fields + [query],
            outputs=[assistant_id, thread_id, chat_bot],
        )

        # refresh
        refresh_button.click(
            refresh,
            inputs=[openai_api_key, thread_id],
            outputs=[chat_bot]
        )

    blocks.queue().launch()
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()