# Last commit by teofizzy (de956af):
# feat: Enhance agent robustness by separating base and fallback LLMs, update Ollama model to 7B, and improve Docker startup reliability.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import streamlit as st
import os
import tempfile
import pandas as pd
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
import pymupdf4llm
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownTextSplitter
from sqlalchemy import create_engine, text
# --- PATH SETUP ---
# All paths below are resolved from the process working directory (not from
# this file's location), so the app must be launched from the project root.
current_dir = os.getcwd() # Should be /home/user/app in Docker
load_dir = os.path.join(current_dir, "src", "load")
# Make <cwd>/src/load importable; this must run before the mshauri_demo import below.
sys.path.append(load_dir)
# Import your agent creator and configurations
# (the agent factory, plus the embedding model name and Ollama base URL that
# the upload and cleanup code in this file also pass to OllamaEmbeddings).
try:
    from mshauri_demo import create_mshauri_agent, DEFAULT_EMBED_MODEL, DEFAULT_OLLAMA_URL
except ImportError as e:
    # Without the agent module the app cannot work: report the search path and halt this run.
    st.error(f"Critical Error: Could not import mshauri_demo. Paths checked: {sys.path}. Details: {e}")
    st.stop()
# --- GLOBALS FOR DB PATHS ---
# SQLAlchemy requires a URI starting with sqlite:///
# sql_path is that SQLAlchemy URI for the SQLite file; vector_path is a plain
# directory path handed to Chroma as its persist_directory.
sql_path = f"sqlite:///{os.path.join(current_dir, 'mshauri_fedha_v6.db')}"
vector_path = os.path.join(current_dir, "mshauri_fedha_chroma_db")
# --- SESSION MANAGEMENT & CLEANUP ---
def init_session_state():
    """Create the per-session keys the rest of the app reads.

    Keys that already exist in ``st.session_state`` are left untouched, so
    chat history, pending ephemeral data and processed-file bookkeeping
    survive Streamlit reruns.
    """
    # Each entry maps a session key to a factory, so every session gets its
    # own fresh container instead of a shared one.
    fresh_defaults = {
        "messages": list,        # chat transcript: {"role", "content"} dicts
        "temp_tables": list,     # SQL tables to drop on cleanup
        "temp_doc_ids": list,    # Chroma ids to delete on cleanup
        "uploaded_files": set,   # file names already ingested this session
    }
    for state_key, make_default in fresh_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
def cleanup_ephemeral_data():
    """Drop temporary SQL tables and ChromaDB chunks if consent was not given.

    Table names come from ``st.session_state.temp_tables`` and vector ids from
    ``st.session_state.temp_doc_ids``. They are removed from the SQLite
    database at ``sql_path`` and from the Chroma store at ``vector_path``, and
    both lists are emptied afterwards. Cleanup is best effort: a failure for
    one table, or for the vector deletion, is printed and does not abort the
    rest. An error opening the database connection still propagates.
    """
    # 1. Clean up SQL tables
    if st.session_state.temp_tables:
        engine = create_engine(sql_path)
        try:
            with engine.connect() as conn:
                for table in st.session_state.temp_tables:
                    # Upload tables are always created with this prefix; refusing
                    # anything else guarantees baseline data is never dropped.
                    if not table.startswith("user_upload_"):
                        print(f"Skipping invalid table name: {table}")
                        continue
                    # Quote as an SQL identifier (embedded quotes doubled) so that
                    # names containing '-', '(', '.', ... are dropped rather than
                    # silently skipped and left behind.
                    quoted_table = '"' + table.replace('"', '""') + '"'
                    try:
                        # exec_driver_sql avoids text()'s ":name" bind-parameter
                        # parsing, which would misread a ':' inside a table name.
                        conn.exec_driver_sql(f"DROP TABLE IF EXISTS {quoted_table}")
                        conn.commit()
                    except Exception as e:
                        print(f"SQL Cleanup error: {e}")
        finally:
            # Release pooled SQLite connections held by this throwaway engine.
            engine.dispose()
    st.session_state.temp_tables = []

    # 2. Clean up vector DB documents
    if st.session_state.temp_doc_ids:
        try:
            embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=DEFAULT_OLLAMA_URL)
            vectorstore = Chroma(persist_directory=vector_path, embedding_function=embeddings)
            vectorstore.delete(ids=st.session_state.temp_doc_ids)
        except Exception as e:
            # Best effort: the ids are forgotten below even if deletion failed.
            print(f"Vector Cleanup error: {e}")
    st.session_state.temp_doc_ids = []
# --- FAST EXTRACTION PIPELINE ---
def _sanitize_table_name(file_name):
    """Map an uploaded file name to an SQL-friendly table name.

    The extension is dropped, the rest is lower-cased, and every character that
    is not an ASCII letter or digit becomes "_", e.g.
    "Q3 Budget.xlsx" -> "user_upload_q3_budget" and
    "Q3 Budget (v2).xlsx" -> "user_upload_q3_budget__v2_".
    """
    stem = file_name.rsplit('.', 1)[0].lower()
    safe_name = "".join(ch if ch.isascii() and ch.isalnum() else "_" for ch in stem)
    return f"user_upload_{safe_name}"


def _save_upload_to_temp(uploaded_file, suffix):
    """Write the upload's bytes to a named temp file and return its path (the caller deletes it)."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getvalue())
        return tmp.name


def _store_text_chunks(raw_text, file_name, consent):
    """Chunk *raw_text*, embed it into the Chroma store and return the new ids.

    Each chunk is tagged with ``{"source": file_name}``. Without consent, the
    ids are recorded in ``st.session_state.temp_doc_ids`` so that
    ``cleanup_ephemeral_data`` can delete them later. Returns ``[]`` and stores
    nothing when the text yields no chunks (e.g. an image-only PDF).
    """
    splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_text(raw_text)
    docs = [Document(page_content=chunk, metadata={"source": file_name}) for chunk in chunks]
    if not docs:
        return []
    embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=DEFAULT_OLLAMA_URL)
    vectorstore = Chroma(persist_directory=vector_path, embedding_function=embeddings)
    doc_ids = vectorstore.add_documents(docs)
    if not consent:
        st.session_state.temp_doc_ids.extend(doc_ids)
    return doc_ids


def process_uploaded_file(uploaded_file, consent):
    """Ingest one uploaded file according to its type and the consent rules.

    * CSV / XLSX / XLS -> written to a ``user_upload_*`` table in the SQLite DB.
    * PDF -> converted to Markdown, chunked and embedded into Chroma.
    * DOCX / TXT / MD -> loaded as text, chunked and embedded into Chroma.

    When ``consent`` is False, the created table or chunk ids are recorded in
    session state so ``cleanup_ephemeral_data`` can remove them. Each file name
    is processed at most once per session. Outcomes are reported in the
    sidebar, and exceptions are shown there instead of being raised.
    """
    file_name = uploaded_file.name
    # Skip if already processed in this session.
    if file_name in st.session_state.uploaded_files:
        return
    # Mark it up-front so a failing file is not re-processed on every rerun.
    st.session_state.uploaded_files.add(file_name)

    # Match extensions case-insensitively so e.g. "REPORT.PDF" is routed correctly.
    lower_name = file_name.lower()
    extension = os.path.splitext(file_name)[1]
    tmp_path = None
    try:
        # PATH 1: TABULAR DATA -> SQL DATABASE
        if lower_name.endswith(('.csv', '.xlsx', '.xls')):
            if lower_name.endswith('.csv'):
                df = pd.read_csv(uploaded_file)
            else:
                df = pd.read_excel(uploaded_file)  # Requires 'openpyxl'
            safe_table_name = _sanitize_table_name(file_name)
            if not consent:
                # Register before writing so a partially written table is still cleaned up.
                st.session_state.temp_tables.append(safe_table_name)
            engine = create_engine(sql_path)
            try:
                df.to_sql(safe_table_name, con=engine, if_exists='replace', index=False)
            finally:
                engine.dispose()
            if not consent:
                st.sidebar.success(f"Spreadsheet loaded ephemerally! Agent can query `{safe_table_name}`.")
            else:
                st.sidebar.success(f"Spreadsheet saved persistently as `{safe_table_name}`.")

        # PATH 2: UNSTRUCTURED DATA -> VECTOR DB (PDF)
        elif lower_name.endswith('.pdf'):
            tmp_path = _save_upload_to_temp(uploaded_file, ".pdf")
            # Markdown extraction keeps table structure intact for the splitter.
            md_text = pymupdf4llm.to_markdown(tmp_path)
            if not _store_text_chunks(md_text, file_name, consent):
                st.sidebar.warning(f"No extractable text found in {file_name}.")
            elif not consent:
                st.sidebar.success("PDF loaded securely for this session only.")
            else:
                st.sidebar.success("PDF saved to active database.")

        # PATH 3: DOCUMENT FILES -> VECTOR DB (TXT/DOCX/MD)
        elif lower_name.endswith(('.docx', '.txt', '.md')):
            tmp_path = _save_upload_to_temp(uploaded_file, extension)
            if lower_name.endswith('.docx'):
                loader = Docx2txtLoader(tmp_path)
            else:
                loader = TextLoader(tmp_path)
            raw_text = "\n".join(doc.page_content for doc in loader.load())
            if not _store_text_chunks(raw_text, file_name, consent):
                st.sidebar.warning(f"No extractable text found in {file_name}.")
            elif not consent:
                st.sidebar.success(f"{extension.upper()} file loaded securely for this session only.")
            else:
                st.sidebar.success(f"{extension.upper()} file saved to active database.")

        else:
            st.sidebar.error("Unsupported file type.")
    except Exception as e:
        st.sidebar.error(f"Error processing {file_name}: {e}")
    finally:
        # Always remove the temp copy, including when extraction/embedding failed.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# --- STREAMLIT UI CONFIGURATION ---
st.set_page_config(page_title="Mshauri Fedha", page_icon="🦁")
init_session_state()
st.title("🦁 Mshauri Fedha")
st.markdown("### AI Financial Advisor for Kenya")

# --- SIDEBAR: UPLOAD & CONSENT ---
st.sidebar.header("📁 Data Upload & Security")
st.sidebar.markdown("Upload your own financial reports or datasets (Max 10 files).")
st.sidebar.info("💡 **Tip:** For best results, upload raw data (like financial ledgers) as **CSV/Excel**. Upload narrative reports as **PDFs**.")
consent = st.sidebar.checkbox(
    "I consent to securely storing this document in the active database (not persistent).",
    value=False,
    help="If unchecked, your data is treated as ephemeral. It will be deleted instantly when you clear the chat."
)

# The file uploader keeps its files across reruns. Giving it a session-scoped
# key lets the clear button below reset it (a new key means a new, empty
# widget). Otherwise cleared uploads would be re-ingested on the next rerun.
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = 0

uploaded_files = st.sidebar.file_uploader(
    "Upload PDF, CSV, XLSX, DOCX, TXT",
    type=['pdf', 'csv', 'xlsx', 'xls', 'docx', 'txt', 'md'],
    accept_multiple_files=True,
    key=f"file_uploader_{st.session_state.uploader_key}",
)

if uploaded_files:
    # Enforce the 10 file limit.
    if len(uploaded_files) > 10:
        st.sidebar.error("⚠️ Please upload a maximum of 10 files at a time.")
    else:
        # Compare counts before and after, so we only rerun when something new
        # was ingested (prevents an infinite rerun loop).
        initial_file_count = len(st.session_state.uploaded_files)
        with st.sidebar:
            with st.spinner(f"Processing {len(uploaded_files)} file(s)..."):
                for uploaded_file in uploaded_files:
                    process_uploaded_file(uploaded_file, consent)
        # New data arrived: drop the cached agent so it is rebuilt and sees it.
        if len(st.session_state.uploaded_files) > initial_file_count:
            if "agent" in st.session_state:
                del st.session_state["agent"]
            st.rerun()

st.sidebar.markdown("---")
if st.sidebar.button("🗑️ Clear Chat & Ephemeral Data"):
    cleanup_ephemeral_data()
    st.session_state.messages = []
    st.session_state.uploaded_files = set()
    # Swap in a fresh, empty uploader; otherwise the widget would still return
    # the old files and the rerun would immediately re-ingest them.
    st.session_state.uploader_key += 1
    # Also delete the agent so it drops the ephemeral tables from its memory.
    if "agent" in st.session_state:
        del st.session_state["agent"]
    st.rerun()
# --- AGENT INITIALIZATION ---
# Build the agent once per session; it is removed from session state (and so
# rebuilt here) whenever new uploads are ingested or the chat is cleared.
if "agent" not in st.session_state:
    with st.spinner("Initializing Mshauri Brain (Loading Models & Data)..."):
        # Deployment sanity check: surface a missing baseline database loudly.
        real_db_path = os.path.join(current_dir, "mshauri_fedha_v6.db")
        baseline_db_present = os.path.exists(real_db_path)
        if not baseline_db_present:
            st.warning(f"Database not found at {real_db_path}. Using empty state.")
        try:
            fresh_agent = create_mshauri_agent(
                sql_db_path=sql_path,
                vector_db_path=vector_path,
            )
        except Exception as e:
            # Leave "agent" unset so the next rerun attempts initialization again.
            st.error(f"Failed to initialize agent: {e}")
        else:
            st.session_state.agent = fresh_agent
# --- CHAT UI ---
# Streamlit re-executes the whole script on every interaction, so replay the
# stored transcript before handling a new prompt.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask about inflation, your uploaded data, or economic trends..."):
    # Display (and record) the user's raw prompt.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # CONTEXT INJECTION: tell the agent (not the user) about uploaded files.
    augmented_prompt = prompt
    if st.session_state.uploaded_files:
        # Sorted so the injected context is deterministic (the source is a set).
        files_str = ", ".join(sorted(st.session_state.uploaded_files))
        augmented_prompt = f"Context: The user has just uploaded the following files: {files_str}. Please prioritize searching these documents and querying their corresponding SQL tables (prefixed with 'user_upload_') to answer the following question.\n\nQuestion: {prompt}"

    with st.chat_message("assistant"):
        with st.spinner("Analyzing..."):
            try:
                # The agent is only stored on successful initialization, so the
                # key may be missing; .get() routes that case to the message below
                # instead of raising an AttributeError.
                agent = st.session_state.get("agent")
                if agent:
                    # Send the AUGMENTED prompt to the agent, not just the raw prompt.
                    response = agent.invoke({"input": augmented_prompt})
                    output_text = response.get("output", "Error generating response.")
                    st.markdown(output_text)
                    st.session_state.messages.append({"role": "assistant", "content": output_text})
                else:
                    st.error("Agent failed to initialize. Please refresh the page.")
            except Exception as e:
                st.error(f"An error occurred: {e}")