ChantalPellegrini commited on
Commit
06257c8
Β·
1 Parent(s): 8b6bb53

first commit

Browse files
README.md CHANGED
@@ -1,12 +1,41 @@
1
  ---
2
  title: Xplainer
3
- emoji: πŸ”₯
4
- colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.35.2
 
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Xplainer
3
+ emoji: πŸ“Š
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
+ python_version: 3.7.16
9
  app_file: app.py
10
  pinned: false
11
+ license: mit
12
  ---
13
 
14
+ This is the official demo for the paper "Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis" (https://arxiv.org/pdf/2303.13391.pdf), which was accepted for publication at MICCAI 2023.
15
+
16
+ We propose a new way of explainability for zero-shot diagnosis prediction in the clinical domain. Instead of directly predicting a diagnosis, we prompt the model to classify the existence of descriptive observations, which a radiologist would look for on an X-Ray scan, and use the descriptor probabilities to estimate the likelihood of a diagnosis, making our model explainable by design. For this we leverage BioVil, a pretrained CLIP model for X-rays and apply contrastive observation-based prompting. We evaluate Xplainer on two chest X-ray
17
+ datasets, CheXpert and ChestX-ray14, and demonstrate its effectiveness
18
+ in improving the performance and explainability of zero-shot diagnosis.
19
+ **Authors**: [Chantal Pellegrini][cp], [Matthias Keicher][mk], [Ege Γ–zsoy][eo], [Petra Jiraskova][pj], [Rickmer Braren][rb], [Nassir Navab][nn]
20
+
21
+ [cp]:https://www.cs.cit.tum.de/camp/members/chantal-pellegrini/
22
+ [eo]:https://www.cs.cit.tum.de/camp/members/ege-oezsoy/
23
+ [mk]:https://www.cs.cit.tum.de/camp/members/matthias-keicher/
24
+ [pj]:https://campus.tum.de/tumonline/ee/ui/ca2/app/desktop/#/pl/ui/$ctx/visitenkarte.show_vcard?$ctx=design=ca2;header=max;lang=de&pPersonenGruppe=3&pPersonenId=46F3A857F258DEE6
25
+ [rb]:https://radiologie.mri.tum.de/de/person/prof-dr-rickmer-f-braren
26
+ [nn]:https://www.cs.cit.tum.de/camp/members/cv-nassir-navab/nassir-navab/
27
+
28
+ Github: https://github.com/ChantalMP/Xplainer/tree/master
29
+
30
+ ```
31
+ @inproceedings{pellegrini2023xplainer,
32
+ title={Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis},
33
+ author={Pellegrini, Chantal and Keicher, Matthias and Γ–zsoy, Ege and Jiraskova, Petra and Braren, Rickmer and Navab, Nassir},
34
+ booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
35
+ year={2023},
36
+ organization={Springer}
37
+ }
38
+ ```
39
+
40
+ ### Intended Use
41
+ This model is intended to be used solely for (I) future research on visual-language processing and (II) reproducibility of the experimental results reported in the reference paper.
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ from matplotlib import pyplot as plt
6
+
7
+ from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
8
+ from model import InferenceModel
9
+
10
+
11
+ def plot_bars(model_output):
12
+ # sort model_output by overall_probability
13
+ model_output = {k: v for k, v in sorted(model_output.items(), key=lambda item: item[1]['overall_probability'], reverse=True)}
14
+
15
+ # Create a figure with as many subplots as there are diseases, arranged vertically
16
+ fig, axs = plt.subplots(len(model_output), 1, figsize=(10, 5 * len(model_output)))
17
+ # axs is not iterable if only one subplot is created, so make it a list
18
+ if len(model_output) == 1:
19
+ axs = [axs]
20
+
21
+ for ax, (disease, data) in zip(axs, model_output.items()):
22
+ desc_probs = list(data['descriptor_probabilities'].items())
23
+ # sort descending
24
+ desc_probs = sorted(desc_probs, key=lambda item: item[1], reverse=True)
25
+
26
+ my_probs = [p[1] for p in desc_probs]
27
+ min_prob = min(my_probs)
28
+ max_prob = max(my_probs)
29
+ my_labels = [p[0] for p in desc_probs]
30
+
31
+ # Convert probabilities to differences from 0.5
32
+ diffs = np.abs(np.array(my_probs) - 0.5)
33
+
34
+ # Set colors based on sign of difference
35
+ colors = ['red' if p < 0.5 else 'forestgreen' for p in my_probs]
36
+
37
+ # Plot bars with appropriate colors and left offsets
38
+ left = [p if p < 0.5 else 0.5 for p in my_probs]
39
+ bars = ax.barh(my_labels, diffs, left=left, color=colors, alpha=0.3)
40
+
41
+ for i, bar in enumerate(bars):
42
+ ax.text(min_prob - 0.04, bar.get_y() + bar.get_height() / 2, my_labels[i], ha='left', va='center', color='black', fontsize=15)
43
+
44
+ ax.set_xlim(min(min_prob - 0.05, 0.49), max(max_prob + 0.05, 0.51))
45
+
46
+ # Invert the y-axis to show bars with values less than 0.5 to the left of the center
47
+ ax.invert_yaxis()
48
+
49
+ ax.set_yticks([])
50
+
51
+ # Add a title for the disease
52
+ if data['overall_probability'] >= 0.5:
53
+ ax.set_title(f"{disease} : score of {data['overall_probability']:.2f}")
54
+ else:
55
+ ax.set_title(f"No {disease} : score of {data['overall_probability']:.2f}")
56
+
57
+ # make title larger and bold
58
+ ax.title.set_fontsize(15)
59
+ ax.title.set_fontweight(600)
60
+
61
+ # Save the plot
62
+ plt.tight_layout() # Adjust subplot parameters to give specified padding
63
+ file_path = 'plot.png'
64
+ plt.savefig(file_path)
65
+ plt.close(fig)
66
+
67
+ return file_path
68
+
69
+
70
+ def classify_image(inference_model, image_path, diseases_to_predict):
71
+ descriptors_with_indication = [d + " indicating " + disease for disease, descriptors in diseases_to_predict.items() for d in descriptors]
72
+ probs, negative_probs = inference_model.get_descriptor_probs(image_path=Path(image_path), descriptors=descriptors_with_indication,
73
+ do_negative_prompting=True, demo=True)
74
+
75
+ disease_probs, negative_disease_probs = inference_model.get_diseases_probs(diseases_to_predict, pos_probs=probs, negative_probs=negative_probs)
76
+
77
+ model_output = {}
78
+ for idx, disease in enumerate(diseases_to_predict.keys()):
79
+ model_output[disease] = {
80
+ 'overall_probability': disease_probs[disease],
81
+ 'descriptor_probabilities': {descriptor: probs[f'{descriptor} indicating {disease}'].item() for descriptor in
82
+ diseases_to_predict[disease]}
83
+ }
84
+
85
+ file_path = plot_bars(model_output)
86
+ return file_path
87
+
88
+
89
+ # Define the function you want to wrap
90
+ def process_input(image_path, prompt_names: list, disease_name: str, descriptors: str):
91
+ diseases_to_predict = {}
92
+
93
+ for prompt in prompt_names:
94
+ if prompt == 'Custom':
95
+ diseases_to_predict[disease_name] = descriptors.split('\n')
96
+ else:
97
+ if prompt in disease_descriptors_chexpert:
98
+ diseases_to_predict[prompt] = disease_descriptors_chexpert[prompt]
99
+ else: # only chestxray14
100
+ diseases_to_predict[prompt] = disease_descriptors_chestxray14[prompt]
101
+
102
+ # classify
103
+ model = InferenceModel()
104
+ output = classify_image(model, image_path, diseases_to_predict)
105
+
106
+ return output
107
+
108
+
109
+ # Define the Gradio interface
110
+ iface = gr.Interface(
111
+ fn=process_input,
112
+ examples = [['examples/enlarged_cardiomediastinum.jpg', ['Enlarged Cardiomediastinum'], '', ''],['examples/edema.jpg', ['Edema'], '', ''],
113
+ ['examples/support_devices.jpg', ['Custom'], 'Pacemaker', 'metalic object\nimplant on the left side of the chest\nimplanted cardiac device']],
114
+ inputs=[gr.inputs.Image(type="filepath"), gr.inputs.CheckboxGroup(
115
+ choices=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
116
+ 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
117
+ 'Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia',
118
+ 'Custom'],
119
+ default=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
120
+ 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'],
121
+ label='Selct to use predefined disease descriptors. Select "Custom" to define your own observations.'),
122
+ gr.inputs.Textbox(lines=2, placeholder="Name of pathology for which you want to define custom observations", label='Pathology:'),
123
+ gr.inputs.Textbox(lines=2, placeholder="Add your custom (positive) observations separated by a new line"
124
+ "\n Note: Each descriptor will automatically be embedded into our prompt format: There is/are (no) <observation> indicating <pathology>"
125
+ "\n Example:\n\n Opacity\nPleural Effusion\nConsolidation"
126
+ , label='Custom Observations:')],
127
+ outputs=gr.outputs.Image(type="filepath")
128
+ )
129
+
130
+ # Launch the interface
131
+ iface.launch()
descriptors.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ disease_descriptors_chexpert = {
2
+ "No Finding": [
3
+ "Clear lung fields",
4
+ "Normal heart size and shape",
5
+ "No Abnormal fluid buildup",
6
+ "No Visible tumors or masses",
7
+ "No Signs of bone fractures or dislocations"
8
+ ],
9
+ "Enlarged Cardiomediastinum": [
10
+ "Increased width of the heart shadow",
11
+ "Widened mediastinum",
12
+ "Abnormal contour of the heart border",
13
+ "Fluid or air within the pericardium",
14
+ "Mass within the mediastinum",
15
+ ],
16
+ "Cardiomegaly": [
17
+ "Increased size of the heart shadow",
18
+ "Enlargement of the heart silhouette",
19
+ "Increased diameter of the heart border",
20
+ "Increased cardiothoracic ratio",
21
+ ],
22
+ "Lung Opacity": [
23
+ "Increased density in the lung field",
24
+ "Whitish or grayish area in the lung field",
25
+ "Obscured or blurred margins of the lung field",
26
+ "Loss of normal lung markings within the opacity",
27
+ "Air bronchograms within the opacity",
28
+ "Fluid levels within the opacity",
29
+ "Silhouette sign loss with adjacent structures",
30
+
31
+ ],
32
+ "Lung Lesion": [
33
+ "Consolidation of lung tissue",
34
+ "Pleural effusion",
35
+ "Cavities or abscesses in the lung",
36
+ "Abnormal opacity or shadow in the lung",
37
+ "Irregular or blurred margins of the lung",
38
+
39
+ ],
40
+ "Edema": [
41
+ "Blurry vascular markings in the lungs",
42
+ "Enlarged heart",
43
+ "Kerley B lines",
44
+ "Increased interstitial markings in the lungs",
45
+ "Widening of interstitial spaces",
46
+ ],
47
+ "Consolidation": [
48
+ "Loss of lung volume",
49
+ "Increased density of lung tissue",
50
+ "Obliteration of the diaphragmatic silhouette",
51
+ "Presence of opacities",
52
+ ],
53
+ "Pneumonia": [
54
+ "Consolidation of lung tissue",
55
+ "Air bronchograms",
56
+ "Cavitation",
57
+ "Interstitial opacities",
58
+ ],
59
+ "Atelectasis": [
60
+ "Increased opacity",
61
+ "Volume loss of the affected lung region",
62
+ "Blunting of the costophrenic angle",
63
+ "Shifting of the mediastinum",
64
+ ],
65
+ "Pneumothorax": [
66
+ "Tracheal deviation",
67
+ "Deep sulcus sign",
68
+ "Increased radiolucency",
69
+ "Flattening of the hemidiaphragm",
70
+ "Absence of lung markings",
71
+ "Shifting of the mediastinum"
72
+ ],
73
+ "Pleural Effusion": [
74
+ "Blunting of costophrenic angles",
75
+ "Opacity in the lower lung fields",
76
+ "Mediastinal shift",
77
+ "Reduced lung volume",
78
+ "Presence of meniscus sign or veil-like appearance"
79
+ ],
80
+ "Pleural Other": [
81
+ "Pleural thickening",
82
+ "Pleural calcification",
83
+ "Pleural masses or nodules",
84
+ "Pleural empyema",
85
+ "Pleural fibrosis",
86
+ "Pleural adhesions"
87
+ ],
88
+ "Fracture": [
89
+ "Visible breaks in the continuity of the bone",
90
+ "Misalignments of bone fragments",
91
+ "Displacements of bone fragments",
92
+ "Disruptions of the cortex or outer layer of the bone",
93
+ "Visible callus or healing tissue",
94
+ "Fracture lines that are jagged or irregular in shape",
95
+ "Multiple fracture lines that intersect at different angles"
96
+ ],
97
+ "Support Devices": [
98
+ "Artificial joints or implants",
99
+ "Pacemakers or cardiac devices",
100
+ "Stents or other vascular devices",
101
+ "Prosthetic devices or limbs",
102
+ "Breast implants",
103
+ "Radiotherapy markers or seeds"
104
+ ]
105
+ }
106
+
107
+ disease_descriptors_chestxray14 = {
108
+
109
+ "No Finding": ["No Finding"],
110
+ "Cardiomegaly": [
111
+ "Increased size of the heart shadow",
112
+ "Enlargement of the heart silhouette",
113
+ "Increased diameter of the heart border",
114
+ "Increased cardiothoracic ratio"
115
+ ],
116
+ "Edema": [
117
+ "Blurry vascular markings in the lungs",
118
+ "Kerley B lines",
119
+ "Increased interstitial markings in the lungs",
120
+ "Widening of interstitial spaces"
121
+ ],
122
+ "Consolidation": [
123
+ "Loss of lung volume",
124
+ "Increased density of lung tissue",
125
+ "Obliteration of the diaphragmatic silhouette",
126
+ "Presence of opacities"
127
+ ],
128
+ "Pneumonia": [
129
+ "Consolidation of lung tissue",
130
+ "Air bronchograms",
131
+ "Cavitation",
132
+ "Interstitial opacities"
133
+ ],
134
+ "Atelectasis": [
135
+ "Increased opacity",
136
+ "Volume loss of the affected lung region",
137
+ "Displacement of the diaphragm",
138
+ "Blunting of the costophrenic angle",
139
+ "Shifting of the mediastinum"
140
+ ],
141
+ "Pneumothorax": [
142
+ "Tracheal deviation",
143
+ "Deep sulcus sign",
144
+ "Increased radiolucency",
145
+ "Flattening of the hemidiaphragm",
146
+ "Absence of lung markings",
147
+ "Shifting of the mediastinum"
148
+ ],
149
+ "Pleural Effusion": [
150
+ "Blunting of costophrenic angles",
151
+ "Opacity in the lower lung fields",
152
+ "Mediastinal shift",
153
+ "Reduced lung volume",
154
+ "Meniscus sign or veil-like appearance"
155
+ ],
156
+ "Infiltration": [
157
+ "Irregular or fuzzy borders around white areas",
158
+ "Blurring",
159
+ "Hazy or cloudy areas",
160
+ "Increased density or opacity of lung tissue",
161
+ "Air bronchograms",
162
+ ],
163
+ "Mass": [
164
+ "Calcifications or mineralizations",
165
+ "Shadowing",
166
+ "Distortion or compression of tissues",
167
+ "Anomalous structure or irregularity in shape"
168
+ ],
169
+ "Nodule": [
170
+ "Nodular shape that protrudes into a cavity or airway",
171
+ "Distinct edges or borders",
172
+ "Calcifications or speckled areas",
173
+ "Small round oral shaped spots",
174
+ "White shadows"
175
+ ],
176
+ "Emphysema": [
177
+ "Flattened hemidiaphragm",
178
+ "Pulmonary bullae",
179
+ "Hyperlucent lungs",
180
+ "Horizontalisation of ribs",
181
+ "Barrel Chest",
182
+ ],
183
+ "Fibrosis": [
184
+ "Reticular shadowing of the lung peripheries",
185
+ "Volume loss",
186
+ "Thickened and irregular interstitial markings",
187
+ "Bronchial dilation",
188
+ "Shaggy heart borders"
189
+ ],
190
+ "Pleural Thickening": [
191
+ "Thickened pleural line",
192
+ "Loss of sharpness of the mediastinal border",
193
+ "Calcifications on the pleura",
194
+ "Lobulated peripheral shadowing",
195
+ "Loss of lung volume",
196
+ ],
197
+ "Hernia": [
198
+ "Bulge or swelling in the abdominal wall",
199
+ "Protrusion of intestine or other abdominal tissue",
200
+ "Swelling or enlargement of the herniated sac or surrounding tissues",
201
+ "Retro-cardiac air-fluid level",
202
+ "Thickening of intestinal folds"
203
+ ]
204
+ }
examples/edema.jpg ADDED
examples/enlarged_cardiomediastinum.jpg ADDED
examples/support_devices.jpg ADDED
flagged/image_path/tmp4gt9tvuq.png ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ image_path,"Selct to use predefined disease descriptors. Select ""Custom"" to define your own observations.",Pathology:,Custom Observations:,output,flag,username,timestamp
2
+ /Users/chantal/Documents/programmieren_gitbackuped/Xplainer/flagged/image_path/tmp4gt9tvuq.png,['Cardiomegaly'],,,/Users/chantal/Documents/programmieren_gitbackuped/Xplainer/flagged/output/tmpy60f1_o9.png,,,2023-06-27 18:09:42.017441
flagged/output/tmpy60f1_o9.png ADDED
inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+
9
+ from chestxray14 import ChestXray14Dataset
10
+ from chexpert import CheXpertDataset
11
+ from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
12
+ from model import InferenceModel
13
+ from utils import calculate_auroc
14
+
15
+ torch.multiprocessing.set_sharing_strategy('file_system')
16
+
17
+
18
+ def inference_chexpert():
19
+ split = 'test'
20
+ dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv') # also do test
21
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0)
22
+ inference_model = InferenceModel()
23
+ all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert)
24
+
25
+ all_labels = []
26
+ all_probs_neg = []
27
+
28
+ for batch in tqdm(dataloader):
29
+ batch = batch[0]
30
+ image_paths, labels, keys = batch
31
+ image_paths = [Path(image_path) for image_path in image_paths]
32
+ agg_probs = []
33
+ agg_negative_probs = []
34
+ for image_path in image_paths:
35
+ probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors)
36
+ agg_probs.append(probs)
37
+ agg_negative_probs.append(negative_probs)
38
+ probs = {} # Aggregated
39
+ negative_probs = {} # Aggregated
40
+ for key in agg_probs[0].keys():
41
+ probs[key] = sum([p[key] for p in agg_probs]) / len(agg_probs) # Mean Aggregation
42
+
43
+ for key in agg_negative_probs[0].keys():
44
+ negative_probs[key] = sum([p[key] for p in agg_negative_probs]) / len(agg_negative_probs) # Mean Aggregation
45
+
46
+ disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chexpert, pos_probs=probs,
47
+ negative_probs=negative_probs)
48
+ predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chexpert,
49
+ disease_probs=disease_probs,
50
+ negative_disease_probs=negative_disease_probs,
51
+ keys=keys)
52
+ all_labels.append(labels)
53
+ all_probs_neg.append(prob_vector_neg_prompt)
54
+
55
+ all_labels = torch.stack(all_labels)
56
+ all_probs_neg = torch.stack(all_probs_neg)
57
+
58
+ # evaluation
59
+ existing_mask = sum(all_labels, 0) > 0
60
+ all_labels_clean = all_labels[:, existing_mask]
61
+ all_probs_neg_clean = all_probs_neg[:, existing_mask]
62
+ all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]
63
+
64
+ overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
65
+ print(f"AUROC: {overall_auroc:.5f}\n")
66
+ for idx, key in enumerate(all_keys_clean):
67
+ print(f'{key}: {per_disease_auroc[idx]:.5f}')
68
+
69
+
70
+ def inference_chestxray14():
71
+ dataset = ChestXray14Dataset(f'data/chestxray14/Data_Entry_2017_v2020_modified.csv')
72
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=1)
73
+ inference_model = InferenceModel()
74
+ all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14)
75
+
76
+ all_labels = []
77
+ all_probs_neg = []
78
+ for batch in tqdm(dataloader):
79
+ batch = batch[0]
80
+ image_path, labels, keys = batch
81
+ image_path = Path(image_path)
82
+ probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors)
83
+ disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chestxray14, pos_probs=probs,
84
+ negative_probs=negative_probs)
85
+ predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chestxray14,
86
+ disease_probs=disease_probs,
87
+ negative_disease_probs=negative_disease_probs,
88
+ keys=keys)
89
+ all_labels.append(labels)
90
+ all_probs_neg.append(prob_vector_neg_prompt)
91
+ gc.collect()
92
+
93
+ all_labels = torch.stack(all_labels)
94
+ all_probs_neg = torch.stack(all_probs_neg)
95
+
96
+ existing_mask = sum(all_labels, 0) > 0
97
+ all_labels_clean = all_labels[:, existing_mask]
98
+ all_probs_neg_clean = all_probs_neg[:, existing_mask]
99
+ all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]
100
+
101
+ overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean[:, 1:], all_labels_clean[:, 1:])
102
+ print(f"AUROC: {overall_auroc:.5f}\n")
103
+ for idx, key in enumerate(all_keys_clean[1:]):
104
+ print(f'{key}: {per_disease_auroc[idx]:.5f}')
105
+
106
+
107
+ if __name__ == '__main__':
108
+ # add argument parser
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument('--dataset', type=str, default='chexpert', help='chexpert or chestxray14')
111
+ args = parser.parse_args()
112
+
113
+ if args.dataset == 'chexpert':
114
+ inference_chexpert()
115
+ elif args.dataset == 'chestxray14':
116
+ inference_chestxray14()
model.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from health_multimodal.image import get_biovil_resnet_inference
7
+ from health_multimodal.text import get_cxr_bert_inference
8
+ from health_multimodal.vlp import ImageTextInferenceEngine
9
+
10
+ from utils import cos_sim_to_prob, prob_to_log_prob, log_prob_to_prob
11
+
12
+
13
+ class InferenceModel():
14
+ def __init__(self):
15
+ self.text_inference = get_cxr_bert_inference()
16
+ self.image_inference = get_biovil_resnet_inference()
17
+ self.image_text_inference = ImageTextInferenceEngine(
18
+ image_inference_engine=self.image_inference,
19
+ text_inference_engine=self.text_inference,
20
+ )
21
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ self.image_text_inference.to(self.device)
23
+
24
+ # caches for faster inference
25
+ self.text_embedding_cache = {}
26
+ self.image_embedding_cache = {}
27
+
28
+ self.transform = self.image_inference.transform
29
+
30
+ def get_similarity_score_from_raw_data(self, image_embedding, query_text: str) -> float:
31
+ """Compute the cosine similarity score between an image and one or more strings.
32
+ If multiple strings are passed, their embeddings are averaged before L2-normalization.
33
+ :param image_path: Path to the input chest X-ray, either a DICOM or JPEG file.
34
+ :param query_text: Input radiology text phrase.
35
+ :return: The similarity score between the image and the text.
36
+ """
37
+ assert not self.image_text_inference.image_inference_engine.model.training
38
+ assert not self.image_text_inference.text_inference_engine.model.training
39
+ if query_text in self.text_embedding_cache:
40
+ text_embedding = self.text_embedding_cache[query_text]
41
+ else:
42
+ text_embedding = self.image_text_inference.text_inference_engine.get_embeddings_from_prompt([query_text], normalize=False)
43
+ text_embedding = text_embedding.mean(dim=0)
44
+ text_embedding = F.normalize(text_embedding, dim=0, p=2)
45
+ self.text_embedding_cache[query_text] = text_embedding
46
+
47
+ cos_similarity = image_embedding @ text_embedding.t()
48
+
49
+ return cos_similarity.item()
50
+
51
+ def process_image(self, image):
52
+ ''' same code as in image_text_inference.image_inference_engine.get_projected_global_embedding() but adapted to deal with image instances instead of path'''
53
+
54
+ transformed_image = self.transform(image)
55
+ projected_img_emb = self.image_inference.model.forward(transformed_image).projected_global_embedding
56
+ projected_img_emb = F.normalize(projected_img_emb, dim=-1)
57
+ assert projected_img_emb.shape[0] == 1
58
+ assert projected_img_emb.ndim == 2
59
+ return projected_img_emb[0]
60
+
61
+ def get_descriptor_probs(self, image_path: Path, descriptors: List[str], do_negative_prompting=True, demo=False):
62
+ probs = {}
63
+ negative_probs = {}
64
+ if image_path in self.image_embedding_cache:
65
+ image_embedding = self.image_embedding_cache[image_path]
66
+ else:
67
+ image_embedding = self.image_text_inference.image_inference_engine.get_projected_global_embedding(image_path)
68
+ if not demo:
69
+ self.image_embedding_cache[image_path] = image_embedding
70
+
71
+ # Default get_similarity_score_from_raw_data would load the image every time. Instead we only load once.
72
+ for desc in descriptors:
73
+ prompt = f'There are {desc}'
74
+ score = self.get_similarity_score_from_raw_data(image_embedding, prompt)
75
+ if do_negative_prompting:
76
+ neg_prompt = f'There are no {desc}'
77
+ neg_score = self.get_similarity_score_from_raw_data(image_embedding, neg_prompt)
78
+
79
+ pos_prob = cos_sim_to_prob(score)
80
+
81
+ if do_negative_prompting:
82
+ pos_prob, neg_prob = torch.softmax((torch.tensor([score, neg_score]) / 0.5), dim=0)
83
+ negative_probs[desc] = neg_prob
84
+
85
+ probs[desc] = pos_prob
86
+
87
+ return probs, negative_probs
88
+
89
+ def get_all_descriptors(self, disease_descriptors):
90
+ all_descriptors = set()
91
+ for disease, descs in disease_descriptors.items():
92
+ all_descriptors.update([f"{desc} indicating {disease}" for desc in descs])
93
+ all_descriptors = sorted(all_descriptors)
94
+ return all_descriptors
95
+
96
+ def get_all_descriptors_only_disease(self, disease_descriptors):
97
+ all_descriptors = set()
98
+ for disease, descs in disease_descriptors.items():
99
+ all_descriptors.update([f"{desc}" for desc in descs])
100
+ all_descriptors = sorted(all_descriptors)
101
+ return all_descriptors
102
+
103
+ def get_diseases_probs(self, disease_descriptors, pos_probs, negative_probs, prior_probs=None, do_negative_prompting=True):
104
+ disease_probs = {}
105
+ disease_neg_probs = {}
106
+ for disease, descriptors in disease_descriptors.items():
107
+ desc_log_probs = []
108
+ desc_neg_log_probs = []
109
+ for desc in descriptors:
110
+ desc = f"{desc} indicating {disease}"
111
+ desc_log_probs.append(prob_to_log_prob(pos_probs[desc]))
112
+ if do_negative_prompting:
113
+ desc_neg_log_probs.append(prob_to_log_prob(negative_probs[desc]))
114
+ disease_log_prob = sum(sorted(desc_log_probs, reverse=True)) / len(desc_log_probs)
115
+ if do_negative_prompting:
116
+ disease_neg_log_prob = sum(desc_neg_log_probs) / len(desc_neg_log_probs)
117
+ disease_probs[disease] = log_prob_to_prob(disease_log_prob)
118
+ if do_negative_prompting:
119
+ disease_neg_probs[disease] = log_prob_to_prob(disease_neg_log_prob)
120
+
121
+ return disease_probs, disease_neg_probs
122
+
123
+ # Threshold Based
124
+ def get_predictions(self, disease_descriptors, threshold, disease_probs, keys):
125
+ predicted_diseases = []
126
+ prob_vector = torch.zeros(len(keys), dtype=torch.float) # num of diseases
127
+ for idx, disease in enumerate(disease_descriptors):
128
+ if disease == 'No Finding':
129
+ continue
130
+ prob_vector[keys.index(disease)] = disease_probs[disease]
131
+ if disease_probs[disease] > threshold:
132
+ predicted_diseases.append(disease)
133
+
134
+ if len(predicted_diseases) == 0: # No finding rule based
135
+ prob_vector[0] = 1.0 - max(prob_vector)
136
+ else:
137
+ prob_vector[0] = 1.0 - max(prob_vector)
138
+
139
+ return predicted_diseases, prob_vector
140
+
141
+ # Negative vs Positive Prompting
142
+ def get_predictions_bin_prompting(self, disease_descriptors, disease_probs, negative_disease_probs, keys):
143
+ predicted_diseases = []
144
+ prob_vector = torch.zeros(len(keys), dtype=torch.float) # num of diseases
145
+ for idx, disease in enumerate(disease_descriptors):
146
+ if disease == 'No Finding':
147
+ continue
148
+ pos_neg_scores = torch.tensor([disease_probs[disease], negative_disease_probs[disease]])
149
+ prob_vector[keys.index(disease)] = pos_neg_scores[0]
150
+ if torch.argmax(pos_neg_scores) == 0: # Positive is More likely
151
+ predicted_diseases.append(disease)
152
+
153
+ if len(predicted_diseases) == 0: # No finding rule based
154
+ prob_vector[0] = 1.0 - max(prob_vector)
155
+ else:
156
+ prob_vector[0] = 1.0 - max(prob_vector)
157
+
158
+ return predicted_diseases, prob_vector
plot.png ADDED
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ hi-ml-multimodal==0.1.2
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ scikit-learn==1.0.2
2
+ transformers==4.17.0
3
+ gradio==3.34.0
4
+ pandas==1.3.5
5
+ torch==1.13.0
utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import log, exp
2
+
3
+ import numpy as np
4
+ from sklearn.metrics import roc_auc_score
5
+
6
+
7
+ def cos_sim_to_prob(sim):
8
+ return (sim + 1) / 2 # linear transformation to 0 and 1
9
+
10
+
11
+ def log_prob_to_prob(log_prob):
12
+ return exp(log_prob)
13
+
14
+
15
+ def prob_to_log_prob(prob):
16
+ return log(prob)
17
+
18
+
19
+ def calculate_auroc(all_disease_probs, gt_diseases):
20
+ '''
21
+ Calculates the AUROC (Area Under the Receiver Operating Characteristic curve) for multiple diseases.
22
+
23
+ Parameters:
24
+ all_disease_probs (numpy array): predicted disease labels, a multi-hot vector of shape (N_samples, 14)
25
+ gt_diseases (numpy array): ground truth disease labels, a multi-hot vector of shape (N_samples, 14)
26
+
27
+ Returns:
28
+ overall_auroc (float): the overall AUROC score
29
+ per_disease_auroc (numpy array): an array of shape (14,) containing the AUROC score for each disease
30
+ '''
31
+
32
+ per_disease_auroc = np.zeros((gt_diseases.shape[1],)) # num of diseases
33
+ for i in range(gt_diseases.shape[1]):
34
+ # Compute the AUROC score for each disease
35
+ per_disease_auroc[i] = roc_auc_score(gt_diseases[:, i], all_disease_probs[:, i])
36
+
37
+ # Compute the overall AUROC score
38
+ overall_auroc = roc_auc_score(gt_diseases, all_disease_probs, average='macro')
39
+
40
+ return overall_auroc, per_disease_auroc