nmed2024 / app.py
skowshik's picture
Update ckpt and data
30d58e8
raw
history blame
4.74 kB
import streamlit as st
import json
import random
import pandas as pd
import pickle
# set page configuration to wide mode
st.set_page_config(layout="wide")

# Inject the stylesheet for the `.bounding-box` class.
# NOTE: CSS only recognises /* ... */ comments. A `#`-style trailing comment is
# parsed as the start of the next declaration, which invalidates it and causes
# the browser to drop the declaration that follows.
st.markdown("""
<style>
.bounding-box {
border: 2px solid #4CAF50; /* Green border */
border-radius: 5px; /* Rounded corners */
padding: 10px; /* Padding inside the box */
margin: 10px; /* Space outside the box */
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Load the ADRD model from its checkpoint onto the CPU.

    The checkpoint in the working directory is tried first. If loading it
    fails for any reason, the development-tree checkpoint is tried instead.
    A failure of the second attempt propagates to the caller (with the first
    failure attached as exception context).

    The result is wrapped in ``st.cache_resource`` so the checkpoint is not
    reloaded on every script rerun.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.
    """
    # Imported lazily so the package is only loaded when the cache is cold.
    import adrd

    primary_ckpt = './ckpt_swinunetr_stripped_MNI.pt'
    fallback_ckpt = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'

    try:
        return adrd.model.ADRDModel.from_ckpt(primary_ckpt, device='cpu')
    except Exception:
        # `Exception` (not a bare `except`) so that KeyboardInterrupt and
        # SystemExit are never swallowed by the fallback logic.
        return adrd.model.ADRDModel.from_ckpt(fallback_ckpt, device='cpu')
@st.cache_resource
def load_nacc_data():
    """Build the test-set ``CSVDataset`` from the bundled CSV files.

    Reads ``./data/test.csv`` as the data file and
    ``./data/input_meta_info.csv`` as the configuration file. Wrapped in
    ``st.cache_resource`` so the dataset object is reused across reruns.

    Returns:
        The constructed ``CSVDataset`` instance.
    """
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/test.csv",
        cnf_file="./data/input_meta_info.csv",
    )
# Bind the shared resources at module level. Both loaders are decorated with
# st.cache_resource, so these calls are served from the cache on reruns.
model = load_model()
dat_tst = load_nacc_data()
def predict_proba(data_dict):
    """Run the module-level ``model`` on a single case.

    Args:
        data_dict: Feature dictionary for one case; it is wrapped in a
            one-element list before being passed to ``model.predict_proba``.

    Returns:
        Element ``[1][0]`` of the value returned by ``model.predict_proba``.
        NOTE(review): presumably index 1 holds the probability output and
        index 0 selects the sole case in the batch — confirm against
        ``ADRDModel.predict_proba``.
    """
    batch_output = model.predict_proba([data_dict])
    return batch_output[1][0]
# load NACC testing data
# Reuse the dataset produced by the cached load_nacc_data() loader (same
# "./data/test.csv" / "./data/input_meta_info.csv" inputs) rather than
# constructing a fresh CSVDataset on every Streamlit rerun, which would
# bypass st.cache_resource and re-read the CSV files each time.
from data.dataset_csv import CSVDataset  # import retained in case the name is used elsewhere
dat_tst = load_nacc_data()
# seed the session-state slot backing the text area on the first run only
if 'input_text' not in st.session_state:
    st.session_state['input_text'] = ""
# sections 1 and 2: each entry is rendered as a heading followed by its body text
for heading, body in (
    (
        "#### About",
        "Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.",
    ),
    (
        "#### Demo",
        "Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.",
    ),
):
    st.markdown(heading)
    st.markdown(body)
# layout: two equal-width columns; the form lives in the left one and
# `layout_r` is used later to show the prediction output
layout_l, layout_r = st.columns([1, 1])

# create a form for user input
with layout_l, st.form("json_input_form"):
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=300,
    )

    # three columns inside the form; the middle one is only a spacer between
    # the two submit buttons
    sample_slot, _spacer, predict_slot = st.columns([3, 4, 1])

    with sample_slot:
        sample_button = st.form_submit_button("Random case")

    with predict_slot:
        submit_button = st.form_submit_button("Predict")
# Load the NACC variable mappings that are later passed to
# convert_dictionary() as `mappings`. The file is unpickled as-is; it is
# presumably a trusted asset shipped with the app — never point this at
# user-supplied data, since unpickling can execute arbitrary code.
with open('./data/nacc_variable_mappings.pkl', 'rb') as mapping_file:
    nacc_mapping = pickle.load(mapping_file)
def convert_dictionary(original_dict, mappings):
    """Rename keys and recode values of a feature dictionary.

    Args:
        original_dict: Mapping of input feature names to values.
        mappings: Mapping from an input feature name to a
            ``(new_key, transform_map)`` pair, where ``transform_map``
            translates raw values to their recoded form.

    Returns:
        A new dict containing only the keys present in ``mappings``, stored
        under their ``new_key``. A value is replaced by
        ``transform_map[value]`` when it appears in ``transform_map`` and is
        kept unchanged otherwise. Keys absent from ``mappings`` are dropped.
    """
    transformed_dict = {}

    for key, value in original_dict.items():
        if key not in mappings:
            continue

        new_key, transform_map = mappings[key]

        # JSON input can carry lists or nested objects. Those are unhashable,
        # so a membership test against a dict would raise TypeError; they can
        # never be a lookup key anyway, so pass them through untouched.
        try:
            hash(value)
        except TypeError:
            transformed_dict[new_key] = value
            continue

        if value in transform_map:
            transformed_dict[new_key] = transform_map[value]
        else:
            # Keep the original value if no transformation is needed
            transformed_dict[new_key] = value

    return transformed_dict
if sample_button:
    # Pick a random case from the test dataset and serialise its first
    # element (the feature dictionary) into the text box's backing state.
    idx = random.randint(0, len(dat_tst) - 1)
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)

    # Rerun the script so the text area is rebuilt with the new value.
    # `st.experimental_rerun` has been superseded by `st.rerun`; use the new
    # name when this Streamlit version provides it and fall back otherwise.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python object; only the parse itself is
        # guarded so unrelated failures are not mislabelled as JSON errors.
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            # Valid JSON that is not an object (e.g. a list or a number)
            # cannot be mapped to features; report it instead of crashing.
            st.error("An error occurred: the input must be a JSON dictionary of features.")
        else:
            data_dict = convert_dictionary(data_dict, nacc_mapping)
            pred_dict = predict_proba(data_dict)
            with layout_r:
                st.write("Predicted probabilities:")
                st.code(json.dumps(pred_dict, indent=2))
# section 3: show the input feature metadata CSV as a static table
st.markdown("#### Feature Table")
df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
st.table(data=df_input_meta_info)