Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from glob import glob | |
| from functools import partial | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| import re | |
| from model import VariationalGNN | |
# Directory that is scanned for the bundled example patient CSV files.
examples_path = "examples"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example patients keyed by the patient id parsed from the file name;
# both registries are filled by _find_example_csv_files().
correct_preds = {}
wrong_preds = {}

# Both lookup tables are read the same way, with every column kept as text.
_LOOKUP_TABLE_READ_OPTIONS = {
    "header": "infer",
    "sep": ",",
    "encoding": "utf-8",
    "dtype": str,
}

# Expected list of conditions (column "condition") a patient CSV must match.
condition_lst = pd.read_csv("data/feature.csv", **_LOOKUP_TABLE_READ_OPTIONS)

# Lab item dictionary queried by ITEMID (columns LABEL / FLUID / CATEGORY).
D_LABITEMS = pd.read_csv("data/D_LABITEMS.csv", **_LOOKUP_TABLE_READ_OPTIONS)
def load_model():
    """Rebuild the trained ``VariationalGNN`` from the checkpoint on disk.

    The checkpoint is unpacked into two items: the keyword arguments for the
    model constructor and the state dict holding the trained weights.

    Returns:
        The model placed on ``device`` with the saved weights loaded.
    """
    checkpoint_path = "models/final_model.pt"
    # NOTE(review): weights_only=False unpickles arbitrary objects, which is
    # only acceptable for a checkpoint file that is fully trusted.
    checkpoint = torch.load(checkpoint_path,
                            weights_only=False,
                            map_location=device)
    init_kwargs, weights = checkpoint
    network = VariationalGNN(**init_kwargs)
    network = network.to(device)
    network.load_state_dict(weights)
    return network


# Single model instance shared by every prediction request.
model = load_model()
def _check_patient_csv_format(df: pd.DataFrame) -> None:
    """Validate that a patient table has the layout the model expects.

    Checks, in order:
      1. the first two columns are named "condition" and "value";
      2. the "condition" column equals ``condition_lst["condition"]``
         (same entries, same order);
      3. every entry of "value" is either 0.0 or 1.0.

    Args:
        df: Patient table as read from a patient CSV file.

    Raises:
        gr.Error: On the first check that fails.
    """
    leading_columns = list(df.columns)[:2]
    if leading_columns != ["condition", "value"]:
        raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)

    if condition_lst["condition"].to_list() != df["condition"].to_list():
        raise gr.Error("Condition set: not expected.", duration=None)

    # Membership test rather than "the unique values are exactly [0, 1]":
    # a column made only of 0s (or only of 1s) contains no invalid value and
    # must be accepted. NaN is not a member, so it is still rejected.
    if not df["value"].isin([0.0, 1.0]).all():
        raise gr.Error("Column 'value': contain invalid values.", duration=None)
| def _extract_patient_data_from_name(csv_file_name: str): | |
| patient_file_pat = r"^Patient_(\d+)_\(Label-(alive|dead)\)_\(Predicted-(dead|alive)\).csv$" | |
| csv_name = Path(csv_file_name).name | |
| matches = re.search(patient_file_pat, csv_name) | |
| if matches is None: | |
| return None | |
| else: | |
| return (matches.group(1), matches.group(2), matches.group(3)) | |
def _find_example_csv_files() -> None:
    """Register the example patient CSV files found in ``examples_path``.

    Every ``*.csv`` file directly inside ``examples_path`` whose name follows
    the patient naming pattern is stored under its patient id, either in
    ``correct_preds`` (label equals prediction) or in ``wrong_preds``.
    Files with an unexpected name, and files whose patient id was already
    registered, are reported on stdout and skipped.
    """
    # glob returns paths in arbitrary order; sorting keeps the registration
    # order (and therefore the order of the example buttons) stable.
    # recursive=True was dropped: it has no effect without "**" in the pattern.
    all_csv_files = sorted(glob(f'{examples_path}/*.csv'))
    if not all_csv_files:
        print("*** No csv files found.")
        return

    for one_csv_file in all_csv_files:
        matches = _extract_patient_data_from_name(one_csv_file)
        if matches is None:
            print(f"*** File [{one_csv_file}]: wrong name.")
            continue

        pat_id, pat_label, pat_predicted = matches
        if pat_id in correct_preds or pat_id in wrong_preds:
            print(f"*** File [{one_csv_file}]: already processed! How come?")
            continue

        registry = correct_preds if pat_label == pat_predicted else wrong_preds
        registry[pat_id] = {"label": pat_label,
                            "predicted": pat_predicted,
                            "file_name": one_csv_file}


# Populate the example registries once, at import time.
_find_example_csv_files()
def _predict(file_path: str):
    """Run the model on one patient CSV file.

    Args:
        file_path: Path of a CSV file with the columns "condition" and "value".

    Returns:
        The sigmoid of the model's first output as a Python float; callers
        present it as the probability of the class "dead".

    Raises:
        gr.Error: If the file cannot be parsed as a patient CSV or fails the
            format validation.
    """
    try:
        df = pd.read_csv(file_path,
                         header="infer",
                         sep=",",
                         encoding="utf-8",
                         dtype={'condition': 'str', 'value': 'float32'},
                         keep_default_na=False)
    except ValueError as err:
        # pandas parser, empty-file, dtype-conversion and decoding errors all
        # derive from ValueError; report them like the other validation
        # failures instead of letting a raw exception reach the UI.
        raise gr.Error(f"File [{Path(file_path).name}]: not a readable patient CSV ({err}).",
                       duration=None) from err
    _check_patient_csv_format(df)

    # Add a leading dimension of size 1 so the single patient forms a batch.
    patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)

    model.eval()
    with torch.inference_mode():
        # The model returns two items; only the first one is used here.
        logits, _ = model(patient_data)
    return torch.sigmoid(logits.cpu()[0]).item()
def _lookup_example_patient(patient_id):
    """Return the registered example record for *patient_id*, or ``None``."""
    # The registries are keyed by the id *string* parsed from the file name,
    # while a number component may hand the id back as int or float.
    if isinstance(patient_id, float) and patient_id.is_integer():
        patient_id = int(patient_id)  # 12.0 -> 12, so that str() gives "12"
    key = str(patient_id).strip()

    registries = (correct_preds, wrong_preds)
    for registry in registries:
        if key in registry:
            return registry[key]

    # Fall back to a numeric comparison so that an id stored with leading
    # zeros (e.g. "007") is still found when the value arrives as 7.
    if key.isdigit():
        wanted = int(key)
        for registry in registries:
            for stored_id, record in registry.items():
                if stored_id.isdigit() and int(stored_id) == wanted:
                    return record
    return None


def example_csv_click(patient_id):
    """Predict mortality for one of the bundled example patients.

    Args:
        patient_id: Id of the example patient. Accepted as str, int or float,
            because the registries use string keys while the hidden number
            component that carries the id may deliver a numeric value.

    Returns:
        A two-item list: a dict mapping "dead"/"alive" to their probabilities,
        and the ground-truth label taken from the example's file name.

    Raises:
        gr.Error: If no example is registered for ``patient_id``.
    """
    print(f"*** Predict patient {patient_id} (Example CSV)")
    patient = _lookup_example_patient(patient_id)
    if patient is None:
        raise gr.Error(f"Patient [{patient_id}]: not found among the examples.", duration=None)
    probability = _predict(patient['file_name'])
    return [{"dead": probability, "alive": 1-probability},
            patient['label']]
def user_csv_upload(temp_csv_file_path):
    """Predict mortality for a patient CSV uploaded by the user.

    Args:
        temp_csv_file_path: Path of the uploaded file, as handed over by the
            upload component.

    Returns:
        A two-item list: a dict mapping "dead"/"alive" to their probabilities,
        and the label encoded in the file name, or "(Not Available)" when the
        name does not follow the patient naming pattern.
    """
    print("*** Predict patient (User CSV Upload)")
    name_parts = _extract_patient_data_from_name(temp_csv_file_path)
    probability = _predict(temp_csv_file_path)

    if name_parts is None:
        label_text = "(Not Available)"
    else:
        # Index 1 of (patient_id, label, predicted) is the label.
        label_text = name_parts[1]

    class_probabilities = {"dead": probability, "alive": 1-probability}
    return [class_probabilities, label_text]
def do_query(query_str, query_type):
    """Answer a look-up request for a diagnosis/procedure code or a lab value.

    For "Diagnosis" and "Procedure" the answer is a link to a Google search
    for the ICD-9 code. Any other query type is treated as a lab value: the
    part of the query after the last underscore (if any) is dropped and the
    remainder is looked up as ITEMID in ``D_LABITEMS``.

    Args:
        query_str: Code typed by the user.
        query_type: "Diagnosis", "Procedure" or "Lab Value".

    Returns:
        A visible ``gr.HTML`` update carrying the answer.
    """
    from html import escape
    from urllib.parse import quote_plus

    query_str = (query_str or "").strip()

    if query_type in ["Diagnosis", "Procedure"]:
        str_to_search = f"ICD-9 {query_type} Code " + query_str
        # User text goes into a URL inside an HTML attribute: URL-encode the
        # query and escape the attribute value so that spaces, '&', quotes or
        # markup can neither break the link nor inject HTML.
        search_url = "https://www.google.com/search?q=" + quote_plus(str_to_search)
        return gr.HTML(value=f'<a href="{escape(search_url)}" target="_blank">Google</a>',
                       visible=True)

    # Lab value: keep only the item id in front of the last underscore.
    # NOTE(review): presumably the suffix is a bucket index of the feature
    # name - confirm against the feature list.
    if (index := query_str.rfind("_")) >= 0:
        query_str = query_str[0:index]

    res = D_LABITEMS[D_LABITEMS["ITEMID"] == query_str]
    if res.shape[0] == 0:
        answer = "(Something wrong. No definition found.)"
    elif res.shape[0] == 1:
        answer = f"{res['LABEL'].values[0]}-{res['FLUID'].values[0]}-{res['CATEGORY'].values[0]}"
    else:
        answer = f"(Something wrong. Too many definitions, given code [{query_str}].)"

    # The answer is rendered as HTML, so escape table and user text alike.
    return gr.HTML(value=escape(answer),
                   visible=True)
def query_input_change_event(query_str, query_type):
    """Enable the query button only while the query form is complete.

    Args:
        query_str: Current content of the query text box.
        query_type: Currently selected query type, or ``None``.

    Returns:
        A two-item list: a button update that is interactive only when the
        text is non-blank and a query type is selected, and an HTML update
        that hides the previous answer.
    """
    has_text = bool(query_str and query_str.strip())
    form_complete = has_text and query_type is not None
    return [gr.Button(interactive=form_complete), gr.HTML(visible=False)]
# Partially applied event registrations collected while the layout is built.
# Each one is completed with its `outputs=` argument once the result
# components exist.
resDispPartFuncs = []

# Custom CSS passed to gr.Blocks.
css = \
"""
#selectFileToUpload {max-height: 180px}
.gradio-container {
background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
background-position: 80% 85%;
background-repeat: no-repeat;
background-size: 200px;
}
#label-label {
height: 50px !important;
}
#label-label > .container {
height: 50px !important;
}
#label-label > .container > h2 {
/* height: 50px !important; */
padding: 0 !important;
}
"""
def _add_example_buttons(examples):
    """Render one panel per example patient: a predict and a download button.

    Must be called inside the layout container that should hold the panels.
    The click listeners are only queued in ``resDispPartFuncs``; they are
    registered later, once the output components exist.

    Args:
        examples: Mapping of patient id to the example record, whose
            "file_name" entry is the path of the example CSV.
    """
    for patient_id, patient in examples.items():
        with gr.Column(variant='panel',
                       min_width=100):
            patient_input_btn = gr.Button(f"Patient {patient_id}",
                                          size="sm")
            gr.DownloadButton(label="Download",
                              value=f"{patient['file_name']}",
                              size="sm")
            # Hidden component that carries the patient id to the handler.
            patient_id_num = gr.Number(value=patient_id,
                                       visible=False)
            resDispPartFuncs.append(partial(patient_input_btn.click,
                                            fn=example_csv_click,
                                            inputs=patient_id_num,
                                            api_name="predict"))


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        # Left column: upload widget and the bundled examples.
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            patient_upload_file = gr.File(label="Upload A Patient",
                                          file_types=['.csv'],
                                          file_count="single",
                                          elem_id="selectFileToUpload")
            # The result components are created further down, so the upload
            # listener is deferred like the example buttons. Registering it
            # here with outputs=None would discard the prediction.
            resDispPartFuncs.append(partial(patient_upload_file.upload,
                                            fn=user_csv_upload,
                                            inputs=patient_upload_file))
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(correct_preds)
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(wrong_preds)

        # Right column: prediction result and background information.
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result_pred = gr.Label(num_top_classes=2, label="Predicted")
            result_label = gr.Label(label="Label", elem_id="label-label")

            with gr.Accordion("More on Patient Conditions...", open=False):
                query_tbx = gr.Textbox(label="Enter one ICD-9 Diagnosis/Procedure Code or Lab Value:",
                                       lines=1,
                                       max_lines=1, placeholder="00869 for 'Other viral intes infec' (Diagnosis)")
                query_type = gr.Radio(["Diagnosis", "Procedure", "Lab Value"], show_label=False)
                query_btn = gr.Button(value="Query", size="sm", interactive=False)
                html = gr.HTML("", visible=False)
                # Any change of the form re-evaluates the button state and
                # hides the previous answer.
                query_tbx.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
                query_type.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
                query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=html)

            with gr.Accordion("More on Technical Details...", open=False):
                gr.Markdown(
                    """
                    - Paper: [Variationally Regularized Graph-based Representation Learning for Electronic Health Records (Zhu et al, 2021)](https://arxiv.org/abs/1912.03761)
                    - Dataset: [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
                    - 50,314 records, 10,591 features
                    - 5,315 positive, 44,999 negative (11.8%)
                    - Split: 80% training, 10% validation, 10% testing
                    - Notable points:
                    - Result: AUPRC 0.7027 (Baseline: 0.118) on Val split
                    - Variational Regularization, inspired by [Kipf et al., 2016](https://arxiv.org/abs/1611.07308)
                    - Trained on NVIDIA A100 with PyTorch 2.4.0
                    - Code on GitHub: [pytorch-variational-gcn-ehr-public](https://github.com/ThachNgocTran/pytorch-variational-gcn-ehr-public)
                    """
                )

            with gr.Accordion("More on Training...", open=False):
                gr.HTML("""
                <img src="/file=images/AUPRC_Training_Graph.png" alt="">
                """)

    # Both result components exist now: complete every deferred listener.
    for bind_outputs in resDispPartFuncs:
        bind_outputs(outputs=[result_pred, result_label])

demo.launch(debug=True, allowed_paths=["images/."])