# chat/src/session_state.py
# Last change: "Update src/session_state.py" by Dhruv-Ty (commit c1d31db, verified)
"""
This module handles session state initialization and management.
"""
import streamlit as st
import uuid
from database import get_db_client
def initialize_session_state():
    """
    Initialize all session state variables needed for the application.

    Every key is only set when it is not already present in
    ``st.session_state``, so calling this on each Streamlit rerun keeps
    existing values intact. Database problems are reported (``st.error`` or
    ``print``) instead of raised; the session then continues without, or with
    partial, database backing.
    """
    state = st.session_state

    # Database client; stored as None when initialization fails so that we
    # do not retry on every rerun.
    if 'db_client' not in state:
        try:
            state.db_client = get_db_client()
        except Exception as e:
            st.error(f"Failed to initialize database: {str(e)}")
            state.db_client = None

    # Unique consultation ID for tracking sessions
    if 'consultation_id' not in state:
        state.consultation_id = str(uuid.uuid4())[:8]
        # Create new conversation in database
        if state.db_client:
            try:
                state.db_client.create_conversation(state.consultation_id)
            except Exception as e:
                st.error(f"Failed to create conversation in database: {str(e)}")

    # Main conversation history (acts as a local cache of the database copy).
    if 'history' not in state:
        state.history = []
        # consultation_id is guaranteed to exist here (set just above), so only
        # the client's availability needs checking.
        if state.db_client:
            try:
                db_history = state.db_client.get_conversation_history(
                    state.consultation_id
                )
            except Exception as e:
                # Best-effort: keep the empty history, but leave a trace rather
                # than silently swallowing the failure.
                print(f"Failed to load history from database: {str(e)}")
            else:
                # Guard against a client returning None so later .append()
                # calls on the history keep working.
                state.history = db_history or []

    # Simple UI flags and report defaults. The dict is built fresh on every
    # call, so mutable values such as patient_info are never shared between
    # sessions.
    defaults = {
        'use_rag': True,            # RAG feature toggle
        'processing': False,        # typing indicator while a reply is pending
        'show_report_form': False,  # report generation state
        'report_step': 0,
        'patient_info': {"name": "", "age": "", "gender": ""},  # for reports
        'pdf_data': None,           # generated PDF report data
        'show_email_form': False,   # email form visibility
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value
def add_message_to_history(message):
    """
    Append ``message`` to the local history cache and try to persist it.

    The message is always added to ``st.session_state.history``. If a truthy
    ``db_client`` is present in session state, the message is also saved under
    the current ``consultation_id``. A failed save is shown with ``st.error``
    and is not raised.

    Args:
        message (dict): The message to add to history
    """
    state = st.session_state
    state.history.append(message)

    client = getattr(state, 'db_client', None)
    if not client:
        # No database configured (or it failed to initialize): cache only.
        return

    try:
        client.save_message(state.consultation_id, message)
    except Exception as exc:
        st.error(f"Failed to save message to database: {str(exc)}")
def get_full_history():
    """
    Return the conversation history, preferring the database copy.

    - Without a ``consultation_id`` in session state, returns ``[]``.
    - Without a usable ``db_client``, returns the local cached history (or ``[]``).
    - Otherwise fetches the history from the database, refreshes the local
      cache with it and returns it. If that raises, the local cache (or ``[]``)
      is returned instead.

    Diagnostic messages are written with ``print``.

    Returns:
        list: Full conversation history
    """
    state = st.session_state

    if not hasattr(state, 'consultation_id'):
        print("No consultation_id in session state")
        return []

    client = getattr(state, 'db_client', None)
    if not client:
        print("No database client available")
        return getattr(state, 'history', [])

    try:
        fetched = client.get_conversation_history(state.consultation_id)
        # Refresh the local cache with the authoritative database copy.
        state.history = fetched
        print(f"Retrieved {len(fetched)} messages from database for consultation {state.consultation_id}")
        return fetched
    except Exception as exc:
        print(f"Failed to retrieve history from database: {str(exc)}")
        # Read the cache lazily: it may already have been refreshed above
        # before the failure happened.
        return getattr(state, 'history', [])
def end_conversation():
    """
    End the current conversation and start a fresh one.

    If a truthy ``db_client`` is in session state, the current conversation is
    deleted from the database. The local history cache is then cleared and a
    new 8-character consultation ID is assigned. Finally, a matching
    conversation record is created in the database. Database failures are
    shown with ``st.error`` and do not stop the reset.
    """
    state = st.session_state
    client = getattr(state, 'db_client', None)
    old_id = getattr(state, 'consultation_id', None)

    # Only try to delete when there is a conversation to delete. Previously a
    # missing consultation_id raised AttributeError inside the try block and
    # surfaced as a misleading "Failed to delete conversation" error.
    if client and old_id is not None:
        try:
            client.delete_conversation(old_id)
        except Exception as e:
            st.error(f"Failed to delete conversation from database: {str(e)}")

    # Reset session state
    state.history = []
    state.consultation_id = str(uuid.uuid4())[:8]
    # NOTE(review): other per-conversation keys set by initialization (e.g.
    # patient_info, pdf_data, report_step) are not reset here; confirm the
    # callers clear them if a new consultation should start clean.

    # Create new conversation in database
    if client:
        try:
            client.create_conversation(state.consultation_id)
        except Exception as e:
            st.error(f"Failed to create new conversation in database: {str(e)}")