import functools
import os
import re
from collections import OrderedDict
from typing import Dict, List, Optional

import duckdb
import pandas as pd
import torch

from .tables import ed_cxr_token_type_ids, ed_module_tables, mimic_cxr_tables


def mimic_cxr_text_path(dir, subject_id, study_id, ext='txt'):
    return os.path.join(dir, 'p' + str(subject_id)[:2], 'p' + str(subject_id),
                        's' + str(study_id) + '.' + ext)
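
# A minimal usage sketch (hypothetical ids): report paths are nested by the
# first two digits of the subject id, e.g.
#   mimic_cxr_text_path('files', 10000032, 50414267)
#   -> 'files/p10/p10000032/s50414267.txt'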


def format(text):
    """Normalise whitespace: newlines and tabs become spaces, runs of whitespace
    collapse to a single space, and leading/trailing whitespace is removed."""
    text = re.sub(r'\n|\t', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
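
# E.g. format('  FINDINGS:\n No acute\t process. ') -> 'FINDINGS: No acute process.'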


def rgetattr(obj, attr, *args):
    """Recursive getattr that resolves dotted attribute paths, with an
    optional default (as per getattr)."""
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
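
# E.g. rgetattr(torch, 'nn.functional.relu') is torch.nn.functional.relu, and
# rgetattr(obj, 'a.b', None) returns None if any attribute along the path is missing.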


def df_to_tensor_index_columns(
    df: pd.DataFrame,
    tensor: torch.Tensor,
    group_idx_to_y_idx: Dict,
    groupby: str,
    index_columns: List[str],
):
    """
    Converts a dataframe with index columns to a tensor, where each index of the y-axis is
    determined by the 'groupby' column.
    """
    assert len(group_idx_to_y_idx) == tensor.shape[0]
    all_columns = index_columns + [groupby]
    # 'row[i] == row[i]' is False only for NaN, so missing indices are skipped.
    y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows()
                 for i in index_columns if row[i] == row[i]]
    x_indices = [row[i] for _, row in df[all_columns].iterrows()
                 for i in index_columns if row[i] == row[i]]
    tensor[y_indices, x_indices] = 1.0
    return tensor
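
# A minimal sketch (hypothetical data): rows sharing a 'stay_id' set the positions
# named by their index columns to 1.0 on that group's row of the tensor:
#   df = pd.DataFrame({'stay_id': [1, 1, 2], 'icd_idx': [0, 2, 1]})
#   df_to_tensor_index_columns(df, torch.zeros(2, 3), {1: 0, 2: 1}, 'stay_id', ['icd_idx'])
#   -> tensor([[1., 0., 1.], [0., 1., 0.]])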


def df_to_tensor_value_columns(
    df: pd.DataFrame,
    tensor: torch.Tensor,
    group_idx_to_y_idx: Dict,
    groupby: str,
    value_columns: List[str],
    value_column_to_idx: Dict,
):
    """
    Converts a dataframe with value columns to a tensor, where each index of the y-axis is
    determined by the 'groupby' column. The x-index is determined by a dictionary using
    the column name.
    """
    assert len(group_idx_to_y_idx) == tensor.shape[0]
    all_columns = value_columns + [groupby]
    # 'row[i] == row[i]' is False only for NaN, so missing values are skipped.
    y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows()
                 for i in value_columns if row[i] == row[i]]
    x_indices = [value_column_to_idx[i] for _, row in df[all_columns].iterrows()
                 for i in value_columns if row[i] == row[i]]
    element_value = [row[i] for _, row in df[all_columns].iterrows()
                     for i in value_columns if row[i] == row[i]]
    tensor[y_indices, x_indices] = torch.tensor(element_value, dtype=tensor.dtype)
    return tensor
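
# A minimal sketch (hypothetical data): each value column is written to the x-position
# given by value_column_to_idx on its group's row:
#   df = pd.DataFrame({'stay_id': [1, 2], 'temp': [37.0, 38.5]})
#   df_to_tensor_value_columns(df, torch.zeros(2, 1), {1: 0, 2: 1}, 'stay_id', ['temp'], {'temp': 0})
#   -> tensor([[37.0], [38.5]])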


class EDCXRSubjectRecords:
    def __init__(
        self,
        database_path: str,
        dataset_dir: Optional[str] = None,
        reports_dir: Optional[str] = None,
        token_type_ids_starting_idx: Optional[int] = None,
        time_delta_map=lambda x: x,
        debug: bool = False,
    ):
        self.database_path = database_path
        self.dataset_dir = dataset_dir
        self.reports_dir = reports_dir
        self.time_delta_map = time_delta_map
        self.debug = debug

        self.connect = duckdb.connect(self.database_path, read_only=True)

        self.streamlit_flag = False

        self.clear_start_end_times()

        self.ed_module_tables = ed_module_tables
        self.mimic_cxr_tables = mimic_cxr_tables

        lut_info = self.connect.sql("FROM lut_info").df()

        # Index columns consume the indices up to 'end_index' from the lookup table;
        # each value column is then assigned the next free index.
        for k, v in (self.ed_module_tables | self.mimic_cxr_tables).items():
            if v.load and (v.value_columns or v.index_columns):
                v.value_column_to_idx = {}
                if v.index_columns:
                    v.total_indices = lut_info[lut_info['table_name'] == k]['end_index'].item() + 1
                else:
                    v.total_indices = 0
                for i in v.value_columns:
                    v.value_column_to_idx[i] = v.total_indices
                    v.total_indices += 1

        # Optionally offset the token type ids, e.g. so they can follow a tokenizer's
        # special token ids.
        self.token_type_to_token_type_id = ed_cxr_token_type_ids
        if token_type_ids_starting_idx is not None:
            self.token_type_to_token_type_id = {
                k: v + token_type_ids_starting_idx for k, v in self.token_type_to_token_type_id.items()
            }

    def __len__(self):
        # `subject_ids` is not set in __init__; it is expected to be assigned to the
        # instance before __len__ is called.
        return len(self.subject_ids)

    def clear_start_end_times(self):
        self.start_time, self.end_time = None, None

    def admission_ed_stay_ids(self, hadm_id):
        if hadm_id:
            return self.connect.sql(
                f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id} AND hadm_id = {hadm_id}',
            ).df()['stay_id'].tolist()
        else:
            return self.connect.sql(
                f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id}',
            ).df()['stay_id'].tolist()

    def subject_study_ids(self):
        mimic_cxr = self.connect.sql(
            f'SELECT study_id, study_datetime FROM mimic_cxr WHERE subject_id = {self.subject_id}',
        ).df()
        if self.start_time and self.end_time:
            mimic_cxr = self.filter_admissions_by_time_span(mimic_cxr, 'study_datetime')
        mimic_cxr = mimic_cxr.drop_duplicates(subset=['study_id']).sort_values(by='study_datetime')
        return dict(zip(mimic_cxr['study_id'], mimic_cxr['study_datetime']))

    def load_ed_module(self, hadm_id=None, stay_id=None, reference_time=None):
        if not self.start_time and stay_id is not None:
            edstay = self.connect.sql(
                f"""
                SELECT intime, outtime
                FROM edstays
                WHERE stay_id = {stay_id}
                """
            ).df()
            self.start_time = edstay['intime'].item()
            self.end_time = edstay['outtime'].item()
        self.load_module(self.ed_module_tables, hadm_id=hadm_id, stay_id=stay_id, reference_time=reference_time)

    def load_mimic_cxr(self, study_id, reference_time=None):
        self.load_module(self.mimic_cxr_tables, study_id=study_id, reference_time=reference_time)
        if self.streamlit_flag:
            self.report_path = mimic_cxr_text_path(self.reports_dir, self.subject_id, study_id, 'txt')

    def load_module(self, module_dict, hadm_id=None, stay_id=None, study_id=None, reference_time=None):
        for k, v in module_dict.items():

            if self.streamlit_flag or v.load:

                query = f"FROM {k}"

                conditions = []
                if hasattr(self, 'subject_id') and v.subject_id_filter:
                    conditions.append(f"subject_id={self.subject_id}")
                if v.hadm_id_filter:
                    assert hadm_id is not None
                    conditions.append(f"hadm_id={hadm_id}")
                if v.stay_id_filter:
                    assert stay_id is not None
                    conditions.append(f"stay_id={stay_id}")
                if v.study_id_filter:
                    assert study_id is not None
                    conditions.append(f"study_id={study_id}")
                if v.mimic_cxr_sectioned:
                    assert study_id is not None
                    conditions.append(f"study='s{study_id}'")

                # E.g. "FROM edstays WHERE (subject_id=1 AND stay_id=2)" (hypothetical ids).
                if conditions:
                    query += f" WHERE ({' AND '.join(conditions)})"

                df = self.connect.sql(query).df()

                if v.load:

                    columns = [v.groupby] + v.time_columns + v.index_columns + v.text_columns + v.value_columns + v.target_sections

                    if v.use_start_time:
                        df['start_time'] = self.start_time
                        columns += ['start_time']

                    if reference_time is not None:
                        time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'

                        # Only rows that occurred before the reference time are admissible:
                        df = df[df[time_column] < reference_time]

                if self.streamlit_flag:
                    setattr(self, k, df)

                if v.load:
                    columns = list(dict.fromkeys(columns))
                    df = df.drop(columns=df.columns.difference(columns))
                    setattr(self, f'{k}_feats', df)

    def return_ed_module_features(self, stay_id, reference_time=None):
        example_dict = {}
        if stay_id is not None:
            self.load_ed_module(stay_id=stay_id, reference_time=reference_time)
            for k, v in self.ed_module_tables.items():
                if v.load:

                    df = getattr(self, f'{k}_feats')

                    if self.debug:
                        example_dict.setdefault('ed_tables', []).append(k)

                    if not df.empty:

                        assert f'{k}_index_value_feats' not in example_dict

                        time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'
                        group_idx_to_y_idx, group_idx_to_datetime = OrderedDict(), OrderedDict()
                        groups = df.dropna(subset=v.index_columns + v.value_columns + v.text_columns, axis=0, how='all')
                        groups = groups.drop_duplicates(subset=[v.groupby])[list(dict.fromkeys([v.groupby, time_column]))]
                        groups = groups.reset_index(drop=True)
                        for i, row in groups.iterrows():
                            group_idx_to_y_idx[row[v.groupby]] = i
                            group_idx_to_datetime[row[v.groupby]] = row[time_column]

                        if (v.index_columns or v.value_columns) and group_idx_to_y_idx:
                            example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
                            if v.index_columns:
                                example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
                                    df=df,
                                    tensor=example_dict[f'{k}_index_value_feats'],
                                    group_idx_to_y_idx=group_idx_to_y_idx,
                                    groupby=v.groupby,
                                    index_columns=v.index_columns,
                                )
                            if v.value_columns:
                                example_dict[f'{k}_index_value_feats'] = df_to_tensor_value_columns(
                                    df=df,
                                    tensor=example_dict[f'{k}_index_value_feats'],
                                    group_idx_to_y_idx=group_idx_to_y_idx,
                                    groupby=v.groupby,
                                    value_columns=v.value_columns,
                                    value_column_to_idx=v.value_column_to_idx,
                                )

                            example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
                                [example_dict[f'{k}_index_value_feats'].shape[0]],
                                self.token_type_to_token_type_id[k],
                                dtype=torch.long,
                            )

                            event_times = list(group_idx_to_datetime.values())
                            assert all([i == i for i in event_times])  # No NaT values.
                            time_delta = [self.compute_time_delta(i, reference_time) for i in event_times]
                            example_dict[f'{k}_index_value_time_delta'] = torch.tensor(time_delta)[:, None]

                            assert example_dict[f'{k}_index_value_feats'].shape[0] == example_dict[f'{k}_index_value_time_delta'].shape[0]

                        if v.text_columns:
                            for j in group_idx_to_y_idx.keys():
                                group_text = df[df[v.groupby] == j]
                                for i in v.text_columns:

                                    column_text = [format(t) for t in list(dict.fromkeys(group_text[i].tolist())) if t is not None]

                                    if column_text:

                                        example_dict.setdefault(f'{k}_{i}', []).append(f"{', '.join(column_text)}.")

                                        event_times = group_text[time_column].iloc[0]
                                        time_delta = self.compute_time_delta(event_times, reference_time, to_tensor=False)
                                        example_dict.setdefault(f'{k}_{i}_time_delta', []).append(time_delta)

        return example_dict

    def return_mimic_cxr_features(self, study_id, reference_time=None):
        example_dict = {}
        if study_id is not None:
            self.load_mimic_cxr(study_id=study_id, reference_time=reference_time)
            for k, v in self.mimic_cxr_tables.items():
                if v.load:

                    df = getattr(self, f'{k}_feats')

                    if self.debug:
                        example_dict.setdefault('mimic_cxr_inputs', []).append(k)

                    if not df.empty:

                        group_idx_to_y_idx = OrderedDict()
                        groups = df.dropna(
                            subset=v.index_columns + v.value_columns + v.text_columns + v.target_sections,
                            axis=0,
                            how='all',
                        )
                        groups = groups.drop_duplicates(subset=[v.groupby])[[v.groupby]]
                        groups = groups.reset_index(drop=True)
                        for i, row in groups.iterrows():
                            group_idx_to_y_idx[row[v.groupby]] = i

                        if v.index_columns and group_idx_to_y_idx:

                            example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
                            example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
                                df=df,
                                tensor=example_dict[f'{k}_index_value_feats'],
                                group_idx_to_y_idx=group_idx_to_y_idx,
                                groupby=v.groupby,
                                index_columns=v.index_columns,
                            )

                            example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
                                [example_dict[f'{k}_index_value_feats'].shape[0]],
                                self.token_type_to_token_type_id[k],
                                dtype=torch.long,
                            )

                        if v.text_columns:
                            for j in group_idx_to_y_idx.keys():
                                group_text = df[df[v.groupby] == j]
                                for i in v.text_columns:
                                    column_text = [format(t) for t in list(dict.fromkeys(group_text[i].tolist())) if t is not None]
                                    if column_text:
                                        example_dict.setdefault(i, []).append(f"{', '.join(column_text)}.")

                        if v.target_sections:
                            for j in group_idx_to_y_idx.keys():
                                group_text = df[df[v.groupby] == j]
                                for i in v.target_sections:
                                    column_text = [format(t) for t in list(dict.fromkeys(group_text[i].tolist())) if t is not None]
                                    assert len(column_text) == 1
                                    example_dict[i] = column_text[-1]

        return example_dict

    def compute_time_delta(self, event_time, reference_time, denominator=3600, to_tensor=True):
        """
        How do we transform time-delta inputs? It appears that minutes are used as the input to
        a weight matrix in "Self-Supervised Transformer for Sparse and Irregularly Sampled
        Multivariate Clinical Time-Series". This is almost confirmed by the CVE class defined here:
        https://github.com/sindhura97/STraTS/blob/main/strats_notebook.ipynb, where the input has
        a size of one. Here, the default denominator of 3600 converts the delta to hours instead.
        """
        time_delta = reference_time - event_time
        time_delta = time_delta.total_seconds() / denominator
        assert isinstance(time_delta, float), f'time_delta should be a float, not {type(time_delta)}.'
        if time_delta < 0:
            raise ValueError(f'time_delta should be greater than or equal to zero, not {time_delta}.')
        time_delta = self.time_delta_map(time_delta)
        if to_tensor:
            time_delta = torch.tensor(time_delta)
        return time_delta
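
    # E.g. for an event two hours before the reference time,
    # (reference_time - event_time).total_seconds() / 3600 == 2.0 before
    # `time_delta_map` is applied.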

    def filter_admissions_by_time_span(self, df, time_column):
        """Keep rows whose time falls within (start_time, end_time]."""
        return df[(df[time_column] > self.start_time) & (df[time_column] <= self.end_time)]
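

# A minimal usage sketch (hypothetical database path, ids, and reference time);
# `subject_id` is assigned on the instance before the per-stay/per-study feature
# getters are called:
#   records = EDCXRSubjectRecords('database.duckdb', time_delta_map=lambda x: 1 / (1 + x))
#   records.subject_id = 10000032
#   ref = pd.Timestamp('2180-01-01 12:00:00')
#   ed_feats = records.return_ed_module_features(stay_id=30000001, reference_time=ref)
#   cxr_feats = records.return_mimic_cxr_features(study_id=50414267, reference_time=ref)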