# SymptomTracker / app.py
# Author: qaisar701shan
# Last commit: "Update app.py" (d0686ec, verified)
import os
import re
from datetime import datetime, timezone

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

from data_store import init_db, upsert_entry, get_user_history, to_csv_bytes
from rules_engine import compute_bmi, evaluate_risks, rules_risk_assessment

# NOTE(review): load_zero_shot_classifier and zero_shot_score are used below but
# are not imported here — add them from the HF helper module, otherwise the
# optional zero-shot path is silently disabled.
# ----------------------
# App Config & Theming
# ----------------------
# set_page_config must be the first Streamlit command executed by the script.
_PAGE_CONFIG = {
    "page_title": "NeuraVia – Symptom Tracker & Risk Dashboard",
    "page_icon": "🧑‍⚕️",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)

# Conditions scored on every submission. This order is also the order of the
# dashboard risk cards; each label maps to a DB column via to_db_key().
PRIMARY_CONDITIONS = [
    "Type 2 Diabetes Risk",
    "Hypertension Risk",
    "Depression/Mood Concern",
    "Migraine Risk",
    "Sleep Apnea Risk",
    "Anemia Risk",
]
# Human label -> DB column name: case is preserved and every run of
# non-alphanumeric characters collapses to "_" (to match data_store.py).
_NON_ALNUM_RUN = re.compile(r"[^0-9A-Za-z]+")


def to_db_key(label: str) -> str:
    """Return the storage column name for a condition label.

    Examples:
        "Type 2 Diabetes Risk"    -> "risk_Type_2_Diabetes_Risk"
        "Depression/Mood Concern" -> "risk_Depression_Mood_Concern"
    """
    slug = _NON_ALNUM_RUN.sub("_", label)
    return f"risk_{slug.strip('_')}"
@st.cache_resource(show_spinner=False)
def get_hf_pipeline():
    """Return the optional zero-shot classifier, or None if it cannot be loaded.

    Any exception while loading degrades the app to rules-only scoring instead
    of crashing. The result (including None) is cached by st.cache_resource, so
    a failed load is not retried until that cache is cleared.
    """
    # NOTE(review): load_zero_shot_classifier is not among the visible imports;
    # as written the resulting NameError is caught below and this returns None.
    classifier = None
    try:
        classifier = load_zero_shot_classifier()
    except Exception:
        classifier = None
    return classifier
# ----------------------
# Sidebar – Navigation
# ----------------------
with st.sidebar:
    st.title("🧑‍⚕️NeuraVia – Symptom Tracker")
    st.caption("Detect • Connect • Personalize")

    # Exactly one of the page sections below renders, selected by `page`.
    page = st.radio(
        "Go to",
        ["Start", "Input", "Dashboard", "History"],
        index=0,
    )
    st.divider()

    # Prefills the Input form and makes Dashboard/History default to "demo_user".
    demo_mode = st.toggle(
        "Use Demo Data",
        value=False,
        help="Prefill a demo input",
    )
    st.divider()

    st.info(
        "Educational use only. Not a medical device. Consult a clinician for advice.",
        icon="⚠️",
    )
def explain_rules(rules_explain: dict) -> str:
    """Render the rule explanations from rules_risk_assessment as Markdown.

    Each (key, value) pair becomes a bullet with a bold label, e.g.
    ``- **BMI**: 31.2``. A falsy input (None or empty dict) yields a fixed
    placeholder sentence instead.
    """
    if rules_explain:
        return "\n".join(
            f"- **{label}**: {detail}" for label, detail in rules_explain.items()
        )
    return "No rule-based explanation available."
# ----------------------
# Global init
# ----------------------
# Prepare the local store (data_store.init_db) before any page reads or writes entries.
init_db()

# Shared CSS for the `.card`, `.metric` and `.section-title` classes used by the pages.
_APP_CSS = """
<style>
.card {background:#fff;border-radius:16px;padding:16px;box-shadow:0 8px 24px rgba(0,0,0,0.07)}
.metric{font-size:22px;font-weight:800}
.section-title{font-size:26px;font-weight:800;margin-bottom:8px}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)

# App-wide header shown above every page.
st.markdown("## 🧑‍⚕️ Symptom Tracker & Risk Dashboard")
st.caption("Capture data → interpretable risk flags → trend charts. Built for NeuraViaHacks.")
# ----------------------
# Start Page
# ----------------------
if page == "Start":
    c1, c2 = st.columns([1.2, 1])
    with c1:
        st.markdown("### What this does")
        st.write(
            "- Collect demographics, vitals, and symptoms.\n"
            "- Score conditions via explainable rules + optional HF zero-shot.\n"
            "- Store locally in SQLite and visualize trends.\n"
            "- Export CSV for sharing."
        )
        st.markdown("### How to demo")
        st.write(
            "1) Open **Input** and submit (or toggle Demo).\n"
            "2) Open **Dashboard** for cards & charts.\n"
            "3) Use **History** to review and export CSV."
        )
    with c2:
        # Each st.markdown call renders as its own element, so a <div class='card'>
        # opened in one call cannot wrap content emitted by later calls (it is
        # closed immediately, leaving an empty card). Emit the whole card as a
        # single HTML fragment instead.
        st.markdown(
            "<div class='card'>"
            "<strong>Beginner-friendly • Interpretable • Real-world relevance</strong><br>"
            "<small>Detect · Connect · Personalize</small>"
            "</div>",
            unsafe_allow_html=True,
        )
# ----------------------
# Input Page
# ----------------------
if page == "Input":
    st.markdown("<div class='section-title'>Patient Data Input</div>", unsafe_allow_html=True)

    # Form defaults: a populated demo profile when "Use Demo Data" is on,
    # otherwise neutral adult values.
    if demo_mode:
        default = dict(
            user_id="demo_user",
            age=42,
            sex="Male",
            height_cm=175,
            weight_kg=92,
            sbp=142,
            dbp=92,
            hr=88,
            spo2=96,
            glucose=118,
            symptoms=["Frequent urination", "Headache", "Daytime sleepiness", "Loud snoring"],
            free_text="Often thirsty and tired; loud snoring; morning headaches.",
        )
    else:
        default = dict(
            user_id="",
            age=30,
            sex="Female",
            height_cm=165,
            weight_kg=65,
            sbp=120,
            dbp=80,
            hr=72,
            spo2=98,
            glucose=95,
            symptoms=[],
            free_text="",
        )

    SEX_OPTIONS = ["Female", "Male", "Other", "Prefer not to say"]

    with st.form("input_form"):
        st.markdown("**Profile**")
        c1, c2, c3 = st.columns(3)
        user_id = c1.text_input("User ID (unique)", value=default["user_id"], help="Used to group your timeline")
        age = c2.number_input("Age", 1, 120, int(default["age"]))
        sex = c3.selectbox(
            "Sex",
            SEX_OPTIONS,
            # Preselect the default's own option; unknown values fall back to the first.
            index=SEX_OPTIONS.index(default["sex"]) if default["sex"] in SEX_OPTIONS else 0,
        )

        st.markdown("**Vitals**")
        c1, c2, c3, c4 = st.columns(4)
        height_cm = c1.number_input("Height (cm)", 80, 230, int(default["height_cm"]))
        weight_kg = c2.number_input("Weight (kg)", 20, 250, int(default["weight_kg"]))
        sbp = c3.number_input("Systolic BP (SBP)", 70, 240, int(default["sbp"]))
        dbp = c4.number_input("Diastolic BP (DBP)", 40, 140, int(default["dbp"]))
        c1, c2, c3 = st.columns(3)
        hr = c1.number_input("Heart Rate (bpm)", 30, 200, int(default["hr"]))
        spo2 = c2.number_input("SpO₂ (%)", 70, 100, int(default["spo2"]))
        glucose = c3.number_input("Glucose (mg/dL)", 50, 400, int(default["glucose"]))

        st.markdown("**Symptoms**")
        SYMPTOMS_LIST = [
            "Frequent urination", "Excessive thirst", "Unintended weight loss", "Blurred vision",
            "Headache", "Nausea", "Sensitivity to light", "Sleep problems", "Daytime sleepiness",
            "Loud snoring", "Pauses in breathing during sleep", "Shortness of breath",
            "Fatigue", "Dizziness", "Pale skin", "Cold hands/feet",
            "Low mood", "Loss of interest", "Anxiety", "Irritability",
        ]
        symptoms = st.multiselect("Select common symptoms (optional)", SYMPTOMS_LIST, default=default["symptoms"])
        free_text = st.text_area("Describe your symptoms (optional)", value=default["free_text"], height=110)
        submitted = st.form_submit_button("Compute Risk", use_container_width=True)

    if submitted:
        # Reject blank and whitespace-only IDs; history is grouped by this value.
        if not user_id.strip():
            st.error("Please enter a User ID.")
            st.stop()

        bmi = compute_bmi(height_cm, weight_kg)
        rules_scores, rules_explain = rules_risk_assessment(
            age=age,
            sex=sex,
            sbp=sbp,
            dbp=dbp,
            hr=hr,
            spo2=spo2,
            glucose=glucose,
            bmi=bmi,
            symptoms=symptoms,
        )

        # Optional HF zero-shot scoring over the selected symptoms plus free text.
        # Only non-empty parts are joined, so an empty symptom list no longer
        # produces a leading "; ", and nothing is sent when there is no text.
        # NOTE(review): zero_shot_score is not among the visible imports.
        hf_scores = {}
        pipe = get_hf_pipeline()
        text = "; ".join(part for part in [*symptoms, free_text] if part)
        if pipe is not None and text:
            hf_scores = zero_shot_score(pipe, text=text, labels=PRIMARY_CONDITIONS) or {}

        # Blend scores, weighting the explainable rules more heavily (65/35).
        # A condition missing from the HF output, or scored None/NaN, keeps its
        # rule score unchanged.
        combined = {}
        for cond in PRIMARY_CONDITIONS:
            r = float(rules_scores.get(cond, 0.0))
            h = hf_scores.get(cond)
            if h is None or pd.isna(h):
                combined[cond] = r
            else:
                combined[cond] = 0.65 * r + 0.35 * float(h)

        # Persist the entry. Risk scores are stored under the underscored DB
        # column names produced by to_db_key(), e.g. "risk_Type_2_Diabetes_Risk".
        entry = {
            "user_id": user_id,
            # Naive UTC ISO-8601 string: same format datetime.utcnow().isoformat()
            # produced, without the deprecated call.
            "ts": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "age": age,
            "sex": sex,
            "height_cm": height_cm,
            "weight_kg": weight_kg,
            "bmi": round(bmi, 2),
            "sbp": sbp,
            "dbp": dbp,
            "hr": hr,
            "spo2": spo2,
            "glucose": glucose,
            "symptoms": ", ".join(symptoms),
            "free_text": free_text,
            **{to_db_key(cond): round(score, 3) for cond, score in combined.items()},
        }
        upsert_entry(entry)
        st.success("Risk computed and saved. Open the **Dashboard** tab.")
        with st.expander("See rule explanations"):
            st.markdown(explain_rules(rules_explain))
# ----------------------
# Dashboard Page
# ----------------------
if page == "Dashboard":
    st.markdown("<div class='section-title'>Risk Dashboard</div>", unsafe_allow_html=True)
    uid = st.text_input("Enter User ID to view", value="demo_user" if demo_mode else "")
    if not uid:
        st.info("Enter your User ID to see assessments.")
        st.stop()
    df = get_user_history(uid)
    if df.empty:
        st.warning("No records found. Submit an entry in **Input**.")
        st.stop()

    df["ts"] = pd.to_datetime(df["ts"], errors="coerce")
    # Unparseable timestamps become NaT. Sort them to the front so the last row
    # is always the newest entry with a valid timestamp (default puts NaT last).
    df = df.sort_values("ts", na_position="first")
    latest = df.iloc[-1]

    # Top risk cards (use DB column names)
    st.markdown("#### Latest Risk Flags")
    cols = st.columns(3)

    def badge(v: float) -> str:
        """Map a risk score to a severity label using 0.25/0.5/0.75 cut-offs."""
        if v >= 0.75:
            return "🔴 High"
        if v >= 0.5:
            return "🟠 Moderate"
        if v >= 0.25:
            return "🟡 Mild"
        return "🟢 Low"

    metrics = []
    for label in PRIMARY_CONDITIONS:
        colname = to_db_key(label)
        if colname not in df.columns:
            continue
        try:
            val = float(latest[colname])
        except (TypeError, ValueError):
            val = 0.0
        # Missing scores (NaN in the frame) are shown as 0 rather than "nan".
        if np.isnan(val):
            val = 0.0
        metrics.append((label.replace(" Risk", ""), val))

    for i, (name, val) in enumerate(metrics):
        with cols[i % 3]:
            # One markdown call per card: separate calls cannot share a wrapping <div>.
            st.markdown(
                f"<div class='card'><strong>{name}</strong>"
                f"<div class='metric'>{val:.2f}</div>"
                f"<small>{badge(val)}</small></div>",
                unsafe_allow_html=True,
            )

    # Trend charts
    st.markdown("#### Trends Over Time")
    vitals = ["bmi", "sbp", "dbp", "hr", "spo2", "glucose"]
    # Display name -> DataFrame column; risk series are labelled "risk: <label>".
    display_to_col = {**{v: v for v in vitals}, **{f"risk: {lbl}": to_db_key(lbl) for lbl in PRIMARY_CONDITIONS}}
    pick = st.multiselect(
        "Choose metrics to plot",
        options=vitals + [f"risk: {lbl}" for lbl in PRIMARY_CONDITIONS],
        default=["bmi", "sbp", "dbp"],
    )
    for disp in pick:
        colname = display_to_col.get(disp, disp)
        if colname not in df.columns:
            continue
        chart = (
            alt.Chart(df).mark_line(point=True).encode(
                x=alt.X("ts:T", title="Time"),
                y=alt.Y(f"{colname}:Q", title=disp),
                tooltip=["ts:T", colname],
            ).properties(height=220)
        )
        st.altair_chart(chart, use_container_width=True)
# ----------------------
# History Page
# ----------------------
if page == "History":
    st.markdown("<div class='section-title'>History & Export</div>", unsafe_allow_html=True)
    uid = st.text_input("User ID", value="demo_user" if demo_mode else "")
    if not uid:
        st.info("Enter a User ID to load history.")
        st.stop()

    history = get_user_history(uid)
    if not history.empty:
        # Show every stored entry and offer the same rows as a CSV download.
        st.dataframe(history, use_container_width=True)
        st.download_button(
            "Download CSV",
            data=to_csv_bytes(history),
            file_name=f"{uid}_history.csv",
            mime="text/csv",
            use_container_width=True,
        )
    else:
        st.warning("No records yet. Submit data in **Input**.")

# Footer, rendered after whichever page section ran (unless it called st.stop()).
st.markdown("---")
st.caption("NeuraViaHacks submission by Qaisar A • Built with Streamlit • Not medical advice.")