|
import os |
|
import streamlit as st |
|
import requests |
|
import pandas as pd |
|
from datetime import datetime |
|
from typing import Dict, List, Optional, Tuple |
|
from dataclasses import dataclass |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UserConfig:
    """Profile of the signed-in user.

    The fields mirror the keys of the user record returned by the login
    endpoint and stored in ``st.session_state.user_data``.
    """

    id: int
    name: str
    email: str
    age: int
    gender: str
    user_type: str  # the sign-in form offers "patient" or "staff"
    created_at: str  # timestamp string exactly as sent by the API (not parsed)
    updated_at: str  # timestamp string exactly as sent by the API (not parsed)

    @classmethod
    def from_dict(cls, data: Dict) -> "UserConfig":
        """Build a ``UserConfig`` from an API user record.

        Only the dataclass fields are read, in declaration order; any extra
        keys in *data* are ignored so new API fields do not break the UI.

        Raises:
            KeyError: if *data* lacks one of the required fields.
        """
        from dataclasses import fields

        return cls(**{f.name: data[f.name] for f in fields(cls)})
|
|
|
@dataclass
class ChatMessage:
    """A single chat turn.

    ``role`` is the speaker label; the chat UI in this file uses "user" and
    "assistant". ``timestamp`` is optional and defaults to ``None``.
    """

    role: str
    content: str
    timestamp: Optional[datetime] = None

    def to_dict(self) -> Dict[str, str]:
        """Return the ``{"role": ..., "content": ...}`` mapping used as chat history.

        The timestamp is deliberately omitted; the history entries sent to the
        triage API in this file carry only role and content.
        """
        return {"role": self.role, "content": self.content}
|
|
|
class Config:
    """Static application settings."""

    # Base URL of the triage backend; override with the API_URL environment
    # variable. Trailing slashes are stripped so that f"{API_URL}/path" never
    # produces a double slash when the override ends in "/".
    API_URL = os.getenv(
        "API_URL",
        "https://clinical-agents-333016757590.us-central1.run.app/api/v1",
    ).rstrip("/")

    PAGE_TITLE = "AI Triage System"
    # Hospital emoji (the previous literal was mis-encoded mojibake).
    PAGE_ICON = "🏥"

    SIDEBAR_WIDTH = 300
    # Height passed to st.container(height=...) for the chat transcript.
    CHAT_HEIGHT = 500

    # Theme colours.
    PRIMARY_COLOR = "#1f77b4"
    SUCCESS_COLOR = "#2ca02c"
    WARNING_COLOR = "#ff7f0e"
    ERROR_COLOR = "#d62728"
|
|
|
|
|
|
|
|
|
|
|
class SessionStateManager:
    """Helpers that read and write Streamlit's per-session state."""

    @staticmethod
    def init_state():
        """Seed ``st.session_state`` with defaults for every key not yet present.

        Keys that already exist (e.g. from an earlier rerun) keep their values.
        """
        initial_values = {
            "user_type": "patient",
            "auth_done": False,
            "user_id": None,
            "user_data": None,
            "messages": [],
            "notes": [],
            "chat_active": False,
            "finished": False,
            "current_assessment_id": None,
            "show_help": False,
        }
        absent = [name for name in initial_values if name not in st.session_state]
        for name in absent:
            st.session_state[name] = initial_values[name]

    @staticmethod
    def get_user_config() -> Optional[UserConfig]:
        """Return the signed-in user's profile, or ``None`` when nobody is signed in.

        Raises:
            KeyError: if the stored user record lacks one of the profile fields.
        """
        if not st.session_state.auth_done or not st.session_state.user_data:
            return None
        record = st.session_state.user_data
        field_order = (
            "id",
            "name",
            "email",
            "age",
            "gender",
            "user_type",
            "created_at",
            "updated_at",
        )
        return UserConfig(**{key: record[key] for key in field_order})

    @staticmethod
    def reset_chat():
        """Clear the conversation and return the chat to its idle state."""
        state = st.session_state
        state.messages = []
        state.chat_active = False
        state.finished = False
        state.current_assessment_id = None
|
|
|
|
|
|
|
|
|
|
|
class APIService:
    """Thin client for the triage backend's REST API.

    All HTTP calls go through ``_request`` so that transport failures, HTTP
    error statuses and undecodable bodies are reported the same way.
    """

    @staticmethod
    def _request(method: str, path: str, *, timeout: float = 10, **kwargs) -> Tuple[bool, object]:
        """Send ``method`` to ``Config.API_URL + path`` and decode the JSON body.

        Args:
            method: HTTP verb understood by ``requests.request`` (e.g. "get").
            path: path appended to ``Config.API_URL``; must start with "/".
            timeout: seconds before the request is abandoned.
            **kwargs: forwarded to ``requests.request`` (e.g. ``json=``).

        Returns:
            ``(True, decoded_json)`` on a successful response, otherwise
            ``(False, {"error": message})``.
        """
        try:
            resp = requests.request(method, f"{Config.API_URL}{path}", timeout=timeout, **kwargs)
            resp.raise_for_status()
            return True, resp.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException covers connection errors, timeouts and the
            # HTTPError from raise_for_status(); ValueError covers a body that
            # is not valid JSON. Anything else is a programming error and is
            # allowed to propagate instead of masquerading as an API failure.
            return False, {"error": str(e)}

    @staticmethod
    def login_user(name: str, email: str, age: int, gender: str, user_type: str) -> Tuple[bool, Dict]:
        """Sign in (or register) a user via ``POST /users/login``.

        Returns:
            ``(True, response_json)`` on success — the UI reads the user record
            fields (``id``, ``name``, ``email``, ...) from it — or
            ``(False, {"error": message})``.
        """
        payload = {
            "name": name,
            "email": email,
            "age": age,
            "gender": gender,
            "user_type": user_type,
        }
        return APIService._request("post", "/users/login", json=payload, timeout=10)

    @staticmethod
    def get_user_by_id(user_id: int) -> Tuple[bool, Dict]:
        """Fetch one user record via ``GET /users/{user_id}``.

        Returns:
            ``(True, response_json)`` or ``(False, {"error": message})``.
        """
        return APIService._request("get", f"/users/{user_id}", timeout=10)

    @staticmethod
    def fetch_assessments() -> List[Dict]:
        """Fetch all assessments via ``GET /assessments``.

        On failure an error is shown in the Streamlit page and an empty list
        is returned, so callers can treat "failed" like "no data".
        """
        ok, data = APIService._request("get", "/assessments", timeout=10)
        if ok:
            return data
        st.error(f"Failed to fetch assessments: {data['error']}")
        return []

    @staticmethod
    def send_chat_message(message: str, history: List[Dict], patient_id: int) -> Tuple[bool, Dict]:
        """Send one chat turn via ``POST /triage/chat``.

        Uses a longer timeout (30 s) than the other calls, since this
        request waits for the triage model's reply.

        Returns:
            ``(True, response_json)`` — the UI reads ``response_json["response"]``
            and ``response_json.get("finished")`` — or ``(False, {"error": message})``.
        """
        payload = {
            "message": message,
            "history": history,
            "patient_id": patient_id,
        }
        return APIService._request("post", "/triage/chat", json=payload, timeout=30)
|
|
|
|
|
|
|
|
|
|
|
class UIComponents:
    """Streamlit widgets for the header, the sign-in sidebar and the patient chat."""

    @staticmethod
    def render_header():
        """Render the branded page header."""
        st.markdown(
            f"""
            <div style="text-align: center; padding: 1rem 0;
                        background: linear-gradient(90deg, {Config.PRIMARY_COLOR}, {Config.SUCCESS_COLOR});
                        border-radius: 10px; margin-bottom: 2rem;">
                <h1 style="color: white; margin: 0; font-size: 2.5rem;">🏥 AI Triage System</h1>
                <p style="color: #f0f0f0; margin: 0.5rem 0 0 0; font-size: 1.1rem;">
                    Intelligent healthcare assessment at your fingertips
                </p>
            </div>
            """,
            unsafe_allow_html=True,
        )

    @staticmethod
    def render_sidebar_auth() -> bool:
        """Render the sidebar authentication section.

        While nobody is signed in, shows the role selector and the sign-in
        form. Once a user is signed in, shows their details and a sign-out
        button instead, so the form cannot be re-submitted by a signed-in user.

        Returns:
            ``st.session_state.auth_done``: True once a user is signed in.
        """
        with st.sidebar:
            st.markdown("### 👤 Authentication")
            if st.session_state.auth_done and st.session_state.user_data:
                UIComponents._render_current_user(st.session_state.user_data)
            else:
                UIComponents._render_signin_form()
        return st.session_state.auth_done

    @staticmethod
    def _render_signin_form():
        """Render the role selector and sign-in form; on success store the user and rerun."""
        user_type = st.selectbox(
            "Select your role:",
            ["patient", "staff"],
            key="user_type_select",
            help="Choose whether you're a patient seeking assessment or medical staff",
        )
        st.session_state.user_type = user_type

        with st.form("signin_form"):
            st.markdown("**Sign In / Register**")
            name = st.text_input("Full Name", placeholder="Enter your full name")
            email = st.text_input("Email", placeholder="Enter your email address")

            if user_type == "patient":
                col_age, col_gender = st.columns(2)
                with col_age:
                    age = st.number_input("Age", min_value=1, max_value=120, value=25)
                with col_gender:
                    gender = st.selectbox("Gender", ["male", "female", "other"])
            else:
                # Staff are not asked for demographics; the login endpoint
                # still receives these placeholder values.
                age = 30
                gender = "not_specified"

            submitted = st.form_submit_button("🔐 Sign In", use_container_width=True)

        if not submitted:
            return

        name, email = name.strip(), email.strip()
        if not (name and email):
            st.error("❌ Please enter both name and email")
            return

        with st.spinner("🔄 Signing in..."):
            success, response = APIService.login_user(name, email, age, gender, user_type)

        if success:
            st.session_state.user_data = response
            st.session_state.user_id = response["id"]
            st.session_state.auth_done = True
            # No success toast here: st.rerun() would discard it at once; the
            # current-user panel rendered on the rerun confirms the sign-in.
            st.rerun()
        else:
            st.error(f"❌ Login failed: {response.get('error', 'Unknown error')}")

    @staticmethod
    def _render_current_user(user_data: Dict):
        """Show the signed-in user's details and a sign-out button."""
        st.markdown("---")
        st.markdown("**Current User:**")
        # Trailing double spaces force markdown line breaks between the fields.
        st.info(
            f"👤 **{user_data['name']}**  \n"
            f"📧 {user_data['email']}  \n"
            f"🏷️ {user_data['user_type'].title()}  \n"
            f"🆔 ID: {user_data['id']}"
        )

        if st.button("🚪 Sign Out", use_container_width=True):
            # Drop every session key; main() calls init_state() on the rerun,
            # which re-seeds the defaults.
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            st.rerun()

    @staticmethod
    def render_chat_interface(user_config: UserConfig):
        """Render the patient chat: transcript, controls, help text and input box."""
        st.markdown("### 💬 Chat Assessment")

        with st.container(height=Config.CHAT_HEIGHT):
            for msg in st.session_state.messages:
                with st.chat_message(msg["role"]):
                    st.markdown(msg["content"])

        # The completion notice is rendered here, after the rerun, because
        # anything drawn right before st.rerun() is discarded immediately.
        if st.session_state.finished:
            st.success("✅ Assessment completed successfully!")
            if st.session_state.get("celebrate", False):
                st.session_state.celebrate = False  # one-shot: balloons only once
                st.balloons()

        col_start, col_reset, col_help = st.columns([2, 1, 1])

        with col_start:
            if not st.session_state.chat_active and st.button(
                "🚀 Start New Assessment", use_container_width=True, type="primary"
            ):
                UIComponents._start_new_assessment(user_config)

        with col_reset:
            if st.session_state.chat_active and st.button("🔄 Reset Chat", use_container_width=True):
                SessionStateManager.reset_chat()
                st.rerun()

        with col_help:
            if st.button("❓ Help", use_container_width=True):
                st.session_state.show_help = not st.session_state.get("show_help", False)

        if st.session_state.get("show_help", False):
            st.markdown(
                """
                ---
                **ℹ️ How to use the AI Triage System:**

                1. **Start Assessment**: Click 'Start New Assessment' to begin
                2. **Describe Symptoms**: Be detailed about your symptoms, when they started, and their severity
                3. **Answer Questions**: The AI will ask follow-up questions to better understand your condition
                4. **Get Results**: Receive your triage level and recommended next steps

                **Tips for better results:**

                - Be honest and specific about your symptoms
                - Include timeline information (when symptoms started)
                - Mention any relevant medical history
                - Don't hesitate to ask for clarification
                """
            )

        if st.session_state.chat_active and not st.session_state.finished:
            user_input = st.chat_input("📝 Describe your symptoms or ask a question...")
            if user_input:
                UIComponents._handle_user_message(user_input, user_config)

    @staticmethod
    def _start_new_assessment(user_config: UserConfig):
        """Open a new conversation and fetch the assistant's opening message.

        Sends an empty message with an empty history; the reply becomes the
        first assistant turn. On failure the chat is left inactive and an
        error is shown.
        """
        st.session_state.chat_active = True
        st.session_state.messages = []
        st.session_state.finished = False

        with st.spinner("🚀 Starting your assessment..."):
            success, response = APIService.send_chat_message("", [], user_config.id)

        if success:
            st.session_state.messages.append({"role": "assistant", "content": response["response"]})
            st.rerun()
        else:
            st.error(f"❌ Failed to start assessment: {response.get('error', 'Unknown error')}")
            st.session_state.chat_active = False

    @staticmethod
    def _handle_user_message(user_input: str, user_config: UserConfig):
        """Record the user's turn, send it to the triage API and store the reply.

        When the API reports ``finished``, the chat is closed and
        ``render_chat_interface`` shows the completion notice after the rerun.
        """
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.spinner("🤖 AI is analyzing your response..."):
            # NOTE(review): the history already ends with the message being
            # sent, so the backend receives it both as `message` and as the
            # last history entry. Confirm this is what the API expects.
            history = [{"role": m["role"], "content": m["content"]} for m in st.session_state.messages]
            success, response = APIService.send_chat_message(user_input, history, user_config.id)

        if not success:
            st.error(f"❌ Error: {response.get('error', 'Unknown error')}")
            return

        st.session_state.messages.append({"role": "assistant", "content": response["response"]})

        if response.get("finished", False):
            st.session_state.chat_active = False
            st.session_state.finished = True
            # Consumed by render_chat_interface on the next run to play the
            # balloons once; drawing them here would be lost to st.rerun().
            st.session_state.celebrate = True

        st.rerun()
|
|
|
class StaffDashboard:
    """Staff-facing overview of all triage assessments."""

    # Bar colour per ESI level: 1 (most urgent) red through 5 (least urgent)
    # light green, matching the "1=Most urgent, 5=Least urgent" table help.
    ESI_COLORS = {1: "#d62728", 2: "#ff7f0e", 3: "#ffbb78", 4: "#2ca02c", 5: "#98df8a"}

    @staticmethod
    def render_dashboard(user_config: UserConfig):
        """Render headline metrics, charts and the assessments table.

        ``user_config`` is currently unused by the dashboard.
        """
        st.markdown("### 📊 Staff Dashboard")

        with st.spinner("📥 Loading assessment data..."):
            assessments = APIService.fetch_assessments()

        if not assessments:
            st.info("📭 No assessments available yet.")
            return

        df = pd.DataFrame(assessments)
        df["created_at"] = pd.to_datetime(df["created_at"])

        StaffDashboard._render_metrics(df)

        col_esi, col_timeline = st.columns(2)
        with col_esi:
            StaffDashboard._render_esi_distribution(df)
        with col_timeline:
            StaffDashboard._render_timeline_chart(df)

        StaffDashboard._render_assessments_table(df)

    @staticmethod
    def _render_metrics(df: pd.DataFrame):
        """Render four metrics: total count, mean ESI level, ESI 1-2 count, time since latest."""
        total = len(df)
        col_total, col_avg, col_emergency, col_latest = st.columns(4)

        with col_total:
            st.metric("📋 Total Assessments", total)

        with col_avg:
            st.metric("⚡ Avg ESI Level", f"{df['esi_level'].mean():.1f}")

        with col_emergency:
            # ESI levels 1 and 2 are counted as emergencies.
            emergency_cases = int((df["esi_level"] <= 2).sum())
            share = emergency_cases / total * 100 if total else 0.0
            # delta_color="off": this is a share of the total, not a change,
            # so it must not be drawn as a green "increase" arrow.
            st.metric(
                "🚨 Emergency Cases",
                emergency_cases,
                delta=f"{share:.1f}% of total",
                delta_color="off",
            )

        with col_latest:
            latest = df["created_at"].max()
            # Take "now" in the timestamps' own timezone (naive local time for
            # naive timestamps). Stripping tzinfo from an aware timestamp and
            # comparing it with local datetime.now() was off by the UTC offset.
            now = pd.Timestamp.now(tz=latest.tzinfo)
            hours_ago = (now - latest).total_seconds() / 3600
            st.metric("🕐 Last Assessment", f"{hours_ago:.1f}h ago")

    @staticmethod
    def _render_esi_distribution(df: pd.DataFrame):
        """Render a bar chart of assessment counts per ESI level."""
        st.markdown("**🎯 ESI Level Distribution**")

        esi_counts = df["esi_level"].value_counts().sort_index()

        fig = px.bar(x=esi_counts.index, y=esi_counts.values, title="Distribution by ESI Level")
        # Colour each bar by urgency. The former continuous "RdYlGn_r" scale
        # painted ESI 1 (most urgent) green and ESI 5 red, i.e. inverted.
        fig.update_traces(
            marker_color=[
                StaffDashboard.ESI_COLORS.get(level, Config.PRIMARY_COLOR) for level in esi_counts.index
            ]
        )
        fig.update_layout(
            xaxis_title="ESI Level",
            yaxis_title="Count",
            showlegend=False,
            height=300,
        )
        fig.update_xaxes(dtick=1)  # one tick per integer ESI level

        st.plotly_chart(fig, use_container_width=True)

    @staticmethod
    def _render_timeline_chart(df: pd.DataFrame):
        """Render a line chart of the number of assessments per calendar day."""
        st.markdown("**📈 Assessment Timeline**")

        df_daily = df.groupby(df["created_at"].dt.date).size().reset_index()
        df_daily.columns = ["date", "count"]

        fig = px.line(
            df_daily,
            x="date",
            y="count",
            title="Daily Assessment Volume",
            markers=True,
        )
        fig.update_layout(
            xaxis_title="Date",
            yaxis_title="Number of Assessments",
            height=300,
        )

        st.plotly_chart(fig, use_container_width=True)

    @staticmethod
    def _render_assessments_table(df: pd.DataFrame):
        """Render all assessments, newest first, with each patient's name and email.

        Patient details are fetched once per distinct ``user_id``; failed
        lookups show "Unknown". Optional columns missing from the API data are
        skipped instead of raising KeyError.
        """
        st.markdown("**📋 Recent Assessments**")

        if df.empty:
            return

        # Sort on real datetimes, then format for display.
        display_df = df.sort_values("created_at", ascending=False).copy()
        display_df["created_at"] = pd.to_datetime(display_df["created_at"]).dt.strftime("%Y-%m-%d %H:%M:%S")

        # NOTE(review): these lookups are repeated on every Streamlit rerun;
        # consider caching them if the user endpoint is slow.
        user_cache: Dict = {}
        names, emails = [], []
        for user_id in display_df["user_id"]:
            if user_id not in user_cache:
                ok, record = APIService.get_user_by_id(user_id)
                user_cache[user_id] = record if ok else {}
            record = user_cache[user_id]
            names.append(record.get("name", "Unknown"))
            emails.append(record.get("email", "Unknown"))

        display_df["patient_name"] = names
        display_df["patient_email"] = emails

        wanted = ["id", "patient_name", "patient_email", "esi_level", "diagnosis", "notes", "created_at"]
        columns = [col for col in wanted if col in display_df.columns]

        st.dataframe(
            display_df[columns],
            use_container_width=True,
            hide_index=True,
            column_config={
                "id": "Assessment ID",
                "patient_name": "Patient Name",
                "patient_email": "Patient Email",
                "esi_level": st.column_config.NumberColumn(
                    "ESI Level",
                    help="Emergency Severity Index (1=Most urgent, 5=Least urgent)",
                    min_value=1,
                    max_value=5,
                    format="%d",
                ),
                "diagnosis": "Diagnosis",
                "notes": "Notes",
                "created_at": "Created At",
            },
        )
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Application entry point.

    Configures the page, seeds session state, applies custom CSS and the
    header, then routes the signed-in user to the patient chat or the staff
    dashboard. Stops the script run early while nobody is signed in.
    """
    st.set_page_config(
        page_title=Config.PAGE_TITLE,
        layout="wide",
        page_icon=Config.PAGE_ICON,
        initial_sidebar_state="expanded",
    )

    SessionStateManager.init_state()

    st.markdown(
        """
        <style>
        .main > div {
            padding-top: 2rem;
        }
        .stChatMessage {
            padding: 1rem;
            border-radius: 10px;
            margin-bottom: 1rem;
        }
        .stButton > button {
            border-radius: 20px;
            border: none;
            font-weight: 600;
        }
        .metric-container {
            background: #f8f9fa;
            padding: 1rem;
            border-radius: 10px;
            margin-bottom: 1rem;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    UIComponents.render_header()

    if not UIComponents.render_sidebar_auth():
        st.markdown(
            """
            <div style="text-align: center; padding: 3rem; color: #666;">
                <h3>👋 Welcome to AI Triage System</h3>
                <p>Please sign in using the sidebar to get started with your health assessment.</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
        st.stop()

    user_config = SessionStateManager.get_user_config()
    if user_config is None:
        # auth_done can be set while the stored user record is empty; without
        # this guard the attribute access below raised AttributeError.
        st.error("❌ Could not load your user profile. Please sign out and sign in again.")
        st.stop()

    if user_config.user_type == "patient":
        UIComponents.render_chat_interface(user_config)
    else:
        StaffDashboard.render_dashboard(user_config)


if __name__ == "__main__":
    main()