Spaces:
Runtime error
Runtime error
| import os | |
| from typing import List, Dict | |
| import time | |
| from openai import OpenAI | |
| from assistant_file_handler import FileHandler | |
| from openai.types.beta.thread import Thread | |
| from openai.types.beta.threads.message import Message | |
| from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile | |
| import structlog | |
| from openai.pagination import SyncCursorPage | |
class OAIAssistant:
    """Thin wrapper around one OpenAI assistant and one vector store.

    The instance holds an ``OpenAI`` client, a ``FileHandler`` used to upload
    and remove files, the assistant fetched by ``assistant_id`` and a
    structlog logger. It can attach files to / detach files from the vector
    store and run a single chat turn on a thread.
    """

    # Seconds to wait between two polls of a run's status.
    POLL_INTERVAL_SECONDS = 1

    # Run statuses (besides "completed") treated as end states: polling stops
    # and an error reply is returned instead of waiting forever.
    _FAILED_RUN_STATUSES = frozenset({"failed", "cancelled", "expired", "incomplete"})

    def __init__(self, assistant_id, vectorstore_id) -> None:
        """Create the client and retrieve the assistant ``assistant_id``.

        Args:
            assistant_id: Id of the assistant that runs are created with.
            vectorstore_id: Id of the vector store used for file operations
                and attached to newly created threads.
        """
        self.file_handler = FileHandler()
        self.assistant_id = assistant_id
        self.vectorstore_id = vectorstore_id
        self.client = OpenAI()
        self.openai_assistant = self.client.beta.assistants.retrieve(
            assistant_id=self.assistant_id
        )
        self.log = structlog.get_logger()

    def create(self):
        """Placeholder; currently does nothing."""
        pass

    def add_file(self, file_path: str):
        """Upload ``file_path`` via the file handler and add it to the vector store."""
        file_id = self.file_handler.add(file_path=file_path).id
        self.client.beta.vector_stores.files.create(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )

    def remove_file(self, file_id: str):
        """Remove ``file_id`` from the vector store, then from file storage."""
        self.client.beta.vector_stores.files.delete(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            f"OAIAssistant: Deleted file with id {file_id} from vector database"
        )
        self.file_handler.remove(file_id=file_id)
        self.log.info(f"OAIAssistant: Deleted file with id {file_id} from file storage")

    def chat(self, query: str, thread_id: str):
        """Send ``query`` on a thread, run the assistant and return its reply.

        Args:
            query: User message to add to the thread.
            thread_id: Id of an existing thread. When falsy, a new thread is
                created; note that the new thread's id is not returned.

        Returns:
            The assistant's reply text, or a fixed error string if any step
            raised (the exception is logged, not propagated).
        """
        try:
            if not thread_id:
                # create_thread() returns the Thread object; keep only its id.
                thread_id = self.create_thread().id

            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=query,
            )
            self.log.info(
                "OAIAssistant: Message added to thread",
                thread_id=thread_id,
                query=query,
            )

            new_message, message_file_ids = self.__run_assistant(thread_id=thread_id)

            # Save every image the run produced locally, then upload the
            # local copies again through the file handler.
            file_paths = []
            for msg_file_id in message_file_ids:
                png_file_path = f"./tmp/{msg_file_id}.png"
                self.__convert_file_to_png(
                    file_id=msg_file_id, write_path=png_file_path
                )
                file_paths.append(png_file_path)
            file_ids = self.__add_files(file_paths=file_paths)

            # NOTE(review): the Assistants API documents "code_interpreter"
            # and "file_search" as attachment tool types - confirm that
            # "image_file" is accepted here.
            attachments = None
            if file_ids:
                attachments = [
                    {"file_id": file_id, "tools": [{"type": "image_file"}]}
                    for file_id in file_ids.values()
                ]
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="assistant",
                content=new_message,
                attachments=attachments,
            )
            self.log.info(
                "OAIAssistant: Assistant response generated", response=new_message
            )
            return new_message
        except Exception as e:
            # Boundary for callers: report a generic failure instead of raising.
            self.log.error("OAIAssistant: Error generating response", error=str(e))
            return "OAIAssistant: An error occurred while generating the response."

    def create_thread(self) -> "Thread":
        """Create and return a thread with this vector store attached for file search."""
        return self.client.beta.threads.create(
            tool_resources={"file_search": {"vector_store_ids": [self.vectorstore_id]}}
        )

    def delete_thread(self, thread_id: str):
        """Delete the thread ``thread_id``."""
        self.client.beta.threads.delete(thread_id=thread_id)
        self.log.info(f"OAIAssistant: Deleted thread with id: {thread_id}")

    def __convert_file_to_png(self, file_id, write_path):
        """Download the content of ``file_id`` and write it to ``write_path``.

        The bytes are written unchanged; no image transcoding is performed.
        The parent directory of ``write_path`` is created when missing.
        Errors are logged and re-raised.
        """
        try:
            data_bytes = self.client.files.content(file_id).read()
            target_dir = os.path.dirname(write_path)
            if target_dir:
                # Without this, open() fails when e.g. ./tmp does not exist yet.
                os.makedirs(target_dir, exist_ok=True)
            with open(write_path, "wb") as out_file:
                out_file.write(data_bytes)
            self.log.info("OAIAssistant: File converted to PNG", file_path=write_path)
        except Exception as e:
            self.log.error("OAIAssistant: Error converting file to PNG", error=str(e))
            raise

    def __add_files(self, file_paths: List[str]) -> Dict[str, str]:
        """Upload each path via the file handler.

        Returns:
            Mapping of each file's base name to the id of the uploaded file.
            Errors are logged and re-raised.
        """
        try:
            files = {}
            for path in file_paths:
                uploaded = self.file_handler.add(path)
                files[os.path.basename(path)] = uploaded.id
            self.log.info("OAIAssistant: Files added", files=files)
            return files
        except Exception as e:
            self.log.error("OAIAssistant: Error adding files", error=str(e))
            raise

    def __run_assistant(self, thread_id: str):
        """Run the assistant on ``thread_id`` and wait for the run to end.

        Returns:
            ``(reply_text, image_file_ids)`` for a completed run, or
            ``("OAIAssistant: Error in generating response", [])`` when the
            run ends in any of ``_FAILED_RUN_STATUSES``.
            Errors are logged and re-raised.
        """
        try:
            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            )
            self.log.info("OAIAssistant: Assistant run started", run_id=run.id)

            while run.status != "completed":
                if run.status in self._FAILED_RUN_STATUSES:
                    # Stop polling on every unsuccessful end state, not only
                    # on "failed"; otherwise this loop would never exit.
                    self.log.error(
                        "OAIAssistant: Assistant run failed",
                        run_id=run.id,
                        status=run.status,
                    )
                    self.log.info(run)
                    return "OAIAssistant: Error in generating response", []
                time.sleep(self.POLL_INTERVAL_SECONDS)
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )

            # Only the messages produced by this run are requested.
            messages = self.client.beta.threads.messages.list(
                thread_id=thread_id, run_id=run.id
            )
            return self.__extract_messages(messages)
        except Exception as e:
            self.log.error("OAIAssistant: Error running assistant", error=str(e))
            raise

    def __extract_messages(self, messages: "SyncCursorPage[Message]"):
        """Flatten a page of messages into reply text plus image file ids.

        Every content part of every message in ``messages.data`` is visited:
        text parts contribute their value, image parts contribute an
        ``"Image File:"`` marker followed by the file id. Other part types
        are ignored.

        Returns:
            ``(text, file_ids)`` where ``file_ids`` lists the image file ids
            in the order they were encountered.
            Errors are logged and re-raised.
        """
        try:
            chunks = []
            file_ids = []
            for message in messages.data:
                # A message may carry several parts (e.g. text plus an image).
                for part in message.content:
                    if part.type == "text":
                        chunks.append(part.text.value)
                    elif part.type == "image_file":
                        image_id = part.image_file.file_id
                        chunks.append(f"Image File:\n{image_id}\n\n")
                        file_ids.append(image_id)
            new_message = "".join(chunks)
            self.log.info("OAIAssistant: Messages extracted", message=new_message)
            return new_message, file_ids
        except Exception as e:
            self.log.error("OAIAssistant: Error extracting messages", error=str(e))
            raise

    def get_files_list(self):
        """Return the ids of the files listed for the vector store."""
        files = self.client.beta.vector_stores.files.list(
            vector_store_id=self.vectorstore_id
        )
        return [file.id for file in files]