Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from glob import glob | |
from functools import partial | |
import numpy as np | |
import torch | |
import gradio as gr | |
import pandas as pd | |
import re | |
from model import VariationalGNN | |
# Folder scanned for example patient CSV files.
examples_path = "examples"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Example registries keyed by patient id; filled in by _find_example_csv_files().
correct_preds = {}
wrong_preds = {}

# Reference table; its "condition" column is compared against uploaded patient files.
condition_lst = pd.read_csv("data/feature.csv", sep=",", header="infer", encoding="utf-8", dtype=str)

# Lab item dictionary; queried by ITEMID to show LABEL / FLUID / CATEGORY.
D_LABITEMS = pd.read_csv("data/D_LABITEMS.csv", sep=",", header="infer", encoding="utf-8", dtype=str)
def load_model():
    """Rebuild the trained network from the checkpoint on disk.

    The checkpoint is unpacked as a ``(kwargs, state)`` pair: ``kwargs`` is
    forwarded to the ``VariationalGNN`` constructor and ``state`` is loaded as
    its state dict.

    Returns:
        The model instance, moved to ``device`` with its weights loaded.
    """
    checkpoint_path = r"models/final_model.pt"
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects, which
    # is only acceptable for a checkpoint file that is trusted.
    init_kwargs, state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    net = VariationalGNN(**init_kwargs).to(device)
    net.load_state_dict(state_dict)
    return net


# Single shared model instance used by every prediction request.
model = load_model()
def _check_patient_csv_format(df: pd.DataFrame):
    """Validate the layout and content of a patient data frame.

    Checks, in order:
      1. the first two columns are named "condition" and "value";
      2. the "condition" column equals the reference list in ``condition_lst``
         (same entries, same order);
      3. the distinct entries of "value" are exactly 0.0 and 1.0 (both must
         be present).

    Args:
        df: Patient data as read from a CSV file.

    Raises:
        gr.Error: On the first check that fails.
    """
    column_names = list(df.columns)
    if column_names[0:2] != ["condition", "value"]:
        raise gr.Error(f"Column set [{column_names}]: not expected.", duration=None)

    expected_conditions = condition_lst["condition"].to_list()
    if expected_conditions != df["condition"].to_list():
        raise gr.Error("Condition set: not expected.", duration=None)

    distinct_values = np.sort(df["value"].unique())
    only_zero_and_one = (
        distinct_values.ndim == 1
        and len(distinct_values) == 2
        and all(distinct_values == np.array([0.0, 1.0]))
    )
    if not only_zero_and_one:
        raise gr.Error("Column 'value': contain invalid values.", duration=None)
def _extract_patient_data_from_name(csv_file_name: str):
    """Parse the patient id, true label and predicted label from a file name.

    Only the final path component is inspected. It must look like
    ``Patient_<digits>_(Label-<alive|dead>)_(Predicted-<dead|alive>).csv``.

    Args:
        csv_file_name: File name or path of a patient CSV file.

    Returns:
        A tuple ``(patient_id, label, predicted)`` of strings, or ``None`` when
        the name does not follow the expected pattern.
    """
    # The dot before "csv" is escaped so that only a literal ".csv" suffix is
    # accepted (an unescaped "." would match any character).
    patient_file_pat = r"^Patient_(\d+)_\(Label-(alive|dead)\)_\(Predicted-(dead|alive)\)\.csv$"
    csv_name = Path(csv_file_name).name
    matches = re.search(patient_file_pat, csv_name)
    if matches is None:
        return None
    return (matches.group(1), matches.group(2), matches.group(3))
def _find_example_csv_files() -> None:
    """Scan ``examples_path`` for patient CSV files and register them.

    Each file whose name parses via ``_extract_patient_data_from_name`` is
    stored under its patient id in ``correct_preds`` (label equals prediction)
    or ``wrong_preds`` (label differs). Files with an unexpected name or an
    already registered patient id are reported on stdout and skipped.
    """
    # glob returns matches in a filesystem-dependent order; sorting keeps the
    # example order (and the winner among duplicate ids) stable between runs.
    all_csv_files = sorted(glob(f'{examples_path}/*.csv'))
    if not all_csv_files:
        print("*** No csv files found.")
        return

    for one_csv_file in all_csv_files:
        matches = _extract_patient_data_from_name(one_csv_file)
        if not matches:
            print(f"*** File [{one_csv_file}]: wrong name.")
            continue

        pat_id, pat_label, pat_predicted = matches
        if pat_id in correct_preds or pat_id in wrong_preds:
            print(f"*** File [{one_csv_file}]: already processed! How come?")
            continue

        registry = correct_preds if pat_label == pat_predicted else wrong_preds
        registry[pat_id] = {
            "label": pat_label,
            "predicted": pat_predicted,
            "file_name": one_csv_file,
        }


# Populate the example registries once at start-up.
_find_example_csv_files()
def _predict(file_path: str):
    """Run the model on one patient CSV file.

    The file is read with "condition" as strings and "value" as float32,
    validated with ``_check_patient_csv_format``, and its "value" column is fed
    to the model as a single-row batch.

    Args:
        file_path: Path of the patient CSV file.

    Returns:
        The sigmoid of the first element of the model's first output, as a
        Python float. Callers report this number as the "dead" probability.

    Raises:
        gr.Error: If the file does not pass the format validation.
    """
    patient_df = pd.read_csv(
        str(file_path),
        sep=",",
        header="infer",
        encoding="utf-8",
        dtype={'condition': 'str', 'value': 'float32'},
        keep_default_na=False,
    )
    _check_patient_csv_format(patient_df)

    # Add a leading batch dimension of size one before moving to the device.
    values = torch.from_numpy(patient_df["value"].to_numpy())
    batch = values.unsqueeze(dim=0).to(device)

    model.eval()
    with torch.inference_mode():
        # The model returns a pair; only its first element is used here.
        raw_output, _ = model(batch)
        first_output = raw_output.detach().cpu()[0]
        result = torch.sigmoid(first_output).item()
    return result
def example_csv_click(patient_id: int):
    """Predict mortality for one of the bundled example patients.

    The example registries are keyed by the id string captured from the file
    name, while the UI may deliver the id as an int or a float. The id is
    therefore looked up as given first, then as ``str(patient_id)``, then as
    the string of its integer value.

    Args:
        patient_id: Id of the example patient (number or numeric string).

    Returns:
        A two-item list: the ``{"dead": p, "alive": 1 - p}`` mapping and the
        patient's true label.

    Raises:
        gr.Error: If no example is registered for ``patient_id``.
    """
    print(f"*** Predict patient {patient_id} (Example CSV)")

    candidate_keys = [patient_id, str(patient_id)]
    try:
        # e.g. 12.0 -> "12"
        candidate_keys.append(str(int(float(patient_id))))
    except (TypeError, ValueError, OverflowError):
        # Not a number; the first two candidates are all there is to try.
        pass

    patient = None
    for key in candidate_keys:
        if key in correct_preds:
            patient = correct_preds[key]
        elif key in wrong_preds:
            patient = wrong_preds[key]
        if patient is not None:
            break
    if patient is None:
        raise gr.Error(f"Patient [{patient_id}]: no such example.", duration=None)

    probability = _predict(patient['file_name'])
    return [{"dead": probability, "alive": 1-probability},
            patient['label']]
def user_csv_upload(temp_csv_file_path):
    """Predict mortality for a CSV file uploaded by the user.

    Args:
        temp_csv_file_path: Path of the uploaded file.

    Returns:
        A two-item list: the ``{"dead": p, "alive": 1 - p}`` mapping and the
        true label parsed from the file name, or "(Not Available)" when the
        name does not follow the example naming pattern.
    """
    print("*** Predict patient (User CSV Upload)")
    name_parts = _extract_patient_data_from_name(temp_csv_file_path)
    dead_probability = _predict(temp_csv_file_path)

    if name_parts is None:
        true_label = "(Not Available)"
    else:
        true_label = name_parts[1]
    return [{"dead": dead_probability, "alive": 1-dead_probability}, true_label]
def do_query(query_str, query_type):
    """Answer a code lookup typed by the user.

    For "Diagnosis" and "Procedure" a Google search link for the ICD-9 code is
    produced. For any other type the text is treated as a lab code: everything
    from the last underscore on is dropped and the remainder is looked up as
    ``ITEMID`` in ``D_LABITEMS``.

    Args:
        query_str: Code entered by the user.
        query_type: Selected query type.

    Returns:
        A visible ``gr.HTML`` component holding the link or the lab item
        description (or a diagnostic message when the lookup fails).
    """
    # Local imports keep this block self-contained.
    from html import escape
    from urllib.parse import quote_plus

    if query_type in ["Diagnosis", "Procedure"]:
        str_to_search = f"ICD-9 {query_type} Code " + query_str
        # The text comes straight from the user: URL-encode it for the query
        # string, then HTML-escape the URL for use inside the href attribute.
        search_url = "https://www.google.com/search?q=" + quote_plus(str_to_search)
        link = f'<a href="{escape(search_url)}" target="_blank" rel="noopener noreferrer">Google</a>'
        return gr.HTML(value=link, visible=True)

    # Lab Code
    query_str = query_str.strip()
    if (index := query_str.rfind("_")) >= 0:
        query_str = query_str[0:index]
    res = D_LABITEMS[D_LABITEMS["ITEMID"] == query_str]
    if res.shape[0] == 0:
        answer = "(Something wrong. No definition found.)"
    elif res.shape[0] == 1:
        answer = f"{res['LABEL'].values[0]}-{res['FLUID'].values[0]}-{res['CATEGORY'].values[0]}"
    else:
        answer = f"(Something wrong. Too many definitions, given code [{query_str}].)"
    # Escaped because the message can echo user input and is rendered as HTML.
    return gr.HTML(value=escape(answer), visible=True)
def query_input_change_event(query_str, query_type):
    """React to a change of the query text box or the query type selector.

    Args:
        query_str: Current content of the query text box.
        query_type: Currently selected query type, or ``None``.

    Returns:
        A two-item list: a ``gr.Button`` update that is interactive only when
        the text is non-blank and a type is selected, and a ``gr.HTML`` update
        that hides the previous answer.
    """
    has_text = query_str is not None and len(query_str.strip()) > 0
    can_query = has_text and query_type is not None
    return [gr.Button(interactive=can_query), gr.HTML(visible=False)]
# Event bindings that still lack their `outputs`; they are completed once the
# result components have been created.
resDispPartFuncs = []

# Custom style sheet passed to gr.Blocks.
css = """
#selectFileToUpload {max-height: 180px}
.gradio-container {
    background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
    background-position: 80% 85%;
    background-repeat: no-repeat;
    background-size: 200px;
}
#label-label {
    height: 50px !important;
}
#label-label > .container {
    height: 50px !important;
}
#label-label > .container > h2 {
    /* height: 50px !important; */
    padding: 0 !important;
}
"""
def _add_example_panels(examples: dict) -> None:
    """Render one panel per example patient and queue its click binding.

    Each panel holds a predict button, a download button for the CSV file and
    a hidden number field carrying the patient id. The click binding is stored
    in ``resDispPartFuncs`` without ``outputs`` because the result components
    are created later.

    Args:
        examples: Mapping of patient id to a dict with a "file_name" entry.
    """
    for patient_id, info in examples.items():
        with gr.Column(variant='panel', min_width=100):
            patient_input_btn = gr.Button(f"Patient {patient_id}", size="sm")
            gr.DownloadButton(label="Download",
                              value=f"{info['file_name']}",
                              size="sm")
            patient_id_num = gr.Number(value=patient_id, visible=False)
            resDispPartFuncs.append(partial(patient_input_btn.click,
                                            fn=example_csv_click,
                                            inputs=patient_id_num,
                                            api_name="predict"))


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        # Left column: file upload and example patients.
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            patient_upload_file = gr.File(label="Upload A Patient",
                                          file_types=['.csv'],
                                          file_count="single",
                                          elem_id="selectFileToUpload")
            # Deferred like the example buttons: binding the upload with
            # outputs=None would discard the prediction instead of showing it.
            resDispPartFuncs.append(partial(patient_upload_file.upload,
                                            fn=user_csv_upload,
                                            inputs=patient_upload_file))
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                _add_example_panels(correct_preds)
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                _add_example_panels(wrong_preds)

        # Right column: prediction result and background information.
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result_pred = gr.Label(num_top_classes=2, label="Predicted")
            result_label = gr.Label(label="Label", elem_id="label-label")
            with gr.Accordion("More on Patient Conditions...", open=False):
                query_tbx = gr.Textbox(label="Enter one ICD-9 Diagnosis/Procedure Code or Lab Value:",
                                       lines=1,
                                       max_lines=1, placeholder="00869 for 'Other viral intes infec' (Diagnosis)")
                query_type = gr.Radio(["Diagnosis", "Procedure", "Lab Value"], show_label=False)
                query_btn = gr.Button(value="Query", size="sm", interactive=False)
                html = gr.HTML("", visible=False)
                # Any change to the text or the type re-evaluates the button
                # state and hides the previous answer.
                query_tbx.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
                query_type.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
                query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=html)
            with gr.Accordion("More on Technical Details...", open=False):
                # NOTE(review): the list is kept flat; confirm whether some of
                # the items were meant to be nested sub-items.
                gr.Markdown(
                    """
                    - Paper: [Variationally Regularized Graph-based Representation Learning for Electronic Health Records (Zhu et al, 2021)](https://arxiv.org/abs/1912.03761)
                    - Dataset: [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
                    - 50,314 records, 10,591 features
                    - 5,315 positive, 44,999 negative (11.8%)
                    - Split: 80% training, 10% validation, 10% testing
                    - Notable points:
                    - Result: AUPRC 0.7027 (Baseline: 0.118) on Val split
                    - Variational Regularization, inspired by [Kipf et al., 2016](https://arxiv.org/abs/1611.07308)
                    - Trained on NVIDIA A100 with PyTorch 2.4.0
                    - Code on GitHub: [pytorch-variational-gcn-ehr-public](https://github.com/ThachNgocTran/pytorch-variational-gcn-ehr-public)
                    """
                )
            with gr.Accordion("More on Training...", open=False):
                gr.HTML("""
                <img src="/file=images/AUPRC_Training_Graph.png" alt="">
                """)

    # The result components exist now, so the queued bindings can be completed.
    for partialFunc in resDispPartFuncs:
        partialFunc(outputs=[result_pred, result_label])

demo.launch(debug=True, allowed_paths=["images/."])