File size: 4,741 Bytes
6397272
fdab7a6
ec804d3
8ea5177
30d58e8
6397272
2730dad
 
 
 
 
 
 
 
 
 
 
 
 
 
3e4df7c
de23f75
 
2730dad
30d58e8
2730dad
 
30d58e8
2730dad
de23f75
 
991df54
 
 
 
30d58e8
 
991df54
 
 
3e4df7c
991df54
3e4df7c
de23f75
 
 
fdab7a6
ec804d3
 
 
30d58e8
 
ec804d3
 
 
 
 
 
2730dad
b6d83f4
2730dad
 
 
 
1d9d836
ec804d3
2730dad
 
ec804d3
2730dad
 
 
 
 
 
8ea5177
2730dad
ec804d3
2730dad
 
ec804d3
2730dad
1d9d836
2730dad
 
 
30d58e8
 
 
 
 
 
 
 
 
 
2730dad
30d58e8
 
 
 
 
 
 
 
 
 
d217981
ec804d3
2730dad
 
fdab7a6
ec804d3
 
 
 
 
fdab7a6
 
 
30d58e8
 
de23f75
2730dad
 
8ea5177
fdab7a6
 
 
2730dad
 
8ea5177
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
import json
import random
import pandas as pd
import pickle

# set page configuration to wide mode (wide layout suits the two-column demo)
st.set_page_config(layout="wide")

# Custom CSS defining a ".bounding-box" container class.
# NOTE: CSS only has /* ... */ comments. A "#" is not a comment marker: it
# makes the following declaration invalid, and the browser's error recovery
# then discards it. That is why every annotation below uses /* */.
st.markdown("""
<style>
.bounding-box {
    border: 2px solid #4CAF50;  /* Green border */
    border-radius: 5px;         /* Rounded corners */
    padding: 10px;              /* Padding inside the box */
    margin: 10px;               /* Space outside the box */
}
</style>
""", unsafe_allow_html=True)

@st.cache_resource
def load_model():
    """Load the ADRD model checkpoint onto the CPU, cached across reruns.

    The checkpoint is first looked up next to the app. If loading from there
    fails for any reason, the development-tree location is tried instead.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.

    Raises:
        Whatever ``from_ckpt`` raises for the fallback path. The primary
        path's failure is kept as the exception's ``__context__``.
    """
    import adrd

    primary_ckpt = './ckpt_swinunetr_stripped_MNI.pt'
    fallback_ckpt = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'

    try:
        return adrd.model.ADRDModel.from_ckpt(primary_ckpt, device='cpu')
    except Exception:
        # Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still propagate. The exact error type
        # from_ckpt raises for a missing file is not visible here, so keep it broad.
        return adrd.model.ADRDModel.from_ckpt(fallback_ckpt, device='cpu')

@st.cache_resource
def load_nacc_data():
    """Construct the NACC test-split dataset once and reuse it across reruns.

    Returns:
        A ``CSVDataset`` built from ``./data/test.csv`` with feature metadata
        taken from ``./data/input_meta_info.csv``.
    """
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/test.csv",
        cnf_file="./data/input_meta_info.csv",
    )

# Obtain the model and the test dataset through the cached loaders above.
# The model is loaded first, then the dataset.
model, dat_tst = load_model(), load_nacc_data()

def predict_proba(data_dict):
    """Run the module-level model on one input case and return its output.

    The case is wrapped into a one-element batch for ``model.predict_proba``.
    From the result, element ``[1]`` is selected and its first entry (the
    single case) is returned.
    NOTE(review): ``[1]`` is presumably the probability output of the model's
    return tuple; confirm against the adrd API.
    """
    batch_result = model.predict_proba([data_dict])
    probabilities = batch_result[1]
    return probabilities[0]

# NACC testing data: `dat_tst` is already provided by the cached
# `load_nacc_data()` call above. Do not construct CSVDataset again here:
# Streamlit re-executes the whole script on every interaction, so an
# uncached construction would re-read both CSV files each time and replace
# the cached dataset with an identical copy.

# Seed the JSON text box contents the first time this session runs the script;
# later runs keep whatever value was stored (e.g. by the "Random case" button).
if "input_text" not in st.session_state:
    st.session_state["input_text"] = ""

# Sections 1 and 2: a short project description followed by usage instructions.
_ABOUT_TEXT = "Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals."
_DEMO_TEXT = "Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies."

for _heading, _body in (("#### About", _ABOUT_TEXT), ("#### Demo", _DEMO_TEXT)):
    st.markdown(_heading)
    st.markdown(_body)

# Page layout: two equal-width columns, with the input form on the left and the
# prediction output (written later) on the right.
layout_l, layout_r = st.columns(2)

# Input form in the left column: a JSON text box plus a row of submit buttons.
with layout_l, st.form("json_input_form"):
    json_input = st.text_area(
        "Please enter JSON-formatted input features:",
        value=st.session_state.input_text,
        height=300,
    )

    # Button row; the wide middle column is empty and only pushes "Predict"
    # to the right edge.
    left_col, middle_col, right_col = st.columns([3, 4, 1])
    with left_col:
        sample_button = st.form_submit_button("Random case")
    with right_col:
        submit_button = st.form_submit_button("Predict")
    
@st.cache_resource
def _load_nacc_mapping(path='./data/nacc_variable_mappings.pkl'):
    """Unpickle the NACC variable mapping table, cached across Streamlit reruns.

    Without caching, the pickle would be re-read on every user interaction,
    because Streamlit re-executes the whole script each time.

    NOTE(review): unpickling can execute arbitrary code. This presumably reads
    a trusted file bundled with the app; never point ``path`` at
    user-supplied data.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)


# Passed to convert_dictionary() when handling a prediction request.
nacc_mapping = _load_nacc_mapping()
    
def convert_dictionary(original_dict, mappings):
    """Rename keys of ``original_dict`` and recode their values via ``mappings``.

    Args:
        original_dict: Input features, e.g. as parsed from the JSON text box.
        mappings: Mapping of ``original_key -> (new_key, transform_map)``.
            A value found in ``transform_map`` is replaced by
            ``transform_map[value]``; any other value is kept unchanged.

    Returns:
        A new dict containing only the keys present in ``mappings``, renamed
        to their ``new_key``. Keys with no mapping entry are dropped.
    """
    transformed_dict = {}

    for key, value in original_dict.items():
        if key not in mappings:
            continue  # features without a mapping are not passed on

        new_key, transform_map = mappings[key]

        # User-supplied JSON may contain lists or objects, which are
        # unhashable. A membership test against a dict-like transform_map
        # would then raise TypeError, so treat such values as
        # "no transformation needed" instead of crashing the request.
        try:
            needs_transform = value in transform_map
        except TypeError:
            needs_transform = False

        transformed_dict[new_key] = transform_map[value] if needs_transform else value

    return transformed_dict
 
if sample_button:
    # Pick a random case from the NACC test split and show its input features
    # in the text box as pretty-printed JSON.
    idx = random.randrange(len(dat_tst))
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)

    # Rerun the script so the text area is re-rendered with the new
    # session-state value. `st.experimental_rerun` is deprecated in favour of
    # `st.rerun` (and removed in later Streamlit releases); use whichever this
    # Streamlit version provides.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()

elif submit_button:
    try:
        # Parse the JSON input into a Python object
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors
        st.error(f"An error occurred: {e}")
    else:
        if not isinstance(data_dict, dict):
            # convert_dictionary() needs key/value pairs; a JSON array or
            # scalar would otherwise fail with an AttributeError on `.items()`.
            st.error("An error occurred: the input must be a JSON object (dictionary) of features.")
        else:
            data_dict = convert_dictionary(data_dict, nacc_mapping)
            pred_dict = predict_proba(data_dict)
            with layout_r:
                st.write("Predicted probabilities:")
                # NOTE(review): the model may return numpy/torch scalar types,
                # which json cannot serialise natively. default=float converts
                # them and is never invoked for plain Python floats.
                st.code(json.dumps(pred_dict, indent=2, default=float))

# Section 3: reference table listing the input-feature metadata.
_FEATURE_META_CSV = './data/input_meta_info.csv'

st.markdown("#### Feature Table")
df_input_meta_info = pd.read_csv(_FEATURE_META_CSV)
st.table(data=df_input_meta_info)