|
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
|
|
|
|
|
# Training table (features merged with target labels).  predict() uses it for
# the patient count, per-feature/per-label match counts, the distribution and
# co-occurrence plots, and the nearest-neighbour search.
TRAIN_DATA_PATH = "X_train_Y_Train_merged_train.csv"
df = pd.read_csv(TRAIN_DATA_PATH)
|
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Load pickled per-target classifiers and run them on a single input row.

    Each model predicts one target column (its file name without extension);
    ``prediction_map`` maps a target column to the human-readable sentence for
    prediction 0 and prediction 1.
    """

    def __init__(self, model_path, model_filenames):
        """
        Args:
            model_path: Directory holding the pickled models.  A trailing path
                separator is optional.
            model_filenames: Model file names relative to ``model_path``; the
                order here is the order of ``make_predictions`` results.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Target column -> [text for prediction 0, text for prediction 1].
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad that nothing could cheer up", "Felt so sad that nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored and lose interest in all enjoyable things",
                         "Felt bored and lost interest in all enjoyable things"],
            "YODPPROB": ["Did not have other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
            "YODPR2WK": ["Did not have periods where feelings lasted 2+ weeks",
                         "Had periods where feelings lasted 2+ weeks"],
            "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
            "YODPDISC": ["Overall mood duration was not sad/depressed",
                         "Overall mood duration was sad/depressed (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in enjoyable things and activities",
                        "Lost interest in enjoyable things and activities"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had depression symptoms lasting 2 weeks or longer",
                         "Had depression symptoms lasting 2 weeks or longer"],
            "YO_MDEA3": ["Did not experience changes in appetite or weight",
                         "Experienced changes in appetite or weight"],
            "YODPLSIN": ["Never lost interest and felt bored", "Lost interest and felt bored"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Had fewer severe symptoms of depression", "Had more severe symptoms of depression"],
            "YOPB2WK": ["Did not experience uneasy feelings lasting every day for 2+ weeks or longer",
                        "Experienced uneasy feelings lasting every day for 2+ weeks or longer"],
            "YO_MDEA2": ["Did not have issues with physical and mental well-being every day for 2 weeks or longer",
                         "Had issues with physical and mental well-being every day for 2 weeks or longer"]
        }

    def load_models(self):
        """Unpickle every file in ``self.model_filenames``, in order.

        SECURITY: ``pickle.load`` can execute arbitrary code.  Only point
        ``model_path`` at trusted, locally produced model files.

        Returns:
            List of the unpickled model objects.
        """
        models = []
        for filename in self.model_filenames:
            # os.path.join works whether or not model_path ends with a separator
            # (plain string concatenation silently produced a wrong path).
            filepath = os.path.join(self.model_path, filename)
            with open(filepath, 'rb') as file:
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """
        Returns a list of numpy arrays, each array is [0] or [1].

        The i-th array corresponds to the i-th model in self.models.
        Each model's raw ``predict`` output is flattened to 1-D.
        """
        predictions = []
        for model in self.models:
            pred = model.predict(user_input)
            pred = np.array(pred).flatten()
            predictions.append(pred)
        return predictions

    def get_majority_vote(self, predictions):
        """Return the most frequent class across all models' predictions.

        Args:
            predictions: List of arrays as returned by ``make_predictions``.

        Returns:
            The majority class value.  Ties resolve to the smallest class.

        Raises:
            ValueError: if ``predictions`` holds no values at all.
        """
        if predictions:
            combined = np.concatenate([np.ravel(pred) for pred in predictions])
        else:
            combined = np.empty(0)
        if combined.size == 0:
            raise ValueError("get_majority_vote() requires at least one prediction")
        # np.unique works for float (0.0/1.0) predictions too; np.bincount
        # only accepted non-negative integer arrays.  For integer input the
        # result is identical (classes sorted, argmax picks the first max).
        classes, counts = np.unique(combined, return_counts=True)
        return classes[np.argmax(counts)]

    def evaluate_severity(self, majority_vote_count):
        """Map the number of positive (1) predictions to a severity sentence.

        Thresholds: >=13 Severe, >=9 Moderate, >=5 Low, otherwise Very Low.
        """
        if majority_vote_count >= 13:
            return "Mental health severity: Severe"
        elif majority_vote_count >= 9:
            return "Mental health severity: Moderate"
        elif majority_vote_count >= 5:
            return "Mental health severity: Low"
        else:
            return "Mental health severity: Very Low"
|
|
|
|
|
|
|
|
|
# Target columns that have a trained model; each model is stored as
# "<TARGET>.pkl" under model_path.  Order here is the order of predictions.
_MODEL_TARGETS = (
    "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
    "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
    "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
    "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
)
model_filenames = [f"{target}.pkl" for target in _MODEL_TARGETS]
model_path = "models/"

# Loads every model at import time.
predictor = ModelPredictor(model_path, model_filenames)
|
|
|
|
|
|
|
|
|
def validate_inputs(*args):
    """Return True only if every argument has a value (neither None nor '')."""
    return all(value is not None and value != '' for value in args)
|
|
|
|
|
|
|
|
|
# Models whose predictions are rendered together in the results textbox,
# keyed by group name (underscores become spaces in the heading).  A model
# listed in several groups (e.g. YOWRELES) is shown under each of them.
_PREDICTION_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
    "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
                                 "YOLOSEV", "YODPLSIN", "YODSCEV"],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
                                                     "YODPR2WK", "YODSMMDE",
                                                     "YOPB2WK"],
}

# Heading for models listed in none of the groups above (e.g. YOSEEDOC), so
# their predictions are still reported instead of being silently dropped.
_UNGROUPED_PREDICTIONS = "Other_Indicators"

# Number of most-similar training patients listed in the neighbours summary.
_NEAREST_NEIGHBOR_COUNT = 5


def _as_binary_label(value):
    """Return ``value`` as the int 0 or 1, or None if it is not such a label.

    Float predictions (0.0 / 1.0) are accepted and converted, so the result
    can be used as a list index into ``predictor.prediction_map``.
    """
    try:
        as_int = int(value)
    except (TypeError, ValueError):
        return None
    if as_int in (0, 1) and as_int == value:
        return as_int
    return None


def _describe_prediction(model_name, value):
    """Return the results-textbox line for one model's prediction."""
    label = _as_binary_label(value)
    if model_name in predictor.prediction_map and label is not None:
        return f"Model {model_name}: {predictor.prediction_map[model_name][label]}"
    return f"Model {model_name}: Prediction = {value} (unmapped)"


def _format_prediction_groups(model_names, predictions):
    """Render every model's prediction as grouped text for the results box."""
    grouped = {group: [] for group in _PREDICTION_GROUPS}
    grouped[_UNGROUPED_PREDICTIONS] = []
    for model_name, pred in zip(model_names, predictions):
        line = _describe_prediction(model_name, pred[0])
        owners = [group for group, members in _PREDICTION_GROUPS.items()
                  if model_name in members]
        for group in owners or [_UNGROUPED_PREDICTIONS]:
            grouped[group].append(line)

    sections = []
    for group, lines in grouped.items():
        if lines:
            sections.append(f"Group {group.replace('_', ' ')}:")
            sections.append("\n".join(lines))
            sections.append("\n")
    text = "\n".join(sections).strip()
    return text or "No predictions made. Please check your inputs."


def _input_match_count_figure(user_input_data):
    """Bar chart: training patients sharing the user's value, per input feature."""
    counts = {
        col: int((df[col] == values[0]).sum())
        for col, values in user_input_data.items()
        if col in df.columns
    }
    fig = px.bar(
        pd.DataFrame({"Feature": list(counts), "Count": list(counts.values())}),
        x="Feature",
        y="Count",
        title="Number of Patients with the Same Value for Each Input Feature",
        labels={"Feature": "Input Feature", "Count": "Number of Patients"},
    )
    fig.update_layout(xaxis={'categoryorder': 'total descending'})
    return fig


def _label_match_count_figure(model_names, predictions):
    """Bar chart: training patients whose target equals each model's prediction."""
    counts = {}
    for model_name, pred in zip(model_names, predictions):
        label = _as_binary_label(pred[0])
        # Skip non-binary predictions and targets absent from the training table.
        if label is not None and model_name in df.columns:
            counts[model_name] = int((df[model_name] == label).sum())
    if not counts:
        return px.bar(title="No valid predicted labels to display")
    return px.bar(
        pd.DataFrame({"Model": list(counts), "Count": list(counts.values())}),
        x="Model",
        y="Count",
        title="Number of Patients with the Same Predicted Label",
        labels={"Model": "Predicted Column", "Count": "Patient Count"},
    )


def _distribution_figure(user_input_data, model_names):
    """Faceted bar chart of value counts: first 4 inputs x first 3 targets (demo)."""
    features = [col for col in list(user_input_data)[:4] if col in df.columns]
    labels = [col for col in model_names[:3] if col in df.columns]
    frames = []
    for feat in features:
        for lbl in labels:
            counts = df.groupby([feat, lbl]).size().reset_index(name="count")
            # Use the same column names for every pair so concat stacks the
            # rows; otherwise each feature/label name becomes its own sparse
            # column and only the first pair has usable x/colour values.
            counts = counts.rename(columns={feat: "feature_value", lbl: "label_value"})
            counts["feature"] = feat
            counts["label"] = lbl
            frames.append(counts)
    if not frames:
        return px.bar(title="No distribution plot could be generated (check feature/label columns).")

    dist = pd.concat(frames, ignore_index=True)
    fig = px.bar(
        dist,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title="Distribution of Sample Input Features vs. Sample Predicted Labels (Demo)",
        labels={"feature_value": "Feature Value", "label_value": "Label Value"},
    )
    fig.update_layout(height=800)
    return fig


def _nearest_neighbors_markdown(user_input, k=_NEAREST_NEIGHBOR_COUNT):
    """Markdown listing the ``k`` training rows closest to ``user_input``.

    Distance is the Hamming distance: the number of input features whose value
    differs from the user's.  Ties keep the dataset's row order (stable sort).
    """
    features = list(user_input.columns)
    # Vectorised: compare every row with the user's values column-by-column
    # (DataFrame.ne aligns the Series on column names) and count mismatches.
    distances = df[features].ne(user_input.iloc[0]).sum(axis=1)
    nearest = (
        df.assign(distance=distances)
        .sort_values("distance", ascending=True, kind="stable")
        .head(k)
    )

    # Numeric code -> answer text, per input feature.
    reverse_input_mapping = {
        col: {code: text for text, code in mapping.items()}
        for col, mapping in input_mapping.items()
    }
    conc_text = predictor.prediction_map.get("YOWRCONC")
    conc_labels = {0: conc_text[0], 1: conc_text[1]} if conc_text else {}

    rows = []
    for idx, neighbor in nearest.iterrows():
        parts = []
        for col in features:
            value = neighbor[col]
            parts.append(f"{col}={reverse_input_mapping.get(col, {}).get(value, value)}")
        if "YOWRCONC" in nearest.columns:
            label_val = neighbor["YOWRCONC"]
            parts.append(f"YOWRCONC={conc_labels.get(label_val, label_val)}")
        rows.append(f"- **Neighbor ID {idx}** (distance={neighbor['distance']}): " + ", ".join(parts))

    return (
        "### Nearest Neighbors (Simple Hamming Distance)\n"
        f"We searched for the top **{k}** patients whose features most closely match your input.\n\n"
        "> **Note**: “Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
        f"This demo simply uses a Hamming distance over all input features and picks K={k} neighbors. "
        "In a real application, you would refine which features are most relevant, how to encode them, "
        "and how many neighbors to select.\n\n"
        "Below is a brief overview of each neighbor's input-feature values and one example label (`YOWRCONC`).\n\n"
        + "\n".join(rows)
    )


def _co_occurrence_figure():
    """Counts of (YMDEYR, YMDERSUD5ANY, YOWRCONC) combinations in the training data."""
    columns = ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]
    if not all(col in df.columns for col in columns):
        return px.bar(title="Co-occurrence plot not available (check columns).")
    co_occ_data = df.groupby(columns).size().reset_index(name="count")
    return px.bar(
        co_occ_data,
        x="YMDEYR",
        y="count",
        color="YOWRCONC",
        facet_col="YMDERSUD5ANY",
        title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC",
    )


def predict(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Run every model on one patient's encoded answers and build all UI outputs.

    Each argument is the numeric code (see ``input_mapping``) for the input
    feature of the same name.  Values are passed through ``int``.

    Returns:
        An 8-tuple: grouped prediction text, severity text, total-patient
        markdown, distribution figure, nearest-neighbour markdown,
        co-occurrence figure, per-feature match-count figure, and
        predicted-label match-count figure.
    """
    # Key order becomes the DataFrame column order fed to the models.
    # Presumably the models were fitted on this order — keep it unchanged.
    user_input_data = {
        'YNURSMDE': [int(YNURSMDE)],
        'YMDEYR': [int(YMDEYR)],
        'YSOCMDE': [int(YSOCMDE)],
        'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
        'YMSUD5YANY': [int(YMSUD5YANY)],
        'YUSUITHK': [int(YUSUITHK)],
        'YMDETXRX': [int(YMDETXRX)],
        'YUSUITHKYR': [int(YUSUITHKYR)],
        'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
        'YUSUIPLNYR': [int(YUSUIPLNYR)],
        'YCOUNMDE': [int(YCOUNMDE)],
        'YPSY1MDE': [int(YPSY1MDE)],
        'YHLTMDE': [int(YHLTMDE)],
        'YDOCMDE': [int(YDOCMDE)],
        'YPSY2MDE': [int(YPSY2MDE)],
        'YMDEHARX': [int(YMDEHARX)],
        'LVLDIFMEM2': [int(LVLDIFMEM2)],
        'MDEIMPY': [int(MDEIMPY)],
        'YMDEHPO': [int(YMDEHPO)],
        'YMIMS5YANY': [int(YMIMS5YANY)],
        'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
        'YMIUD5YANY': [int(YMIUD5YANY)],
        'YMDEHPRX': [int(YMDEHPRX)],
        'YMIMI5YANY': [int(YMIMI5YANY)],
        'YUSUIPLN': [int(YUSUIPLN)],
        'YTXMDEYR': [int(YTXMDEYR)],
        'YMDEAUD5YR': [int(YMDEAUD5YR)],
        'YRXMDEYR': [int(YRXMDEYR)],
        'YMDELT': [int(YMDELT)]
    }
    user_input = pd.DataFrame(user_input_data)

    predictions = predictor.make_predictions(user_input)
    # Target column predicted by each model, aligned with ``predictions``.
    model_names = [filename.split('.')[0] for filename in model_filenames]

    formatted_results = _format_prediction_groups(model_names, predictions)

    # Severity is driven by how many models predicted the positive class (1).
    positive_count = sum(int(np.count_nonzero(np.asarray(pred) == 1)) for pred in predictions)
    severity = predictor.evaluate_severity(positive_count)

    # Count each model once (not once per display group) when deciding
    # whether too many predictions could not be interpreted.
    num_unknown = sum(
        1 for name, pred in zip(model_names, predictions)
        if name not in predictor.prediction_map or _as_binary_label(pred[0]) is None
    )
    if num_unknown > len(model_filenames) / 2:
        severity += " (Unknown prediction count is high. Please consult with a human.)"

    total_patient_count_markdown = (
        "### Total Patient Count\n"
        f"There are **{len(df)}** total patients in the dataset.\n"
        "All subsequent analyses refer to these patients."
    )

    return (
        formatted_results,
        severity,
        total_patient_count_markdown,
        _distribution_figure(user_input_data, model_names),
        _nearest_neighbors_markdown(user_input),
        _co_occurrence_figure(),
        _input_match_count_figure(user_input_data),
        _label_match_count_figure(model_names, predictions),
    )
|
|
|
|
|
|
|
|
|
# Answer-choice templates: (dropdown text, numeric code) pairs.  Each feature
# gets its own dict copy; key order is the order choices appear in the UI.
_YES1_NO0 = (("Yes", 1), ("No", 0))
_YES1_NO2 = (("Yes", 1), ("No", 2))
_YES_NO_UNSURE_DECLINE = (
    ("Yes", 1), ("No", 2), ("I'm not sure", 3), ("I don't want to answer", 4),
)

# Input feature -> {dropdown answer text: numeric code used by the models}.
input_mapping = {
    'YNURSMDE': dict(_YES1_NO0),
    'YMDEYR': dict(_YES1_NO2),
    'YSOCMDE': dict(_YES1_NO0),
    'YMDESUD5ANYO': dict((
        ("SUD only, no MDE", 1),
        ("MDE only, no SUD", 2),
        ("SUD and MDE", 3),
        ("Neither SUD or MDE", 4),
    )),
    'YMSUD5YANY': dict(_YES1_NO0),
    'YUSUITHK': dict(_YES_NO_UNSURE_DECLINE),
    'YMDETXRX': dict(_YES1_NO0),
    'YUSUITHKYR': dict(_YES_NO_UNSURE_DECLINE),
    'YMDERSUD5ANY': dict(_YES1_NO0),
    'YUSUIPLNYR': dict(_YES_NO_UNSURE_DECLINE),
    'YCOUNMDE': dict(_YES1_NO0),
    'YPSY1MDE': dict(_YES1_NO0),
    'YHLTMDE': dict(_YES1_NO0),
    'YDOCMDE': dict(_YES1_NO0),
    'YPSY2MDE': dict(_YES1_NO0),
    'YMDEHARX': dict(_YES1_NO0),
    'LVLDIFMEM2': dict((
        ("No Difficulty", 1),
        ("Some difficulty", 2),
        ("A lot of difficulty or cannot do at all", 3),
    )),
    'MDEIMPY': dict(_YES1_NO2),
    'YMDEHPO': dict(_YES1_NO0),
    'YMIMS5YANY': dict(_YES1_NO0),
    'YMDEIMAD5YR': dict(_YES1_NO0),
    'YMIUD5YANY': dict(_YES1_NO0),
    'YMDEHPRX': dict(_YES1_NO0),
    'YMIMI5YANY': dict(_YES1_NO0),
    'YUSUIPLN': dict(_YES_NO_UNSURE_DECLINE),
    'YTXMDEYR': dict(_YES1_NO0),
    'YMDEAUD5YR': dict(_YES1_NO0),
    'YRXMDEYR': dict(_YES1_NO0),
    'YMDELT': dict(_YES1_NO2),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
# (feature, description) for each dropdown.  Order must match the positional
# parameters of predict_with_text; each label renders as "FEATURE: DESCRIPTION".
_INPUT_SPECS = [
    # Depression and co-occurring disorder indicators
    ('YMDEYR', "PAST YEARS MAJOR DEPRESSIVE EPISODE"),
    ('YMDERSUD5ANY', "MDE OR SUBSTANCE USE DISORDER - ANY"),
    ('YMDEIMAD5YR', "MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
    ('YMIMS5YANY', "MDE W/ SEV. IMP + SUBSTANCE USE DISORDER"),
    ('YMDELT', "HAD MAJOR DEPRESSIVE EPISODE IN LIFETIME"),
    ('YMDEHARX', "SAW HEALTH PROF + MEDS FOR MDE"),
    ('YMDEHPRX', "SAW HEALTH PROF OR MEDS FOR MDE"),
    ('YMDETXRX', "RECEIVED TREATMENT/COUNSELING FOR MDE"),
    ('YMDEHPO', "SAW HEALTH PROF ONLY FOR MDE"),
    ('YMDEAUD5YR', "MDE + ALCOHOL USE DISORDER"),
    ('YMIMI5YANY', "MDE W/ ILL DRUG USE DISORDER"),
    ('YMIUD5YANY', "MDE + ILL DRUG USE DISORDER"),
    ('YMDESUD5ANYO', "MDE vs. SUD vs. BOTH vs. NEITHER"),
    # Professionals consulted about MDE
    ('YNURSMDE', "SAW/TALK TO NURSE/OT ABOUT MDE"),
    ('YSOCMDE', "SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
    ('YCOUNMDE', "SAW/TALK TO COUNSELOR ABOUT MDE"),
    ('YPSY1MDE', "SAW/TALK TO PSYCHOLOGIST ABOUT MDE"),
    ('YPSY2MDE', "SAW/TALK TO PSYCHIATRIST ABOUT MDE"),
    ('YHLTMDE', "SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
    ('YDOCMDE', "SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
    ('YTXMDEYR', "SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
    # Suicidal thoughts and plans
    ('YUSUITHKYR', "SERIOUSLY THOUGHT ABOUT KILLING SELF"),
    ('YUSUIPLNYR', "MADE PLANS TO KILL SELF"),
    ('YUSUITHK', "THINK ABOUT KILLING SELF (12 MONTHS)"),
    ('YUSUIPLN', "MADE PLANS TO KILL SELF (12 MONTHS)"),
    # Impairment, cognition and medication
    ('MDEIMPY', "MDE W/ SEVERE ROLE IMPAIRMENT"),
    ('LVLDIFMEM2', "LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
    ('YMSUD5YANY', "MDE + SUBSTANCE USE DISORDER - ANY"),
    ('YRXMDEYR', "USED MEDS FOR MDE IN PAST YEAR"),
]

# One dropdown per feature; its choices are the answer texts in input_mapping.
inputs = [
    gr.Dropdown(list(input_mapping[feature]), label=f"{feature}: {description}")
    for feature, description in _INPUT_SPECS
]
|
|
|
|
|
# Output components.  The list order matches the 8-tuple returned by
# predict / predict_with_text.
_results_textbox = gr.Textbox(label="Prediction Results", lines=30)
_severity_textbox = gr.Textbox(label="Mental Health Severity", lines=4)
_patient_count_md = gr.Markdown(label="Total Patient Count")
_distribution_plot = gr.Plot(label="Distribution Plot (Sample of Features & Labels)")
_neighbors_md = gr.Markdown(label="Nearest Neighbors Summary")
_co_occurrence_plot = gr.Plot(label="Co-Occurrence Plot")
_input_counts_plot = gr.Plot(label="Number of Patients per Input Feature")
_label_counts_plot = gr.Plot(label="Number of Patients with Predicted Labels")

outputs = [
    _results_textbox,
    _severity_textbox,
    _patient_count_md,
    _distribution_plot,
    _neighbors_md,
    _co_occurrence_plot,
    _input_counts_plot,
    _label_counts_plot,
]
|
|
|
|
|
|
|
|
|
def predict_with_text(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Encode the dropdown answers with input_mapping and delegate to predict.

    If any answer is missing (None or ''), an 8-tuple of placeholders with a
    validation message is returned instead of calling predict.
    """
    answers = {
        'YNURSMDE': YNURSMDE,
        'YMDEYR': YMDEYR,
        'YSOCMDE': YSOCMDE,
        'YMDESUD5ANYO': YMDESUD5ANYO,
        'YMSUD5YANY': YMSUD5YANY,
        'YUSUITHK': YUSUITHK,
        'YMDETXRX': YMDETXRX,
        'YUSUITHKYR': YUSUITHKYR,
        'YMDERSUD5ANY': YMDERSUD5ANY,
        'YUSUIPLNYR': YUSUIPLNYR,
        'YCOUNMDE': YCOUNMDE,
        'YPSY1MDE': YPSY1MDE,
        'YHLTMDE': YHLTMDE,
        'YDOCMDE': YDOCMDE,
        'YPSY2MDE': YPSY2MDE,
        'YMDEHARX': YMDEHARX,
        'LVLDIFMEM2': LVLDIFMEM2,
        'MDEIMPY': MDEIMPY,
        'YMDEHPO': YMDEHPO,
        'YMIMS5YANY': YMIMS5YANY,
        'YMDEIMAD5YR': YMDEIMAD5YR,
        'YMIUD5YANY': YMIUD5YANY,
        'YMDEHPRX': YMDEHPRX,
        'YMIMI5YANY': YMIMI5YANY,
        'YUSUIPLN': YUSUIPLN,
        'YTXMDEYR': YTXMDEYR,
        'YMDEAUD5YR': YMDEAUD5YR,
        'YRXMDEYR': YRXMDEYR,
        'YMDELT': YMDELT,
    }

    if not validate_inputs(*answers.values()):
        # One placeholder per output component.
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            None,
            "No data",
            None,
            None,
            None,
        )

    encoded = {
        feature: input_mapping[feature][answer]
        for feature, answer in answers.items()
    }
    return predict(**encoded)
|
|
|
|
|
# CSS passed to gr.Interface: forces the text colour #1B1212 (with
# !important) inside the Gradio container.
# NOTE(review): the first rule (`.gradio-container *`) already matches every
# element targeted by the later, more specific selectors, so those rules look
# redundant.  Also confirm readability if a dark theme is ever used.
custom_css = """
.gradio-container * {
    color: #1B1212 !important;
}
.gradio-container .form .form-group label {
    color: #1B1212 !important;
}
.gradio-container .output-textbox,
.gradio-container .output-textbox textarea {
    color: #1B1212 !important;
}
.gradio-container .label,
.gradio-container .input-label {
    color: #1B1212 !important;
}
"""
|
|
|
|
|
|
|
|
|
# Heading shown at the top of the app.
APP_TITLE = "Adolescents with Substance Use Mental Health Screening (NSDUH Data)"

# Wire the dropdowns to predict_with_text and its 8-tuple to the outputs.
interface = gr.Interface(
    fn=predict_with_text,
    title=APP_TITLE,
    inputs=inputs,
    outputs=outputs,
    css=custom_css,
)
|
|
|
def main():
    """Launch the Gradio app defined by ``interface``."""
    interface.launch()


if __name__ == "__main__":
    main()
|
|