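# Source-page residue (commit info), kept as a comment so the module parses:
# author: mostafa-atef21 | message: "adds deployment md" | commit: 8d48081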
import streamlit as st
import pandas as pd
import os
import sys
from typing import List, Dict
# Add the project root (the parent of this file's directory) to sys.path so the
# local ``utils`` package resolves regardless of the working directory.
# abspath() matters: with a relative __file__ such as "app.py", dirname() would
# yield "" and every path derived from project_root would silently become
# relative to the CWD.
current_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file)
project_root = os.path.dirname(parent_dir)
# Streamlit re-executes this script on every interaction within one process;
# without the membership check sys.path would gain a duplicate entry each rerun.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from utils.embeddings import get_chroma
from utils.rules import apply_rules
st.set_page_config(page_title="Epilepsy AI Agent", page_icon="🧠", layout="wide")
st.title("🧠 Epilepsy Management AI Agent")
st.caption("Personalized, evidence-informed recommendations for epilepsy management devices and techniques.")

# Load catalog - use project root for data paths
CATALOG_PATH = os.path.join(project_root, "data", "devices.csv")
# Build/attach vector index on first run
chroma_dir = os.path.join(project_root, ".chroma")


@st.cache_data(show_spinner=False)
def _load_catalog(path: str, mtime: float) -> pd.DataFrame:
    """Read the device catalog CSV at *path*.

    Streamlit re-runs this whole script on every widget interaction; caching
    avoids re-parsing the CSV each time. *mtime* is not used in the body but is
    part of the cache key, so editing the file invalidates the cached frame.
    (It must not be renamed with a leading underscore: Streamlit excludes such
    parameters from the cache key.)
    """
    return pd.read_csv(path)


@st.cache_resource(show_spinner=False)
def _open_device_collection(persist_dir: str):
    """Return ``(client, collection)`` for the "devices" collection in *persist_dir*.

    Cached as a shared resource so reruns reuse one client/collection instead of
    opening a new one on every interaction.
    """
    return get_chroma(collection_name="devices", persist_dir=persist_dir)


# Fail with a readable message instead of a raw traceback when the catalog is missing.
if not os.path.isfile(CATALOG_PATH):
    st.error(f"Device catalog not found: {CATALOG_PATH}")
    st.stop()
df = _load_catalog(CATALOG_PATH, os.path.getmtime(CATALOG_PATH))
client, col = _open_device_collection(chroma_dir)
def ensure_index(df: pd.DataFrame):
existing = set(col.get(ids=None)["ids"]) if col.count() else set()
new_rows = []
for _, row in df.iterrows():
rid = str(row["device_id"])
if rid not in existing:
new_rows.append(row)
if new_rows:
ids = [str(r["device_id"]) for r in new_rows]
docs = [f"{r['name']} | {r['category']} | {r['indication']} | {r['invasiveness']} | {r['approvals']} | {r['summary']}" for r in new_rows]
metadatas = [r.to_dict() for r in new_rows]
col.add(ids=ids, documents=docs, metadatas=metadatas)
# Sync the vector collection with the catalog before any query runs;
# devices whose id is already indexed are skipped.
ensure_index(df=df)
# Sidebar: rule toggles consumed by apply_rules, the Top-K result cap, and a
# read-only preview of the catalog's key columns.
panel = st.sidebar
panel.header("Filters & Safety Rules")
avoid_invasive = panel.toggle("Avoid invasive devices", value=True)
req_apps = panel.text_input("Require approvals (comma separated)", value="")
top_k = panel.slider("Top K", min_value=1, max_value=10, value=3)
panel.markdown("---")
panel.subheader("Catalog Snapshot")
snapshot_columns = ["device_id", "name", "category", "invasiveness", "approvals"]
panel.write(df[snapshot_columns])
st.subheader("Patient Input")
patient_text = st.text_area(
    "Describe the patient (seizure type, frequency, age group, prior treatments, lifestyle).",
    height=140,
    placeholder="Example: 16-year-old with focal aware seizures 2–3 times/week, mostly nocturnal; wants non-invasive detection and caregiver alerts.",
)

# Number of nearest neighbours retrieved before the rules filter and the Top-K
# cut are applied (the Top-K slider tops out at this same value).
CANDIDATE_POOL = 10


def _ranked_catalog_rows(catalog: pd.DataFrame, ranked_ids: List, candidates: List[Dict]) -> pd.DataFrame:
    """Return the catalog rows for *ranked_ids*, in that order, plus a ``why`` column.

    ``device_id`` values are compared as strings on both sides, so a numeric id
    in the CSV still matches the id coming back in the vector-store metadata.
    Ids absent from *catalog* are dropped. ``why`` is the ``summary`` from the
    matching candidate metadata, or "" when there is none.
    """
    rank: Dict[str, int] = {}
    for position, device_id in enumerate(ranked_ids):
        rank.setdefault(str(device_id), position)
    why = {str(md.get("device_id")): md.get("summary", "") for md in candidates if md}
    rows = catalog[catalog["device_id"].astype(str).isin(list(rank))].copy()
    row_ids = rows["device_id"].astype(str)
    rows["why"] = row_ids.map(why).fillna("")
    # Boolean-mask selection keeps catalog (CSV) order; restore the ranked order.
    rows["_rank"] = row_ids.map(rank)
    return rows.sort_values("_rank", kind="stable").drop(columns="_rank")


if st.button("Generate Recommendations", type="primary"):
    q = patient_text.strip()
    if not q:
        st.warning("Please enter a short patient description first.")
    else:
        # Retrieve candidates (returned nearest-first).
        res = col.query(query_texts=[q], n_results=CANDIDATE_POOL)
        # Entries without metadata can be neither rule-filtered nor displayed.
        candidates: List[Dict] = [md for md in (res.get("metadatas") or [[]])[0] if md]
        # Apply safety/business rules
        filtered = apply_rules(candidates, require_approvals=req_apps, avoid_invasive=avoid_invasive)
        # Keep the best top_k survivors.
        # NOTE(review): assumes apply_rules preserves the input (similarity) order -- confirm.
        filtered_ids = [c.get("device_id") for c in filtered][:top_k]
        show = _ranked_catalog_rows(df, filtered_ids, candidates)
        if show.empty:
            st.info("No items matched your filters. Try clearing approvals, or switch off 'Avoid invasive' when looking for implants.")
        else:
            # Count what is actually rendered (stale index ids missing from the CSV are dropped).
            st.success(f"Showing {len(show)} recommendation(s).")
            for _, r in show.iterrows():
                with st.container(border=True):
                    st.markdown(f"### {r['name']}")
                    c1, c2, c3, c4 = st.columns([2, 1, 1, 2])
                    c1.write(f"**Category:** {r['category']}")
                    c2.write(f"**Invasiveness:** {r['invasiveness']}")
                    c3.write(f"**Approvals:** {r['approvals']}")
                    c4.write(f"**Indication:** {r['indication']}")
                    st.write(r["why"])
                    source_url = r.get("source_url")
                    # Blank CSV cells are read by pandas as NaN, which is not a usable
                    # URL; only render the button for a non-empty string.
                    if isinstance(source_url, str) and source_url.strip():
                        st.link_button("Open Source Page", source_url, use_container_width=False)

st.markdown("---")
st.info("This tool provides informational recommendations only and is not a substitute for professional medical advice. Always consult a licensed clinician.")