|
import os |
|
import streamlit as st |
|
import time |
|
import json |
|
from datetime import datetime |
|
from openai import OpenAI |
|
|
|
|
|
st.set_page_config(page_title="Schlager ContractAi", layout="wide")

# Accounts permitted to sign in, each mapped to its password.
# SECURITY NOTE(review): these credentials are hard-coded in plaintext and every
# account shares the same password. Consider loading them from st.secrets or the
# environment and storing only password hashes.
AUTHORIZED_USERS = dict.fromkeys(
    (
        "andrew@lortechnologies.com",
        "manish@schlagergroup.com.au",
        "benb@schlagergroup.com.au",
        "nick@schlagergroup.com.au",
        "thomasc@schlagergroup.com.au",
    ),
    "Pass.123",
)

# A new session starts signed out; the flag then persists across reruns.
if "authenticated" not in st.session_state:
    st.session_state["authenticated"] = False
|
|
|
def login():
    """Check the submitted credentials and mark the session as signed in.

    Intended as the ``on_click`` callback of the Login button. Reads the
    ``email`` and ``password`` widget values from ``st.session_state``. When
    they match an entry in ``AUTHORIZED_USERS``, it sets
    ``st.session_state["authenticated"]`` to True. Otherwise it shows an
    error message.
    """
    import hmac  # local import keeps this block self-contained

    # Normalise the email so stray whitespace or capitalisation typed into the
    # text box does not cause a false rejection. Assumes AUTHORIZED_USERS keys
    # are stored lower-case (true for the accounts defined in this file).
    email = (st.session_state.get("email") or "").strip().lower()
    password = st.session_state.get("password") or ""

    expected = AUTHORIZED_USERS.get(email)
    # compare_digest runs in constant time, so response timing does not reveal
    # how much of the password matched.
    if expected is not None and hmac.compare_digest(
        password.encode("utf-8"), expected.encode("utf-8")
    ):
        st.session_state["authenticated"] = True
    else:
        st.error("Invalid email or password. Please try again.")
|
|
|
# While the session is not signed in, draw only the sign-in form and halt this
# script run so nothing further down the page is rendered.
if not st.session_state.authenticated:
    st.title("Sign In")
    st.text_input(label="Email", key="email")
    st.text_input(label="Password", key="password", type="password")
    st.button(label="Login", on_click=login)
    st.stop()
|
|
|
|
|
st.title("Schlager ContractAi")
st.caption("Chat with your contract or manage meeting minutes")

# The OpenAI API key and the two assistant IDs are read from environment variables.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_CONTRACT_ID = os.getenv("ASSISTANT_CONTRACT_ID")
ASSISTANT_TECHNICAL_ID = os.getenv("ASSISTANT_TECHNICAL_ID")

# Collect every variable that is unset or empty, so the error names exactly
# what a misconfigured deployment is missing instead of listing all of them.
_missing_env_vars = [
    var_name
    for var_name, var_value in (
        ("OPENAI_API_KEY", OPENAI_API_KEY),
        ("ASSISTANT_CONTRACT_ID", ASSISTANT_CONTRACT_ID),
        ("ASSISTANT_TECHNICAL_ID", ASSISTANT_TECHNICAL_ID),
    )
    if not var_value
]

if _missing_env_vars:
    st.error(
        "Missing required environment variables: "
        + ", ".join(_missing_env_vars)
        + ". Please set OPENAI_API_KEY, ASSISTANT_CONTRACT_ID, and ASSISTANT_TECHNICAL_ID."
    )
    st.stop()
|
|
|
|
|
# Contract and Technical chat tabs, plus a tab for reviewing flagged answers.
tab1, tab2, tab3 = st.tabs(("Contract", "Technical", "Flagged Responses"))

# Folder that holds flagged query/response pairs as JSON files. It is created
# up front, and an already-existing folder is not an error.
FLAGGED_RESPONSES_DIR = "flagged_responses"
os.makedirs(name=FLAGGED_RESPONSES_DIR, exist_ok=True)
|
|
|
def save_flagged_response(user_query, ai_response):
    """Save a flagged query/response pair as a JSON file and confirm it in the UI.

    The file goes into ``FLAGGED_RESPONSES_DIR`` and is named
    ``flagged_<YYYYmmdd_HHMMSS_ffffff>.json``. It holds the keys ``query``,
    ``response`` and ``timestamp`` (ISO 8601, naive local time). An existing
    file is never overwritten: if the name is already taken, a numeric suffix
    is appended.

    Args:
        user_query: the user message that produced the response.
        ai_response: the assistant response being flagged.
    """
    # Read the clock once so the filename and the stored timestamp agree.
    now = datetime.now()
    # Microseconds in the name keep rapid successive flags from colliding.
    stamp = now.strftime("%Y%m%d_%H%M%S_%f")
    flagged_data = {
        "query": user_query,
        "response": ai_response,
        "timestamp": now.isoformat(),
    }

    attempt = 0
    while True:
        suffix = f"_{attempt}" if attempt else ""
        filename = os.path.join(FLAGGED_RESPONSES_DIR, f"flagged_{stamp}{suffix}.json")
        try:
            # Mode "x" raises FileExistsError instead of truncating an existing
            # file, so a previously saved flag can never be lost.
            with open(filename, "x", encoding="utf-8") as file:
                json.dump(flagged_data, file, indent=4)
        except FileExistsError:
            attempt += 1
            continue
        break

    st.success(f"Response flagged and saved as {filename}.")
|
|
|
|
|
# Run statuses in which the assistant is still working on a run. Any other
# status (completed, failed, cancelled, expired, incomplete, requires_action)
# is treated as final.
_RUN_PENDING_STATUSES = frozenset({"queued", "in_progress", "cancelling"})


def _wait_for_run(client, thread_id, run_id, timeout, poll_interval=1.0):
    """Poll an assistant run until it leaves the pending states; return the final run.

    Raises:
        TimeoutError: if the run is still pending after ``timeout`` seconds.
            Before raising, a best-effort cancellation of the run is requested.
    """
    deadline = time.monotonic() + timeout
    while True:
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
        if run.status not in _RUN_PENDING_STATUSES:
            return run
        if time.monotonic() >= deadline:
            try:
                client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run_id)
            except Exception:
                pass  # best effort only; the timeout below is the error we report
            raise TimeoutError(f"Assistant did not respond within {timeout} seconds.")
        time.sleep(poll_interval)


def _assistant_reply_text(client, thread_id):
    """Return the text of all assistant messages in a thread, oldest first.

    Non-text content blocks (for example images) are skipped. The texts are
    joined with blank lines, and an empty string means no text was found.
    """
    messages = client.beta.threads.messages.list(thread_id=thread_id)
    texts = []
    # The listing is newest-first by default; reverse it so a reply spread
    # over several messages reads in order.
    for message in reversed(messages.data):
        if message.role != "assistant":
            continue
        for block in message.content:
            if block.type == "text":
                texts.append(block.text.value)
    return "\n\n".join(texts)


def contract_chat_section(tab, assistant_id, session_key, input_key, *, run_timeout=180):
    """Render one chat tab whose replies come from an OpenAI assistant.

    Args:
        tab: Streamlit container (a tab from ``st.tabs``) to draw into.
        assistant_id: ID of the OpenAI assistant that answers prompts.
        session_key: ``st.session_state`` key holding this tab's history, a
            list of ``{"role": ..., "content": ...}`` dicts.
        input_key: widget key for this tab's ``st.chat_input``.
        run_timeout: seconds to wait for the assistant before reporting an error.
    """
    with tab:
        st.subheader("Chat")
        client = OpenAI(api_key=OPENAI_API_KEY)

        if session_key not in st.session_state:
            st.session_state[session_key] = []

        if st.button("Clear Chat", key=f"clear_chat_{session_key}"):
            st.session_state[session_key] = []
            st.rerun()

        history = st.session_state[session_key]
        for idx, message in enumerate(history):
            role, content = message["role"], message["content"]
            if role != "assistant":
                st.chat_message(role).write(content)
                continue
            with st.container():
                st.chat_message(role).write(content)
                # Keyed by message index, so the same button is recreated on
                # every rerun and its click is handled here.
                if st.button("🚩 Flag", key=f"flag_{session_key}_{idx}"):
                    user_query = history[idx - 1]["content"] if idx > 0 else "Unknown"
                    save_flagged_response(user_query, content)

        prompt = st.chat_input("Enter your message:", key=input_key)
        if not prompt:
            return

        st.session_state[session_key].append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        try:
            # NOTE(review): each prompt is sent on a brand-new thread, so the
            # assistant does not see earlier turns of the on-screen chat.
            thread_id = client.beta.threads.create().id
            client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=prompt,
            )
            run = client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=assistant_id,
            )
            run = _wait_for_run(client, thread_id, run.id, run_timeout)
            if run.status != "completed":
                detail = getattr(run.last_error, "message", None)
                raise RuntimeError(
                    f"Assistant run ended with status '{run.status}'"
                    + (f": {detail}" if detail else ".")
                )
            assistant_message = _assistant_reply_text(client, thread_id)
            if not assistant_message:
                raise RuntimeError("Assistant returned no text response.")
        except Exception as e:
            st.error(f"Error: {e}")
            return

        st.session_state[session_key].append({"role": "assistant", "content": assistant_message})
        # Rerun so the reply is drawn by the history loop above, whose Flag
        # button survives later reruns. A button drawn inside this `if prompt`
        # path would disappear on the very rerun its own click triggers, so
        # the click would never be handled. Kept outside the try block so
        # Streamlit's rerun signal is never caught by the error handler.
        st.rerun()
|
|
|
# Draw both assistant-backed chat tabs. Each has its own history key and input widget key.
for _chat_tab, _assistant, _history_key, _input_key in (
    (tab1, ASSISTANT_CONTRACT_ID, "contract_messages", "contract_input"),
    (tab2, ASSISTANT_TECHNICAL_ID, "technical_messages", "technical_input"),
):
    contract_chat_section(_chat_tab, _assistant, _history_key, _input_key)
|
|
|
|
|
with tab3:
    st.subheader("Flagged Responses")
    # os.listdir returns names in arbitrary order. Flag files are named
    # flagged_<YYYYmmdd_HHMMSS...>.json, so reverse lexicographic order lists
    # the most recent flags first. Entries that are not regular files are skipped.
    flagged_files = sorted(
        (
            name
            for name in os.listdir(FLAGGED_RESPONSES_DIR)
            if name.endswith(".json")
            and os.path.isfile(os.path.join(FLAGGED_RESPONSES_DIR, name))
        ),
        reverse=True,
    )

    if flagged_files:
        selected_file = st.selectbox("Select a flagged response file to download:", flagged_files)

        if selected_file:
            # Read raw bytes so the download matches the saved file exactly,
            # independent of the platform's default text encoding.
            with open(os.path.join(FLAGGED_RESPONSES_DIR, selected_file), "rb") as file:
                flagged_responses = file.read()
            st.download_button(
                "Download Selected Flagged Responses",
                data=flagged_responses,
                file_name=selected_file,
                mime="application/json",
            )
    else:
        st.info("No flagged responses available.")