Spaces:
Runtime error
Runtime error
File size: 7,017 Bytes
cc83df3 ebe102a cc83df3 ebe102a cc83df3 ebe102a cc83df3 ebe102a cc83df3 ebe102a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import os
from typing import List, Dict
import time
from openai import OpenAI
from assistant_file_handler import FileHandler
from openai.types.beta.thread import Thread
from openai.types.beta.threads.message import Message
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile
import structlog
from openai.pagination import SyncCursorPage
class OAIAssistant:
    """Wrapper around an existing OpenAI Assistant and one vector store.

    On construction the assistant ``assistant_id`` is retrieved. Methods manage
    files in the vector store ``vectorstore_id``, create/delete threads, and run
    a single chat turn against a thread.
    """

    # Run statuses for which the run is still progressing and must be polled
    # again. Any other status ends polling; only "completed" counts as success.
    _PENDING_RUN_STATUSES = ("queued", "in_progress", "cancelling")

    def __init__(self, assistant_id, vectorstore_id) -> None:
        """Bind to an existing assistant and vector store.

        Args:
            assistant_id: ID of the OpenAI Assistant to run.
            vectorstore_id: ID of the vector store that files are attached to
                and that new threads search with ``file_search``.

        Raises:
            Whatever ``client.beta.assistants.retrieve`` raises if the assistant
            cannot be fetched.
        """
        self.log = structlog.get_logger()
        self.file_handler = FileHandler()
        self.assistant_id = assistant_id
        self.vectorstore_id = vectorstore_id
        self.client = OpenAI()
        self.openai_assistant = self.client.beta.assistants.retrieve(
            assistant_id=self.assistant_id
        )

    def create(self):
        """Placeholder; currently does nothing."""
        pass

    def add_file(self, file_path: str):
        """Upload a local file and attach it to the vector store.

        Args:
            file_path: Path of the local file, uploaded via ``FileHandler.add``.

        Returns:
            The ID of the uploaded file (usable with ``remove_file``).
        """
        file_id = self.file_handler.add(file_path=file_path).id
        self.client.beta.vector_stores.files.create(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            "OAIAssistant: Added file to vector database",
            file_id=file_id,
            file_path=file_path,
        )
        return file_id

    def remove_file(self, file_id: str):
        """Detach a file from the vector store, then delete it from storage.

        Args:
            file_id: ID of the file to remove.
        """
        self.client.beta.vector_stores.files.delete(
            file_id=file_id, vector_store_id=self.vectorstore_id
        )
        self.log.info(
            f"OAIAssistant: Deleted file with id {file_id} from vector database"
        )
        self.file_handler.remove(file_id=file_id)
        self.log.info(f"OAIAssistant: Deleted file with id {file_id} from file storage")

    def chat(self, query: str, thread_id: str):
        """Post ``query`` to a thread, run the assistant and return its reply.

        Args:
            query: User message text.
            thread_id: Thread to continue. If falsy, a new thread is created
                with ``create_thread``; its ID is logged but not returned.

        Returns:
            The assistant's reply text, with image outputs referenced inline
            as "Image File:" entries. If anything raises, the error is logged
            and a fixed error string is returned instead; exceptions are
            not propagated.
        """
        try:
            if not thread_id:
                # create_thread() returns a Thread object; take its ID once.
                thread_id = self.create_thread().id
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=query,
            )
            self.log.info(
                "OAIAssistant: Message added to thread",
                thread_id=thread_id,
                query=query,
            )
            new_message, message_file_ids = self.__run_assistant(thread_id=thread_id)

            # Download each image the assistant produced to a local file so it
            # can be re-uploaded and attached to the follow-up message below.
            file_paths = []
            if message_file_ids:
                # open(..., "wb") does not create missing parent directories.
                os.makedirs("./tmp", exist_ok=True)
            for msg_file_id in message_file_ids:
                png_file_path = f"./tmp/{msg_file_id}.png"
                self.__convert_file_to_png(
                    file_id=msg_file_id, write_path=png_file_path
                )
                file_paths.append(png_file_path)
            file_ids = self.__add_files(file_paths=file_paths)

            attachments = [
                {"file_id": file_id, "tools": [{"type": "file_search"}]}
                for file_id in file_ids.values()
            ]
            # NOTE(review): the run has already added its own assistant
            # messages to the thread; this posts the reply text again with
            # the re-uploaded images attached. Confirm this duplication is
            # intended.
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="assistant",
                content=new_message,
                attachments=attachments or None,
            )
            self.log.info(
                "OAIAssistant: Assistant response generated", response=new_message
            )
            return new_message
        except Exception as e:
            self.log.error("OAIAssistant: Error generating response", error=str(e))
            return "OAIAssistant: An error occurred while generating the response."

    def create_thread(self) -> Thread:
        """Create a new thread whose file_search tool uses this vector store."""
        thread: Thread = self.client.beta.threads.create(
            tool_resources={"file_search": {"vector_store_ids": [self.vectorstore_id]}}
        )
        return thread

    def delete_thread(self, thread_id: str):
        """Delete the thread ``thread_id``."""
        self.client.beta.threads.delete(thread_id=thread_id)
        self.log.info(f"OAIAssistant: Deleted thread with id: {thread_id}")

    def __convert_file_to_png(self, file_id, write_path):
        """Download an OpenAI file's bytes and write them to ``write_path``.

        Despite the name, no format conversion happens: the raw bytes are
        written unchanged. The ".png" naming relies on the caller only passing
        IDs of image_file outputs.

        Raises:
            Re-raises any download or write error after logging it.
        """
        try:
            data = self.client.files.content(file_id)
            data_bytes = data.read()
            with open(write_path, "wb") as file:
                file.write(data_bytes)
            self.log.info("OAIAssistant: File converted to PNG", file_path=write_path)
        except Exception as e:
            self.log.error("OAIAssistant: Error converting file to PNG", error=str(e))
            raise

    def __add_files(self, file_paths: List[str]) -> Dict[str, str]:
        """Upload local files via ``FileHandler.add``.

        Returns:
            Mapping of each file's basename to the ID of its uploaded file.

        Raises:
            Re-raises any upload error after logging it.
        """
        try:
            files = {}
            for path in file_paths:
                uploaded = self.file_handler.add(path)
                files[os.path.basename(path)] = uploaded.id
            self.log.info("OAIAssistant: Files added", files=files)
            return files
        except Exception as e:
            self.log.error("OAIAssistant: Error adding files", error=str(e))
            raise

    def __run_assistant(self, thread_id: str):
        """Run the assistant on a thread and block until the run stops.

        The run is polled once per second while its status is pending.

        Returns:
            ``(text, image_file_ids)`` extracted from the run's messages in
            chronological order when the run completes. For any other final
            status, returns ``("OAIAssistant: Error in generating response", [])``.
            This covers "failed", "cancelled", "expired", "incomplete" and
            "requires_action"; tool outputs are never submitted here.

        Raises:
            Re-raises any API error after logging it.
        """
        try:
            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=self.assistant_id,
            )
            self.log.info("OAIAssistant: Assistant run started", run_id=run.id)
            while run.status in self._PENDING_RUN_STATUSES:
                time.sleep(1)
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )
            if run.status != "completed":
                self.log.error(
                    "OAIAssistant: Assistant run failed",
                    run_id=run.id,
                    status=run.status,
                )
                self.log.info(run)
                return "OAIAssistant: Error in generating response", []
            # The list endpoint defaults to newest-first; ask for chronological
            # order so multi-message replies read in the order they were written.
            messages: SyncCursorPage[Message] = self.client.beta.threads.messages.list(
                thread_id=thread_id, run_id=run.id, order="asc"
            )
            return self.__extract_messages(messages)
        except Exception as e:
            self.log.error("OAIAssistant: Error running assistant", error=str(e))
            raise

    def __extract_messages(self, messages: SyncCursorPage[Message]):
        """Flatten a page of messages into reply text plus image file IDs.

        Every content block of every message in ``messages.data`` is visited
        in order:
        - Text blocks contribute their text verbatim, with no separator.
        - image_file blocks contribute an "Image File:" line, the file ID and
          a blank line, and the file ID is also collected.
        - Other block types are skipped.

        Returns:
            ``(text, image_file_ids)``.

        Raises:
            Re-raises any error after logging it.
        """
        try:
            parts = []
            file_ids = []
            for message in messages.data:
                for content in message.content:
                    if content.type == "text":
                        parts.append(content.text.value)
                    elif content.type == "image_file":
                        image_file_id = content.image_file.file_id
                        parts.append("Image File:\n" + image_file_id + "\n\n")
                        file_ids.append(image_file_id)
            new_message = "".join(parts)
            self.log.info("OAIAssistant: Messages extracted", message=new_message)
            return new_message, file_ids
        except Exception as e:
            self.log.error("OAIAssistant: Error extracting messages", error=str(e))
            raise

    def get_files_list(self) -> List[str]:
        """Return the IDs of all files attached to the vector store.

        Iterating the returned page object uses the SDK's auto-pagination, so
        files beyond the first page are included.
        """
        files: SyncCursorPage[VectorStoreFile] = (
            self.client.beta.vector_stores.files.list(
                vector_store_id=self.vectorstore_id
            )
        )
        return [file.id for file in files]
|