import os
import json
import time
import uuid
import sqlite3
import traceback

import pandas as pd
from dotenv import load_dotenv

# Disable Streamlit's static file serving. This is set before Streamlit (or
# anything that imports it) is loaded so the setting is seen at import time.
os.environ['STREAMLIT_SERVER_ENABLE_STATIC_SERVING'] = 'false'

from simple_rag import app

import streamlit as st
import tiktoken

load_dotenv()

# Invocation config for the RAG graph: every turn in this process shares a
# single checkpoint thread.
config = {"configurable": {"thread_id": "sample"}}

# Context-window budgets (tokens) used to warn the user before overflow.
GPT_LIMIT = 128_000
GEMINI_LIMIT = 1_000_000

# Build the tiktoken encoder once; encoding_for_model is relatively expensive.
_GPT_ENCODER = tiktoken.encoding_for_model("gpt-4")


def count_tokens_gpt(text):
    """Count tokens as GPT-4 would, via tiktoken."""
    return len(_GPT_ENCODER.encode(text))


def count_tokens_gemini(text):
    """Rough Gemini estimate: whitespace-separated words, not true tokens."""
    return len(text.split())
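# An exact count could come from the Gemini API itself (a sketch, assuming the
# google-generativeai package and a configured API key; the model name is
# illustrative). The word-split heuristic above avoids the extra network call:
#
#   import google.generativeai as genai
#   model = genai.GenerativeModel("gemini-1.5-pro")
#   token_count = model.count_tokens(text).total_tokens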

def calculate_context_window_usage(json_data=None):
    """Estimate tokens used by the full conversation plus an optional JSON payload."""
    full_conversation = ""
    for sender, message in st.session_state.chat_history:
        full_conversation += f"{sender}: {message}\n\n"

    # The uploaded JSON is sent with every request, so it counts toward the window.
    if json_data:
        full_conversation += json.dumps(json_data)

    gpt_tokens = count_tokens_gpt(full_conversation)
    gemini_tokens = count_tokens_gemini(full_conversation)

    return gpt_tokens, gemini_tokens

st.set_page_config(page_title="💬 RAG Chat Assistant", layout="wide")

SESSION_DB_DIR = "Data/sessions"

def initialize_session_database(session_id):
    """Initializes a new database for a chat session."""
    db_path = os.path.join(SESSION_DB_DIR, f"{session_id}.db")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS chat_history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            sender TEXT,
            message TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    conn.close()
    return db_path

def save_message(db_path, sender, message):
    """Saves a message to the specified session database."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("INSERT INTO chat_history (sender, message) VALUES (?, ?)", (sender, message))
    conn.commit()
    conn.close()


def clear_chat_history(db_path):
    """Clears the chat history in the specified session database."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("DELETE FROM chat_history")
    conn.commit()
    conn.close()

os.makedirs(SESSION_DB_DIR, exist_ok=True)

# Session-state defaults.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
    ]
if "processing" not in st.session_state:
    st.session_state.processing = False
if "total_gpt_tokens" not in st.session_state:
    st.session_state.total_gpt_tokens = 0
if "total_gemini_tokens" not in st.session_state:
    st.session_state.total_gemini_tokens = 0
if "window_gpt_tokens" not in st.session_state:
    st.session_state.window_gpt_tokens = 0
if "window_gemini_tokens" not in st.session_state:
    st.session_state.window_gemini_tokens = 0

# Give each browser session its own id and SQLite database.
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.session_db_path = initialize_session_database(st.session_state.session_id)

def load_chat_history(db_path):
    """Loads (sender, message) rows from a session database, oldest first."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT sender, message FROM chat_history ORDER BY timestamp")
    history = cursor.fetchall()
    conn.close()
    return history
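# load_chat_history is currently unused. A minimal sketch of how it could
# restore a persisted conversation (assuming the session id is recovered from
# somewhere, e.g. a query parameter; `session_id` below is hypothetical input):
#
#   db_path = os.path.join(SESSION_DB_DIR, f"{session_id}.db")
#   if os.path.exists(db_path):
#       st.session_state.chat_history = load_chat_history(db_path)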

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(BASE_DIR)
print(PROJECT_ROOT, BASE_DIR)

# Two-column layout: chat on the left, progress/telemetry on the right.
col_chat, col_progress = st.columns([3, 1])

with col_chat:
    st.title("💬 RAG Assistant")

    with st.expander("📂 Upload Required JSON Files", expanded=True):
        # Paths to the two required JSON files come from the environment.
        user_data_path = os.getenv('USER_DATA_PATH')
        allocations_path = os.getenv('ALLOCATIONS_PATH')

        try:
            with open(user_data_path, 'r') as f:
                user_data = json.load(f)
        except (FileNotFoundError, TypeError):
            # TypeError covers an unset USER_DATA_PATH (open(None)).
            st.error(f"Error: user_data.json not found at {user_data_path}")
            user_data = None
        except json.JSONDecodeError:
            st.error("Error: Could not decode user_data.json. Please ensure it is valid JSON.")
            user_data = None

        try:
            with open(allocations_path, 'r') as f:
                allocations = json.load(f)
        except (FileNotFoundError, TypeError):
            st.error(f"Error: allocations.json not found at {allocations_path}")
            allocations = None
        except json.JSONDecodeError:
            st.error("Error: Could not decode allocations.json. Please ensure it is valid JSON.")
            allocations = None

        if user_data:
            # NOTE: "sematic" and "prefrences" are spelled this way in the source JSON.
            sematic = user_data.get("sematic", {})
            demographic = sematic.get("demographic", {})
            financial = sematic.get("financial", {})
            episodic = user_data.get("episodic", {})

            col1, col2, col3 = st.columns(3)

            with col1:
                st.markdown("### 🧾 **Demographic Info**")
                for key, value in demographic.items():
                    st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

            with col2:
                st.markdown("### 📈 **Financial Status**")
                for key, value in financial.items():
                    st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

            with col3:
                st.markdown("### ⚙️ **Preferences & Goals**")
                st.markdown("**User Preferences:**")
                for pref in episodic.get("prefrences", []):
                    st.markdown(f"- {pref.capitalize()}")
                st.markdown("**Goals:**")
                for goal in episodic.get("goals", []):
                    for k, v in goal.items():
                        st.markdown(f"- **{k.replace('_', ' ').title()}**: {v}")

        # Cache allocations in session state so the graph can update them later.
        if "allocations" not in st.session_state:
            st.session_state.allocations = allocations

        if st.session_state.allocations:
            try:
                st.markdown("### 💼 Investment Allocations")

                # Flatten {asset_class: [entries]} into rows for a dataframe.
                records = []
                for asset_class, entries in st.session_state.allocations.items():
                    for item in entries:
                        records.append({
                            "Asset Class": asset_class.replace("_", " ").title(),
                            "Type": item.get("type", ""),
                            "Label": item.get("label", ""),
                            "Amount (₹)": item.get("amount", 0)
                        })

                df = pd.DataFrame(records)
                st.dataframe(df)

            except Exception as e:
                st.error(f"Failed to parse allocations.json: {e}")

    if st.button("Clear Chat"):
        st.session_state.chat_history = [
            ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
        ]
        st.session_state.total_gpt_tokens = 0
        st.session_state.total_gemini_tokens = 0
        st.session_state.window_gpt_tokens = 0
        st.session_state.window_gemini_tokens = 0

        # Wipe the persisted copy as well, then rerun to redraw a fresh chat.
        clear_chat_history(st.session_state.session_db_path)
        st.rerun()

    st.markdown("---")

    # Render the conversation so far.
    chat_container = st.container()
    with chat_container:
        for sender, message in st.session_state.chat_history:
            if sender == "user":
                st.chat_message("user").write(message)
            else:
                st.chat_message("assistant").write(message)

    # Simple "Thinking..." animation shown while a request is in flight.
    if st.session_state.processing:
        thinking_placeholder = st.empty()
        with st.chat_message("assistant"):
            for _ in range(3):
                for dots in [".", "..", "..."]:
                    thinking_placeholder.markdown(f"Thinking{dots}")
                    time.sleep(0.3)

    user_input = st.chat_input("Type your question...")

    if user_input and not st.session_state.processing:
        st.session_state.processing = True

        # Record and persist the user turn, then rerun so it renders (with the
        # thinking animation) before the model is invoked.
        st.session_state.chat_history.append(("user", user_input))
        save_message(st.session_state.session_db_path, "user", user_input)
        st.rerun()

    if st.session_state.processing:
        if not user_data or not allocations:
            st.session_state.chat_history.append(("assistant", "⚠️ Please upload both JSON files before asking questions."))
            st.session_state.processing = False
            st.rerun()
        else:
            try:
                combined_json_data = {"user_data": user_data, "allocations": allocations}

                # The user turn was appended before the rerun, so fetch it back.
                last_user_message = next((msg for sender, msg in reversed(st.session_state.chat_history) if sender == "user"), "")

                user_msg_gpt_tokens = count_tokens_gpt(last_user_message)
                user_msg_gemini_tokens = count_tokens_gemini(last_user_message)

                st.session_state.total_gpt_tokens += user_msg_gpt_tokens
                st.session_state.total_gemini_tokens += user_msg_gemini_tokens

                window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
                st.session_state.window_gpt_tokens = window_gpt
                st.session_state.window_gemini_tokens = window_gemini

                if window_gpt > GPT_LIMIT or window_gemini > GEMINI_LIMIT:
                    st.session_state.chat_history.append(("assistant", "⚠️ Your conversation has exceeded token limits. Please clear the chat to continue."))
                    st.session_state.processing = False
                    st.rerun()
                else:
                    inputs = {
                        "query": last_user_message,
                        "user_data": user_data,
                        "allocations": allocations,
                        "chat_history": st.session_state.chat_history
                    }
                    print(st.session_state.chat_history)

                    # Run the RAG graph for this turn.
                    output = app.invoke(inputs, config=config)
                    response = output.get('output')
                    print(response)

                    # The graph may rebalance the portfolio; pick up new allocations.
                    if "allocations" in output:
                        st.session_state.allocations = output["allocations"]

                    response_gpt_tokens = count_tokens_gpt(response)
                    response_gemini_tokens = count_tokens_gemini(response)

                    st.session_state.total_gpt_tokens += response_gpt_tokens
                    st.session_state.total_gemini_tokens += response_gemini_tokens

                    st.session_state.chat_history.append(("assistant", response))
                    # Persist the assistant turn too, mirroring the user path.
                    save_message(st.session_state.session_db_path, "assistant", response)

                    window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
                    st.session_state.window_gpt_tokens = window_gpt
                    st.session_state.window_gemini_tokens = window_gemini

            except Exception as e:
                # Report where the exception was raised, not where it was caught;
                # the previous extract_stack() pointed at the handler itself.
                frame = traceback.extract_tb(e.__traceback__)[-1]
                error_message = f"❌ Error: {str(e)} in {frame.filename} at line {frame.lineno}, function: {frame.name}"
                st.session_state.chat_history.append(("assistant", error_message))

            st.session_state.processing = False
            st.rerun()
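
# col_progress is created above but never populated here. A minimal sketch of a
# token-usage panel it could hold (the names mirror the session-state keys and
# limits defined above; this is an illustration, not part of the original app):
#
#   with col_progress:
#       st.subheader("Token usage")
#       st.progress(min(st.session_state.window_gpt_tokens / GPT_LIMIT, 1.0))
#       st.caption(f"GPT window: {st.session_state.window_gpt_tokens:,} / {GPT_LIMIT:,}")
#       st.progress(min(st.session_state.window_gemini_tokens / GEMINI_LIMIT, 1.0))
#       st.caption(f"Gemini window: {st.session_state.window_gemini_tokens:,} / {GEMINI_LIMIT:,}")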