Spaces:
Sleeping
Sleeping
import os | |
from typing import List, Dict | |
import time | |
from openai import OpenAI | |
from assistant_file_handler import FileHandler | |
from openai.types.beta.thread import Thread | |
from openai.types.beta.threads.message import Message | |
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile | |
import structlog | |
from openai.pagination import SyncCursorPage | |
class OAIAssistant:
    """Wrapper around an existing OpenAI Assistant and one vector store.

    The instance holds an ``OpenAI`` client, a project ``FileHandler`` used
    for file storage, and the ids of the assistant and the vector store it
    operates on. Public methods cover file management (``add_file``,
    ``remove_file``, ``get_files_list``), thread management
    (``create_thread``, ``delete_thread``) and a blocking ``chat`` call.
    """

    # Seconds to wait between two polls of a run's status.
    _POLL_INTERVAL_SECONDS = 1

    # Run states from which a run can no longer reach "completed"; polling
    # must stop on any of them or the loop would never terminate.
    _DEAD_RUN_STATES = frozenset({"failed", "cancelled", "expired", "incomplete"})

    def __init__(self, assistant_id, vectorstore_id) -> None:
        """Store the ids, build the clients and retrieve the assistant.

        Args:
            assistant_id: Id of the assistant to retrieve and run.
            vectorstore_id: Id of the vector store used for file search.
        """
        self.file_handler = FileHandler()
        self.assistant_id = assistant_id
        self.vectorstore_id = vectorstore_id
        self.client = OpenAI()
        self.openai_assistant = self.client.beta.assistants.retrieve(
            assistant_id=self.assistant_id
        )
        self.log = structlog.get_logger()

    def create(self):
        """Placeholder; currently does nothing."""
        pass

    def add_file(self, file_path: str):
        """Add a local file through the file handler and attach it to the vector store.

        Args:
            file_path: Path of the file handed to ``FileHandler.add``.
        """
        file_id = self.file_handler.add(file_path=file_path).id
        self.client.beta.vector_stores.files.create(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )

    def remove_file(self, file_id: str):
        """Detach a file from the vector store, then remove it via the file handler.

        Args:
            file_id: Id of the file to remove.
        """
        self.client.beta.vector_stores.files.delete(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            f"OAIAssistant: Deleted file with id {file_id} from vector database"
        )
        self.file_handler.remove(file_id=file_id)
        self.log.info(f"OAIAssistant: Deleted file with id {file_id} from file storage")

    def chat(self, query: str, thread_id: str) -> str:
        """Send ``query`` on a thread, run the assistant and return its reply.

        Image files referenced by the reply are downloaded to ``./tmp``,
        added again through the file handler, and the reply text is appended
        to the thread as an assistant message.

        Args:
            query: User message text.
            thread_id: Id of an existing thread. When falsy, a new thread is
                created; its id is only logged, not returned.

        Returns:
            The assistant's reply text, or a fixed error string when any step
            raises (the exception is logged, never propagated).
        """
        try:
            if not thread_id:
                # create_thread() returns the Thread object; take its id once.
                thread_id = self.create_thread().id
                self.log.info("OAIAssistant: Thread created", thread_id=thread_id)

            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=query,
            )
            self.log.info(
                "OAIAssistant: Message added to thread",
                thread_id=thread_id,
                query=query,
            )

            new_message, message_file_ids = self.__run_assistant(thread_id=thread_id)

            file_paths = []
            for msg_file_id in message_file_ids:
                png_file_path = f"./tmp/{msg_file_id}.png"
                self.__convert_file_to_png(
                    file_id=msg_file_id, write_path=png_file_path
                )
                file_paths.append(png_file_path)

            uploaded = self.__add_files(file_paths=file_paths)

            # NOTE(review): the attachment tool type "image_file" is kept from
            # the original code; confirm against the API reference that it is
            # an accepted value for message attachments.
            attachments = [
                {"file_id": file_id, "tools": [{"type": "image_file"}]}
                for file_id in uploaded.values()
            ] or None

            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="assistant",
                content=new_message,
                attachments=attachments,
            )
            self.log.info(
                "OAIAssistant: Assistant response generated", response=new_message
            )
            return new_message
        except Exception as e:
            self.log.error("OAIAssistant: Error generating response", error=str(e))
            return "OAIAssistant: An error occurred while generating the response."

    def create_thread(self) -> "Thread":
        """Create a thread whose file-search tool uses this vector store.

        Returns:
            The newly created thread object.
        """
        return self.client.beta.threads.create(
            tool_resources={"file_search": {"vector_store_ids": [self.vectorstore_id]}}
        )

    def delete_thread(self, thread_id: str):
        """Delete the thread with the given id and log the deletion."""
        self.client.beta.threads.delete(thread_id=thread_id)
        self.log.info(f"OAIAssistant: Deleted thread with id: {thread_id}")

    def __convert_file_to_png(self, file_id, write_path):
        """Download a file's content and write the raw bytes to ``write_path``.

        No image re-encoding takes place; the bytes are stored as received.

        Raises:
            Exception: Any download or write error, after it has been logged.
        """
        try:
            payload = self.client.files.content(file_id).read()
            # The target directory (e.g. ./tmp) may not exist yet.
            target_dir = os.path.dirname(write_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            with open(write_path, "wb") as out:
                out.write(payload)
            self.log.info("OAIAssistant: File converted to PNG", file_path=write_path)
        except Exception as e:
            self.log.error("OAIAssistant: Error converting file to PNG", error=str(e))
            raise

    def __add_files(self, file_paths: List[str]) -> Dict[str, str]:
        """Add each path through the file handler.

        Returns:
            Mapping of file base name to the id of the added file. Paths that
            share a base name overwrite each other's entry.

        Raises:
            Exception: Any error from the file handler, after it has been logged.
        """
        try:
            files = {}
            for path in file_paths:
                files[os.path.basename(path)] = self.file_handler.add(path).id
            self.log.info("OAIAssistant: Files added", files=files)
            return files
        except Exception as e:
            self.log.error("OAIAssistant: Error adding files", error=str(e))
            raise

    def __run_assistant(self, thread_id: str):
        """Start a run on the thread and block until it finishes.

        Returns:
            ``(text, image_file_ids)`` extracted from the run's messages, or a
            fixed error text with an empty list when the run ends in a state
            other than "completed".

        Raises:
            Exception: Any API error, after it has been logged.
        """
        try:
            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            )
            self.log.info("OAIAssistant: Assistant run started", run_id=run.id)

            while run.status != "completed":
                # Check before sleeping so a dead run is reported immediately
                # instead of being polled forever.
                if run.status in self._DEAD_RUN_STATES:
                    self.log.error(
                        "OAIAssistant: Assistant run failed",
                        run_id=run.id,
                        status=run.status,
                        last_error=getattr(run, "last_error", None),
                    )
                    return "OAIAssistant: Error in generating response", []
                time.sleep(self._POLL_INTERVAL_SECONDS)
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )

            # Oldest first, so several messages from one run are concatenated
            # in the order they were produced.
            messages = self.client.beta.threads.messages.list(
                thread_id=thread_id, run_id=run.id, order="asc"
            )
            return self.__extract_messages(messages)
        except Exception as e:
            self.log.error("OAIAssistant: Error running assistant", error=str(e))
            raise

    def __extract_messages(self, messages: "SyncCursorPage[Message]"):
        """Flatten a page of messages into reply text and image file ids.

        Every content part of every message in ``messages.data`` is visited:
        text parts contribute their value; image parts contribute an
        "Image File:" marker with the file id and are collected in the id
        list. Other part types, and messages without content, are skipped.

        Returns:
            ``(text, image_file_ids)``.

        Raises:
            Exception: Any error while reading the messages, after logging.
        """
        try:
            pieces = []
            file_ids = []
            for message in messages.data:
                for part in message.content:
                    if part.type == "text":
                        pieces.append(part.text.value)
                    elif part.type == "image_file":
                        image_id = part.image_file.file_id
                        pieces.append(f"Image File:\n{image_id}\n\n")
                        file_ids.append(image_id)
            new_message = "".join(pieces)
            self.log.info("OAIAssistant: Messages extracted", message=new_message)
            return new_message, file_ids
        except Exception as e:
            self.log.error("OAIAssistant: Error extracting messages", error=str(e))
            raise

    def get_files_list(self) -> List[str]:
        """Return the ids of the files listed for this vector store."""
        files = self.client.beta.vector_stores.files.list(
            vector_store_id=self.vectorstore_id
        )
        return [file.id for file in files]