|
import streamlit as st |
|
import json |
|
import random |
|
import pandas as pd |
|
import pickle |
|
|
|
|
|
# Use the full browser width for the page layout.
st.set_page_config(layout="wide")

# Inject a small stylesheet defining the `.bounding-box` helper class.
# CSS has no `#` line comments: an unrecognised `# ...` token sequence makes
# the parser discard everything up to the next `;`, which silently dropped
# the border-radius, padding and margin declarations. Use `/* ... */`.
st.markdown("""
<style>
.bounding-box {
    border: 2px solid #4CAF50;  /* Green border */
    border-radius: 5px;         /* Rounded corners */
    padding: 10px;              /* Padding inside the box */
    margin: 10px;               /* Space outside the box */
}
</style>
""", unsafe_allow_html=True)
|
|
|
@st.cache_resource
def load_model():
    """Load the ADRD model checkpoint onto the CPU.

    The checkpoint is first looked up in the current working directory; if
    loading from there fails for any ordinary reason, a second location in a
    sibling development tree is tried instead.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.

    Raises:
        Whatever ``from_ckpt`` raises for the fallback path when both
        locations fail.
    """
    import adrd

    ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
    try:
        model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
    except Exception:
        # Was a bare `except:`, which also trapped KeyboardInterrupt and
        # SystemExit; only ordinary load failures should trigger the fallback.
        ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
        model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
    return model
|
|
|
@st.cache_resource
def load_nacc_data():
    """Build the test dataset from the CSV files under ``./data``.

    Returns:
        A ``CSVDataset`` constructed from ``./data/test.csv`` (data) and
        ``./data/input_meta_info.csv`` (configuration).
    """
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/test.csv",
        cnf_file="./data/input_meta_info.csv",
    )
|
|
|
# Shared objects for the page; both loaders are wrapped in st.cache_resource.
model = load_model()
dat_tst = load_nacc_data()


def predict_proba(data_dict):
    """Run the model on a single case and return its prediction entry.

    Args:
        data_dict: Feature dictionary for one case; it is wrapped in a
            one-element list before being handed to the model.

    Returns:
        Element ``[1][0]`` of ``model.predict_proba``'s result, i.e. the
        first (and only) item of the second output.
        NOTE(review): presumably a dict of per-label probabilities, since the
        caller JSON-serialises it -- confirm against the adrd API.
    """
    model_outputs = model.predict_proba([data_dict])
    per_case_outputs = model_outputs[1]
    return per_case_outputs[0]
|
|
|
|
|
# Import kept in place in case other code relies on the module-level name.
from data.dataset_csv import CSVDataset

# Reuse the cached loader instead of constructing a second CSVDataset here:
# the direct construction re-read both CSV files on every Streamlit rerun and
# overwrote the cached instance, defeating st.cache_resource.
dat_tst = load_nacc_data()
|
|
|
|
|
# Seed the text-area contents once per session so later reads never fail.
if "input_text" not in st.session_state:
    st.session_state["input_text"] = ""
|
|
|
|
|
# "About" section: short description of the model.
st.markdown("#### About")
st.markdown(
    "Differential diagnosis of dementia remains a challenge in neurology due "
    "to symptom overlap across etiologies, yet it is crucial for formulating "
    "early, personalized management strategies. Here, we present an AI model "
    "that harnesses a broad array of data, including demographics, individual "
    "and family medical history, medication use, neuropsychological "
    "assessments, functional evaluations, and multimodal neuroimaging, to "
    "identify the etiologies contributing to dementia in individuals."
)

# "Demo" section: usage instructions for the form rendered below.
st.markdown("#### Demo")
st.markdown(
    'Please enter the input features in the textbox below, formatted as a '
    'JSON dictionary. Click the "**Random case**" button to populate the '
    'textbox with a randomly selected case from the NACC testing dataset. '
    'Use the "**Predict**" button to submit your input to the model, which '
    'will then provide probability predictions for mental status and all 10 '
    'etiologies.'
)
|
|
|
|
|
# Two equal-width page columns: the input form goes in the left one, and the
# prediction output is written into the right one further down the script.
layout_l, layout_r = st.columns([1, 1])

with layout_l, st.form("json_input_form"):
    # Text area pre-filled from session state (set by the "Random case" flow).
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=300,
    )

    # Button row; the middle column is never written to and only separates
    # the two submit buttons.
    left_col, middle_col, right_col = st.columns([3, 4, 1])

    with left_col:
        sample_button = st.form_submit_button("Random case")
    with right_col:
        submit_button = st.form_submit_button("Predict")
|
|
|
# Load the variable mapping consumed by convert_dictionary().
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# this is only safe as long as the .pkl file ships with the app and is trusted.
with open('./data/nacc_variable_mappings.pkl', 'rb') as mapping_file:
    nacc_mapping = pickle.load(mapping_file)
|
|
|
def convert_dictionary(original_dict, mappings):
    """Rename keys and translate values of a dictionary via a mapping table.

    Args:
        original_dict: Dictionary to convert.
        mappings: Maps each recognised key to a ``(new_key, transform_map)``
            pair. ``transform_map`` translates known values; values it does
            not contain are carried over unchanged.

    Returns:
        A new dictionary holding only the entries whose key appears in
        ``mappings``, stored under their new keys.
    """
    converted = {}
    for source_key, raw_value in original_dict.items():
        # Keys without a mapping entry are dropped from the result.
        if source_key not in mappings:
            continue

        target_key, value_lookup = mappings[source_key]
        converted[target_key] = (
            value_lookup[raw_value] if raw_value in value_lookup else raw_value
        )
    return converted
|
|
|
# React to whichever form button triggered this script run.
if sample_button:
    # Pick a random test sample and store the first element of that sample
    # as pretty-printed JSON; the text area reads it as its initial value.
    idx = random.randint(0, len(dat_tst) - 1)
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)

    # Rerun so the text area is rebuilt with the new contents. The previous
    # `'input_text' in st.session_state` guard was always true at this point.
    # st.experimental_rerun is deprecated/removed in newer Streamlit releases,
    # so prefer st.rerun and fall back only on versions that lack it.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()

elif submit_button:
    # Keep the try body limited to the parsing step it is meant to guard.
    try:
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            # Valid JSON that is not an object (e.g. a list or a number) used
            # to crash in convert_dictionary with AttributeError.
            st.error(
                "An error occurred: expected a JSON object (dictionary) of "
                f"input features, got {type(data_dict).__name__}"
            )
        else:
            data_dict = convert_dictionary(data_dict, nacc_mapping)
            pred_dict = predict_proba(data_dict)

            # Show the result in the right-hand page column.
            with layout_r:
                st.write("Predicted probabilities:")
                st.code(json.dumps(pred_dict, indent=2))
|
|
|
|
|
# Reference section: render the input feature metadata CSV as a static table.
st.markdown("#### Feature Table")
df_input_meta_info = pd.read_csv("./data/input_meta_info.csv")
st.table(df_input_meta_info)
|
|