File size: 4,741 Bytes
6397272 fdab7a6 ec804d3 8ea5177 30d58e8 6397272 2730dad 3e4df7c de23f75 2730dad 30d58e8 2730dad 30d58e8 2730dad de23f75 991df54 30d58e8 991df54 3e4df7c 991df54 3e4df7c de23f75 fdab7a6 ec804d3 30d58e8 ec804d3 2730dad b6d83f4 2730dad 1d9d836 ec804d3 2730dad ec804d3 2730dad 8ea5177 2730dad ec804d3 2730dad ec804d3 2730dad 1d9d836 2730dad 30d58e8 2730dad 30d58e8 d217981 ec804d3 2730dad fdab7a6 ec804d3 fdab7a6 30d58e8 de23f75 2730dad 8ea5177 fdab7a6 2730dad 8ea5177 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import streamlit as st
import json
import random
import pandas as pd
import pickle
# Use the full browser width so the input form and the results can sit side by side.
st.set_page_config(layout="wide")

# Inject the stylesheet for the `.bounding-box` class.
# NOTE: CSS only understands /* ... */ comments. A `#`-style trailing comment is
# parsed as the start of the next declaration, which makes the browser discard
# that declaration (border-radius, padding and margin were silently dropped).
st.markdown("""
<style>
.bounding-box {
border: 2px solid #4CAF50; /* Green border */
border-radius: 5px; /* Rounded corners */
padding: 10px; /* Padding inside the box */
margin: 10px; /* Space outside the box */
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Load the ADRD model from its checkpoint onto the CPU.

    The checkpoint is first looked up in the working directory. If loading
    from there fails for any reason, the development checkpoint location is
    tried instead; a failure of that second attempt propagates to the caller.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.
    """
    # Imported lazily so the package is only loaded when the model is first built.
    import adrd

    primary_ckpt = './ckpt_swinunetr_stripped_MNI.pt'
    fallback_ckpt = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'

    try:
        return adrd.model.ADRDModel.from_ckpt(primary_ckpt, device='cpu')
    except Exception:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
        # SystemExit are not swallowed while still falling back on any load error.
        return adrd.model.ADRDModel.from_ckpt(fallback_ckpt, device='cpu')
@st.cache_resource
def load_nacc_data():
    """Build the dataset object backing the "Random case" button.

    Returns:
        A ``CSVDataset`` constructed from ``./data/test.csv`` with
        ``./data/input_meta_info.csv`` passed as its ``cnf_file``.
    """
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/test.csv",
        cnf_file="./data/input_meta_info.csv",
    )
# Module-level handles used by the rest of the script. Both loaders are wrapped
# in `st.cache_resource`, so these calls are made on every script run but go
# through Streamlit's resource cache.
model = load_model()
dat_tst = load_nacc_data()
def predict_proba(data_dict):
    """Run the module-level ``model`` on a single case.

    Args:
        data_dict: Input features for one case; it is wrapped in a
            one-element list before being handed to the model.

    Returns:
        Element ``[1][0]`` of what ``model.predict_proba`` returns
        (presumably the probability output for the only case in the batch;
        confirm against the model's API).
    """
    outputs = model.predict_proba([data_dict])
    return outputs[1][0]
# NACC testing data: `dat_tst` is already bound above to the dataset returned by
# the cached `load_nacc_data()`. It used to be rebuilt here with the same
# arguments, which re-read the CSV files on every Streamlit rerun and replaced
# the cached instance, so the duplicate construction has been removed.

# initialize session state for the text input if it's not already set
if 'input_text' not in st.session_state:
    st.session_state.input_text = ""
# Static copy for the two introductory sections, kept apart from the rendering calls.
about_text = "Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals."
demo_text = "Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies."

# section 1: what the tool is
st.markdown("#### About")
st.markdown(about_text)

# section 2: how to use the demo
st.markdown("#### Demo")
st.markdown(demo_text)
# Two equal-width columns: the input form on the left, results on the right.
layout_l, layout_r = st.columns([1, 1])

# The form lives in the left column; both buttons are form submit buttons.
with layout_l, st.form("json_input_form"):
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=300,
    )

    # Three columns under the text area; the middle one is left empty as a spacer.
    left_col, middle_col, right_col = st.columns([3, 4, 1])

    with left_col:
        sample_button = st.form_submit_button("Random case")

    with right_col:
        submit_button = st.form_submit_button("Predict")
# Load the variable mappings consumed by `convert_dictionary`.
# NOTE(review): `pickle.load` executes arbitrary code from the file it reads;
# this is only safe as long as the pickle is a trusted file shipped with the app.
with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
    nacc_mapping = pickle.load(file)
def convert_dictionary(original_dict, mappings):
    """Rename keys and translate values of ``original_dict`` using ``mappings``.

    Args:
        original_dict: Mapping of input keys to values.
        mappings: Mapping of input key -> ``(new_key, transform_map)``, where
            ``transform_map`` maps original values to replacement values.

    Returns:
        A new dict holding one entry per key of ``original_dict`` that appears
        in ``mappings``, stored under ``new_key``. The value is replaced by
        ``transform_map[value]`` when ``value`` is found in ``transform_map``
        and kept unchanged otherwise. Keys absent from ``mappings`` are dropped.
    """
    transformed_dict = {}
    for key, value in original_dict.items():
        if key not in mappings:
            continue
        new_key, transform_map = mappings[key]
        try:
            needs_transform = value in transform_map
        except TypeError:
            # Unhashable values (e.g. lists or dicts parsed from JSON) cannot be
            # looked up in the map; pass them through unchanged instead of crashing.
            needs_transform = False
        transformed_dict[new_key] = transform_map[value] if needs_transform else value
    return transformed_dict
if sample_button:
    # Draw a random item from the test dataset and put element 0 of it
    # (presumably the feature dictionary; confirm against CSVDataset) into the
    # text box as pretty-printed JSON.
    idx = random.randint(0, len(dat_tst) - 1)
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)
    # Rerun the script so the text area is rebuilt with the updated text.
    # `st.experimental_rerun` was deprecated in favour of `st.rerun`; prefer the
    # new name when the installed Streamlit provides it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python object
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            # Valid JSON that is not an object (e.g. a list or a number) cannot be
            # mapped to features; report it instead of failing inside the conversion.
            st.error("An error occurred: the input must be a JSON dictionary.")
        else:
            data_dict = convert_dictionary(data_dict, nacc_mapping)
            pred_dict = predict_proba(data_dict)
            # Show the result in the right-hand column, next to the form.
            with layout_r:
                st.write("Predicted probabilities:")
                st.code(json.dumps(pred_dict, indent=2))
# section 3: render the input-feature metadata CSV as a static table
st.markdown("#### Feature Table")
# The same CSV is passed to CSVDataset as `cnf_file` when the test data is loaded.
df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
st.table(df_input_meta_info)
|