import os
import streamlit as st
import time
import json
from datetime import datetime
from openai import OpenAI

# Streamlit Page Config
st.set_page_config(page_title="Schlager ContractAi", layout="wide")

# Authentication
# Maps each permitted sign-in email to the password expected for it.
# NOTE(review): these credentials are stored in plaintext in source control and
# every account shares one password -- consider moving them to a secrets store
# and keeping only salted hashes.
AUTHORIZED_USERS = dict.fromkeys(
    (
        "andrew@lortechnologies.com",
        "manish@schlagergroup.com.au",
        "benb@schlagergroup.com.au",
        "nick@schlagergroup.com.au",
        "thomasc@schlagergroup.com.au",
    ),
    "Pass.123",
)

# A browser session starts out signed out; the flag is only created when it
# is not already present, so an existing sign-in survives script reruns.
st.session_state.setdefault("authenticated", False)

def login():
    """Validate the submitted credentials and mark the session as signed in.

    Reads the "email" and "password" values from ``st.session_state`` and
    checks them against ``AUTHORIZED_USERS``. On a match the "authenticated"
    session flag is set to True; otherwise an error message is displayed.
    """
    import hmac  # local import: only needed for the constant-time comparison

    # Surrounding whitespace and capitalisation (e.g. from copy/paste or
    # mobile keyboards) should not lock out a valid user; all keys in
    # AUTHORIZED_USERS are lower-case.
    email = st.session_state.get("email", "").strip().lower()
    password = st.session_state.get("password", "")

    is_known_user = email in AUTHORIZED_USERS
    expected_password = AUTHORIZED_USERS.get(email, "")

    # compare_digest takes time independent of where the inputs differ, and it
    # is run even for unknown emails so both paths do comparable work. Bytes
    # are used because the str form rejects non-ASCII input with a TypeError.
    password_matches = hmac.compare_digest(
        expected_password.encode("utf-8"),
        password.encode("utf-8"),
    )

    if is_known_user and password_matches:
        st.session_state["authenticated"] = True
    else:
        st.error("Invalid email or password. Please try again.")

# Sign-in gate: while the session is not authenticated, render only the login
# form and halt the script so nothing below this point is executed.
if not st.session_state["authenticated"]:
    st.title(body="Sign In")
    # The widget keys are the session-state entries that login() reads.
    st.text_input(label="Email", key="email")
    st.text_input(label="Password", key="password", type="password")
    # login() runs as the button's click callback.
    st.button(label="Login", on_click=login)
    st.stop()

# Main App
st.title("Schlager ContractAi")
st.caption("Chat with your contract or manage meeting minutes")

# The OpenAI API key and both assistant IDs are read from the process
# environment; each is None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ASSISTANT_CONTRACT_ID = os.environ.get("ASSISTANT_CONTRACT_ID")
ASSISTANT_TECHNICAL_ID = os.environ.get("ASSISTANT_TECHNICAL_ID")

# All three settings are mandatory: report the problem and halt the script if
# any of them is missing or empty.
if not all((OPENAI_API_KEY, ASSISTANT_CONTRACT_ID, ASSISTANT_TECHNICAL_ID)):
    st.error("Missing required environment variables. Please set OPENAI_API_KEY, ASSISTANT_CONTRACT_ID, and ASSISTANT_TECHNICAL_ID.")
    st.stop()

# Tabs
tab1, tab2, tab3 = st.tabs(["Contract", "Technical", "Flagged Responses"])

# Directory (relative to the working directory) holding one JSON file per
# flagged response; created up front so later reads and writes can assume it.
FLAGGED_RESPONSES_DIR = "flagged_responses"
os.makedirs(FLAGGED_RESPONSES_DIR, exist_ok=True)

def save_flagged_response(user_query, ai_response):
    """Persist a flagged question/answer pair as a JSON file.

    Writes a file named ``flagged_<timestamp>.json`` into
    ``FLAGGED_RESPONSES_DIR`` containing the query, the response and an
    ISO-8601 timestamp, then shows a success message with the file path.

    Args:
        user_query: The user message that led to the flagged response.
        ai_response: The assistant response being flagged.
    """
    # Take the clock reading once so the filename and the stored timestamp
    # always describe the same instant.
    flagged_at = datetime.now()

    # Microseconds are part of the name so that two flags raised within the
    # same second do not overwrite each other.
    file_stamp = flagged_at.strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{FLAGGED_RESPONSES_DIR}/flagged_{file_stamp}.json"

    flagged_data = {
        "query": user_query,
        "response": ai_response,
        "timestamp": flagged_at.isoformat()
    }

    # Explicit encoding keeps the output independent of the platform default.
    with open(filename, "w", encoding="utf-8") as file:
        json.dump(flagged_data, file, indent=4)

    st.success(f"Response flagged and saved as {filename}.")

# Contract Chat Section
def _run_assistant(client, assistant_id, prompt, timeout=600.0, poll_interval=1.0):
    """Send *prompt* to an assistant on a fresh thread and return its text reply.

    Args:
        client: OpenAI client used for all API calls.
        assistant_id: ID of the assistant that should answer.
        prompt: The user's message.
        timeout: Maximum number of seconds to wait for the run to finish.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The text of the newest message on the thread; multiple text parts are
        joined with blank lines.

    Raises:
        RuntimeError: If the run ends in a state other than "completed", or
            the reply contains no text.
        TimeoutError: If the run has not finished within *timeout* seconds.
    """
    # NOTE(review): a new thread per prompt means the assistant sees no
    # earlier turns of the conversation -- confirm this is intended.
    thread_id = client.beta.threads.create().id
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=prompt
    )
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id
    )

    # States from which this app can never reach "completed": the run is over,
    # or it is waiting for tool outputs that this app never submits.
    dead_end_states = {"failed", "cancelled", "expired", "incomplete", "requires_action"}

    # The deadline is a safety net so a stuck run cannot block the script forever.
    deadline = time.monotonic() + timeout
    while True:
        status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id).status
        if status == "completed":
            break
        if status in dead_end_states:
            raise RuntimeError(f"Assistant run ended with status '{status}'.")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Assistant run did not complete within {timeout:.0f} seconds.")
        time.sleep(poll_interval)

    messages = client.beta.threads.messages.list(thread_id=thread_id)
    if not messages.data:
        raise RuntimeError("Assistant returned no messages.")

    # A message may mix text with other content parts (e.g. images); keep the text.
    text_parts = [
        part.text.value
        for part in messages.data[0].content
        if getattr(part, "type", None) == "text"
    ]
    if not text_parts:
        raise RuntimeError("Assistant reply contained no text.")
    return "\n\n".join(text_parts)


def contract_chat_section(tab, assistant_id, session_key, input_key):
    """Render a chat UI for one assistant inside the given tab.

    Args:
        tab: Streamlit tab container to render into.
        assistant_id: ID of the OpenAI assistant that answers in this tab.
        session_key: ``st.session_state`` key holding this tab's message
            history, a list of ``{"role": ..., "content": ...}`` dicts.
        input_key: Widget key for this tab's chat input.
    """
    with tab:
        st.subheader("Chat")
        client = OpenAI(api_key=OPENAI_API_KEY)

        if session_key not in st.session_state:
            st.session_state[session_key] = []

        if st.button("Clear Chat", key=f"clear_chat_{session_key}"):
            st.session_state[session_key] = []
            st.rerun()

        # Replay the stored conversation; every assistant message gets a flag
        # button that saves it together with the message preceding it.
        for idx, message in enumerate(st.session_state[session_key]):
            role, content = message["role"], message["content"]
            if role == "assistant":
                with st.container():
                    st.chat_message(role).write(content)
                    if st.button("🚩 Flag", key=f"flag_{session_key}_{idx}"):
                        user_query = st.session_state[session_key][idx-1]["content"] if idx > 0 else "Unknown"
                        save_flagged_response(user_query, content)
            else:
                st.chat_message(role).write(content)

        if prompt := st.chat_input("Enter your message:", key=input_key):
            st.session_state[session_key].append({"role": "user", "content": prompt})
            st.chat_message("user").write(prompt)

            try:
                assistant_message = _run_assistant(client, assistant_id, prompt)
            except Exception as e:
                st.error(f"Error: {str(e)}")
            else:
                st.session_state[session_key].append({"role": "assistant", "content": assistant_message})
                # Rerun so the reply is drawn by the history loop above. A flag
                # button created inside this branch would never fire: clicking
                # it reruns the script with no pending chat input, so this
                # branch (and the button's handler) would not execute.
                st.rerun()

# One chat section per assistant, each with its own tab, history key and
# chat-input widget key.
for chat_tab, chat_assistant_id, history_key, widget_key in (
    (tab1, ASSISTANT_CONTRACT_ID, "contract_messages", "contract_input"),
    (tab2, ASSISTANT_TECHNICAL_ID, "technical_messages", "technical_input"),
):
    contract_chat_section(chat_tab, chat_assistant_id, history_key, widget_key)

# Flagged Responses Tab
# Lists the JSON files in FLAGGED_RESPONSES_DIR and offers the selected one
# for download.
with tab3:
    st.subheader("Flagged Responses")

    # os.listdir returns entries in arbitrary order; sort so the list is
    # stable across reruns, with the highest (timestamped) name first.
    flagged_files = sorted(
        (f for f in os.listdir(FLAGGED_RESPONSES_DIR) if f.endswith(".json")),
        reverse=True,
    )

    if flagged_files:
        selected_file = st.selectbox("Select a flagged response file to download:", flagged_files)

        if selected_file:
            # Explicit encoding keeps reading independent of the platform default.
            with open(os.path.join(FLAGGED_RESPONSES_DIR, selected_file), "r", encoding="utf-8") as file:
                flagged_responses = file.read()
            st.download_button("Download Selected Flagged Responses", data=flagged_responses, file_name=selected_file, mime="application/json")
    else:
        st.info("No flagged responses available.")