|
import streamlit as st |
|
import openai |
|
import time |
|
import config.pagesetup as ps |
|
from openai import OpenAI |
|
import uuid |
|
|
|
|
|
|
|
# --- Page configuration (must be the first Streamlit command in the script) ---
st.set_page_config(
    page_title="FlowGenius",
    initial_sidebar_state="collapsed",
    layout="wide",
)

# Shared page chrome rendered by the project-level helper module.
ps.set_title("FlowGenius", "Legal Assistant")
ps.set_page_overview("Overview", "**Legal Assistant** provides a way to quickly ask about the law")

# --- OpenAI configuration, pulled from Streamlit secrets ---
# Assistant id and model name kept here for reference/visibility.
assistant = st.secrets.openai.assistant_key
model = "gpt-4-1106-preview"

# Single API client reused by every call below.
client = OpenAI(api_key=st.secrets.openai.api_key)
|
|
|
|
|
if "session_id" not in st.session_state: |
|
st.session_state.session_id = str(uuid.uuid4()) |
|
|
|
if "run" not in st.session_state: |
|
st.session_state.run = {"status": None} |
|
|
|
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
st.chat_message("assistant").markdown("I am your assistant. How may I help you?") |
|
if "retry_error" not in st.session_state: |
|
st.session_state.retry_error = 0 |
|
|
|
|
|
if "assistant" not in st.session_state: |
|
openai.api_key = st.secrets.openai.api_key |
|
|
|
|
|
st.session_state.assistant = client.beta.assistants.retrieve(st.secrets.openai.assistant_key) |
|
|
|
|
|
st.session_state.thread = client.beta.threads.create( |
|
metadata={ |
|
'session_id': st.session_state.session_id, |
|
} |
|
) |
|
|
|
|
|
# A previous run has finished: refresh the authoritative message list for the
# thread and rewrite any annotations into numbered footnote citations.
elif hasattr(st.session_state.run, 'status') and st.session_state.run.status == "completed":

    st.session_state.messages = client.beta.threads.messages.list(
        thread_id=st.session_state.thread.id
    )

    for thread_message in st.session_state.messages.data:
        for message_content in thread_message.content:

            # NOTE(review): rebinds the loop variable to its .text payload;
            # presumably every content part is a text block — a non-text part
            # (e.g. an image) would raise AttributeError here. TODO confirm.
            message_content = message_content.text
            annotations = message_content.annotations
            citations = []

            # Replace each inline annotation marker with ' [index]' and
            # collect a matching citation line for the footer.
            for index, annotation in enumerate(annotations):

                message_content.value = message_content.value.replace(annotation.text, f' [{index}]')

                # File citation: a quote drawn from an uploaded document.
                if (file_citation := getattr(annotation, 'file_citation', None)):
                    cited_file = client.files.retrieve(file_citation.file_id)
                    citations.append(f'[{index}] {file_citation.quote} from {cited_file.filename}')
                # File path: a downloadable file produced by the assistant.
                elif (file_path := getattr(annotation, 'file_path', None)):
                    cited_file = client.files.retrieve(file_path.file_id)
                    citations.append(f'[{index}] Click <here> to download {cited_file.filename}')

            # Append the collected footnotes in place (note: adds a bare
            # trailing newline even when there are no citations).
            message_content.value += '\n' + '\n'.join(citations)
|
|
|
|
|
# --- Render the conversation history (API returns newest-first, so reverse) ---
# st.session_state.messages only has a ``.data`` list after a completed run
# has been fetched above; on first page load it is a plain list ([]), and
# immediately after messages.create() it is a single Message object. The
# original ``st.session_state.messages.data`` raised AttributeError in both
# of those states — getattr() with a default renders nothing instead.
history = getattr(st.session_state.messages, "data", [])
for message in reversed(history):
    # Only user/assistant turns are shown; any other role is skipped.
    if message.role in ["user", "assistant"]:
        with st.chat_message(message.role):
            for content_part in message.content:
                # Assumes text content parts; .text.value holds the markdown body.
                message_text = content_part.text.value
                st.markdown(message_text)
|
|
|
# --- Handle new user input ---
if prompt := st.chat_input("How can I help you?"):
    # Echo the user's message immediately for responsiveness.
    with st.chat_message('user'):
        st.write(prompt)

    # Append the message to the API thread. NOTE: messages.create() returns a
    # single Message object, not the thread's message list, so it must NOT be
    # stored in st.session_state.messages — the original code did exactly
    # that, clobbering the history list the renderer above iterates over.
    client.beta.threads.messages.create(
        thread_id=st.session_state.thread.id,
        role="user",
        content=prompt
    )

    # Start an assistant run on the thread; its status is polled below.
    st.session_state.run = client.beta.threads.runs.create(
        thread_id=st.session_state.thread.id,
        assistant_id=st.session_state.assistant.id,
    )

    # Rerun so the polling block picks up the new run, unless the retry cap
    # has already been reached.
    if st.session_state.retry_error < 3:
        time.sleep(1)
        st.rerun()
|
|
|
|
|
# --- Poll the current run until it completes or fails ---
# The initial placeholder run is a plain dict (no .status attribute), so
# hasattr() stays False until a real Run object has been created.
if hasattr(st.session_state.run, 'status'):

    # In-flight: the Assistants API reports "queued" or "in_progress" —
    # there is no "running" status, so the original comparison never
    # matched and the "Thinking" hint was dead code. Show the hint,
    # refresh the run, and rerun to poll again.
    if st.session_state.run.status in ("queued", "in_progress"):
        with st.chat_message('assistant'):
            st.write("Thinking ......")
        if st.session_state.retry_error < 3:
            time.sleep(1)
            # Refresh the run so the next rerun sees an up-to-date status;
            # without this the polling loop could never terminate.
            st.session_state.run = client.beta.threads.runs.retrieve(
                thread_id=st.session_state.thread.id,
                run_id=st.session_state.run.id,
            )
            st.rerun()

    # Failed: retry up to 3 times, then surface the error to the user.
    elif st.session_state.run.status == "failed":
        st.session_state.retry_error += 1
        with st.chat_message('assistant'):
            if st.session_state.retry_error < 3:
                st.write("Run failed, retrying ......")
                time.sleep(3)
                st.rerun()
            else:
                st.error("FAILED: The OpenAI API is currently processing too many requests. Please try again later ......")

    # Any other non-terminal status (e.g. "requires_action", "cancelling"):
    # refresh the run and poll again.
    elif st.session_state.run.status != "completed":
        st.session_state.run = client.beta.threads.runs.retrieve(
            thread_id=st.session_state.thread.id,
            run_id=st.session_state.run.id,
        )
        if st.session_state.retry_error < 3:
            time.sleep(3)
            st.rerun()
|
|
|
|
|
|