| |
| import sqlite3 |
| import os |
| import streamlit as st |
| import chromadb |
| from typing import Dict, Optional, Any |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from llama_index.core import VectorStoreIndex, StorageContext, Settings |
| from llama_index.vector_stores.chroma import ChromaVectorStore |
| from llama_index.llms.groq import Groq |
| from llama_index.embeddings.cohere import CohereEmbedding |
|
|
| |
# Load API keys and other settings from a local .env file into os.environ.
load_dotenv()

# Opt out of ChromaDB's anonymized telemetry via its environment switch.
os.environ["ANONYMIZED_TELEMETRY"] = "False"


@st.cache_resource(show_spinner=False)
def _init_arize_tracing():
    """Register Arize tracing and instrument LlamaIndex once per server process.

    Streamlit re-executes this script on every user interaction; without the
    cache, each rerun would call ``register()`` again (creating another tracer
    provider) and re-run the instrumentor. Tracing is optional: missing
    credentials or any failure only prints a message and leaves the app
    untraced.

    Returns:
        True when instrumentation was installed, False otherwise.
    """
    try:
        # Imported lazily so the app still runs when the tracing packages are absent.
        from arize.otel import register
        from openinference.instrumentation.llama_index import LlamaIndexInstrumentor

        space_id = os.getenv("ARIZE_SPACE_ID")
        api_key = os.getenv("ARIZE_API_KEY")
        if not (space_id and api_key):
            print("Arize credentials not found, skipping instrumentation")
            return False

        tracer_provider = register(
            space_id=space_id,
            api_key=api_key,
            project_name="rbacrag",
        )
        LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
        return True
    except Exception as e:
        print(f"Warning: Arize instrumentation failed: {e}")
        return False


_init_arize_tracing()
|
|
| |
| from database import db, initialize_users |
|
|
| |
@st.cache_resource(show_spinner=False)
def _seed_users_once():
    """Run ``initialize_users()`` once per server process and print the outcome.

    Streamlit re-executes the whole script on every interaction, so calling the
    seeding routine directly at module level would hit the user database on
    every rerun. Exceptions propagate to the caller below; a call that raises
    is not cached, so a failed attempt is retried on the next rerun.

    Returns:
        The ``(success_count, error_count)`` pair returned by ``initialize_users()``.
    """
    success_count, error_count = initialize_users()
    if error_count > 0:
        print(f"Database initialization completed with {error_count} errors (likely users already exist)")
    else:
        print(f"Database initialization successful: {success_count} users ready")
    return success_count, error_count


try:
    _seed_users_once()
except Exception as e:
    # Seeding is best-effort: report and keep serving the app.
    print(f"Error during user initialization: {e}")
|
|
| |
# Department -> document categories that department may read; every department
# additionally gets the shared "general" category.
# NOTE(review): not referenced in the code shown here (retrieval is scoped by the
# per-role ChromaDB directory instead); confirm whether this map is still needed.
ROLE_ACCESS = {
    department: [department, "general"]
    for department in ("hr", "engineering", "finance", "marketing")
}
|
|
def initialize_session_state():
    """Seed any missing session-state keys with their logged-out defaults.

    Keys that already exist are left untouched, so this is safe to call on
    every rerun; after ``st.session_state.clear()`` it restores a clean slate.
    """
    # Built per call so "messages" always gets a fresh list.
    defaults = {
        "authenticated": False,
        "username": None,
        "role": None,
        "messages": [],
        "vector_index": None,
        "query_engine": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
| |
# Browser-tab title/icon and page layout for the app.
# NOTE(review): the page_icon value looks like a mis-decoded emoji; confirm the
# intended icon and restore it.
st.set_page_config(
    layout="centered",
    initial_sidebar_state="collapsed",
    page_title="Departmental RAG System",
    page_icon="π",
)

# Make sure every session key exists before any page logic reads it.
initialize_session_state()
|
|
def login(username: str, password: str) -> bool:
    """Authenticate a user and, on success, start their session.

    The credentials are checked with ``db.verify_user``. On success, the user's
    name and role are stored in session state, the chat history is replaced
    with a welcome message, and ``st.rerun()`` is called to redraw the app as
    the logged-in view. Because of the rerun, the ``return True`` is normally
    not reached.

    Args:
        username: The username to authenticate.
        password: The password to verify.

    Returns:
        True if authentication was successful, False otherwise. Lookup errors
        are shown with ``st.error`` and reported as a failed login.
    """
    try:
        user = db.verify_user(username, password)
        if not user:
            return False
        # Read every field before touching session state, so a malformed user
        # record cannot leave the session half-authenticated.
        user_name = user["username"]
        user_role = user["role"]
    except Exception as e:
        st.error(f"An error occurred during login: {str(e)}")
        return False

    st.session_state.username = user_name
    st.session_state.role = user_role
    st.session_state.messages = [
        {"role": "assistant", "content": f"Welcome, {user_name}! How can I assist you today?"}
    ]
    # Flip the flag last: everything main() reads for a logged-in user is now set.
    st.session_state.authenticated = True
    st.rerun()
    return True
|
|
def logout():
    """End the current session: wipe session state, restore defaults, then rerun.

    NOTE(review): st.rerun() follows immediately, so the success message may
    only flash briefly, or not be seen at all, before the login page is redrawn.
    """
    departing_user = st.session_state.get("username", "Unknown")

    # Drop every key, including auth flags and chat history, then re-seed the
    # logged-out defaults.
    st.session_state.clear()
    initialize_session_state()

    st.success(f"Successfully logged out {departing_user}")
    st.rerun()
|
|
# NOTE(review): messages below prefixed with "π" look like mis-decoded emoji
# whose original glyph cannot be recovered from this file; confirm and restore.


def _open_chroma_client(persist_dir: str):
    """Return a ChromaDB client for *persist_dir*, or an in-memory client if it can't be opened.

    A failed open is reported and degraded to an in-memory client. The store on
    disk is never reset or deleted here, so documents already ingested for the
    role survive a transient failure (e.g. a permissions or settings problem).
    """
    try:
        return chromadb.PersistentClient(
            path=persist_dir,
            settings=chromadb.Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )
    except Exception as e:
        st.warning(f"Failed to connect to persistent ChromaDB: {e}")

    st.warning("⚠️ Using in-memory ChromaDB (data will not persist)")
    return chromadb.Client(settings=chromadb.Settings(anonymized_telemetry=False))


def _get_or_create_collection(chroma_client, role: str, collection_name: str = "documents"):
    """Fetch *collection_name* from *chroma_client*, creating an empty cosine-space one if missing.

    Stops the Streamlit script (``st.stop()``) if the collection can neither be
    fetched nor created.
    """
    try:
        chroma_collection = chroma_client.get_collection(collection_name)
    except Exception:
        st.warning(f"⚠️ Collection '{collection_name}' not found for role '{role}'. Creating empty collection.")
    else:
        st.success(f"✅ Connected to existing collection for {role} role")
        return chroma_collection

    try:
        chroma_collection = chroma_client.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
    except Exception as create_error:
        st.error(f"❌ Failed to create collection: {create_error}")
        st.stop()
    st.info("π Created new empty collection. You may need to add documents first.")
    return chroma_collection


@st.cache_resource
def load_vector_index(role: str):
    """Load (or initialise) the ChromaDB-backed vector index for *role*.

    Each role has its own store under ``./chroma_db/<role>`` holding a
    ``documents`` collection, so the returned index only covers that role's
    data. The result is cached per ``role`` by ``st.cache_resource``.

    Args:
        role: Department role of the logged-in user; also the store's subdirectory name.

    Returns:
        A ``VectorStoreIndex`` over the role's collection. The index is empty
        if the collection has no entries.

    Stops the Streamlit script (``st.stop()``) with on-screen guidance when
    COHERE_API_KEY is missing or the index cannot be loaded.
    """
    try:
        cohere_api_key = os.getenv("COHERE_API_KEY")
        if not cohere_api_key:
            st.error("❌ COHERE_API_KEY not found in environment variables")
            st.info("Please set your Cohere API key in the .env file")
            st.stop()

        embed_model = CohereEmbedding(
            cohere_api_key=cohere_api_key,
            model_name="embed-english-v3.0",
            input_type="search_document",
        )
        # Also set globally so llama_index components that fall back to the
        # global Settings use the same embeddings.
        Settings.embed_model = embed_model

        persist_dir = f"./chroma_db/{role}"
        Path(persist_dir).mkdir(parents=True, exist_ok=True)

        chroma_client = _open_chroma_client(persist_dir)
        chroma_collection = _get_or_create_collection(chroma_client, role)

        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        stored_count = chroma_collection.count()
        if stored_count == 0:
            st.warning(f"π No documents found in {role} collection.")
            st.info("The system will work, but responses will be limited without documents.")
            return VectorStoreIndex([], storage_context=storage_context, embed_model=embed_model)

        st.info(f"π Found {stored_count} documents in {role} collection")
        return VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            storage_context=storage_context,
            embed_model=embed_model,
        )

    except Exception as e:
        st.error(f"❌ Error loading vector index: {str(e)}")
        st.info("**Possible solutions:**")
        st.info("1. Check that ChromaDB collections exist for this role")
        st.info("2. Verify database files are properly mounted in Docker")
        st.info("3. Check permissions on the database directory")
        st.info("4. Ensure COHERE_API_KEY is set correctly")
        st.stop()
|
|
def chat_interface():
    """Render the department chat page for the logged-in user.

    The page replays the stored conversation and builds a retrieval query
    engine over the role's vector index. It then answers any new
    ``st.chat_input`` prompt, appending both the prompt and the answer (or the
    error text) to ``st.session_state.messages``.

    Environment:
        GROQ_API_KEY: required; the page stops with an error when it is unset.
        GROQ_MODEL: optional Groq model override (defaults to ``llama3-8b-8192``).
    """
    role = st.session_state.role

    st.markdown(
        f"<h2 style='color: #1407fa;'>💬 {role.capitalize()} Department Chat</h2>",
        unsafe_allow_html=True,
    )

    # Replay the conversation so far; Streamlit redraws the whole page on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    index = load_vector_index(role)

    # Checked outside the try below so st.stop() is not mixed into the
    # LLM-initialisation error handling.
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        st.error("❌ GROQ_API_KEY not found in environment variables")
        st.info("Please set your Groq API key in the .env file")
        st.stop()

    try:
        # NOTE(review): confirm the default model is still served by Groq; set
        # GROQ_MODEL to switch models without a code change.
        llm = Groq(
            model=os.getenv("GROQ_MODEL", "llama3-8b-8192"),
            api_key=groq_api_key,
            temperature=0.5,
            system_prompt=(
                f"You are a helpful assistant specialized in {role} department documents. "
                "Answer the user queries with the help of the provided context with high accuracy and precision."
            ),
        )
        query_engine = index.as_query_engine(
            llm=llm,
            similarity_top_k=3,
            response_mode="compact",
        )
    except Exception as e:
        st.error(f"❌ Error initializing LLM: {str(e)}")
        st.warning("⚠️ Falling back to default LLM settings. Some features may be limited.")
        try:
            # Without an explicit llm, llama_index presumably resolves its global
            # default LLM, which can itself fail to initialise; stop cleanly if so.
            query_engine = index.as_query_engine(
                similarity_top_k=3,
                response_mode="compact",
            )
        except Exception as fallback_error:
            st.error(f"❌ Fallback LLM unavailable: {fallback_error}")
            st.stop()

    if prompt := st.chat_input(f"Ask about {role} documents..."):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            try:
                response = query_engine.query(prompt)
                full_response = str(response)
                message_placeholder.markdown(full_response)
            except Exception as e:
                full_response = f"❌ Error generating response: {str(e)}"
                message_placeholder.error(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
|
|
# Page-level CSS for the login screen.
# NOTE(review): ".st-emotion-cache-1r6slb0" is a generated Streamlit class name and
# may change between Streamlit releases; confirm it still targets the form container.
_LOGIN_PAGE_CSS = """
<style>
.main {
background-color: #1a1a2e;
color: white;
}
.stTextInput > div > div > input {
background-color: #2a2a3e;
color: white;
border: 1px solid #4a4a6a;
border-radius: 8px;
}
.stTextInput > div > div > input::placeholder {
color: #a0a0b0 !important;
opacity: 1 !important;
}
.stButton > button {
background-color: #e94560;
color: white;
border: none;
border-radius: 8px;
padding: 10px 20px;
font-size: 16px;
width: 100%;
}
.stButton > button:hover {
background-color: #d83450;
}
h1, h2, h3, h4, h5, h6 {
color: white;
}
.st-emotion-cache-1r6slb0 {
border: 1px solid #4a4a6a;
border-radius: 12px;
padding: 2rem;
background-color: #232339;
}
</style>
"""

# Header snippets rendered, in order, above the login form.
# NOTE(review): the first snippet's icon looks like a mis-decoded emoji; confirm it.
_LOGIN_HEADER_HTML = (
    '<div style="text-align: center; margin-top: -80px; margin-bottom: 30px;"><h1 style="font-size: 3rem;">π</h1></div>',
    '<h1 style="text-align: center; margin-bottom: 20px;">Department Portal</h1>',
    '<p style="text-align: center; color: #a0a0b0; margin-bottom: 30px;">Sign in to access your department\'s knowledge base</p>',
)

# NOTE(review): these demo credentials are shown to anyone who opens the login
# page; remove or gate this hint before exposing the app beyond a demo.
_DEMO_CREDENTIALS_MD = """
- **Engineering:** `Tony` / `password123`
- **Marketing:** `Bruce` / `securepass`
- **Finance:** `Sam` / `financepass`
- **HR:** `Natasha` / `hrpass123`
"""


def show_login_form():
    """Render the login page: styling, header, credential form and demo hints.

    An empty username or password is rejected with an inline error; otherwise
    the pair is passed to ``login()``. A successful ``login()`` calls
    ``st.rerun()`` before returning, so the "Redirecting..." message is
    normally not reached.
    """
    st.markdown(_LOGIN_PAGE_CSS, unsafe_allow_html=True)
    for header_html in _LOGIN_HEADER_HTML:
        st.markdown(header_html, unsafe_allow_html=True)

    with st.container():
        with st.form("login_form", border=True):
            username = st.text_input("Username", placeholder="Enter your username")
            password = st.text_input("Password", type="password", placeholder="Enter your password")
            submitted = st.form_submit_button("Sign In")

            if submitted:
                if not (username and password):
                    st.error("Please enter both username and password")
                elif login(username, password):
                    st.success(f"Welcome, {username}! Redirecting...")
                else:
                    st.error("Invalid username or password")

        with st.expander("Need demo credentials?"):
            st.markdown(_DEMO_CREDENTIALS_MD)

    st.markdown(
        '<p style="text-align: center; margin-top: 2rem; color: #a0a0b0;">2025 Department RAG System</p>',
        unsafe_allow_html=True,
    )
|
|
def main():
    """Application entry point: route between the login page and the chat page.

    Unauthenticated users see the login form. Authenticated users get a
    sidebar (name, role, logout button, "About" text and a database status
    check), followed by the chat interface.
    """
    if not st.session_state.authenticated:
        show_login_form()
        return

    with st.sidebar:
        st.markdown(f"### Welcome, {st.session_state.username}")
        st.markdown(f"**Role:** {st.session_state.role.capitalize()}")

        if st.button("Logout", key="logout_btn"):
            logout()
            return  # logout() reruns the script; this is only a safety net.

        st.markdown("---")
        st.markdown("### About")
        st.markdown("""
        This is a secure departmental RAG system that provides
        role-based access to information across different departments.
        """)

        # Status is best-effort: any database failure degrades to a warning line.
        try:
            user_count = len(db.list_users())
        except Exception:
            st.markdown("⚠️ Database: Connection issues")
        else:
            st.markdown("---")
            st.markdown("### System Status")
            st.markdown(f"✅ Database: {user_count} users")
            st.markdown("✅ Authentication: Active")

    chat_interface()


if __name__ == "__main__":
    main()