anicolson committed on
Commit 6f7f115
1 Parent(s): a1f73d2

Upload model

Files changed (9)
  1. README.md +199 -0
  2. config.json +237 -0
  3. dataset.py +382 -0
  4. generation_config.json +7 -0
  5. model.safetensors +3 -0
  6. modelling_cxrmate_ed.py +1129 -0
  7. modelling_uniformer.py +412 -0
  8. records.py +369 -0
  9. tables.py +159 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
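The model card's "How to Get Started with the Model" section above is still a placeholder. As a minimal sketch, assuming the files in this commit are served from a Hugging Face Hub repository (the repository id below is a placeholder) and that running the repository's custom code is acceptable, the model can be loaded through the `auto_map` entry in `config.json`:

```python
from transformers import AutoModel

# Placeholder repository id; substitute the Hub repo this commit was pushed to.
repo_id = "<user-or-org>/<repo-name>"

# config.json maps AutoModel to modelling_cxrmate_ed.MIMICIVEDCXRMultimodalModel,
# so trust_remote_code=True is required to import that module from the repository.
# Note: the custom module also imports duckdb, streamlit and the local dataset.py,
# whose own imports (lmdb, data.mimic_cxr, tools.*) must be importable as well.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()
```

No tokenizer files are part of this commit, so report generation additionally needs the matching tokenizer and the prompt-preparation methods defined in `modelling_cxrmate_ed.py`.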
config.json ADDED
@@ -0,0 +1,237 @@
1
+ {
2
+ "architectures": [
3
+ "MIMICIVEDCXRMultimodalModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoModel": "modelling_cxrmate_ed.MIMICIVEDCXRMultimodalModel"
7
+ },
8
+ "decoder": {
9
+ "_name_or_path": "",
10
+ "add_cross_attention": false,
11
+ "add_time_deltas": true,
12
+ "architectures": null,
13
+ "attention_bias": false,
14
+ "attention_dropout": 0.0,
15
+ "bad_words_ids": null,
16
+ "begin_suppress_tokens": null,
17
+ "bos_token_id": 1,
18
+ "chunk_size_feed_forward": 0,
19
+ "cross_attention_hidden_size": null,
20
+ "decoder_start_token_id": null,
21
+ "diversity_penalty": 0.0,
22
+ "do_sample": false,
23
+ "early_stopping": false,
24
+ "ed_module_columns": [
25
+ "triage_chiefcomplaint",
26
+ "triage_pain",
27
+ "vitalsign_pain"
28
+ ],
29
+ "encoder_no_repeat_ngram_size": 0,
30
+ "eos_token_id": 2,
31
+ "exponential_decay_length_penalty": null,
32
+ "finetuning_task": null,
33
+ "forced_bos_token_id": null,
34
+ "forced_eos_token_id": null,
35
+ "hidden_act": "silu",
36
+ "hidden_size": 768,
37
+ "id2label": {
38
+ "0": "LABEL_0",
39
+ "1": "LABEL_1"
40
+ },
41
+ "include_time_delta": true,
42
+ "index_value_encoder_config": {
43
+ "edstays": 40,
44
+ "triage": 7,
45
+ "vitalsign": 1177
46
+ },
47
+ "index_value_encoder_intermediate_size": 2048,
48
+ "initializer_range": 0.02,
49
+ "intermediate_size": 3072,
50
+ "is_decoder": true,
51
+ "is_encoder_decoder": false,
52
+ "label2id": {
53
+ "LABEL_0": 0,
54
+ "LABEL_1": 1
55
+ },
56
+ "length_penalty": 1.0,
57
+ "max_length": 20,
58
+ "max_position_embeddings": 2048,
59
+ "mimic_cxr_columns": [
60
+ "indication",
61
+ "history"
62
+ ],
63
+ "min_length": 0,
64
+ "model_type": "llama",
65
+ "no_repeat_ngram_size": 0,
66
+ "num_attention_heads": 12,
67
+ "num_beam_groups": 1,
68
+ "num_beams": 1,
69
+ "num_hidden_layers": 6,
70
+ "num_key_value_heads": 12,
71
+ "num_return_sequences": 1,
72
+ "num_token_types": 19,
73
+ "output_attentions": false,
74
+ "output_hidden_states": false,
75
+ "output_scores": false,
76
+ "pad_token_id": 4,
77
+ "prefix": null,
78
+ "pretraining_tp": 1,
79
+ "problem_type": null,
80
+ "pruned_heads": {},
81
+ "remove_invalid_values": false,
82
+ "repetition_penalty": 1.0,
83
+ "return_dict": true,
84
+ "return_dict_in_generate": false,
85
+ "rms_norm_eps": 1e-06,
86
+ "rope_scaling": null,
87
+ "rope_theta": 10000.0,
88
+ "sep_token_id": null,
89
+ "suppress_tokens": null,
90
+ "task_specific_params": null,
91
+ "temperature": 1.0,
92
+ "tf_legacy_loss": false,
93
+ "tie_encoder_decoder": false,
94
+ "tie_word_embeddings": false,
95
+ "time_delta_monotonic_inversion": true,
96
+ "token_type_to_token_type_id": {
97
+ "comparison": 15,
98
+ "edstays": 1,
99
+ "findings": 12,
100
+ "history": 11,
101
+ "image": 14,
102
+ "impression": 13,
103
+ "indication": 10,
104
+ "medrecon": 0,
105
+ "medrecon_name": 6,
106
+ "mimic_cxr_2_0_0_metadata": 5,
107
+ "previous_findings": 16,
108
+ "previous_image": 18,
109
+ "previous_impression": 17,
110
+ "pyxis": 4,
111
+ "triage": 2,
112
+ "triage_chiefcomplaint": 7,
113
+ "triage_pain": 8,
114
+ "vitalsign": 3,
115
+ "vitalsign_pain": 9
116
+ },
117
+ "tokenizer_class": null,
118
+ "top_k": 50,
119
+ "top_p": 1.0,
120
+ "torch_dtype": null,
121
+ "torchscript": false,
122
+ "typical_p": 1.0,
123
+ "use_bfloat16": false,
124
+ "use_cache": true,
125
+ "vocab_size": 30000,
126
+ "zero_time_delta_value": 1.0
127
+ },
128
+ "encoder": {
129
+ "_name_or_path": "",
130
+ "add_cross_attention": false,
131
+ "architectures": null,
132
+ "attention_probs_dropout_prob": 0.0,
133
+ "attn_drop_rate": 0.0,
134
+ "bad_words_ids": null,
135
+ "begin_suppress_tokens": null,
136
+ "bos_token_id": null,
137
+ "chunk_size_feed_forward": 0,
138
+ "conv_stem": false,
139
+ "cross_attention_hidden_size": null,
140
+ "decoder_start_token_id": null,
141
+ "depth": [
142
+ 5,
143
+ 8,
144
+ 20,
145
+ 7
146
+ ],
147
+ "diversity_penalty": 0.0,
148
+ "do_sample": false,
149
+ "drop_path_rate": 0.3,
150
+ "drop_rate": 0.0,
151
+ "early_stopping": false,
152
+ "embed_dim": [
153
+ 64,
154
+ 128,
155
+ 320,
156
+ 512
157
+ ],
158
+ "encoder_no_repeat_ngram_size": 0,
159
+ "encoder_stride": 16,
160
+ "eos_token_id": null,
161
+ "exponential_decay_length_penalty": null,
162
+ "finetuning_task": null,
163
+ "forced_bos_token_id": null,
164
+ "forced_eos_token_id": null,
165
+ "head_dim": 64,
166
+ "hidden_act": "gelu",
167
+ "hidden_dropout_prob": 0.0,
168
+ "hidden_size": 768,
169
+ "id2label": {
170
+ "0": "LABEL_0",
171
+ "1": "LABEL_1"
172
+ },
173
+ "image_size": 384,
174
+ "in_chans": 3,
175
+ "initializer_range": 0.02,
176
+ "intermediate_size": 3072,
177
+ "is_decoder": false,
178
+ "is_encoder_decoder": false,
179
+ "label2id": {
180
+ "LABEL_0": 0,
181
+ "LABEL_1": 1
182
+ },
183
+ "layer_norm_eps": 1e-06,
184
+ "length_penalty": 1.0,
185
+ "max_length": 20,
186
+ "min_length": 0,
187
+ "mlp_ratio": 4,
188
+ "model_type": "vit",
189
+ "no_repeat_ngram_size": 0,
190
+ "num_attention_heads": 12,
191
+ "num_beam_groups": 1,
192
+ "num_beams": 1,
193
+ "num_channels": 3,
194
+ "num_classes": 1000,
195
+ "num_hidden_layers": 12,
196
+ "num_return_sequences": 1,
197
+ "output_attentions": false,
198
+ "output_hidden_states": false,
199
+ "output_scores": false,
200
+ "pad_token_id": null,
201
+ "patch_size": [
202
+ 4,
203
+ 2,
204
+ 2,
205
+ 2
206
+ ],
207
+ "prefix": null,
208
+ "problem_type": null,
209
+ "projection_size": 768,
210
+ "pruned_heads": {},
211
+ "qk_scale": null,
212
+ "qkv_bias": true,
213
+ "remove_invalid_values": false,
214
+ "repetition_penalty": 1.0,
215
+ "representation_size": null,
216
+ "return_dict": true,
217
+ "return_dict_in_generate": false,
218
+ "sep_token_id": null,
219
+ "suppress_tokens": null,
220
+ "task_specific_params": null,
221
+ "temperature": 1.0,
222
+ "tf_legacy_loss": false,
223
+ "tie_encoder_decoder": false,
224
+ "tie_word_embeddings": true,
225
+ "tokenizer_class": null,
226
+ "top_k": 50,
227
+ "top_p": 1.0,
228
+ "torch_dtype": null,
229
+ "torchscript": false,
230
+ "typical_p": 1.0,
231
+ "use_bfloat16": false
232
+ },
233
+ "model_type": "vision-encoder-decoder",
234
+ "tie_word_embeddings": false,
235
+ "torch_dtype": "float32",
236
+ "transformers_version": "4.39.0"
237
+ }
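`config.json` above is a composite vision-encoder-decoder configuration: the `encoder` block describes the 384×384 UniFormer-style image encoder (staged `embed_dim` of [64, 128, 320, 512]) and the `decoder` block a 6-layer Llama-style text decoder carrying custom fields such as `token_type_to_token_type_id` and `index_value_encoder_config`. A small sketch for inspecting the nested configuration without loading any weights (repository id is again a placeholder):

```python
from transformers import AutoConfig

repo_id = "<user-or-org>/<repo-name>"  # placeholder, as above

# The top-level model_type "vision-encoder-decoder" is native to transformers, so the
# configuration loads without custom code; fields that are not part of the standard
# ViT/Llama configs are kept as extra attributes on the nested config objects.
config = AutoConfig.from_pretrained(repo_id)

print(config.model_type)                                            # vision-encoder-decoder
print(config.encoder.image_size, config.encoder.embed_dim)          # 384 [64, 128, 320, 512]
print(config.decoder.num_hidden_layers, config.decoder.vocab_size)  # 6 30000
print(config.decoder.token_type_to_token_type_id["findings"])       # 12
```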
dataset.py ADDED
@@ -0,0 +1,382 @@
1
+ import os
2
+ import struct
3
+
4
+ import lmdb
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision.io import decode_image, read_image
10
+
11
+ from data.mimic_cxr.dcm_processing import load_and_preprocess_dcm_uint16
12
+ from tools.mimic_iv.ed_cxr.records import EDCXRSubjectRecords
13
+ from tools.utils import mimic_cxr_image_path
14
+
15
+ # Ordered by oblique, lateral, AP, and then PA views so that PA views are closest in position to the generated tokens (and oblique is furthest).
16
+ VIEW_ORDER = ['LPO', 'RAO', 'LAO', 'SWIMMERS', 'XTABLE LATERAL', 'LL', 'LATERAL', 'AP AXIAL', 'AP RLD', 'AP LLD', 'AP', 'PA RLD', 'PA LLD', 'PA']
17
+
18
+
19
+ class StudyIDEDStayIDSubset(Dataset):
20
+ """
21
+ Study ID & ED stay ID subset. Examples are indexed by the study identifier.
22
+ Information from the ED module is added by finding the stay_id for the subject_id
23
+ whose timespan contains the study_datetime of the study_id. The history and indication
24
+ sections are also included.
25
+ """
26
+ def __init__(
27
+ self,
28
+ mimic_iv_duckdb_path,
29
+ split,
30
+ dataset_dir=None,
31
+ max_images_per_study=None,
32
+ transforms=None,
33
+ images=True,
34
+ columns='study_id, dicom_id, subject_id, findings, impression',
35
+ and_condition='',
36
+ records=None,
37
+ study_id_inclusion_list=None,
38
+ return_images=True,
39
+ ed_module=True,
40
+ extension='jpg',
41
+ images_rocksdb_path=None,
42
+ jpg_lmdb_path=None,
43
+ jpg_rocksdb_path=None,
44
+ ):
45
+ """
46
+ Argument/s:
47
+ mimic_iv_duckdb_path - Path to MIMIC-IV DuckDB database.
48
+ split - 'train', 'validate', or 'test'.
49
+ dataset_dir - Dataset directory.
50
+ max_images_per_study - the maximum number of images per study.
51
+ transforms - torchvision transformations.
52
+ colour_space - PIL target colour space.
53
+ images - flag to return processed images.
54
+ columns - which columns to query on.
55
+ and_condition - AND condition to add to the SQL query.
56
+ records - MIMIC-IV records class instance.
57
+ study_id_inclusion_list - studies not in this list are excluded.
58
+ return_images - return CXR images for the study as tensors.
59
+ ed_module - use the ED module.
60
+ extension - 'jpg' or 'dcm'.
61
+ images_rocksdb_path - path to image RocksDB database.
62
+ jpg_lmdb_path - path to LMDB .jpg database.
63
+ jpg_rocksdb_path - path to RocksDB .jpg database.
64
+ """
65
+ super(StudyIDEDStayIDSubset, self).__init__()
66
+ self.split = split
67
+ self.dataset_dir = dataset_dir
68
+ self.max_images_per_study = max_images_per_study
69
+ self.transforms = transforms
70
+ self.images = images
71
+ self.columns = columns
72
+ self.and_condition = and_condition
73
+ self.return_images = return_images
74
+ self.ed_module = ed_module
75
+ self.extension = extension
76
+ self.images_rocksdb_path = images_rocksdb_path
77
+ self.jpg_lmdb_path = jpg_lmdb_path
78
+ self.jpg_rocksdb_path = jpg_rocksdb_path
79
+
80
+ # If max images per study is not set:
81
+ self.max_images_per_study = float('inf') if self.max_images_per_study is None else self.max_images_per_study
82
+
83
+ assert self.extension == 'jpg' or self.extension == 'dcm'
84
+
85
+ if self.dataset_dir is not None and self.images_rocksdb_path is None:
86
+ if self.extension == 'jpg':
87
+ if 'physionet.org/files/mimic-cxr-jpg/2.0.0/files' not in self.dataset_dir:
88
+ self.dataset_dir = os.path.join(self.dataset_dir, 'physionet.org/files/mimic-cxr-jpg/2.0.0/files')
89
+ elif self.extension == 'dcm':
90
+ if 'physionet.org/files/mimic-cxr/2.0.0/files' not in self.dataset_dir:
91
+ self.dataset_dir = os.path.join(self.dataset_dir, 'physionet.org/files/mimic-cxr/2.0.0/files')
92
+
93
+ # Open the RocksDB images database:
94
+ if self.images_rocksdb_path is not None:
95
+ import rocksdb
96
+
97
+ # Define the column families:
98
+ column_families = {
99
+ b'shape': rocksdb.ColumnFamilyOptions(),
100
+ b'image': rocksdb.ColumnFamilyOptions(),
101
+ }
102
+
103
+ opts = rocksdb.Options()
104
+ opts.max_open_files = int(1e+5)
105
+ self.images_db = rocksdb.DB(self.images_rocksdb_path, opts, column_families=column_families, read_only=True)
106
+
107
+ self.shape_handle = self.images_db.get_column_family(b'shape')
108
+ self.image_handle = self.images_db.get_column_family(b'image')
109
+
110
+ self.shape_dtype = np.int32
111
+ self.image_dtype = np.uint16
112
+
113
+ # Prepare the RocksDB .jpg database:
114
+ if self.jpg_rocksdb_path is not None:
115
+ import rocksdb
116
+
117
+ opts = rocksdb.Options()
118
+ opts.max_open_files = int(1e+5)
119
+
120
+ self.images_db = rocksdb.DB(self.jpg_rocksdb_path, opts, read_only=True)
121
+
122
+ # Prepare the LMDB .jpg database:
123
+ if self.jpg_lmdb_path is not None:
124
+
125
+ print('Loading images using LMDB.')
126
+
127
+ # Map size:
128
+ map_size = int(0.65 * (1024 ** 4))
129
+ assert isinstance(map_size, int)
130
+
131
+ self.env = lmdb.open(self.jpg_lmdb_path, map_size=map_size, lock=False, readonly=True)
132
+ self.txn = self.env.begin(write=False)
133
+
134
+ self.records = EDCXRSubjectRecords(database_path=mimic_iv_duckdb_path) if records is None else records
135
+
136
+ query = f"""
137
+ SELECT {columns}
138
+ FROM mimic_cxr
139
+ WHERE split = '{split}'
140
+ {and_condition}
141
+ ORDER BY study_id
142
+ """
143
+
144
+ # For multi-image, the study identifiers make up the training examples:
145
+ df = self.records.connect.sql(query).df()
146
+
147
+ # Drop studies that don't have a findings or impression section:
148
+ df = df.dropna(subset=['findings', 'impression'], how='any')
149
+
150
+ # This study has two rows in edstays (removed as it causes issues):
151
+ if self.ed_module:
152
+ df = df[df['study_id'] != 59128861]
153
+
154
+ # Exclude studies not in list:
155
+ if study_id_inclusion_list is not None:
156
+ df = df[df['study_id'].isin(study_id_inclusion_list)]
157
+
158
+ # Example study identifiers for the subset:
159
+ self.examples = df['study_id'].unique().tolist()
160
+
161
+ # Record statistics:
162
+ self.num_study_ids = len(self.examples)
163
+ self.num_dicom_ids = len(df['dicom_id'].unique().tolist())
164
+ self.num_subject_ids = len(df['subject_id'].unique().tolist())
165
+
166
+ def __len__(self):
167
+ return self.num_study_ids
168
+
169
+ def __getitem__(self, index):
170
+
171
+ study_id = self.examples[index]
172
+
173
+ # Get the study:
174
+ study = self.records.connect.sql(
175
+ f"""
176
+ SELECT dicom_id, study_id, subject_id, study_datetime, ViewPosition
177
+ FROM mimic_cxr
178
+ WHERE (study_id = {study_id});
179
+ """
180
+ ).df()
181
+ subject_id = study.iloc[0, study.columns.get_loc('subject_id')]
182
+ study_id = study.iloc[0, study.columns.get_loc('study_id')]
183
+ study_datetime = study['study_datetime'].max()
184
+
185
+ example_dict = {
186
+ 'study_ids': study_id,
187
+ 'subject_id': subject_id,
188
+ 'index': index,
189
+ }
190
+
191
+ example_dict.update(self.records.return_mimic_cxr_features(study_id))
192
+
193
+ if self.ed_module:
194
+ edstays = self.records.connect.sql(
195
+ f"""
196
+ SELECT stay_id, intime, outtime
197
+ FROM edstays
198
+ WHERE (subject_id = {subject_id})
199
+ AND intime < '{study_datetime}'
200
+ AND outtime > '{study_datetime}';
201
+ """
202
+ ).df()
203
+
204
+ assert len(edstays) <= 1
205
+ stay_id = edstays.iloc[0, edstays.columns.get_loc('stay_id')] if not edstays.empty else None
206
+ self.records.clear_start_end_times()
207
+ example_dict.update(self.records.return_ed_module_features(stay_id, study_datetime))
208
+
209
+ example_dict['stay_ids'] = stay_id
210
+
211
+ if self.return_images:
212
+ example_dict['images'], example_dict['image_time_deltas'] = self.get_images(study, study_datetime)
213
+
214
+ return example_dict
215
+
216
+ def get_images(self, example, reference_time):
217
+ """
218
+ Get the image/s for a given example.
219
+
220
+ Argument/s:
221
+ example - dataframe for the example.
222
+ reference_time - reference_time for time delta.
223
+
224
+ Returns:
225
+ The image/s for the example
226
+ """
227
+
228
+ # Sample if over max_images_per_study. Only allowed during training:
229
+ if len(example) > self.max_images_per_study:
230
+ assert self.split == 'train'
231
+ example = example.sample(n=self.max_images_per_study, axis=0)
232
+
233
+ # Order by ViewPosition:
234
+ example['ViewPosition'] = example['ViewPosition'].astype(pd.CategoricalDtype(categories=VIEW_ORDER, ordered=True))
235
+
236
+ # Sort the DataFrame based on the categorical column
237
+ example = example.sort_values(by=['study_datetime', 'ViewPosition'])
238
+
239
+ # Load and pre-process each CXR:
240
+ images, time_deltas = [], []
241
+ for _, row in example.iterrows():
242
+ images.append(
243
+ self.load_and_preprocess_image(
244
+ row['subject_id'],
245
+ row['study_id'],
246
+ row['dicom_id'],
247
+ ),
248
+ )
249
+ time_deltas.append(self.records.compute_time_delta(row['study_datetime'], reference_time, to_tensor=False))
250
+
251
+ if self.transforms is not None:
252
+ images = torch.stack(images, 0)
253
+ return images, time_deltas
254
+
255
+ def load_and_preprocess_image(self, subject_id, study_id, dicom_id):
256
+ """
257
+ Load and preprocess an image using torchvision.transforms.v2:
258
+ https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py
259
+
260
+ Argument/s:
261
+ subject_id - subject identifier.
262
+ study_id - study identifier.
263
+ dicom_id - DICOM identifier.
264
+
265
+ Returns:
266
+ image - Tensor of the CXR.
267
+ """
268
+
269
+ if self.extension == 'jpg':
270
+
271
+ if self.jpg_rocksdb_path is not None:
272
+
273
+ # Convert to bytes:
274
+ key = bytes(dicom_id, 'utf-8')
275
+
276
+ # Retrieve image:
277
+ image = bytearray(self.images_db.get(key))
278
+ image = torch.frombuffer(image, dtype=torch.uint8)
279
+ image = decode_image(image)
280
+
281
+ elif self.jpg_lmdb_path is not None:
282
+
283
+ # Convert to bytes:
284
+ key = bytes(dicom_id, 'utf-8')
285
+
286
+ # Retrieve image:
287
+ image = bytearray(self.txn.get(key))
288
+ image = torch.frombuffer(image, dtype=torch.uint8)
289
+ image = decode_image(image)
290
+
291
+ else:
292
+ image_file_path = mimic_cxr_image_path(self.dataset_dir, subject_id, study_id, dicom_id, self.extension)
293
+ image = read_image(image_file_path)
294
+
295
+ elif self.extension == 'dcm':
296
+ if self.images_rocksdb_path is not None:
297
+
298
+ key = dicom_id.encode('utf-8')
299
+
300
+ # Retrieve the serialized image shape associated with the key:
301
+ shape_bytes = self.images_db.get((self.shape_handle, key), key)
302
+ shape = struct.unpack('iii', shape_bytes)
303
+
304
+ np.frombuffer(shape_bytes, dtype=self.shape_dtype).reshape(3)
305
+
306
+ # Retrieve the serialized image data associated with the key:
307
+ image_bytes = self.images_db.get((self.image_handle, key), key)
308
+ image = np.frombuffer(image_bytes, dtype=self.image_dtype).reshape(*shape)
309
+
310
+ else:
311
+ image_file_path = mimic_cxr_image_path(self.dataset_dir, subject_id, study_id, dicom_id, self.extension)
312
+ image = load_and_preprocess_dcm_uint16(image_file_path)
313
+
314
+ # Convert to a torch tensor:
315
+ image = torch.from_numpy(image)
316
+
317
+ if self.transforms is not None:
318
+ image = self.transforms(image)
319
+
320
+ return image
321
+
322
+
323
+ if __name__ == '__main__':
324
+ import time
325
+
326
+ from tqdm import tqdm
327
+
328
+ num_samples = 20
329
+
330
+ datasets = []
331
+ datasets.append(
332
+ StudyIDEDStayIDSubset(
333
+ dataset_dir='/datasets/work/hb-mlaifsp-mm/work/archive',
334
+ mimic_iv_duckdb_path='/scratch3/nic261/database/mimic_iv_duckdb_rev_b.db',
335
+ split='train',
336
+ extension='jpg',
337
+ ed_module=False,
338
+ ),
339
+ )
340
+
341
+ datasets.append(
342
+ StudyIDEDStayIDSubset(
343
+ dataset_dir='/scratch3/nic261/datasets',
344
+ mimic_iv_duckdb_path='/scratch3/nic261/database/mimic_iv_duckdb_rev_b.db',
345
+ split='train',
346
+ extension='jpg',
347
+ ed_module=False,
348
+ ),
349
+ )
350
+
351
+ datasets.append(
352
+ StudyIDEDStayIDSubset(
353
+ jpg_lmdb_path='/scratch3/nic261/database/mimic_cxr_jpg_lmdb_rev_a.db',
354
+ mimic_iv_duckdb_path='/scratch3/nic261/database/mimic_iv_duckdb_rev_b.db',
355
+ split='train',
356
+ extension='jpg',
357
+ ed_module=False,
358
+ ),
359
+ )
360
+
361
+ datasets.append(
362
+ StudyIDEDStayIDSubset(
363
+ jpg_rocksdb_path='/scratch3/nic261/database/mimic_cxr_jpg_rocksdb.db',
364
+ mimic_iv_duckdb_path='/scratch3/nic261/database/mimic_iv_duckdb_rev_b.db',
365
+ split='train',
366
+ extension='jpg',
367
+ ed_module=False,
368
+ )
369
+ )
370
+
371
+ assert (datasets[1][0]['images'][0] == datasets[2][0]['images'][0]).all().item()
372
+ assert (datasets[1][5]['images'][0] == datasets[2][5]['images'][0]).all().item()
373
+
374
+ for d in datasets:
375
+ start_time = time.time()
376
+ indices = torch.randperm(len(d))[:num_samples] # Get random indices.
377
+ for i in tqdm(indices):
378
+ _ = d[i]
379
+ end_time = time.time()
380
+ elapsed_time = end_time - start_time
381
+ print(f"Elapsed time: {elapsed_time} seconds")
382
+
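`dataset.py` builds study-level examples by joining MIMIC-CXR studies with MIMIC-IV ED stays through a DuckDB database. A minimal usage sketch, assuming credentialed access to MIMIC-CXR-JPG, a prepared MIMIC-IV DuckDB file, and that the helper packages imported at the top of the file (`data.mimic_cxr`, `tools.mimic_iv`, `tools.utils`) are importable from the accompanying training repository; all paths below are placeholders:

```python
import torch
from torchvision.transforms import v2

from dataset import StudyIDEDStayIDSubset  # the file above

# 384 matches the encoder input size declared in config.json.
transforms = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Resize(size=384, antialias=True),
    v2.CenterCrop(384),
])

dataset = StudyIDEDStayIDSubset(
    mimic_iv_duckdb_path='/path/to/mimic_iv_duckdb.db',  # placeholder
    dataset_dir='/path/to/physionet.org',                # placeholder
    split='train',
    extension='jpg',
    ed_module=True,
    max_images_per_study=5,
    transforms=transforms,
)

# Each example is a dict with study/stay identifiers, ED and MIMIC-CXR text features,
# the stacked image tensor, and a time delta per image.
example = dataset[0]
print(example['study_ids'], example['images'].shape, example['image_time_deltas'])
```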
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 4,
6
+ "transformers_version": "4.39.0"
7
+ }
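`generation_config.json` only pins the special-token ids that `generate()` falls back to when they are not passed explicitly (BOS 1, EOS 2, PAD 4). A quick check, using the same placeholder repository id as above:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("<user-or-org>/<repo-name>")  # placeholder
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)  # 1 2 4
```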
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4b1ed2a5298bb8999cb91a9b905ace6733e5c66ebdef9702baa4d421428fad3
3
+ size 644854104
modelling_cxrmate_ed.py ADDED
@@ -0,0 +1,1129 @@
1
+ import csv
2
+ import functools
3
+ import math
4
+ import os
5
+ import re
6
+ from collections import OrderedDict
7
+ from glob import glob
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import duckdb
12
+ import pandas as pd
13
+ import streamlit as st
14
+ import torch
15
+ import transformers
16
+ from torch.nn import CrossEntropyLoss
17
+ from tqdm import tqdm
18
+ from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.modeling_outputs import Seq2SeqLMOutput
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
23
+ VisionEncoderDecoderConfig,
24
+ )
25
+ from transformers.utils import logging
26
+
27
+ from .dataset import StudyIDEDStayIDSubset
28
+ from .modelling_uniformer import MultiUniFormerWithProjectionHead
29
+ from .records import EDCXRSubjectRecords
30
+ from .tables import ed_module_tables
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ def create_lookup_table(df, columns, start_idx):
36
+ df = df.groupby(columns).head(1)[columns].sort_values(by=columns)
37
+ indices = range(start_idx, start_idx + len(df))
38
+ df['index'] = indices
39
+ return df, indices[-1]
40
+
41
+
42
+ class FNNEncoder(torch.nn.Module):
43
+ def __init__(self, num_features, intermediate_size, decoder_hidden_size):
44
+ super().__init__()
45
+ self.up_proj = torch.nn.Linear(num_features, intermediate_size, bias=False)
46
+ self.down_proj = torch.nn.Linear(intermediate_size, decoder_hidden_size, bias=False)
47
+ self.act_fn = torch.nn.SiLU()
48
+
49
+ def forward(self, x):
50
+ return self.down_proj(self.act_fn(self.up_proj(x)))
51
+
52
+
53
+ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
54
+
55
+ config_class = VisionEncoderDecoderConfig
56
+ base_model_prefix = "vision_encoder_decoder"
57
+ main_input_name = "input_ids"
58
+ supports_gradient_checkpointing = True
59
+
60
+ def __init__(
61
+ self,
62
+ config: Optional[PretrainedConfig] = None,
63
+ encoder: Optional[PreTrainedModel] = None,
64
+ decoder: Optional[PreTrainedModel] = None,
65
+ DefaultEncoderClass = MultiUniFormerWithProjectionHead,
66
+ DefaultDecoderClass = transformers.LlamaForCausalLM,
67
+ ):
68
+
69
+ if decoder:
70
+ assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
71
+ assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
72
+
73
+ if config is None and (encoder is None or decoder is None):
74
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
75
+ if config is None:
76
+ config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
77
+ else:
78
+ if not isinstance(config, self.config_class):
79
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
80
+
81
+ config.tie_word_embeddings = False
82
+
83
+ # Initialize with config:
84
+ PreTrainedModel.__init__(self, config)
85
+
86
+ # Encoder:
87
+ if encoder is None:
88
+ encoder = DefaultEncoderClass(config=config.encoder)
89
+
90
+ # Decoder:
91
+ if decoder is None:
92
+ assert not config.decoder.add_cross_attention
93
+ decoder = DefaultDecoderClass(config=config.decoder)
94
+
95
+ self.encoder = encoder
96
+ self.decoder = decoder
97
+
98
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
99
+ logger.warning(
100
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
101
+ f" {self.config.encoder}"
102
+ )
103
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
104
+ logger.warning(
105
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
106
+ f" {self.config.decoder}"
107
+ )
108
+
109
+ self.encoder.config = self.config.encoder
110
+ self.decoder.config = self.config.decoder
111
+
112
+ assert config.decoder.is_decoder
113
+ assert not config.decoder.is_encoder_decoder
114
+ assert 'pad_token_id' in self.decoder.config.__dict__
115
+ assert 'time_delta_monotonic_inversion' in self.decoder.config.__dict__
116
+ assert 'zero_time_delta_value' in self.decoder.config.__dict__
117
+ assert 'add_time_deltas' in self.decoder.config.__dict__
118
+
119
+ assert isinstance(self.decoder.config.time_delta_monotonic_inversion, bool)
120
+ assert isinstance(self.decoder.config.zero_time_delta_value, float)
121
+
122
+ for k, v in self.decoder.config.index_value_encoder_config.items():
123
+ setattr(
124
+ self,
125
+ f'{k}_index_value_encoder',
126
+ FNNEncoder(
127
+ num_features=v,
128
+ intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
129
+ decoder_hidden_size=self.decoder.config.hidden_size,
130
+ ),
131
+ )
132
+ if self.decoder.config.add_time_deltas:
133
+ self.time_delta_encoder = FNNEncoder(
134
+ num_features=1,
135
+ intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
136
+ decoder_hidden_size=self.decoder.config.hidden_size,
137
+ )
138
+ self.token_type_embeddings = torch.nn.Embedding(self.decoder.config.num_token_types, self.decoder.config.hidden_size)
139
+
140
+ @classmethod
141
+ def from_encoder_decoder_pretrained(
142
+ cls,
143
+ encoder_pretrained_model_name_or_path: str = None,
144
+ decoder_pretrained_model_name_or_path: str = None,
145
+ *model_args,
146
+ **kwargs,
147
+ ) -> PreTrainedModel:
148
+ r"""
149
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
150
+ checkpoints.
151
+
152
+
153
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
154
+ the model, you need to first set it back in training mode with `model.train()`.
155
+
156
+ Params:
157
+ encoder_pretrained_model_name_or_path (`str`, *optional*):
158
+ Information necessary to initiate the image encoder. Can be either:
159
+
160
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
161
+ example is `google/vit-base-patch16-224-in21k`.
162
+ - A path to a *directory* containing model weights saved using
163
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
164
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
165
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
166
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
167
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
168
+
169
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
170
+ Information necessary to initiate the text decoder. Can be either:
171
+
172
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
173
+ - A path to a *directory* containing model weights saved using
174
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
175
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
176
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
177
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
178
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
179
+
180
+ model_args (remaining positional arguments, *optional*):
181
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
182
+
183
+ kwargs (remaining dictionary of keyword arguments, *optional*):
184
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
185
+ `output_attentions=True`).
186
+
187
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
188
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
189
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
190
+
191
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
192
+
193
+ Example:
194
+
195
+ ```python
196
+ >>> from transformers import VisionEncoderDecoderModel
197
+
198
+ >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
199
+ >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
200
+ ... "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
201
+ ... )
202
+ >>> # saving model after fine-tuning
203
+ >>> model.save_pretrained("./vit-bert")
204
+ >>> # load fine-tuned model
205
+ >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
206
+ ```"""
207
+
208
+ kwargs_encoder = {
209
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
210
+ }
211
+
212
+ kwargs_decoder = {
213
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
214
+ }
215
+
216
+ # remove encoder, decoder kwargs from kwargs
217
+ for key in kwargs_encoder.keys():
218
+ del kwargs["encoder_" + key]
219
+ for key in kwargs_decoder.keys():
220
+ del kwargs["decoder_" + key]
221
+
222
+ # Load and initialize the encoder and decoder
223
+ # The distinction between encoder and decoder at the model level is made
224
+ # by the value of the flag `is_decoder` that we need to set correctly.
225
+ encoder = kwargs_encoder.pop("model", None)
226
+ if encoder is None:
227
+ if encoder_pretrained_model_name_or_path is None:
228
+ raise ValueError(
229
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
230
+ "to be defined."
231
+ )
232
+
233
+ if "config" not in kwargs_encoder:
234
+ encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
235
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
236
+ )
237
+
238
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
239
+ logger.info(
240
+ f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
241
+ "from a decoder model. Cross-attention and causal mask are disabled."
242
+ )
243
+ encoder_config.is_decoder = False
244
+ encoder_config.add_cross_attention = False
245
+
246
+ kwargs_encoder["config"] = encoder_config
247
+
248
+ encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
249
+
250
+ decoder = kwargs_decoder.pop("model", None)
251
+ if decoder is None:
252
+ if decoder_pretrained_model_name_or_path is None:
253
+ raise ValueError(
254
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
255
+ "to be defined."
256
+ )
257
+
258
+ if "config" not in kwargs_decoder:
259
+ decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
260
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
261
+ )
262
+
263
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
264
+ logger.info(
265
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
266
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
267
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
268
+ )
269
+ decoder_config.is_decoder = True
270
+ decoder_config.add_cross_attention = False
271
+
272
+ kwargs_decoder["config"] = decoder_config
273
+
274
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
275
+ logger.warning(
276
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
277
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
278
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
279
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
280
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
281
+ )
282
+
283
+ decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
284
+
285
+ # instantiate config with corresponding kwargs
286
+ config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
287
+
288
+ # make sure input & output embeddings is not tied
289
+ config.tie_word_embeddings = False
290
+ return cls(encoder=encoder, decoder=decoder, config=config)
291
+
292
+ def forward(
293
+ self,
294
+ decoder_input_ids: Optional[torch.LongTensor] = None,
295
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
296
+ decoder_token_type_ids: Optional[torch.LongTensor] = None,
297
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
298
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
299
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
300
+ decoder_position_ids: Optional[torch.LongTensor] = None,
301
+ labels: Optional[torch.LongTensor] = None,
302
+ use_cache: Optional[bool] = None,
303
+ output_attentions: Optional[bool] = None,
304
+ output_hidden_states: Optional[bool] = None,
305
+ return_dict: Optional[bool] = None,
306
+ **kwargs,
307
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
308
+
309
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
310
+
311
+ kwargs_decoder = {
312
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
313
+ }
314
+
315
+ assert decoder_position_ids is not None
316
+ assert decoder_attention_mask is not None
317
+ assert decoder_attention_mask.dtype == torch.long, f'The dtype for {decoder_attention_mask} was {decoder_attention_mask.dtype}. It should be torch.long'
318
+ assert decoder_token_type_ids is not None
319
+
320
+ if decoder_inputs_embeds is None:
321
+ decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)
322
+ decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)
323
+
324
+ # Generation:
325
+ decoder_outputs = self.decoder(
326
+ inputs_embeds=decoder_inputs_embeds,
327
+ attention_mask=decoder_attention_mask,
328
+ position_ids=decoder_position_ids,
329
+ output_attentions=output_attentions,
330
+ output_hidden_states=output_hidden_states,
331
+ use_cache=use_cache,
332
+ past_key_values=past_key_values,
333
+ return_dict=return_dict,
334
+ **kwargs_decoder,
335
+ )
336
+
337
+ # Loss:
338
+ loss = None
339
+ if labels is not None:
340
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
341
+ loss_fct = CrossEntropyLoss()
342
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
343
+
344
+ if not return_dict:
345
+ if loss is not None:
346
+ return (loss,) + decoder_outputs + encoder_outputs
347
+ else:
348
+ return decoder_outputs + encoder_outputs
349
+
350
+ return Seq2SeqLMOutput(
351
+ loss=loss,
352
+ logits=decoder_outputs.logits,
353
+ past_key_values=decoder_outputs.past_key_values,
354
+ decoder_hidden_states=decoder_outputs.hidden_states,
355
+ decoder_attentions=decoder_outputs.attentions,
356
+ )
357
+
358
+ def prepare_inputs_for_generation(
359
+ self,
360
+ input_ids,
361
+ special_token_ids,
362
+ prompt_attention_mask,
363
+ prompt_position_ids,
364
+ token_type_id_sections=None,
365
+ past_key_values=None,
366
+ use_cache=None,
367
+ **kwargs,
368
+ ):
369
+ """
370
+ Modification of:
371
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
372
+ """
373
+
374
+ report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()
375
+
376
+ if past_key_values is None:
377
+
378
+ # 4D attention mask:
379
+ decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(prompt_attention_mask, report_attention_mask)
380
+
381
+ # Position identifiers accounting for padding:
382
+ report_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
383
+ report_position_ids.masked_fill_(report_attention_mask == 0, 1)
384
+ decoder_position_ids = torch.cat([prompt_position_ids, report_position_ids], dim=1)
385
+
386
+ # `inputs_embeds` are only to be used in the 1st generation step:
387
+ inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.decoder.get_input_embeddings()(input_ids)], dim=1)
388
+
389
+ decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, token_type_id_sections)
390
+ decoder_token_type_ids = torch.cat(
391
+ [
392
+ kwargs['decoder_token_type_ids'],
393
+ decoder_token_type_ids,
394
+ ],
395
+ dim=1,
396
+ ) # Add image token type identifiers.
397
+
398
+ input_dict = {
399
+ 'decoder_input_ids': input_ids,
400
+ 'decoder_inputs_embeds': inputs_embeds,
401
+ 'decoder_token_type_ids': decoder_token_type_ids,
402
+ }
403
+ else:
404
+
405
+ # 4D attention mask:
406
+ decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(prompt_attention_mask, report_attention_mask)
407
+
408
+ # Position identifiers accounting for padding:
409
+ decoder_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
410
+ decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
411
+
412
+ # Always place token_ids_to_token_type_ids_past_key_values before input_ids = input_ids[:, remove_prefix_length:]:
413
+ decoder_token_type_ids = self.token_ids_to_token_type_ids_past_key_values(input_ids, special_token_ids, token_type_id_sections)
414
+ decoder_position_ids = decoder_position_ids[:, -1:]
415
+
416
+ past_length = past_key_values[0][0].shape[2]
417
+
418
+ # Some generation methods only pass the last input ID:
419
+ if input_ids.shape[1] > past_length:
420
+ remove_prefix_length = past_length
421
+ else:
422
+ # Keep only the final ID:
423
+ remove_prefix_length = input_ids.shape[1] - 1
424
+
425
+ input_ids = input_ids[:, remove_prefix_length:]
426
+
427
+ input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}
428
+
429
+ input_dict.update(
430
+ {
431
+ 'decoder_attention_mask': decoder_attention_mask,
432
+ 'decoder_position_ids': decoder_position_ids,
433
+ 'past_key_values': past_key_values,
434
+ 'use_cache': use_cache,
435
+ }
436
+ )
437
+ return input_dict
438
+
439
+ def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
440
+ """
441
+ Extract token type identifiers from the token identifiers.
442
+
443
+ Argument/s:
444
+ token_ids - token identifiers.
445
+ special_token_ids - special token identifiers that indicate the separation between sections.
446
+ token_type_id_section - token type identifier for each section.
447
+
448
+ Returns:
449
+ token_type_ids - token type identifiers.
450
+ """
451
+
452
+ token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
453
+
454
+ mbatch_size, seq_len = token_ids.shape
455
+ token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
456
+
457
+ for i, j in enumerate(special_token_ids):
458
+ # Find first occurrence of special tokens that indicate the boundary between sections:
459
+ cols = (token_ids == j).int().argmax(dim=1)
460
+ rows = torch.arange(mbatch_size, device=token_ids.device)
461
+
462
+ # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
463
+ cols += 1
464
+
465
+ # Ensure that the column index is not out of bounds. If 0, then token_id not present.
466
+ # This is safe as index 0 is always a special token (now equal to 1 due to +1):
467
+ rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
468
+ cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
469
+
470
+ # Indices that correspond to the second sequence:
471
+ if rows.nelement() != 0:
472
+ ids = torch.stack([
473
+ torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
474
+ y, seq_len, device=token_ids.device,
475
+ )
476
+ ])
477
+
478
+ token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
479
+
480
+ return token_type_ids
481
+
482
+ def token_ids_to_token_type_ids_past_key_values(self, token_ids, special_token_ids, token_type_id_sections=None):
483
+ """
484
+ Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
485
+ token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
486
+
487
+ Argument/s:
488
+ token_ids - token identifiers.
489
+ special_token_ids - special token identifiers that indicate the separation between sections.
490
+
491
+ Returns:
492
+ token_type_ids - token type identifiers.
493
+ """
494
+
495
+ token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
496
+ token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
497
+
498
+ # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
499
+ token_ids = token_ids[:, :-1]
500
+
501
+ for i, j in enumerate(special_token_ids):
502
+
503
+ # Find first occurrence of special token, which indicates the boundary between sections:
504
+ exists = torch.any(token_ids == j, dim=1, keepdim=True)
505
+ token_type_ids[exists] = token_type_id_sections[i + 1]
506
+
507
+ return token_type_ids
508
+
509
+ def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
510
+ """
511
+ Tokenize the reports and create the inputs and targets for teacher forcing.
512
+
513
+ Argument/s:
514
+ findings - findings sections.
515
+ impression - impression sections.
516
+ return_token_type_ids - return the token type identifiers.
517
+ tokenizer - Hugging Face tokenizer.
518
+ max_len - maximum number of tokens.
519
+
520
+ Returns:
521
+ decoder_input_ids - the token identifiers for the input of the decoder.
522
+ decoder_attention_mask - the attention mask for the decoder_input_ids.
523
+ label_ids - the label token identifiers for the decoder.
524
+ """
525
+
526
+ # Prepare the sections for the tokenizer by placing special tokens between each section:
527
+ reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
528
+ zip(findings, impression)]
529
+
530
+ # Tokenize the report:
531
+ tokenized = tokenizer(
532
+ reports,
533
+ padding='longest',
534
+ truncation=True,
535
+ max_length=max_len + 1, # +1 to account for the bias between input and target.
536
+ return_tensors='pt',
537
+ return_token_type_ids=False,
538
+ add_special_tokens=False,
539
+ ).to(self.device)
540
+
541
+ # Modify for language modelling:
542
+ batch_dict = {
543
+
544
+ # Labels for the decoder (shifted right by one for autoregression):
545
+ 'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
546
+
547
+ # Remove last token identifier to match the sequence length of the labels:
548
+ 'decoder_input_ids': tokenized['input_ids'][:, :-1],
549
+
550
+ # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
551
+ 'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
552
+ }
553
+
554
+ return batch_dict
555
+
556
+ def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
557
+ """
558
+ Tokenize the reports and create the inputs and targets for teacher forcing.
559
+
560
+ Argument/s:
561
+ tokenizer - Hugging Face tokenizer.
562
+ max_len - maximum number of tokens.
563
+ findings - findings sections.
564
+ impression - impression sections.
565
+ reports - prepared reports, with special tokens and report sections.
566
+
567
+ Returns:
568
+ decoder_input_ids - the token identifiers for the input of the decoder.
569
+ decoder_attention_mask - the attention mask for the decoder_input_ids.
570
+ label_ids - the label token identifiers for the decoder.
571
+ """
572
+
573
+ # Prepare the sections for the tokenizer by placing special tokens between each section:
574
+ if reports is None:
575
+ assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
576
+ reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
577
+ zip(findings, impression)]
578
+
579
+ # Tokenize the report:
580
+ tokenized = tokenizer(
581
+ reports,
582
+ padding='longest',
583
+ truncation=True,
584
+ max_length=max_len + 1, # +1 to account for the bias between input and target.
585
+ return_tensors='pt',
586
+ return_token_type_ids=False,
587
+ add_special_tokens=False,
588
+ ).to(self.device)
589
+
590
+ # Modify for language modelling:
591
+ batch_dict = {
592
+
593
+ # Labels for the decoder (shifted right by one for autoregression):
594
+ 'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
595
+
596
+ # Remove last token identifier to match the sequence length of the labels:
597
+ 'decoder_input_ids': tokenized['input_ids'][:, :-1],
598
+
599
+ # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
600
+ 'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
601
+ }
602
+
603
+ return batch_dict
604
+
605
+ def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
606
+ """
607
+ Split the token identifiers into sections, then convert the token identifiers into strings.
608
+
609
+ Argument/s:
610
+ token_ids - token identifiers.
611
+ special_token_ids - special token identifiers that indicate the end of each section.
612
+ tokenizer - Hugging Face tokenizer.
613
+
614
+ Returns:
615
+ sections - a tuple containing the decoded strings for each section.
616
+ """
617
+
618
+ _, seq_len = token_ids.shape
619
+
620
+ # The number of sections is the same as the number of special_token_ids:
621
+ num_sections = len(special_token_ids)
622
+
623
+ sections = {k: [] for k in range(num_sections)}
624
+
625
+ for i in token_ids:
626
+ prev_col = 0
627
+ for j, k in enumerate(special_token_ids):
628
+
629
+ # The maximum sequence length was exceeded, thus no more tokens:
630
+ if prev_col >= seq_len:
631
+ sections[j].append('')
632
+ continue
633
+
634
+ # Find first occurrence of special tokens that indicate the boundary between sections:
635
+ col = (i == k).int().argmax().item()
636
+
637
+ # If col equals 0, the special token was not found; set the column to the sequence length (as the decoder exceeded
638
+ # the maximum sequence length):
639
+ if col == 0:
640
+ col = seq_len
641
+
642
+ # Extract section token identifiers:
643
+ section_token_ids = i[prev_col:col]
644
+ prev_col = col
645
+ section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
646
+
647
+ sections[j].append(section_string)
648
+
649
+ return tuple(sections.values())
650
+
651
+ def tokenize_text_columns(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
652
+ """
653
+ Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
654
+ Time deltas for the input_ids are also prepared here.
655
+
656
+ Argument/s:
657
+ tokenizer - Hugging Face tokenizer.
658
+
659
+ Returns:
660
+ ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
661
+ cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
662
+ """
663
+
664
+ batch_size = len(kwargs['index'])
665
+
666
+ tokenized = {
667
+ 'input_ids': {i: [] for i in range(batch_size)},
668
+ 'token_type_ids': {i: [] for i in range(batch_size)},
669
+ 'time_delta': {i: [] for i in range(batch_size)},
670
+ 'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
671
+ }
672
+
673
+ for i in self.decoder.config.ed_module_columns + self.decoder.config.mimic_cxr_columns + ['previous_findings', 'previous_impression']:
674
+ if i in kwargs:
675
+ if f'{i}_time_delta' not in kwargs:
676
+ kwargs[f'{i}_time_delta'] = [[self.decoder.config.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
677
+ for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
678
+ if y is not None:
679
+ assert isinstance(y, list)
680
+ assert isinstance(z, list)
681
+ for text, time_delta in zip(y, z):
682
+ tokenized['input_ids'][x].append(
683
+ tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
684
+ )
685
+ tokenized['token_type_ids'][x].append(
686
+ torch.full(
687
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
688
+ self.decoder.config.token_type_to_token_type_id[i],
689
+ dtype=torch.long,
690
+ device=self.device,
691
+ )
692
+ )
693
+ tokenized['time_delta'][x].append(
694
+ torch.full(
695
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
696
+ time_delta,
697
+ dtype=torch.float32,
698
+ device=self.device,
699
+ )
700
+ )
701
+
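+ # Concatenate each example's token sequences into a single (length, 1) tensor, then pad the examples to
+ # the longest sequence in the mini-batch; the trailing singleton dimension is dropped with [:, :, 0].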
702
+ tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
703
+ tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
704
+ tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
705
+
706
+ tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
707
+ tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
708
+ )[:, :, 0]
709
+ tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
710
+ tokenized['token_type_ids'], batch_first=True, padding_value=0,
711
+ )[:, :, 0]
712
+
713
+ tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
714
+
715
+ tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
716
+ tokenized['time_delta'], batch_first=True, padding_value=0,
717
+ )
718
+
719
+ return tokenized
720
+
721
+ def prepare_inputs(
722
+ self,
723
+ images,
724
+ tokenizer: PreTrainedTokenizerFast,
725
+ tokenized_report=None,
726
+ sep_token_id=None,
727
+ section_ids=None,
728
+ **batch,
729
+ ):
730
+ """
731
+ Prepare the inputs for the decoder, i.e., the prompt embeddings, attention mask, token type identifiers, position identifiers, and BOS token identifiers.
732
+
733
+ Argument/s:
734
+ images - images.
735
+ tokenizer - Hugging Face tokenizer.
736
+ tokenized_report - during training/teacher forcing, the tokenized report dict to include in the prepared inputs.
737
+ sep_token_id - separator token identifier.
738
+ section_ids - section identifiers for the findings and impression sections.
739
+
740
+ Returns:
741
+ inputs_embeds - input embeddings.
742
+ attention_mask - attention mask.
743
+ token_type_ids - token type identifiers.
744
+ position_ids - position identifiers.
745
+ bos_token_ids - bos_token_ids for generation.
746
+ """
747
+
748
+ input_ids = []
749
+ inputs_embeds = []
750
+ token_type_ids = []
751
+ attention_mask = []
752
+ time_delta = []
753
+ position_ids = None
754
+ bos_token_ids = None
755
+
756
+ # Index and value columns:
757
+ batch_size = len(batch['index'])
758
+ for k in self.decoder.config.index_value_encoder_config.keys():
759
+ if f'{k}_index_value_feats' not in batch:
760
+ batch[f'{k}_index_value_feats'] = torch.empty(batch_size, 0, self.decoder.config.index_value_encoder_config[k], device=self.device)
761
+ inputs_embeds.append(
762
+ getattr(self, f'{k}_index_value_encoder')(batch[f'{k}_index_value_feats'])
763
+ )
764
+ token_type_ids.append(batch[f'{k}_index_value_token_type_ids'] if f'{k}_index_value_token_type_ids' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
765
+ attention_mask.append(batch[f'{k}_index_value_mask'] if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
766
+ if f'{k}_time_delta' in batch:
767
+ time_delta.append(batch[f'{k}_time_delta'])
768
+ else:
769
+ time_delta_index_value = torch.zeros(*batch[f'{k}_index_value_mask'].shape, 1, device=self.device) if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, 1, device=self.device)
770
+ time_delta.append(time_delta_index_value)
771
+
772
+ # Tokenize text columns for prompt:
773
+ tokenized = self.tokenize_text_columns(tokenizer, **batch)
774
+ input_ids.append(tokenized['input_ids'])
775
+ token_type_ids.append(tokenized['token_type_ids'])
776
+ attention_mask.append(tokenized['attention_mask'])
777
+ time_delta.append(tokenized['time_delta'])
778
+
779
+ # Image encoder:
780
+ encoder_outputs = self.encoder(images)
781
+ inputs_embeds.append(encoder_outputs[0])
782
+ inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]
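+ # Each image contributes inputs_per_image feature positions, so its time delta is repeated across them;
+ # positions with the zero time delta are treated as the current image and the rest as previous images
+ # when assigning the token type identifiers below.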
783
+ padded_image_time_deltas = [i + [self.decoder.config.zero_time_delta_value] * (images.shape[1] - len(i)) for i in batch['image_time_deltas']]
784
+ time_delta_image_features = torch.tensor(padded_image_time_deltas, device=self.device).repeat_interleave(inputs_per_image, dim=1)
785
+ token_type_ids.append(
786
+ torch.where(
787
+ time_delta_image_features == self.decoder.config.zero_time_delta_value,
788
+ self.decoder.config.token_type_to_token_type_id['image'],
789
+ self.decoder.config.token_type_to_token_type_id['previous_image'],
790
+ ),
791
+ )
792
+ attention_mask.append(encoder_outputs[1])
793
+ time_delta.append(time_delta_image_features[:, :, None])
794
+
795
+ # Compute embeddings from token identifiers:
796
+ input_ids = torch.cat(input_ids, dim=1)
797
+ inputs_embeds.append(self.decoder.get_input_embeddings()(input_ids))
798
+
799
+ # Concatenate time deltas and input embeddings before adding the time delta embeddings to the prompt:
800
+ time_delta = torch.cat(time_delta, dim=1)
801
+ inputs_embeds = torch.cat(inputs_embeds, dim=1)
802
+
803
+ # Add time delta embeddings to prompt:
804
+ if time_delta.shape[1] > 0 and self.decoder.config.add_time_deltas:
805
+ time_delta = time_delta.to(dtype=inputs_embeds.dtype)
806
+ inputs_embeds += self.time_delta_encoder(time_delta)
807
+
808
+ # Concatenate the attention mask:
809
+ attention_mask = torch.cat(attention_mask, dim=1)
810
+
811
+ # Position identifiers:
812
+ position_ids = self.position_ids_from_time_deltas_and_attention_mask(time_delta, attention_mask)
813
+
814
+ # Tokenize report:
815
+ if tokenized_report is not None:
816
+ inputs_embeds = torch.cat([inputs_embeds, self.decoder.get_input_embeddings()(tokenized_report['decoder_input_ids'])], dim=1)
817
+
818
+ report_token_type_ids = self.token_ids_to_token_type_ids(
819
+ token_ids=tokenized_report['decoder_input_ids'],
820
+ special_token_ids=[sep_token_id],
821
+ token_type_id_sections=section_ids,
822
+ )
823
+ token_type_ids.append(report_token_type_ids)
824
+
825
+ # Position identifiers accounting for padding:
826
+ report_position_ids = tokenized_report['decoder_attention_mask'].cumsum(-1) + position_ids.max(dim=1).values[:, None]
827
+ report_position_ids.masked_fill_(tokenized_report['decoder_attention_mask'] == 0, 1)
828
+ position_ids = torch.cat([position_ids, report_position_ids], dim=1)
829
+
830
+ # 4D attention mask:
831
+ attention_mask = self.create_4d_attention_mask_mixed_causality(attention_mask, tokenized_report['decoder_attention_mask'])
832
+ # attention_mask_diagonal = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
833
+
834
+ else:
835
+
836
+ # BOS token identifiers for inference/generation:
837
+ bos_token_ids = torch.full((encoder_outputs[0].shape[0], 1), tokenizer.bos_token_id, dtype=torch.long, device=self.device)
838
+
839
+ # Concatenate the token type identifiers:
840
+ token_type_ids = torch.cat(token_type_ids, dim=1)
841
+
842
+ assert inputs_embeds.shape[1] == attention_mask.shape[-1]
843
+ assert inputs_embeds.shape[1] == token_type_ids.shape[1]
844
+
845
+ return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
846
+
847
+ @staticmethod
848
+ def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
849
+
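+ # The returned mask (batch, 1, total_len, total_len) has four blocks: the prompt attends to the prompt
+ # bidirectionally (upper left), the report attends to the prompt (lower left) and to itself causally
+ # (lower right), and the prompt cannot attend to the report (upper right, all zeros).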
850
+ prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
851
+ report_seq_len = causal_2d_attention_mask.shape[-1]
852
+
853
+ non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
854
+ causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
855
+
856
+ # Upper left of attention matrix:
857
+ upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
858
+ upper_left = upper_left * non_causal_2d_attention_mask
859
+ upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
860
+
861
+ causal_mask = torch.tril(
862
+ torch.ones(
863
+ (
864
+ report_seq_len,
865
+ report_seq_len,
866
+ ),
867
+ dtype=torch.long,
868
+ device=causal_2d_attention_mask.device,
869
+ ),
870
+ )
871
+
872
+ # Lower right of attention matrix:
873
+ lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
874
+ lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
875
+ lower_right = lower_right * causal_mask
876
+
877
+ # Upper right of attention matrix:
878
+ upper_right = torch.zeros(
879
+ causal_2d_attention_mask.shape[0],
880
+ 1,
881
+ prompt_seq_len,
882
+ report_seq_len,
883
+ dtype=torch.long,
884
+ device=causal_2d_attention_mask.device,
885
+ )
886
+
887
+ # Lower left of attention matrix:
888
+ lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
889
+ lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
890
+
891
+ left = torch.cat((upper_left, lower_left), dim=2)
892
+ right = torch.cat((upper_right, lower_right), dim=2)
893
+
894
+ mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
895
+ return mixed_causality_4d_attention_mask
896
+
897
+ @staticmethod
898
+ def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
899
+
900
+ non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
901
+ causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
902
+
903
+ mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
904
+ return mixed_causality_4d_attention_mask
905
+
906
+ def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
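+ # Assign position identifiers by ranking the prompt elements by their time deltas (direction controlled
+ # by time_delta_monotonic_inversion); padded positions are replaced with the minimum representable value
+ # before sorting and are then set to 1.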
907
+ _, col_indices = torch.sort(torch.where(attention_mask == 1, time_deltas[:, :, 0], torch.finfo(time_deltas.dtype).min), descending=not self.decoder.config.time_delta_monotonic_inversion)
908
+
909
+ num_rows, num_cols, _ = time_deltas.shape
910
+
911
+ row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
912
+ position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
913
+ position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
914
+ position_ids.masked_fill_(attention_mask == 0, 1) # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
915
+
916
+ return position_ids
917
+
918
+ @staticmethod
919
+ def prepare_data(physionet_dir, database_path, dataset_dir=None):
920
+
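+ # Build a DuckDB database from the MIMIC-CXR, MIMIC-CXR-JPG, and MIMIC-IV-ED CSV files: the sectioned
+ # MIMIC-CXR reports, the lookup tables for the categorical (index) columns of the ED tables, and the
+ # study_id-to-stay_id mappings are all created here.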
921
+ dataset_dir = physionet_dir if dataset_dir is None else dataset_dir
922
+
923
+ sectioned_dir = os.path.join(dataset_dir, 'mimic_cxr_sectioned')
924
+
925
+ mimic_cxr_sectioned_path = os.path.join(sectioned_dir, 'mimic_cxr_sectioned.csv')
926
+ if not os.path.exists(mimic_cxr_sectioned_path):
927
+ print(f'{mimic_cxr_sectioned_path} does not exist, creating...')
928
+
929
+ # Check if the reports exist. Only the reports for the first and last patients are checked; this compromises comprehensiveness for speed:
930
+ report_paths = [
931
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s50414267.txt'),
932
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53189527.txt'),
933
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53911762.txt'),
934
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s56699142.txt'),
935
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s55368167.txt'),
936
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58621812.txt'),
937
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58971208.txt'),
938
+ ]
939
+ assert all([os.path.isfile(i) for i in report_paths]), f"""The reports do not exist for the following glob pattern: {os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p1*/p1*/s*.txt')}.
940
+ Please download them using wget -r -N -c -np --reject dcm --user <username> --ask-password https://physionet.org/files/mimic-cxr/2.0.0/"""
941
+
942
+ print('Extracting sections from reports...')
943
+ create_sectioned_files(
944
+ reports_path=os.path.join(physionet_dir, 'mimic-cxr', '2.0.0', 'files'),
945
+ output_path=sectioned_dir,
946
+ no_split=True,
947
+ )
948
+
949
+ if not os.path.exists(database_path):
950
+
951
+ connect = duckdb.connect(database_path)
952
+
953
+ csv_paths = []
954
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'edstays.csv.gz'))[0])
955
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'medrecon.csv.gz'))[0])
956
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'pyxis.csv.gz'))[0])
957
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'triage.csv.gz'))[0])
958
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'vitalsign.csv.gz'))[0])
959
+
960
+ base_names = [os.path.basename(i) for i in csv_paths]
961
+
962
+ for i in ['edstays.csv.gz', 'medrecon.csv.gz', 'pyxis.csv.gz', 'triage.csv.gz', 'vitalsign.csv.gz']:
963
+ assert i in base_names, f"""Table {i} is missing from MIMIC-IV-ED.
964
+ Please download the tables from https://physionet.org/content/mimic-iv-ed. Do not decompress them."""
965
+
966
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-metadata.csv.gz'))[0])
967
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-chexpert.csv.gz'))[0])
968
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-split.csv.gz'))[0])
969
+
970
+ base_names = [os.path.basename(i) for i in csv_paths[-3:]]
971
+
972
+ for i in ['mimic-cxr-2.0.0-metadata.csv.gz', 'mimic-cxr-2.0.0-chexpert.csv.gz', 'mimic-cxr-2.0.0-split.csv.gz']:
973
+ assert i in base_names, f"""CSV file {i} is missing from MIMIC-IV-ED.
974
+ Please download the tables from https://physionet.org/content/mimic-cxr-jpg. Do not decompress them."""
975
+
976
+ for i in csv_paths:
977
+ name = Path(i).stem.replace('.csv', '').replace('.gz', '').replace('-', '_').replace('.', '_')
978
+ print(f'Copying {name} into database...')
979
+ connect.sql(f"CREATE OR REPLACE TABLE {name} AS FROM '{i}';")
980
+
981
+ # MIMIC-CXR report sections:
982
+ print('Copying mimic_cxr_sectioned into database...')
983
+ connect.sql(f"CREATE OR REPLACE TABLE mimic_cxr_sectioned AS FROM '{mimic_cxr_sectioned_path}';")
984
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column0 TO study;")
985
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column1 TO impression;")
986
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column2 TO findings;")
987
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column3 TO indication;")
988
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column4 TO history;")
989
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column5 TO last_paragraph;")
990
+ connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column6 TO comparison;")
991
+ connect.sql("DELETE FROM mimic_cxr_sectioned WHERE study='study';")
992
+
993
+ splits = connect.sql("FROM mimic_cxr_2_0_0_split").df()
994
+ reports = connect.sql("FROM mimic_cxr_sectioned").df()
995
+ metadata = connect.sql("FROM mimic_cxr_2_0_0_metadata").df()
996
+ chexpert = connect.sql("FROM mimic_cxr_2_0_0_chexpert").df()
997
+
998
+ # Create datetime column:
999
+ metadata['StudyTime'] = metadata['StudyTime'].astype(int)
1000
+ metadata['study_datetime'] = pd.to_datetime(
1001
+ metadata.apply(lambda x: f'{x["StudyDate"]} {x["StudyTime"]:06}', axis=1),
1002
+ format='%Y%m%d %H%M%S',
1003
+ )
1004
+ reports.rename(columns={'study': 'study_id'}, inplace=True)
1005
+ reports.study_id = reports.study_id.str[1:].astype('int32')
1006
+ df = pd.merge(splits, reports, on='study_id')
1007
+ df = pd.merge(df, metadata, on=['dicom_id', 'study_id', 'subject_id'])
1008
+ df = pd.merge(df, chexpert, on=['study_id', 'subject_id'])
1009
+
1010
+ connect.sql(f"CREATE OR REPLACE TABLE mimic_cxr AS SELECT * FROM df")
1011
+
1012
+ # Create lookup tables (do this only for ED tables, as the MIMIC-CXR metadata table is not useful):
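+ # The lookup tables map the values of each index column to integer indices; lut_info records the end
+ # index for each table, which EDCXRSubjectRecords uses to size the index/value feature tensors
+ # (total_indices) in records.py.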
1013
+ for k, v in ed_module_tables.items():
1014
+ if v.load and v.index_columns:
1015
+ start_idx = 0
1016
+ for i in v.index_columns_source:
1017
+ lut_name = f'{k}_{i}_lut'
1018
+ table = k
1019
+ lut, end_idx = create_lookup_table(connect.sql(f"SELECT {i} FROM {table}").df(), [i], start_idx)
1020
+ start_idx = end_idx + 1
1021
+ lut = lut.rename(columns={'index': f'{i}_index'})
1022
+
1023
+ print(f'Creating {lut_name}...')
1024
+
1025
+ connect.sql(f"CREATE OR REPLACE TABLE {lut_name} AS SELECT * FROM lut")
1026
+
1027
+ if f'{i}_index' in connect.sql(f"FROM {k} LIMIT 0").df().columns:
1028
+ connect.sql(
1029
+ f"""
1030
+ ALTER TABLE {k}
1031
+ DROP COLUMN {i}_index;
1032
+ """
1033
+ )
1034
+
1035
+ connect.sql(
1036
+ f"""
1037
+ CREATE OR REPLACE TABLE {k} AS
1038
+ SELECT {k}.*, {lut_name}.{i}_index
1039
+ FROM {k} LEFT JOIN {lut_name}
1040
+ ON {k}.{i} = {lut_name}.{i}
1041
+ """
1042
+ )
1043
+
1044
+ connect.sql(
1045
+ f"""
1046
+ CREATE TABLE IF NOT EXISTS lut_info (table_name VARCHAR PRIMARY KEY, start_index INT, end_index INT);
1047
+ INSERT OR REPLACE INTO lut_info VALUES ('{k}', {0}, {end_idx});
1048
+ """
1049
+ )
1050
+
1051
+ table_studies = {
1052
+ 'edstays': [],
1053
+ 'triage': [],
1054
+ 'medrecon': [],
1055
+ 'vitalsign': [],
1056
+ 'pyxis': [],
1057
+ }
1058
+ stay_id_tables = ['triage']
1059
+ stay_id_charttime_tables = ['medrecon', 'vitalsign', 'pyxis']
1060
+
1061
+ df = connect.sql(f"FROM mimic_cxr").df()
1062
+
1063
+ # DICOM identifiers can have different datetimes, so use the most recent datetime for the study:
1064
+ df = df.sort_values(by='study_datetime', ascending=False)
1065
+ df = df.groupby('study_id').first().reset_index()
1066
+
1067
+ for _, row in tqdm(df.iterrows(), total=df.shape[0]):
1068
+ edstays = connect.sql(
1069
+ f"""
1070
+ SELECT stay_id, intime, outtime
1071
+ FROM edstays
1072
+ WHERE (subject_id = {row['subject_id']})
1073
+ AND intime < '{row['study_datetime']}'
1074
+ AND outtime > '{row['study_datetime']}';
1075
+ """
1076
+ ).df()
1077
+
1078
+ if len(edstays) > 0:
1079
+
1080
+ for i in edstays['stay_id'].to_list():
1081
+ table_studies['edstays'].append({'study_id': row['study_id'], 'stay_id': i})
1082
+ for j in stay_id_tables:
1083
+ table = connect.sql(
1084
+ f"""
1085
+ SELECT stay_id
1086
+ FROM {j}
1087
+ WHERE (stay_id = {i});
1088
+ """
1089
+ ).df()
1090
+
1091
+ for k in table['stay_id'].to_list():
1092
+ table_studies[j].append({'study_id': row['study_id'], 'stay_id': k})
1093
+
1094
+ for j in stay_id_charttime_tables:
1095
+ table = connect.sql(
1096
+ f"""
1097
+ SELECT stay_id
1098
+ FROM {j}
1099
+ WHERE (stay_id = {i})
1100
+ AND charttime < '{row['study_datetime']}';
1101
+ """
1102
+ ).df()
1103
+
1104
+ for k in table['stay_id'].to_list():
1105
+ table_studies[j].append({'study_id': row['study_id'], 'stay_id': k})
1106
+
1107
+ for k, v in table_studies.items():
1108
+ df = pd.DataFrame(v)
1109
+ df = df.drop_duplicates(subset=['study_id', 'stay_id'])
1110
+ connect.sql(f"CREATE TABLE {k}_study_ids AS SELECT * FROM df")
1111
+
1112
+ @staticmethod
1113
+ def get_dataset(split, transforms, database_path, mimic_cxr_jpg_dir, max_images_per_study=5):
1114
+
1115
+ records = EDCXRSubjectRecords(database_path=database_path, time_delta_map=lambda x: 1 / math.sqrt(x + 1))
1116
+
1117
+ dataset = StudyIDEDStayIDSubset(
1118
+ mimic_iv_duckdb_path=database_path,
1119
+ dataset_dir=mimic_cxr_jpg_dir,
1120
+ transforms=transforms,
1121
+ split=split,
1122
+ max_images_per_study=max_images_per_study,
1123
+ records=records,
1124
+ )
1125
+ print(f'No. of examples: {len(dataset)}.')
1126
+ print(
1127
+ f'No. of training dicom_ids, study_ids, & subject_ids: {dataset.num_dicom_ids},',
1128
+ f'{dataset.num_study_ids}, & {dataset.num_subject_ids}.',
1129
+ )
+
+ return dataset
modelling_uniformer.py ADDED
@@ -0,0 +1,412 @@
1
+ from collections import OrderedDict
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+ from math import isqrt
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
+ from transformers import ViTConfig
10
+ from transformers.modeling_outputs import ModelOutput
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.utils import logging
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
+ layer_scale = False
18
+ init_value = 1e-6
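+ # Module-level toggle for LayerScale in the self-attention blocks (SABlock); when enabled, init_value
+ # initialises the learnable per-channel scaling factors gamma_1 and gamma_2.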
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.fc1 = nn.Linear(in_features, hidden_features)
27
+ self.act = act_layer()
28
+ self.fc2 = nn.Linear(hidden_features, out_features)
29
+ self.drop = nn.Dropout(drop)
30
+
31
+ def forward(self, x):
32
+ x = self.fc1(x)
33
+ x = self.act(x)
34
+ x = self.drop(x)
35
+ x = self.fc2(x)
36
+ x = self.drop(x)
37
+ return x
38
+
39
+
40
+ class CMlp(nn.Module):
41
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42
+ super().__init__()
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
46
+ self.act = act_layer()
47
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
48
+ self.drop = nn.Dropout(drop)
49
+
50
+ def forward(self, x):
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ class Attention(nn.Module):
60
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
61
+ super().__init__()
62
+ self.num_heads = num_heads
63
+ head_dim = dim // num_heads
64
+ self.scale = qk_scale or head_dim ** -0.5
65
+
66
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
67
+ self.attn_drop = nn.Dropout(attn_drop)
68
+ self.proj = nn.Linear(dim, dim)
69
+ self.proj_drop = nn.Dropout(proj_drop)
70
+
71
+ def forward(self, x):
72
+ B, N, C = x.shape
73
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
74
+ q, k, v = qkv[0], qkv[1], qkv[2]
75
+
76
+ attn = (q @ k.transpose(-2, -1)) * self.scale
77
+ attn = attn.softmax(dim=-1)
78
+ attn = self.attn_drop(attn)
79
+
80
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
81
+ x = self.proj(x)
82
+ x = self.proj_drop(x)
83
+ return x
84
+
85
+
86
+ class CBlock(nn.Module):
87
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU):
88
+ super().__init__()
89
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
90
+ self.norm1 = nn.BatchNorm2d(dim)
91
+ self.conv1 = nn.Conv2d(dim, dim, 1)
92
+ self.conv2 = nn.Conv2d(dim, dim, 1)
93
+ self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
94
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
+ self.norm2 = nn.BatchNorm2d(dim)
96
+ mlp_hidden_dim = int(dim * mlp_ratio)
97
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
+
99
+ def forward(self, x):
100
+ x = x + self.pos_embed(x)
101
+ x = x + self.module_1(x)
102
+ x = x + self.module_2(x)
103
+ return x
104
+
105
+ def module_1(self, x):
106
+ x = self.norm1(x.to(dtype=self.norm1.weight.dtype)) # Cast explicitly, as autocast won't cast the input to the dtype of the nn.BatchNorm2d parameters.
107
+ x = self.conv1(x)
108
+ x = self.attn(x)
109
+ x = self.conv2(x)
110
+ x = self.drop_path(x)
111
+ return x
112
+
113
+ def module_2(self, x):
114
+ x = self.norm2(x.to(dtype=self.norm2.weight.dtype)) # Cast explicitly, as autocast won't cast the input to the dtype of the nn.BatchNorm2d parameters.
115
+ x = self.mlp(x)
116
+ x = self.drop_path(x)
117
+ return x
118
+
119
+ class SABlock(nn.Module):
120
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
121
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
122
+ super().__init__()
123
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
124
+ self.norm1 = norm_layer(dim)
125
+ self.attn = Attention(
126
+ dim,
127
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
128
+ attn_drop=attn_drop, proj_drop=drop)
129
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
130
+ self.norm2 = norm_layer(dim)
131
+ mlp_hidden_dim = int(dim * mlp_ratio)
132
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
133
+ global layer_scale
134
+ self.ls = layer_scale
135
+ if self.ls:
136
+ global init_value
137
+ print(f"Use layer_scale: {layer_scale}, init_values: {init_value}")
138
+ self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
139
+ self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
140
+
141
+ def forward(self, x):
142
+ x = x + self.pos_embed(x)
143
+ B, N, H, W = x.shape
144
+ x = x.flatten(2).transpose(1, 2)
145
+ if self.ls:
146
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
147
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
148
+ else:
149
+ x = x + self.drop_path(self.attn(self.norm1(x)))
150
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
151
+ x = x.transpose(1, 2).reshape(B, N, H, W)
152
+ return x
153
+
154
+
155
+ class HeadEmbedding(nn.Module):
156
+ def __init__(self, in_channels, out_channels):
157
+ super(HeadEmbedding, self).__init__()
158
+
159
+ self.proj = nn.Sequential(
160
+ nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
161
+ nn.BatchNorm2d(out_channels // 2),
162
+ nn.GELU(),
163
+ nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
164
+ nn.BatchNorm2d(out_channels),
165
+ )
166
+
167
+ def forward(self, x):
168
+ x = self.proj(x)
169
+ return x
170
+
171
+
172
+ class MiddleEmbedding(nn.Module):
173
+ def __init__(self, in_channels, out_channels):
174
+ super(MiddleEmbedding, self).__init__()
175
+
176
+ self.proj = nn.Sequential(
177
+ nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
178
+ nn.BatchNorm2d(out_channels),
179
+ )
180
+
181
+ def forward(self, x):
182
+ x = self.proj(x)
183
+ return x
184
+
185
+
186
+ class PatchEmbed(nn.Module):
187
+ def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768):
188
+ super().__init__()
189
+ image_size = to_2tuple(image_size)
190
+ patch_size = to_2tuple(patch_size)
191
+ num_patches_height = image_size[0] // patch_size[0]
192
+ num_patches_width = image_size[1] // patch_size[1]
193
+ num_patches = num_patches_height * num_patches_width
194
+ self.image_size = image_size
195
+ self.patch_size = patch_size
196
+ self.num_patches = num_patches
197
+
198
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
199
+ self.norm = nn.LayerNorm(embed_dim)
200
+
201
+ def forward(self, x):
202
+ _, _, H, W = x.shape
203
+ assert H == self.image_size[0] and W == self.image_size[1], \
204
+ f"Input image size ({H}*{W}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
205
+ x = self.proj(x)
206
+ B, _, H, W = x.shape
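+ # Flatten the spatial dimensions so that LayerNorm is applied over the channel dimension, then restore
+ # the (B, C, H, W) layout for the following convolutional stage.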
207
+ x = x.flatten(2).transpose(1, 2)
208
+ x = self.norm(x)
209
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
210
+ return x
211
+
212
+
213
+ class UniFormer(nn.Module):
214
+ def __init__(self, depth=[3, 4, 8, 3], image_size=224, in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
215
+ head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, patch_size=[4, 2, 2, 2],
216
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., conv_stem=False, layer_norm_eps=1e-6, **kwargs):
217
+ super().__init__()
218
+ self.num_classes = num_classes
219
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
220
+ norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)
221
+ if conv_stem:
222
+ self.patch_embed1 = HeadEmbedding(in_channels=in_chans, out_channels=embed_dim[0])
223
+ self.patch_embed2 = MiddleEmbedding(in_channels=embed_dim[0], out_channels=embed_dim[1])
224
+ self.patch_embed3 = MiddleEmbedding(in_channels=embed_dim[1], out_channels=embed_dim[2])
225
+ self.patch_embed4 = MiddleEmbedding(in_channels=embed_dim[2], out_channels=embed_dim[3])
226
+ else:
227
+ self.patch_embed1 = PatchEmbed(
228
+ image_size=image_size, patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0])
229
+ self.patch_embed2 = PatchEmbed(
230
+ image_size=image_size // patch_size[0], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1])
231
+ self.patch_embed3 = PatchEmbed(
232
+ image_size=image_size // (patch_size[0]*patch_size[1]), patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2])
233
+ self.patch_embed4 = PatchEmbed(
234
+ image_size=image_size // (patch_size[0]*patch_size[1]*patch_size[2]), patch_size=patch_size[3], in_chans=embed_dim[2], embed_dim=embed_dim[3])
235
+
236
+ self.pos_drop = nn.Dropout(p=drop_rate)
237
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule
238
+ num_heads = [dim // head_dim for dim in embed_dim]
239
+ self.blocks1 = nn.ModuleList([
240
+ CBlock(dim=embed_dim[0], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i])
241
+ for i in range(depth[0])])
242
+ self.blocks2 = nn.ModuleList([
243
+ CBlock(dim=embed_dim[1], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i+depth[0]])
244
+ for i in range(depth[1])])
245
+ self.blocks3 = nn.ModuleList([
246
+ SABlock(
247
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
248
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
249
+ for i in range(depth[2])])
250
+ self.blocks4 = nn.ModuleList([
251
+ SABlock(
252
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
253
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
254
+ for i in range(depth[3])])
255
+ self.norm = nn.BatchNorm2d(embed_dim[-1])
256
+
257
+ # Representation layer
258
+ if representation_size:
259
+ self.num_features = representation_size
260
+ self.pre_logits = nn.Sequential(OrderedDict([
261
+ ('fc', nn.Linear(embed_dim, representation_size)),
262
+ ('act', nn.Tanh())
263
+ ]))
264
+ else:
265
+ self.pre_logits = nn.Identity()
266
+
267
+ def forward_features(self, x):
268
+ x = self.patch_embed1(x)
269
+ x = self.pos_drop(x)
270
+ for blk in self.blocks1:
271
+ x = blk(x)
272
+ x = self.patch_embed2(x)
273
+ for blk in self.blocks2:
274
+ x = blk(x)
275
+ x = self.patch_embed3(x)
276
+ for blk in self.blocks3:
277
+ x = blk(x)
278
+ x = self.patch_embed4(x)
279
+ for blk in self.blocks4:
280
+ x = blk(x)
281
+ x = self.norm(x.to(dtype=self.norm.weight.dtype)) # Cast explicitly, as autocast won't cast the input to the dtype of the nn.BatchNorm2d parameters.
282
+ x = self.pre_logits(x)
283
+ return x
284
+
285
+ def forward(self, x):
286
+ x = self.forward_features(x)
287
+ return x
288
+
289
+
290
+ class UniFormerPreTrainedModel(PreTrainedModel):
291
+ """
292
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
293
+ models.
294
+ """
295
+
296
+ config_class = ViTConfig
297
+ base_model_prefix = "vit"
298
+ main_input_name = "pixel_values"
299
+
300
+ def _init_weights(self, m):
301
+ if isinstance(m, nn.Linear):
302
+ trunc_normal_(m.weight, std=.02)
303
+ if isinstance(m, nn.Linear) and m.bias is not None:
304
+ nn.init.constant_(m.bias, 0)
305
+ elif isinstance(m, nn.LayerNorm):
306
+ nn.init.constant_(m.bias, 0)
307
+ nn.init.constant_(m.weight, 1.0)
308
+
309
+
310
+ class UniFormerProjectionHead(torch.nn.Module):
311
+
312
+ def __init__(self, config) -> None:
313
+ super().__init__()
314
+
315
+ # Layer normalisation before projection:
316
+ self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
317
+
318
+ # No bias as following layer normalisation with bias:
319
+ self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
320
+
321
+
322
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
323
+ x = self.layer_norm(x)
324
+ x = self.projection(x)
325
+ return x
326
+
327
+
328
+ class UniFormerModel(UniFormerPreTrainedModel):
329
+ def __init__(self, config):
330
+ super().__init__(config)
331
+
332
+ self.uniformer = UniFormer(**vars(config))
333
+
334
+ # Initialize weights and apply final processing:
335
+ self.post_init()
336
+
337
+ def forward(
338
+ self,
339
+ pixel_values: Optional[torch.Tensor] = None,
340
+ output_hidden_states: Optional[bool] = None,
341
+ return_dict: Optional[bool] = None,
342
+ ) -> Union[Tuple, ModelOutput]:
343
+
344
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
345
+
346
+ last_hidden_state = self.uniformer(pixel_values)
347
+
348
+ # Flatten h x w:
349
+ last_hidden_state = torch.flatten(last_hidden_state, 2)
350
+
351
+ # Permute last hidden state:
352
+ last_hidden_state = torch.permute(last_hidden_state, [0, 2, 1])
353
+
354
+ # return last_hidden_state
355
+ if not return_dict:
356
+ return last_hidden_state
357
+
358
+ return ModelOutput(last_hidden_state=last_hidden_state)
359
+
360
+
361
+ class MultiUniFormerWithProjectionHead(UniFormerPreTrainedModel):
362
+ def __init__(self, config):
363
+ super().__init__(config)
364
+
365
+ self.uniformer = UniFormer(**vars(config))
366
+ self.projection_head = UniFormerProjectionHead(config)
367
+
368
+ # Initialize weights and apply final processing:
369
+ self.post_init()
370
+
371
+ def forward(
372
+ self,
373
+ pixel_values: Optional[torch.Tensor] = None,
374
+ output_hidden_states: Optional[bool] = None,
375
+ return_dict: Optional[bool] = None,
376
+ ) -> Union[Tuple, ModelOutput]:
377
+
378
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
379
+
380
+ # Flatten the batch and study_id dimensions:
381
+ assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
382
+ last_hidden_state = self.uniformer(pixel_values.view(-1, *pixel_values.shape[2:]))
383
+ # last_hidden_state = self.uniformer(pixel_values.flatten(start_dim=0, end_dim=1))
384
+
385
+ # Flatten h x w:
386
+ last_hidden_state = torch.flatten(last_hidden_state, 2)
387
+
388
+ # Project the features for each spatial position to the decoder's hidden size:
389
+ projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
390
+
391
+ # Concatenate the features for each chest X-ray:
392
+ projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
393
+
394
+ # Derive the attention mask from the pixel values:
395
+ mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
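+ # Images used to pad a study to the maximum number of images are assumed to be all zeros, so checking a
+ # single pixel per image is sufficient to detect them and mask out their spatial positions.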
396
+ attention_mask = torch.ones(
397
+ [projection.shape[0], pixel_values.shape[1], projection.shape[1] // pixel_values.shape[1]],
398
+ dtype=torch.long,
399
+ device=mask.device,
400
+ )
401
+ attention_mask = attention_mask * mask
402
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1)
403
+
404
+ if not return_dict:
405
+ return projection
406
+
407
+ return ModelOutput(last_hidden_state=projection, attention_mask=attention_mask)
408
+
409
+
410
+ if __name__ == '__main__':
411
+ y = PatchEmbed()
412
+ y(torch.randn(2, 3, 224, 224))
records.py ADDED
@@ -0,0 +1,369 @@
1
+ import functools
2
+ import os
3
+ import re
4
+ from collections import OrderedDict
5
+ from typing import Dict, List, Optional
6
+
7
+ import duckdb
8
+ import pandas as pd
9
+ import torch
10
+
11
+ from .tables import ed_cxr_token_type_ids, ed_module_tables, mimic_cxr_tables
12
+
13
+
14
+ def mimic_cxr_text_path(dir, subject_id, study_id, ext='txt'):
15
+ return os.path.join(dir, 'p' + str(subject_id)[:2], 'p' + str(subject_id),
16
+ 's' + str(study_id) + '.' + ext)
17
+
18
+ def format(text):
19
+ # Remove newline, tab, repeated whitespaces, and leading and trailing whitespaces:
20
+ text = re.sub(r'\n|\t', ' ', text)
21
+ text = re.sub(r'\s+', ' ', text)
22
+ text = text.strip()
23
+ return text
24
+
25
+
26
+ def rgetattr(obj, attr, *args):
27
+ def _getattr(obj, attr):
28
+ return getattr(obj, attr, *args)
29
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
30
+
31
+
32
+ def df_to_tensor_index_columns(
33
+ df: pd.DataFrame,
34
+ tensor: torch.Tensor,
35
+ group_idx_to_y_idx: Dict,
36
+ groupby: str,
37
+ index_columns: List[str],
38
+ ):
39
+ """
40
+ Converts a dataframe with index columns to a tensor, where each index of the y-axis is determined by the
41
+ 'groupby' column.
42
+ """
43
+ assert len(group_idx_to_y_idx) == tensor.shape[0]
44
+ all_columns = index_columns + [groupby]
45
+ y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows() for i in index_columns if row[i] == row[i]]
46
+ x_indices = [row[i] for _, row in df[all_columns].iterrows() for i in index_columns if row[i] == row[i]]
47
+ tensor[y_indices, x_indices] = 1.0
48
+ return tensor
49
+
50
+
51
+ def df_to_tensor_value_columns(
52
+ df: pd.DataFrame,
53
+ tensor: torch.Tensor,
54
+ group_idx_to_y_idx: Dict,
55
+ groupby: str,
56
+ value_columns: List[str],
57
+ value_column_to_idx: Dict,
58
+ ):
59
+ """
60
+ Converts a dataframe with value columns to a tensor, where each index of the y-axis is determined by the
61
+ 'groupby' column. The x-index is determined by a dictionary using the column name.
62
+ """
63
+ assert len(group_idx_to_y_idx) == tensor.shape[0]
64
+ all_columns = value_columns + [groupby]
65
+ y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
66
+ x_indices = [value_column_to_idx[i] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
67
+ element_value = [row[i] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
68
+ tensor[y_indices, x_indices] = torch.tensor(element_value, dtype=tensor.dtype)
69
+ return tensor
70
+
71
+
72
+ class EDCXRSubjectRecords:
73
+ def __init__(
74
+ self,
75
+ database_path: str,
76
+ dataset_dir: Optional[str] = None,
77
+ reports_dir: Optional[str] = None,
78
+ token_type_ids_starting_idx: Optional[int] = None,
79
+ time_delta_map = lambda x: x,
80
+ debug: bool = False
81
+ ):
82
+
83
+ self.database_path = database_path
84
+ self.dataset_dir = dataset_dir
85
+ self.reports_dir = reports_dir
86
+ self.time_delta_map = time_delta_map
87
+ self.debug = debug
88
+
89
+ self.connect = duckdb.connect(self.database_path, read_only=True)
90
+
91
+ self.streamlit_flag = False
92
+
93
+ self.clear_start_end_times()
94
+
95
+ # Module configurations:
96
+ self.ed_module_tables = ed_module_tables
97
+ self.mimic_cxr_tables = mimic_cxr_tables
98
+
99
+ lut_info = self.connect.sql("FROM lut_info").df()
100
+
101
+ for k, v in (self.ed_module_tables | self.mimic_cxr_tables).items():
102
+ if v.load and (v.value_columns or v.index_columns):
103
+ v.value_column_to_idx = {}
104
+ if v.index_columns:
105
+ v.total_indices = lut_info[lut_info['table_name'] == k]['end_index'].item() + 1
106
+ else:
107
+ v.total_indices = 0
108
+ for i in v.value_columns:
109
+ v.value_column_to_idx[i] = v.total_indices
110
+ v.total_indices += 1
111
+
112
+ # Token type identifiers:
113
+ self.token_type_to_token_type_id = ed_cxr_token_type_ids
114
+ if token_type_ids_starting_idx is not None:
115
+ self.token_type_to_token_type_id = {k: v + token_type_ids_starting_idx for k, v in self.token_type_to_token_type_id.items()}
116
+
117
+ def __len__(self):
118
+ return len(self.subject_ids)
119
+
120
+ def clear_start_end_times(self):
121
+ self.start_time, self.end_time = None, None
122
+
123
+ def admission_ed_stay_ids(self, hadm_id):
124
+ if hadm_id:
125
+ return self.connect.sql(f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id} AND hadm_id = {hadm_id}').df()['stay_id'].tolist()
126
+ else:
127
+ return self.connect.sql(f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id}').df()['stay_id'].tolist()
128
+
129
+ def subject_study_ids(self):
130
+ mimic_cxr = self.connect.sql(
131
+ f'SELECT study_id, study_datetime FROM mimic_cxr WHERE subject_id = {self.subject_id}',
132
+ ).df()
133
+ if self.start_time and self.end_time:
134
+ mimic_cxr = self.filter_admissions_by_time_span(mimic_cxr, 'study_datetime')
135
+ mimic_cxr = mimic_cxr.drop_duplicates(subset=['study_id']).sort_values(by='study_datetime')
136
+ return dict(zip(mimic_cxr['study_id'], mimic_cxr['study_datetime']))
137
+
138
+ def load_ed_module(self, hadm_id=None, stay_id=None, reference_time=None):
139
+ if not self.start_time and stay_id is not None:
140
+ edstay = self.connect.sql(
141
+ f"""
142
+ SELECT intime, outtime
143
+ FROM edstays
144
+ WHERE stay_id = {stay_id}
145
+ """
146
+ ).df()
147
+ self.start_time = edstay['intime'].item()
148
+ self.end_time = edstay['outtime'].item()
149
+ self.load_module(self.ed_module_tables, hadm_id=hadm_id, stay_id=stay_id, reference_time=reference_time)
150
+
151
+ def load_mimic_cxr(self, study_id, reference_time=None):
152
+ self.load_module(self.mimic_cxr_tables, study_id=study_id, reference_time=reference_time)
153
+ if self.streamlit_flag:
154
+ self.report_path = mimic_cxr_text_path(self.reports_dir, self.subject_id, study_id, 'txt')
155
+
156
+ def load_module(self, module_dict, hadm_id=None, stay_id=None, study_id=None, reference_time=None):
157
+ for k, v in module_dict.items():
158
+
159
+ if self.streamlit_flag or v.load:
160
+
161
+ query = f"FROM {k}"
162
+
163
+ conditions = []
164
+ if hasattr(self, 'subject_id') and v.subject_id_filter:
165
+ conditions.append(f"subject_id={self.subject_id}")
166
+ if v.hadm_id_filter:
167
+ assert hadm_id is not None
168
+ conditions.append(f"hadm_id={hadm_id}")
169
+ if v.stay_id_filter:
170
+ assert stay_id is not None
171
+ conditions.append(f"stay_id={stay_id}")
172
+ if v.study_id_filter:
173
+ assert study_id is not None
174
+ conditions.append(f"study_id={study_id}")
175
+ if v.mimic_cxr_sectioned:
176
+ assert study_id is not None
177
+ conditions.append(f"study='s{study_id}'")
178
+ ands = ['AND'] * (len(conditions) * 2 - 1)
179
+ ands[0::2] = conditions
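+ # Interleave the conditions with 'AND', e.g., ['subject_id=1', 'AND', 'stay_id=2'].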
180
+
181
+ if conditions:
182
+ query += " WHERE ("
183
+ query += ' '.join(ands)
184
+ query += ")"
185
+
186
+ df = self.connect.sql(query).df()
187
+
188
+ if v.load:
189
+
190
+ columns = [v.groupby] + v.time_columns + v.index_columns + v.text_columns + v.value_columns + v.target_sections
191
+
192
+ # Use the starting time of the stay/admission as the time:
193
+ if v.use_start_time:
194
+ df['start_time'] = self.start_time
195
+ columns += ['start_time']
196
+
197
+ if reference_time is not None:
198
+ time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'
199
+
200
+ # Remove rows that are after the reference time to maintain causality:
201
+ df = df[df[time_column] < reference_time]
202
+
203
+ if self.streamlit_flag:
204
+ setattr(self, k, df)
205
+
206
+ if v.load:
207
+ columns = list(dict.fromkeys(columns)) # remove repetitions.
208
+ df = df.drop(columns=df.columns.difference(columns), axis=1)
209
+ setattr(self, f'{k}_feats', df)
210
+
211
+ def return_ed_module_features(self, stay_id, reference_time=None):
212
+ example_dict = {}
213
+ if stay_id is not None:
214
+ self.load_ed_module(stay_id=stay_id, reference_time=reference_time)
215
+ for k, v in self.ed_module_tables.items():
216
+ if v.load:
217
+
218
+ df = getattr(self, f'{k}_feats')
219
+
220
+ if self.debug:
221
+ example_dict.setdefault('ed_tables', []).append(k)
222
+
223
+ if not df.empty:
224
+
225
+ assert f'{k}_index_value_feats' not in example_dict
226
+
227
+ # The y-index and the time for each group:
228
+ time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'
229
+ group_idx_to_y_idx, group_idx_to_datetime = OrderedDict(), OrderedDict()
230
+ groups = df.dropna(subset=v.index_columns + v.value_columns + v.text_columns, axis=0, how='all')
231
+ groups = groups.drop_duplicates(subset=[v.groupby])[list(dict.fromkeys([v.groupby, time_column]))]
232
+ groups = groups.reset_index(drop=True)
233
+ for i, row in groups.iterrows():
234
+ group_idx_to_y_idx[row[v.groupby]] = i
235
+ group_idx_to_datetime[row[v.groupby]] = row[time_column]
236
+
237
+ if (v.index_columns or v.value_columns) and group_idx_to_y_idx:
238
+ example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
239
+ if v.index_columns:
240
+ example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
241
+ df=df,
242
+ tensor=example_dict[f'{k}_index_value_feats'],
243
+ group_idx_to_y_idx=group_idx_to_y_idx,
244
+ groupby=v.groupby,
245
+ index_columns=v.index_columns,
246
+ )
247
+ if v.value_columns:
248
+ example_dict[f'{k}_index_value_feats'] = df_to_tensor_value_columns(
249
+ df=df,
250
+ tensor=example_dict[f'{k}_index_value_feats'],
251
+ group_idx_to_y_idx=group_idx_to_y_idx,
252
+ groupby=v.groupby,
253
+ value_columns=v.value_columns,
254
+ value_column_to_idx=v.value_column_to_idx
255
+ )
256
+
257
+ example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
258
+ [example_dict[f'{k}_index_value_feats'].shape[0]],
259
+ self.token_type_to_token_type_id[k],
260
+ dtype=torch.long,
261
+ )
262
+
263
+ event_times = list(group_idx_to_datetime.values())
264
+ assert all([i == i for i in event_times])
265
+ time_delta = [self.compute_time_delta(i, reference_time) for i in event_times]
266
+ example_dict[f'{k}_index_value_time_delta'] = torch.tensor(time_delta)[:, None]
267
+
268
+ assert example_dict[f'{k}_index_value_feats'].shape[0] == example_dict[f'{k}_index_value_time_delta'].shape[0]
269
+
270
+ if v.text_columns:
271
+ for j in group_idx_to_y_idx.keys():
272
+ group_text = df[df[v.groupby] == j]
273
+ for i in v.text_columns:
274
+
275
+ column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
276
+
277
+ if column_text:
278
+
279
+ example_dict.setdefault(f'{k}_{i}', []).append(f"{', '.join(column_text)}.")
280
+
281
+ event_times = group_text[time_column].iloc[0]
282
+ time_delta = self.compute_time_delta(event_times, reference_time, to_tensor=False)
283
+ example_dict.setdefault(f'{k}_{i}_time_delta', []).append(time_delta)
284
+
285
+ return example_dict
286
+
287
+ def return_mimic_cxr_features(self, study_id, reference_time=None):
288
+ example_dict = {}
289
+ if study_id is not None:
290
+ self.load_mimic_cxr(study_id=study_id, reference_time=reference_time)
291
+ for k, v in self.mimic_cxr_tables.items():
292
+ if v.load:
293
+
294
+ df = getattr(self, f'{k}_feats')
295
+
296
+ if self.debug:
297
+ example_dict.setdefault('mimic_cxr_inputs', []).append(k)
298
+
299
+ if not df.empty:
300
+
301
+ # The y-index for each group:
302
+ group_idx_to_y_idx = OrderedDict()
303
+ groups = df.dropna(
304
+ subset=v.index_columns + v.value_columns + v.text_columns + v.target_sections,
305
+ axis=0,
306
+ how='all'
307
+ )
308
+ groups = groups.drop_duplicates(subset=[v.groupby])[[v.groupby]]
309
+ groups = groups.reset_index(drop=True)
310
+ for i, row in groups.iterrows():
311
+ group_idx_to_y_idx[row[v.groupby]] = i
312
+
313
+ if v.index_columns and group_idx_to_y_idx:
314
+
315
+ example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
316
+ if v.index_columns:
317
+ example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
318
+ df=df,
319
+ tensor=example_dict[f'{k}_index_value_feats'],
320
+ group_idx_to_y_idx=group_idx_to_y_idx,
321
+ groupby=v.groupby,
322
+ index_columns=v.index_columns,
323
+ )
324
+
325
+ example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
326
+ [example_dict[f'{k}_index_value_feats'].shape[0]],
327
+ self.token_type_to_token_type_id[k],
328
+ dtype=torch.long,
329
+ )
330
+
331
+ if v.text_columns:
332
+ for j in group_idx_to_y_idx.keys():
333
+ group_text = df[df[v.groupby] == j]
334
+ for i in v.text_columns:
335
+ column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
336
+ if column_text:
337
+ example_dict.setdefault(f'{i}', []).append(f"{', '.join(column_text)}.")
338
+
339
+ if v.target_sections:
340
+ for j in group_idx_to_y_idx.keys():
341
+ group_text = df[df[v.groupby] == j]
342
+ for i in v.target_sections:
343
+ column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
344
+ assert len(column_text) == 1
345
+ example_dict[i] = column_text[-1]
346
+
347
+ return example_dict
348
+
349
+ def compute_time_delta(self, event_time, reference_time, denominator = 3600, to_tensor=True):
350
+ """
351
+ How to we transform time-delta inputs? It appears that minutes are used as the input to
352
+ a weight matrix in "Self-Supervised Transformer for Sparse and Irregularly Sampled Multivariate
353
+ Clinical Time-Series". This is almost confirmed by the CVE class defined here:
354
+ https://github.com/sindhura97/STraTS/blob/main/strats_notebook.ipynb, where the input has
355
+ a size of one.
356
+ """
357
+ time_delta = reference_time - event_time
358
+ time_delta = time_delta.total_seconds() / (denominator)
359
+ assert isinstance(time_delta, float), f'time_delta should be float, not {type(time_delta)}.'
360
+ if time_delta < 0:
361
+ raise ValueError(f'time_delta should be greater than or equal to zero, not {time_delta}.')
362
+ time_delta = self.time_delta_map(time_delta)
363
+ if to_tensor:
364
+ time_delta = torch.tensor(time_delta)
365
+ return time_delta
366
+
367
+ def filter_admissions_by_time_span(self, df, time_column):
368
+ return df[(df[time_column] > self.start_time) & (df[time_column] <= self.end_time)]
369
+
tables.py ADDED
@@ -0,0 +1,159 @@
1
+ from collections import OrderedDict
2
+ from typing import Optional
3
+
4
+ ed_cxr_token_type_ids = {
5
+ 'medrecon': 0,
6
+ 'edstays': 1,
7
+ 'triage': 2,
8
+ 'vitalsign': 3,
9
+ 'pyxis': 4,
10
+ 'mimic_cxr_2_0_0_metadata': 5,
11
+ 'medrecon_name': 6,
12
+ 'triage_chiefcomplaint': 7,
13
+ 'triage_pain': 8,
14
+ 'vitalsign_pain': 9,
15
+ 'indication': 10,
16
+ 'history': 11,
17
+ 'findings': 12,
18
+ 'impression': 13,
19
+ 'image': 14,
20
+ 'comparison': 15,
21
+ 'previous_findings': 16,
22
+ 'previous_impression': 17,
23
+ 'previous_image': 18,
24
+ }
25
+
26
+ NUM_ED_CXR_TOKEN_TYPE_IDS = max(ed_cxr_token_type_ids.values()) + 1
27
+
28
+
+ class TableConfig:
+     def __init__(
+         self,
+         name: str,
+         hadm_id_filter: bool = False,
+         stay_id_filter: bool = False,
+         study_id_filter: bool = False,
+         subject_id_filter: bool = True,
+         load: Optional[bool] = None,
+         groupby: Optional[str] = None,
+         index_columns: list = [],
+         text_columns: list = [],
+         value_columns: list = [],
+         time_columns: list = [],
+         target_sections: list = [],
+         use_start_time: bool = False,
+         mimic_cxr_sectioned: bool = False,
+     ):
+         self.name = name
+         self.hadm_id_filter = hadm_id_filter
+         self.stay_id_filter = stay_id_filter
+         self.study_id_filter = study_id_filter
+         self.subject_id_filter = subject_id_filter
+         self.load = load
+         self.groupby = groupby
+         self.index_columns_source = [index_columns] if isinstance(index_columns, str) else index_columns
+         self.index_columns = [f'{i}_index' for i in self.index_columns_source]
+         self.text_columns = [text_columns] if isinstance(text_columns, str) else text_columns
+         self.value_columns = [value_columns] if isinstance(value_columns, str) else value_columns
+         self.time_columns = [time_columns] if isinstance(time_columns, str) else time_columns
+         self.target_sections = [target_sections] if isinstance(target_sections, str) else target_sections
+         self.use_start_time = use_start_time
+         self.mimic_cxr_sectioned = mimic_cxr_sectioned
+
+         assert self.time_columns is None or isinstance(self.time_columns, list)
+
+         self.value_column_to_idx = {}
+         self.total_indices = None
+
+
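`TableConfig.__init__` promotes scalar column arguments to lists and derives the `*_index` column names. A short illustration of that normalisation, using arbitrary example values:

```python
# Illustration of the argument normalisation performed in TableConfig.__init__
# (the column names here are arbitrary example values).
from tables import TableConfig

cfg = TableConfig(
    'Triage',
    stay_id_filter=True,
    index_columns='acuity',                      # a single string is accepted
    text_columns=['chiefcomplaint', 'pain'],
    groupby='stay_id',
)

assert cfg.index_columns_source == ['acuity']    # str promoted to list
assert cfg.index_columns == ['acuity_index']     # '_index' suffix added
assert cfg.text_columns == ['chiefcomplaint', 'pain']
assert cfg.time_columns == []                    # unset columns stay as empty lists
```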
+ # ed module:
+ """
+ Order the tables for position_ids based on their order of occurrence (for cases where their time deltas match).
+ The order given here is the order in which they will be provided as input.
+
+ 1. medrecon - the medications which the patient was taking prior to their ED stay.
+ 2. edstays - patient stays are tracked in the edstays table.
+ 3. triage - information collected from the patient at the time of triage.
+ 4. vitalsign - aperiodic vital signs documented for patients during their stay.
+ 5. pyxis - dispensation information for medications provided by the BD Pyxis MedStation (its position is interchangeable with 4).
+ """
+ ed_module_tables = OrderedDict(
+     {
+         'medrecon': TableConfig(
+             'Medicine reconciliation',
+             stay_id_filter=True,
+             load=True,
+             index_columns=['gsn', 'ndc', 'etc_rn', 'etccode'],
+             text_columns='name',
+             groupby='stay_id',
+             use_start_time=True,
+         ),
+         'edstays': TableConfig(
+             'ED admissions',
+             stay_id_filter=True,
+             load=True,
+             index_columns=['gender', 'race', 'arrival_transport'],
+             groupby='stay_id',
+             time_columns='intime',
+         ),
+         'triage': TableConfig(
+             'Triage',
+             stay_id_filter=True,
+             load=True,
+             text_columns=['chiefcomplaint', 'pain'],
+             value_columns=['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'acuity'],
+             groupby='stay_id',
+             use_start_time=True,
+         ),
+         'vitalsign': TableConfig(
+             'Aperiodic vital signs',
+             stay_id_filter=True,
+             load=True,
+             index_columns=['rhythm'],
+             text_columns=['pain'],
+             value_columns=['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp'],
+             groupby='charttime',
+             time_columns='charttime',
+         ),
+         'pyxis': TableConfig(
+             'Dispensation information for medications provided by the BD Pyxis MedStation',
+             stay_id_filter=True,
+             load=True,
+             index_columns=['med_rn', 'name', 'gsn_rn', 'gsn'],
+             groupby='charttime',
+             time_columns='charttime',
+         ),
+         'diagnosis': TableConfig('Diagnosis', stay_id_filter=True, hadm_id_filter=False),
+     }
+ )
+
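A hedged sketch of how these per-table configurations might drive loading and filtering. The file layout and filter columns assumed below follow the usual MIMIC-IV-ED CSV naming (subject_id/stay_id columns, one gzipped CSV per table); the repository's own loading logic is in dataset.py and records.py.

```python
# Sketch only: iterate the ED table configs and apply the declared filters.
# Paths and column names are assumptions about the MIMIC-IV-ED layout.
import pandas as pd

from tables import ed_module_tables


def load_ed_tables(ed_dir, subject_id, stay_id):
    loaded = {}
    for table_name, config in ed_module_tables.items():
        if not config.load:
            continue  # e.g. 'diagnosis' is declared but not loaded
        df = pd.read_csv(f'{ed_dir}/{table_name}.csv.gz')
        if config.subject_id_filter:
            df = df[df['subject_id'] == subject_id]
        if config.stay_id_filter:
            df = df[df['stay_id'] == stay_id]
        loaded[table_name] = df
    return loaded
```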
130
+ # MIMIC-CXR module:
131
+ mimic_cxr_tables = OrderedDict(
132
+ {
133
+ 'mimic_cxr_2_0_0_metadata': TableConfig(
134
+ 'Metadata',
135
+ study_id_filter=True,
136
+ load=True,
137
+ index_columns=[
138
+ 'PerformedProcedureStepDescription',
139
+ 'ViewPosition',
140
+ 'ProcedureCodeSequence_CodeMeaning',
141
+ 'ViewCodeSequence_CodeMeaning',
142
+ 'PatientOrientationCodeSequence_CodeMeaning',
143
+ ],
144
+ groupby='study_id',
145
+ ),
146
+ 'mimic_cxr_sectioned': TableConfig(
147
+ 'Report sections',
148
+ mimic_cxr_sectioned=True,
149
+ subject_id_filter=False,
150
+ load=True,
151
+ groupby='study',
152
+ text_columns=['indication', 'history', 'comparison'],
153
+ target_sections=['findings', 'impression'],
154
+ ),
155
+ 'mimic_cxr_2_0_0_chexpert': TableConfig('CheXpert', study_id_filter=True),
156
+ 'mimic_cxr_2_0_0_split': TableConfig('Split', study_id_filter=True),
157
+ }
158
+ )
159
+
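In these MIMIC-CXR configs, `text_columns` mark report sections used as prompt context while `target_sections` mark the generation targets. A small sketch of reading those roles back out of the configuration (assuming tables.py is importable as a module):

```python
# Collect prompt and target report sections from the MIMIC-CXR table configs.
from tables import mimic_cxr_tables

prompt_sections, target_sections = [], []
for name, config in mimic_cxr_tables.items():
    prompt_sections.extend(config.text_columns)
    target_sections.extend(config.target_sections)

print(prompt_sections)  # ['indication', 'history', 'comparison']
print(target_sections)  # ['findings', 'impression']
```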