File size: 14,176 Bytes
6bccf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import sys
import os

import pandas as pd
import langchain
os.environ['STREAMLIT_SERVER_ENABLE_STATIC_SERVING'] = 'false'

from simple_rag import app

import streamlit as st
import json
from io import StringIO
import tiktoken
import time
from langchain_community.document_loaders import PyMuPDFLoader
import traceback
import sqlite3  # Import SQLite
from dotenv import load_dotenv
load_dotenv()

import uuid  # Import the UUID library

# Token limits: the estimated context-window size is compared against these
# values before the graph is invoked.
GPT_LIMIT = 128000
GEMINI_LIMIT = 1000000

# Run configuration passed to `app.invoke`.
# NOTE(review): the thread_id is a fixed literal, so every browser session
# uses the same one — confirm this is intended if the graph keeps per-thread state.
config = {"configurable": {"thread_id": "sample"}}
# Token counters
def count_tokens_gpt(text):
    """Return the number of tokens `text` encodes to with tiktoken's "gpt-4" encoding."""
    encoder = tiktoken.encoding_for_model("gpt-4")
    token_ids = encoder.encode(text)
    return len(token_ids)

def count_tokens_gemini(text):
    """Approximate a token count for `text` as its number of whitespace-separated words."""
    words = text.split()
    return len(words)

# Calculate tokens for the entire context window
def calculate_context_window_usage(json_data=None):
    """Estimate how many tokens the current conversation occupies.

    The transcript is rebuilt from ``st.session_state.chat_history`` as
    "sender: message" entries separated by blank lines; when `json_data` is
    truthy its JSON serialization is appended to the end.

    Args:
        json_data: Optional JSON-serializable context to include in the count.

    Returns:
        A ``(gpt_tokens, gemini_tokens)`` tuple for the assembled text.
    """
    pieces = [
        f"{sender}: {message}\n\n"
        for sender, message in st.session_state.chat_history
    ]

    # The JSON context, if any, is counted as part of the same window.
    if json_data:
        pieces.append(json.dumps(json_data))

    transcript = "".join(pieces)
    return count_tokens_gpt(transcript), count_tokens_gemini(transcript)




# Page configuration: browser-tab title and wide layout.
st.set_page_config(page_title="📊 RAG Chat Assistant", layout="wide")

# --- Database setup ---
# Each chat session gets its own SQLite file, named "<session_id>.db", inside
# this directory (it replaces the earlier single "Data/chat_history.db" path).
SESSION_DB_DIR = "Data/sessions"

def initialize_session_database(session_id):
    """Create the SQLite database for one chat session if it does not exist yet.

    Args:
        session_id: Identifier used as the database file name
            (``<session_id>.db`` inside ``SESSION_DB_DIR``).

    Returns:
        The path of the session database file.

    Raises:
        sqlite3.Error: If the database cannot be opened or the table created.
        OSError: If the session directory cannot be created.
    """
    # Make the function self-sufficient: connect() fails if the folder is missing.
    os.makedirs(SESSION_DB_DIR, exist_ok=True)
    db_path = os.path.join(SESSION_DB_DIR, f"{session_id}.db")

    conn = sqlite3.connect(db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                sender TEXT,
                message TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.commit()
    finally:
        # Always release the connection, even when the DDL statement fails.
        conn.close()
    return db_path

def save_message(db_path, sender, message):
    """Append one chat message to the session database.

    Args:
        db_path: Path of the session's SQLite file; its ``chat_history``
            table must already exist.
        sender: Who wrote the message (the app uses "user" / "assistant").
        message: The message text.

    Raises:
        sqlite3.Error: If the insert fails (e.g. the table is missing).
    """
    conn = sqlite3.connect(db_path)
    try:
        # Parameterized query: message text is user input and must not be
        # interpolated into the SQL string.
        conn.execute(
            "INSERT INTO chat_history (sender, message) VALUES (?, ?)",
            (sender, message),
        )
        conn.commit()
    finally:
        # Close the connection even if the insert raised.
        conn.close()

def clear_chat_history(db_path):
    """Delete every stored message from the session database.

    Args:
        db_path: Path of the session's SQLite file; its ``chat_history``
            table must already exist.

    Raises:
        sqlite3.Error: If the delete fails (e.g. the table is missing).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("DELETE FROM chat_history")
        conn.commit()
    finally:
        # Close the connection even if the delete raised.
        conn.close()

# Initialize session DB directory. exist_ok avoids the check-then-create race
# when two script runs start at the same time.
os.makedirs(SESSION_DB_DIR, exist_ok=True)

# --- Session state setup ---
# Streamlit re-executes this script on every interaction, so each value is
# created only when it is not already present in the session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
    ]
if "processing" not in st.session_state:
    st.session_state.processing = False

# Token counters: "total_*" accumulate over the conversation, "window_*" hold
# the latest estimate for the current context window.
for _counter_key in (
    "total_gpt_tokens",
    "total_gemini_tokens",
    "window_gpt_tokens",
    "window_gemini_tokens",
):
    if _counter_key not in st.session_state:
        st.session_state[_counter_key] = 0

# Generate a unique session ID if one doesn't exist
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())

# Checked separately from session_id: if database creation failed on an earlier
# run, the ID already exists but the path does not, and it must be retried.
if "session_db_path" not in st.session_state:
    st.session_state.session_db_path = initialize_session_database(st.session_state.session_id)

# --- Load chat history from the session database ---
def load_chat_history(db_path):
    """Return all stored messages of a session in chronological order.

    Args:
        db_path: Path of the session's SQLite file; its ``chat_history``
            table must already exist.

    Returns:
        A list of ``(sender, message)`` tuples, oldest first.

    Raises:
        sqlite3.Error: If the query fails (e.g. the table is missing).
    """
    conn = sqlite3.connect(db_path)
    try:
        # CURRENT_TIMESTAMP has one-second resolution, so messages written in
        # the same second tie on timestamp; the autoincrement id breaks the
        # tie in insertion order.
        cursor = conn.execute(
            "SELECT sender, message FROM chat_history ORDER BY timestamp, id"
        )
        return cursor.fetchall()
    finally:
        # Close the connection even if the query raised.
        conn.close()


# Absolute path of the directory that contains this script.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Parent of that directory (one level up, to reach RAG_rubik/).
PROJECT_ROOT = os.path.split(BASE_DIR)[0]
print(PROJECT_ROOT, BASE_DIR)
def _load_json_file(path, filename, env_var):
    """Load one JSON input file, reporting any problem in the Streamlit UI.

    Args:
        path: File path taken from the environment; None when the variable
            is not set.
        filename: File name shown in error messages.
        env_var: Name of the environment variable that supplies `path`.

    Returns:
        The parsed JSON content, or None when the path is missing, the file
        does not exist, or its content cannot be decoded as JSON.
    """
    if not path:
        # open(None) would raise an uncaught TypeError and break the whole page.
        st.error(f"Error: {env_var} environment variable is not set.")
        return None
    try:
        with open(path, 'r', encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        st.error(f"Error: {filename} not found at {path}")
    except (json.JSONDecodeError, UnicodeDecodeError):
        st.error(f"Error: Could not decode {filename}. Please ensure it is valid JSON.")
    return None


# --- Layout: Chat UI Left | Progress Bars Right ---
col_chat, col_progress = st.columns([3, 1])

# --- LEFT COLUMN: Chat UI ---
with col_chat:
    st.title("💬 RAG Assistant")

    with st.expander("📂 Upload Required JSON Files", expanded=True):
        # The input files are read from paths supplied via environment variables.
        user_data_path = os.getenv('USER_DATA_PATH')
        allocations_path = os.getenv('ALLOCATIONS_PATH')

        user_data = _load_json_file(user_data_path, "user_data.json", "USER_DATA_PATH")
        allocations = _load_json_file(allocations_path, "allocations.json", "ALLOCATIONS_PATH")

        if user_data:
            # NOTE(review): "sematic" and "prefrences" are read exactly as
            # spelled; presumably they match the JSON schema — check the data
            # file before "correcting" them.
            sematic = user_data.get("sematic", {})
            demographic = sematic.get("demographic", {})
            financial = sematic.get("financial", {})
            episodic = user_data.get("episodic", {})

            col1, col2, col3 = st.columns(3)

            with col1:
                st.markdown("### 🧾 **Demographic Info**")
                for key, value in demographic.items():
                    st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

            with col2:
                st.markdown("### 📊 **Financial Status**")
                for key, value in financial.items():
                    st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

            with col3:
                st.markdown("### ⚙️ **Preferences & Goals**")
                st.markdown("**User Preferences:**")
                for pref in episodic.get("prefrences", []):
                    st.markdown(f"- {pref.capitalize()}")
                st.markdown("**Goals:**")
                for goal in episodic.get("goals", []):
                    for k, v in goal.items():
                        st.markdown(f"- **{k.replace('_', ' ').title()}**: {v}")

        # Seed the session copy from the file. A None left behind by an earlier
        # failed load is replaced too, so a later successful load is picked up.
        if "allocations" not in st.session_state or st.session_state.allocations is None:
            st.session_state.allocations = allocations

        if st.session_state.allocations:
            try:
                st.markdown("### 💼 Investment Allocations")

                # Flatten {asset_class: [entry, ...]} into one row per entry.
                records = [
                    {
                        "Asset Class": asset_class.replace("_", " ").title(),
                        "Type": item.get("type", ""),
                        "Label": item.get("label", ""),
                        "Amount (₹)": item.get("amount", 0),
                    }
                    for asset_class, entries in st.session_state.allocations.items()
                    for item in entries
                ]

                df = pd.DataFrame(records)
                st.dataframe(df)

            except Exception as e:
                # UI boundary: show the problem instead of breaking the page.
                st.error(f"Failed to parse allocations.json: {e}")

        # Clear chat button: reset the conversation, the token counters and
        # the persisted history, then redraw.
        if st.button("Clear Chat"):
            st.session_state.chat_history = [
                ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
            ]
            st.session_state.total_gpt_tokens = 0
            st.session_state.total_gemini_tokens = 0
            st.session_state.window_gpt_tokens = 0
            st.session_state.window_gemini_tokens = 0

            # Clear the chat history in the session database
            clear_chat_history(st.session_state.session_db_path)

            st.rerun()

    st.markdown("---")

    # Display chat history
    chat_container = st.container()
    with chat_container:
        for sender, message in st.session_state.chat_history:
            # Anything that is not a user message is rendered as the assistant.
            role = "user" if sender == "user" else "assistant"
            st.chat_message(role).write(message)

        # Show thinking animation if processing. The placeholder is created
        # inside the assistant bubble so the text appears in it rather than
        # above an empty bubble.
        if st.session_state.processing:
            with st.chat_message("assistant"):
                thinking_placeholder = st.empty()
                for _ in range(3):
                    for dots in (".", "..", "..."):
                        thinking_placeholder.markdown(f"Thinking{dots}")
                        time.sleep(0.3)

    # Input box at the bottom
    user_input = st.chat_input("Type your question...")

    if user_input and not st.session_state.processing:
        # Set processing flag; the answer is produced on the next script run.
        st.session_state.processing = True

        # Add user message to history immediately and persist it.
        st.session_state.chat_history.append(("user", user_input))
        save_message(st.session_state.session_db_path, "user", user_input)

        # Force a rerun to show the message and thinking indicator
        st.rerun()

# This part runs after the rerun if we're processing: it answers the most
# recent user message, updates the token counters, and triggers a final rerun
# so the new assistant message is displayed.
if st.session_state.processing:
    if not user_data or not allocations:
        st.session_state.chat_history.append(("assistant", "⚠️ Please upload both JSON files before asking questions."))
    else:
        try:
            # Combined JSON data (for token calculation)
            combined_json_data = {"user_data": user_data, "allocations": allocations}

            # Get the last user message
            last_user_message = next((msg for sender, msg in reversed(st.session_state.chat_history) if sender == "user"), "")

            # Count tokens for this user message
            user_msg_gpt_tokens = count_tokens_gpt(last_user_message)
            user_msg_gemini_tokens = count_tokens_gemini(last_user_message)

            # Add to accumulated totals
            st.session_state.total_gpt_tokens += user_msg_gpt_tokens
            st.session_state.total_gemini_tokens += user_msg_gemini_tokens

            # Calculate context window usage (conversation + JSON data)
            window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
            st.session_state.window_gpt_tokens = window_gpt
            st.session_state.window_gemini_tokens = window_gemini

            # Check token limits for context window
            if window_gpt > GPT_LIMIT or window_gemini > GEMINI_LIMIT:
                st.session_state.chat_history.append(("assistant", "⚠️ Your conversation has exceeded token limits. Please clear the chat to continue."))
            else:
                # --- Call LangGraph ---
                # NOTE(review): the graph receives `allocations` as loaded from
                # the file, not st.session_state.allocations, so updates from
                # an earlier answer are not passed back in — confirm intended.
                inputs = {
                    "query": last_user_message,
                    "user_data": user_data,
                    "allocations": allocations,
                    #"data":"",
                    "chat_history": st.session_state.chat_history
                }
                print(st.session_state.chat_history)

                output = app.invoke(inputs, config=config)
                response = output.get('output')
                print(response)

                # Check if the response contains allocation updates
                if "allocations" in output:
                    st.session_state.allocations = output["allocations"]

                # The token counters and the chat display need a string; a
                # missing or non-text answer must not crash them.
                if response is None:
                    response = "⚠️ No response was returned. Please try again."
                elif not isinstance(response, str):
                    response = str(response)

                # Count tokens for the response
                response_gpt_tokens = count_tokens_gpt(response)
                response_gemini_tokens = count_tokens_gemini(response)

                # Add to accumulated totals
                st.session_state.total_gpt_tokens += response_gpt_tokens
                st.session_state.total_gemini_tokens += response_gemini_tokens

                # Add to chat history and persist it, mirroring what is done
                # for user messages.
                st.session_state.chat_history.append(("assistant", response))
                save_message(st.session_state.session_db_path, "assistant", response)

                # Update context window calculations after adding response
                window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
                st.session_state.window_gpt_tokens = window_gpt
                st.session_state.window_gemini_tokens = window_gemini

        except Exception as e:
            # Report where the exception was raised: the last frame of the
            # exception's own traceback (extract_stack() would describe the
            # current call stack instead).
            frames = traceback.extract_tb(e.__traceback__)
            if frames:
                origin = frames[-1]
                error_message = f"❌ Error: {str(e)} in {origin.filename} at line {origin.lineno}, function: {origin.name}"
            else:
                error_message = f"❌ Error: {str(e)}"
            st.session_state.chat_history.append(("assistant", error_message))

    # Reset processing flag and redraw; every branch above ends here.
    st.session_state.processing = False
    st.rerun()