Commit
Β·
06257c8
1
Parent(s):
8b6bb53
first commit
Browse files- README.md +34 -5
- app.py +131 -0
- descriptors.py +204 -0
- examples/edema.jpg +0 -0
- examples/enlarged_cardiomediastinum.jpg +0 -0
- examples/support_devices.jpg +0 -0
- flagged/image_path/tmp4gt9tvuq.png +0 -0
- flagged/log.csv +2 -0
- flagged/output/tmpy60f1_o9.png +0 -0
- inference.py +116 -0
- model.py +158 -0
- plot.png +0 -0
- pre-requirements.txt +1 -0
- requirements.txt +5 -0
- utils.py +40 -0
README.md
CHANGED
@@ -1,12 +1,41 @@
|
|
1 |
---
|
2 |
title: Xplainer
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Xplainer
|
3 |
+
emoji: π
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.34.0
|
8 |
+
python_version: 3.7.16
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
+
license: mit
|
12 |
---
|
13 |
|
14 |
+
This is the official demo for the paper "Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis" (https://arxiv.org/pdf/2303.13391.pdf), which was accepted for publication at MICCAI 2023.
|
15 |
+
|
16 |
+
We propose a new way of explainability for zero-shot diagnosis prediction in the clinical domain. Instead of directly predicting a diagnosis, we prompt the model to classify the existence of descriptive observations, which a radiologist would look for on an X-Ray scan, and use the descriptor probabilities to estimate the likelihood of a diagnosis, making our model explainable by design. For this we leverage BioVil, a pretrained CLIP model for X-rays and apply contrastive observation-based prompting. We evaluate Xplainer on two chest X-ray
|
17 |
+
datasets, CheXpert and ChestX-ray14, and demonstrate its effectiveness
|
18 |
+
in improving the performance and explainability of zero-shot diagnosis.
|
19 |
+
**Authors**: [Chantal Pellegrini][cp], [Matthias Keicher][mk], [Ege Γzsoy][eo], [Petra Jiraskova][pj], [Rickmer Braren][rb], [Nassir Navab][nn]
|
20 |
+
|
21 |
+
[cp]:https://www.cs.cit.tum.de/camp/members/chantal-pellegrini/
|
22 |
+
[eo]:https://www.cs.cit.tum.de/camp/members/ege-oezsoy/
|
23 |
+
[mk]:https://www.cs.cit.tum.de/camp/members/matthias-keicher/
|
24 |
+
[pj]:https://campus.tum.de/tumonline/ee/ui/ca2/app/desktop/#/pl/ui/$ctx/visitenkarte.show_vcard?$ctx=design=ca2;header=max;lang=de&pPersonenGruppe=3&pPersonenId=46F3A857F258DEE6
|
25 |
+
[rb]:https://radiologie.mri.tum.de/de/person/prof-dr-rickmer-f-braren
|
26 |
+
[nn]:https://www.cs.cit.tum.de/camp/members/cv-nassir-navab/nassir-navab/
|
27 |
+
|
28 |
+
Github: https://github.com/ChantalMP/Xplainer/tree/master
|
29 |
+
|
30 |
+
```
|
31 |
+
@inproceedings{pellegrini2023xplainer,
|
32 |
+
title={Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis},
|
33 |
+
author={Pellegrini, Chantal and Keicher, Matthias and Γzsoy, Ege and Jiraskova, Petra and Braren, Rickmer and Navab, Nassir},
|
34 |
+
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
|
35 |
+
year={2023},
|
36 |
+
organization={Springer}
|
37 |
+
}
|
38 |
+
```
|
39 |
+
|
40 |
+
### Intended Use
|
41 |
+
This model is intended to be used solely for (I) future research on visual-language processing and (II) reproducibility of the experimental results reported in the reference paper.
|
app.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
|
8 |
+
from model import InferenceModel
|
9 |
+
|
10 |
+
|
11 |
+
def plot_bars(model_output):
|
12 |
+
# sort model_output by overall_probability
|
13 |
+
model_output = {k: v for k, v in sorted(model_output.items(), key=lambda item: item[1]['overall_probability'], reverse=True)}
|
14 |
+
|
15 |
+
# Create a figure with as many subplots as there are diseases, arranged vertically
|
16 |
+
fig, axs = plt.subplots(len(model_output), 1, figsize=(10, 5 * len(model_output)))
|
17 |
+
# axs is not iterable if only one subplot is created, so make it a list
|
18 |
+
if len(model_output) == 1:
|
19 |
+
axs = [axs]
|
20 |
+
|
21 |
+
for ax, (disease, data) in zip(axs, model_output.items()):
|
22 |
+
desc_probs = list(data['descriptor_probabilities'].items())
|
23 |
+
# sort descending
|
24 |
+
desc_probs = sorted(desc_probs, key=lambda item: item[1], reverse=True)
|
25 |
+
|
26 |
+
my_probs = [p[1] for p in desc_probs]
|
27 |
+
min_prob = min(my_probs)
|
28 |
+
max_prob = max(my_probs)
|
29 |
+
my_labels = [p[0] for p in desc_probs]
|
30 |
+
|
31 |
+
# Convert probabilities to differences from 0.5
|
32 |
+
diffs = np.abs(np.array(my_probs) - 0.5)
|
33 |
+
|
34 |
+
# Set colors based on sign of difference
|
35 |
+
colors = ['red' if p < 0.5 else 'forestgreen' for p in my_probs]
|
36 |
+
|
37 |
+
# Plot bars with appropriate colors and left offsets
|
38 |
+
left = [p if p < 0.5 else 0.5 for p in my_probs]
|
39 |
+
bars = ax.barh(my_labels, diffs, left=left, color=colors, alpha=0.3)
|
40 |
+
|
41 |
+
for i, bar in enumerate(bars):
|
42 |
+
ax.text(min_prob - 0.04, bar.get_y() + bar.get_height() / 2, my_labels[i], ha='left', va='center', color='black', fontsize=15)
|
43 |
+
|
44 |
+
ax.set_xlim(min(min_prob - 0.05, 0.49), max(max_prob + 0.05, 0.51))
|
45 |
+
|
46 |
+
# Invert the y-axis to show bars with values less than 0.5 to the left of the center
|
47 |
+
ax.invert_yaxis()
|
48 |
+
|
49 |
+
ax.set_yticks([])
|
50 |
+
|
51 |
+
# Add a title for the disease
|
52 |
+
if data['overall_probability'] >= 0.5:
|
53 |
+
ax.set_title(f"{disease} : score of {data['overall_probability']:.2f}")
|
54 |
+
else:
|
55 |
+
ax.set_title(f"No {disease} : score of {data['overall_probability']:.2f}")
|
56 |
+
|
57 |
+
# make title larger and bold
|
58 |
+
ax.title.set_fontsize(15)
|
59 |
+
ax.title.set_fontweight(600)
|
60 |
+
|
61 |
+
# Save the plot
|
62 |
+
plt.tight_layout() # Adjust subplot parameters to give specified padding
|
63 |
+
file_path = 'plot.png'
|
64 |
+
plt.savefig(file_path)
|
65 |
+
plt.close(fig)
|
66 |
+
|
67 |
+
return file_path
|
68 |
+
|
69 |
+
|
70 |
+
def classify_image(inference_model, image_path, diseases_to_predict):
|
71 |
+
descriptors_with_indication = [d + " indicating " + disease for disease, descriptors in diseases_to_predict.items() for d in descriptors]
|
72 |
+
probs, negative_probs = inference_model.get_descriptor_probs(image_path=Path(image_path), descriptors=descriptors_with_indication,
|
73 |
+
do_negative_prompting=True, demo=True)
|
74 |
+
|
75 |
+
disease_probs, negative_disease_probs = inference_model.get_diseases_probs(diseases_to_predict, pos_probs=probs, negative_probs=negative_probs)
|
76 |
+
|
77 |
+
model_output = {}
|
78 |
+
for idx, disease in enumerate(diseases_to_predict.keys()):
|
79 |
+
model_output[disease] = {
|
80 |
+
'overall_probability': disease_probs[disease],
|
81 |
+
'descriptor_probabilities': {descriptor: probs[f'{descriptor} indicating {disease}'].item() for descriptor in
|
82 |
+
diseases_to_predict[disease]}
|
83 |
+
}
|
84 |
+
|
85 |
+
file_path = plot_bars(model_output)
|
86 |
+
return file_path
|
87 |
+
|
88 |
+
|
89 |
+
# Define the function you want to wrap
|
90 |
+
def process_input(image_path, prompt_names: list, disease_name: str, descriptors: str):
|
91 |
+
diseases_to_predict = {}
|
92 |
+
|
93 |
+
for prompt in prompt_names:
|
94 |
+
if prompt == 'Custom':
|
95 |
+
diseases_to_predict[disease_name] = descriptors.split('\n')
|
96 |
+
else:
|
97 |
+
if prompt in disease_descriptors_chexpert:
|
98 |
+
diseases_to_predict[prompt] = disease_descriptors_chexpert[prompt]
|
99 |
+
else: # only chestxray14
|
100 |
+
diseases_to_predict[prompt] = disease_descriptors_chestxray14[prompt]
|
101 |
+
|
102 |
+
# classify
|
103 |
+
model = InferenceModel()
|
104 |
+
output = classify_image(model, image_path, diseases_to_predict)
|
105 |
+
|
106 |
+
return output
|
107 |
+
|
108 |
+
|
109 |
+
# Define the Gradio interface
|
110 |
+
iface = gr.Interface(
|
111 |
+
fn=process_input,
|
112 |
+
examples = [['examples/enlarged_cardiomediastinum.jpg', ['Enlarged Cardiomediastinum'], '', ''],['examples/edema.jpg', ['Edema'], '', ''],
|
113 |
+
['examples/support_devices.jpg', ['Custom'], 'Pacemaker', 'metalic object\nimplant on the left side of the chest\nimplanted cardiac device']],
|
114 |
+
inputs=[gr.inputs.Image(type="filepath"), gr.inputs.CheckboxGroup(
|
115 |
+
choices=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
|
116 |
+
'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
|
117 |
+
'Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia',
|
118 |
+
'Custom'],
|
119 |
+
default=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
|
120 |
+
'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'],
|
121 |
+
label='Selct to use predefined disease descriptors. Select "Custom" to define your own observations.'),
|
122 |
+
gr.inputs.Textbox(lines=2, placeholder="Name of pathology for which you want to define custom observations", label='Pathology:'),
|
123 |
+
gr.inputs.Textbox(lines=2, placeholder="Add your custom (positive) observations separated by a new line"
|
124 |
+
"\n Note: Each descriptor will automatically be embedded into our prompt format: There is/are (no) <observation> indicating <pathology>"
|
125 |
+
"\n Example:\n\n Opacity\nPleural Effusion\nConsolidation"
|
126 |
+
, label='Custom Observations:')],
|
127 |
+
outputs=gr.outputs.Image(type="filepath")
|
128 |
+
)
|
129 |
+
|
130 |
+
# Launch the interface
|
131 |
+
iface.launch()
|
descriptors.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
disease_descriptors_chexpert = {
|
2 |
+
"No Finding": [
|
3 |
+
"Clear lung fields",
|
4 |
+
"Normal heart size and shape",
|
5 |
+
"No Abnormal fluid buildup",
|
6 |
+
"No Visible tumors or masses",
|
7 |
+
"No Signs of bone fractures or dislocations"
|
8 |
+
],
|
9 |
+
"Enlarged Cardiomediastinum": [
|
10 |
+
"Increased width of the heart shadow",
|
11 |
+
"Widened mediastinum",
|
12 |
+
"Abnormal contour of the heart border",
|
13 |
+
"Fluid or air within the pericardium",
|
14 |
+
"Mass within the mediastinum",
|
15 |
+
],
|
16 |
+
"Cardiomegaly": [
|
17 |
+
"Increased size of the heart shadow",
|
18 |
+
"Enlargement of the heart silhouette",
|
19 |
+
"Increased diameter of the heart border",
|
20 |
+
"Increased cardiothoracic ratio",
|
21 |
+
],
|
22 |
+
"Lung Opacity": [
|
23 |
+
"Increased density in the lung field",
|
24 |
+
"Whitish or grayish area in the lung field",
|
25 |
+
"Obscured or blurred margins of the lung field",
|
26 |
+
"Loss of normal lung markings within the opacity",
|
27 |
+
"Air bronchograms within the opacity",
|
28 |
+
"Fluid levels within the opacity",
|
29 |
+
"Silhouette sign loss with adjacent structures",
|
30 |
+
|
31 |
+
],
|
32 |
+
"Lung Lesion": [
|
33 |
+
"Consolidation of lung tissue",
|
34 |
+
"Pleural effusion",
|
35 |
+
"Cavities or abscesses in the lung",
|
36 |
+
"Abnormal opacity or shadow in the lung",
|
37 |
+
"Irregular or blurred margins of the lung",
|
38 |
+
|
39 |
+
],
|
40 |
+
"Edema": [
|
41 |
+
"Blurry vascular markings in the lungs",
|
42 |
+
"Enlarged heart",
|
43 |
+
"Kerley B lines",
|
44 |
+
"Increased interstitial markings in the lungs",
|
45 |
+
"Widening of interstitial spaces",
|
46 |
+
],
|
47 |
+
"Consolidation": [
|
48 |
+
"Loss of lung volume",
|
49 |
+
"Increased density of lung tissue",
|
50 |
+
"Obliteration of the diaphragmatic silhouette",
|
51 |
+
"Presence of opacities",
|
52 |
+
],
|
53 |
+
"Pneumonia": [
|
54 |
+
"Consolidation of lung tissue",
|
55 |
+
"Air bronchograms",
|
56 |
+
"Cavitation",
|
57 |
+
"Interstitial opacities",
|
58 |
+
],
|
59 |
+
"Atelectasis": [
|
60 |
+
"Increased opacity",
|
61 |
+
"Volume loss of the affected lung region",
|
62 |
+
"Blunting of the costophrenic angle",
|
63 |
+
"Shifting of the mediastinum",
|
64 |
+
],
|
65 |
+
"Pneumothorax": [
|
66 |
+
"Tracheal deviation",
|
67 |
+
"Deep sulcus sign",
|
68 |
+
"Increased radiolucency",
|
69 |
+
"Flattening of the hemidiaphragm",
|
70 |
+
"Absence of lung markings",
|
71 |
+
"Shifting of the mediastinum"
|
72 |
+
],
|
73 |
+
"Pleural Effusion": [
|
74 |
+
"Blunting of costophrenic angles",
|
75 |
+
"Opacity in the lower lung fields",
|
76 |
+
"Mediastinal shift",
|
77 |
+
"Reduced lung volume",
|
78 |
+
"Presence of meniscus sign or veil-like appearance"
|
79 |
+
],
|
80 |
+
"Pleural Other": [
|
81 |
+
"Pleural thickening",
|
82 |
+
"Pleural calcification",
|
83 |
+
"Pleural masses or nodules",
|
84 |
+
"Pleural empyema",
|
85 |
+
"Pleural fibrosis",
|
86 |
+
"Pleural adhesions"
|
87 |
+
],
|
88 |
+
"Fracture": [
|
89 |
+
"Visible breaks in the continuity of the bone",
|
90 |
+
"Misalignments of bone fragments",
|
91 |
+
"Displacements of bone fragments",
|
92 |
+
"Disruptions of the cortex or outer layer of the bone",
|
93 |
+
"Visible callus or healing tissue",
|
94 |
+
"Fracture lines that are jagged or irregular in shape",
|
95 |
+
"Multiple fracture lines that intersect at different angles"
|
96 |
+
],
|
97 |
+
"Support Devices": [
|
98 |
+
"Artificial joints or implants",
|
99 |
+
"Pacemakers or cardiac devices",
|
100 |
+
"Stents or other vascular devices",
|
101 |
+
"Prosthetic devices or limbs",
|
102 |
+
"Breast implants",
|
103 |
+
"Radiotherapy markers or seeds"
|
104 |
+
]
|
105 |
+
}
|
106 |
+
|
107 |
+
disease_descriptors_chestxray14 = {
|
108 |
+
|
109 |
+
"No Finding": ["No Finding"],
|
110 |
+
"Cardiomegaly": [
|
111 |
+
"Increased size of the heart shadow",
|
112 |
+
"Enlargement of the heart silhouette",
|
113 |
+
"Increased diameter of the heart border",
|
114 |
+
"Increased cardiothoracic ratio"
|
115 |
+
],
|
116 |
+
"Edema": [
|
117 |
+
"Blurry vascular markings in the lungs",
|
118 |
+
"Kerley B lines",
|
119 |
+
"Increased interstitial markings in the lungs",
|
120 |
+
"Widening of interstitial spaces"
|
121 |
+
],
|
122 |
+
"Consolidation": [
|
123 |
+
"Loss of lung volume",
|
124 |
+
"Increased density of lung tissue",
|
125 |
+
"Obliteration of the diaphragmatic silhouette",
|
126 |
+
"Presence of opacities"
|
127 |
+
],
|
128 |
+
"Pneumonia": [
|
129 |
+
"Consolidation of lung tissue",
|
130 |
+
"Air bronchograms",
|
131 |
+
"Cavitation",
|
132 |
+
"Interstitial opacities"
|
133 |
+
],
|
134 |
+
"Atelectasis": [
|
135 |
+
"Increased opacity",
|
136 |
+
"Volume loss of the affected lung region",
|
137 |
+
"Displacement of the diaphragm",
|
138 |
+
"Blunting of the costophrenic angle",
|
139 |
+
"Shifting of the mediastinum"
|
140 |
+
],
|
141 |
+
"Pneumothorax": [
|
142 |
+
"Tracheal deviation",
|
143 |
+
"Deep sulcus sign",
|
144 |
+
"Increased radiolucency",
|
145 |
+
"Flattening of the hemidiaphragm",
|
146 |
+
"Absence of lung markings",
|
147 |
+
"Shifting of the mediastinum"
|
148 |
+
],
|
149 |
+
"Pleural Effusion": [
|
150 |
+
"Blunting of costophrenic angles",
|
151 |
+
"Opacity in the lower lung fields",
|
152 |
+
"Mediastinal shift",
|
153 |
+
"Reduced lung volume",
|
154 |
+
"Meniscus sign or veil-like appearance"
|
155 |
+
],
|
156 |
+
"Infiltration": [
|
157 |
+
"Irregular or fuzzy borders around white areas",
|
158 |
+
"Blurring",
|
159 |
+
"Hazy or cloudy areas",
|
160 |
+
"Increased density or opacity of lung tissue",
|
161 |
+
"Air bronchograms",
|
162 |
+
],
|
163 |
+
"Mass": [
|
164 |
+
"Calcifications or mineralizations",
|
165 |
+
"Shadowing",
|
166 |
+
"Distortion or compression of tissues",
|
167 |
+
"Anomalous structure or irregularity in shape"
|
168 |
+
],
|
169 |
+
"Nodule": [
|
170 |
+
"Nodular shape that protrudes into a cavity or airway",
|
171 |
+
"Distinct edges or borders",
|
172 |
+
"Calcifications or speckled areas",
|
173 |
+
"Small round oral shaped spots",
|
174 |
+
"White shadows"
|
175 |
+
],
|
176 |
+
"Emphysema": [
|
177 |
+
"Flattened hemidiaphragm",
|
178 |
+
"Pulmonary bullae",
|
179 |
+
"Hyperlucent lungs",
|
180 |
+
"Horizontalisation of ribs",
|
181 |
+
"Barrel Chest",
|
182 |
+
],
|
183 |
+
"Fibrosis": [
|
184 |
+
"Reticular shadowing of the lung peripheries",
|
185 |
+
"Volume loss",
|
186 |
+
"Thickened and irregular interstitial markings",
|
187 |
+
"Bronchial dilation",
|
188 |
+
"Shaggy heart borders"
|
189 |
+
],
|
190 |
+
"Pleural Thickening": [
|
191 |
+
"Thickened pleural line",
|
192 |
+
"Loss of sharpness of the mediastinal border",
|
193 |
+
"Calcifications on the pleura",
|
194 |
+
"Lobulated peripheral shadowing",
|
195 |
+
"Loss of lung volume",
|
196 |
+
],
|
197 |
+
"Hernia": [
|
198 |
+
"Bulge or swelling in the abdominal wall",
|
199 |
+
"Protrusion of intestine or other abdominal tissue",
|
200 |
+
"Swelling or enlargement of the herniated sac or surrounding tissues",
|
201 |
+
"Retro-cardiac air-fluid level",
|
202 |
+
"Thickening of intestinal folds"
|
203 |
+
]
|
204 |
+
}
|
examples/edema.jpg
ADDED
![]() |
examples/enlarged_cardiomediastinum.jpg
ADDED
![]() |
examples/support_devices.jpg
ADDED
![]() |
flagged/image_path/tmp4gt9tvuq.png
ADDED
![]() |
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
image_path,"Selct to use predefined disease descriptors. Select ""Custom"" to define your own observations.",Pathology:,Custom Observations:,output,flag,username,timestamp
|
2 |
+
/Users/chantal/Documents/programmieren_gitbackuped/Xplainer/flagged/image_path/tmp4gt9tvuq.png,['Cardiomegaly'],,,/Users/chantal/Documents/programmieren_gitbackuped/Xplainer/flagged/output/tmpy60f1_o9.png,,,2023-06-27 18:09:42.017441
|
flagged/output/tmpy60f1_o9.png
ADDED
![]() |
inference.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from chestxray14 import ChestXray14Dataset
|
10 |
+
from chexpert import CheXpertDataset
|
11 |
+
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
|
12 |
+
from model import InferenceModel
|
13 |
+
from utils import calculate_auroc
|
14 |
+
|
15 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
16 |
+
|
17 |
+
|
18 |
+
def inference_chexpert():
|
19 |
+
split = 'test'
|
20 |
+
dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv') # also do test
|
21 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0)
|
22 |
+
inference_model = InferenceModel()
|
23 |
+
all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert)
|
24 |
+
|
25 |
+
all_labels = []
|
26 |
+
all_probs_neg = []
|
27 |
+
|
28 |
+
for batch in tqdm(dataloader):
|
29 |
+
batch = batch[0]
|
30 |
+
image_paths, labels, keys = batch
|
31 |
+
image_paths = [Path(image_path) for image_path in image_paths]
|
32 |
+
agg_probs = []
|
33 |
+
agg_negative_probs = []
|
34 |
+
for image_path in image_paths:
|
35 |
+
probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors)
|
36 |
+
agg_probs.append(probs)
|
37 |
+
agg_negative_probs.append(negative_probs)
|
38 |
+
probs = {} # Aggregated
|
39 |
+
negative_probs = {} # Aggregated
|
40 |
+
for key in agg_probs[0].keys():
|
41 |
+
probs[key] = sum([p[key] for p in agg_probs]) / len(agg_probs) # Mean Aggregation
|
42 |
+
|
43 |
+
for key in agg_negative_probs[0].keys():
|
44 |
+
negative_probs[key] = sum([p[key] for p in agg_negative_probs]) / len(agg_negative_probs) # Mean Aggregation
|
45 |
+
|
46 |
+
disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chexpert, pos_probs=probs,
|
47 |
+
negative_probs=negative_probs)
|
48 |
+
predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chexpert,
|
49 |
+
disease_probs=disease_probs,
|
50 |
+
negative_disease_probs=negative_disease_probs,
|
51 |
+
keys=keys)
|
52 |
+
all_labels.append(labels)
|
53 |
+
all_probs_neg.append(prob_vector_neg_prompt)
|
54 |
+
|
55 |
+
all_labels = torch.stack(all_labels)
|
56 |
+
all_probs_neg = torch.stack(all_probs_neg)
|
57 |
+
|
58 |
+
# evaluation
|
59 |
+
existing_mask = sum(all_labels, 0) > 0
|
60 |
+
all_labels_clean = all_labels[:, existing_mask]
|
61 |
+
all_probs_neg_clean = all_probs_neg[:, existing_mask]
|
62 |
+
all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]
|
63 |
+
|
64 |
+
overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
|
65 |
+
print(f"AUROC: {overall_auroc:.5f}\n")
|
66 |
+
for idx, key in enumerate(all_keys_clean):
|
67 |
+
print(f'{key}: {per_disease_auroc[idx]:.5f}')
|
68 |
+
|
69 |
+
|
70 |
+
def inference_chestxray14():
|
71 |
+
dataset = ChestXray14Dataset(f'data/chestxray14/Data_Entry_2017_v2020_modified.csv')
|
72 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=1)
|
73 |
+
inference_model = InferenceModel()
|
74 |
+
all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14)
|
75 |
+
|
76 |
+
all_labels = []
|
77 |
+
all_probs_neg = []
|
78 |
+
for batch in tqdm(dataloader):
|
79 |
+
batch = batch[0]
|
80 |
+
image_path, labels, keys = batch
|
81 |
+
image_path = Path(image_path)
|
82 |
+
probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors)
|
83 |
+
disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chestxray14, pos_probs=probs,
|
84 |
+
negative_probs=negative_probs)
|
85 |
+
predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chestxray14,
|
86 |
+
disease_probs=disease_probs,
|
87 |
+
negative_disease_probs=negative_disease_probs,
|
88 |
+
keys=keys)
|
89 |
+
all_labels.append(labels)
|
90 |
+
all_probs_neg.append(prob_vector_neg_prompt)
|
91 |
+
gc.collect()
|
92 |
+
|
93 |
+
all_labels = torch.stack(all_labels)
|
94 |
+
all_probs_neg = torch.stack(all_probs_neg)
|
95 |
+
|
96 |
+
existing_mask = sum(all_labels, 0) > 0
|
97 |
+
all_labels_clean = all_labels[:, existing_mask]
|
98 |
+
all_probs_neg_clean = all_probs_neg[:, existing_mask]
|
99 |
+
all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]
|
100 |
+
|
101 |
+
overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean[:, 1:], all_labels_clean[:, 1:])
|
102 |
+
print(f"AUROC: {overall_auroc:.5f}\n")
|
103 |
+
for idx, key in enumerate(all_keys_clean[1:]):
|
104 |
+
print(f'{key}: {per_disease_auroc[idx]:.5f}')
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == '__main__':
|
108 |
+
# add argument parser
|
109 |
+
parser = argparse.ArgumentParser()
|
110 |
+
parser.add_argument('--dataset', type=str, default='chexpert', help='chexpert or chestxray14')
|
111 |
+
args = parser.parse_args()
|
112 |
+
|
113 |
+
if args.dataset == 'chexpert':
|
114 |
+
inference_chexpert()
|
115 |
+
elif args.dataset == 'chestxray14':
|
116 |
+
inference_chestxray14()
|
model.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from health_multimodal.image import get_biovil_resnet_inference
|
7 |
+
from health_multimodal.text import get_cxr_bert_inference
|
8 |
+
from health_multimodal.vlp import ImageTextInferenceEngine
|
9 |
+
|
10 |
+
from utils import cos_sim_to_prob, prob_to_log_prob, log_prob_to_prob
|
11 |
+
|
12 |
+
|
13 |
+
class InferenceModel():
|
14 |
+
def __init__(self):
|
15 |
+
self.text_inference = get_cxr_bert_inference()
|
16 |
+
self.image_inference = get_biovil_resnet_inference()
|
17 |
+
self.image_text_inference = ImageTextInferenceEngine(
|
18 |
+
image_inference_engine=self.image_inference,
|
19 |
+
text_inference_engine=self.text_inference,
|
20 |
+
)
|
21 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
self.image_text_inference.to(self.device)
|
23 |
+
|
24 |
+
# caches for faster inference
|
25 |
+
self.text_embedding_cache = {}
|
26 |
+
self.image_embedding_cache = {}
|
27 |
+
|
28 |
+
self.transform = self.image_inference.transform
|
29 |
+
|
30 |
+
def get_similarity_score_from_raw_data(self, image_embedding, query_text: str) -> float:
|
31 |
+
"""Compute the cosine similarity score between an image and one or more strings.
|
32 |
+
If multiple strings are passed, their embeddings are averaged before L2-normalization.
|
33 |
+
:param image_path: Path to the input chest X-ray, either a DICOM or JPEG file.
|
34 |
+
:param query_text: Input radiology text phrase.
|
35 |
+
:return: The similarity score between the image and the text.
|
36 |
+
"""
|
37 |
+
assert not self.image_text_inference.image_inference_engine.model.training
|
38 |
+
assert not self.image_text_inference.text_inference_engine.model.training
|
39 |
+
if query_text in self.text_embedding_cache:
|
40 |
+
text_embedding = self.text_embedding_cache[query_text]
|
41 |
+
else:
|
42 |
+
text_embedding = self.image_text_inference.text_inference_engine.get_embeddings_from_prompt([query_text], normalize=False)
|
43 |
+
text_embedding = text_embedding.mean(dim=0)
|
44 |
+
text_embedding = F.normalize(text_embedding, dim=0, p=2)
|
45 |
+
self.text_embedding_cache[query_text] = text_embedding
|
46 |
+
|
47 |
+
cos_similarity = image_embedding @ text_embedding.t()
|
48 |
+
|
49 |
+
return cos_similarity.item()
|
50 |
+
|
51 |
+
def process_image(self, image):
|
52 |
+
''' same code as in image_text_inference.image_inference_engine.get_projected_global_embedding() but adapted to deal with image instances instead of path'''
|
53 |
+
|
54 |
+
transformed_image = self.transform(image)
|
55 |
+
projected_img_emb = self.image_inference.model.forward(transformed_image).projected_global_embedding
|
56 |
+
projected_img_emb = F.normalize(projected_img_emb, dim=-1)
|
57 |
+
assert projected_img_emb.shape[0] == 1
|
58 |
+
assert projected_img_emb.ndim == 2
|
59 |
+
return projected_img_emb[0]
|
60 |
+
|
61 |
+
def get_descriptor_probs(self, image_path: Path, descriptors: List[str], do_negative_prompting=True, demo=False):
|
62 |
+
probs = {}
|
63 |
+
negative_probs = {}
|
64 |
+
if image_path in self.image_embedding_cache:
|
65 |
+
image_embedding = self.image_embedding_cache[image_path]
|
66 |
+
else:
|
67 |
+
image_embedding = self.image_text_inference.image_inference_engine.get_projected_global_embedding(image_path)
|
68 |
+
if not demo:
|
69 |
+
self.image_embedding_cache[image_path] = image_embedding
|
70 |
+
|
71 |
+
# Default get_similarity_score_from_raw_data would load the image every time. Instead we only load once.
|
72 |
+
for desc in descriptors:
|
73 |
+
prompt = f'There are {desc}'
|
74 |
+
score = self.get_similarity_score_from_raw_data(image_embedding, prompt)
|
75 |
+
if do_negative_prompting:
|
76 |
+
neg_prompt = f'There are no {desc}'
|
77 |
+
neg_score = self.get_similarity_score_from_raw_data(image_embedding, neg_prompt)
|
78 |
+
|
79 |
+
pos_prob = cos_sim_to_prob(score)
|
80 |
+
|
81 |
+
if do_negative_prompting:
|
82 |
+
pos_prob, neg_prob = torch.softmax((torch.tensor([score, neg_score]) / 0.5), dim=0)
|
83 |
+
negative_probs[desc] = neg_prob
|
84 |
+
|
85 |
+
probs[desc] = pos_prob
|
86 |
+
|
87 |
+
return probs, negative_probs
|
88 |
+
|
89 |
+
def get_all_descriptors(self, disease_descriptors):
|
90 |
+
all_descriptors = set()
|
91 |
+
for disease, descs in disease_descriptors.items():
|
92 |
+
all_descriptors.update([f"{desc} indicating {disease}" for desc in descs])
|
93 |
+
all_descriptors = sorted(all_descriptors)
|
94 |
+
return all_descriptors
|
95 |
+
|
96 |
+
def get_all_descriptors_only_disease(self, disease_descriptors):
|
97 |
+
all_descriptors = set()
|
98 |
+
for disease, descs in disease_descriptors.items():
|
99 |
+
all_descriptors.update([f"{desc}" for desc in descs])
|
100 |
+
all_descriptors = sorted(all_descriptors)
|
101 |
+
return all_descriptors
|
102 |
+
|
103 |
+
def get_diseases_probs(self, disease_descriptors, pos_probs, negative_probs, prior_probs=None, do_negative_prompting=True):
|
104 |
+
disease_probs = {}
|
105 |
+
disease_neg_probs = {}
|
106 |
+
for disease, descriptors in disease_descriptors.items():
|
107 |
+
desc_log_probs = []
|
108 |
+
desc_neg_log_probs = []
|
109 |
+
for desc in descriptors:
|
110 |
+
desc = f"{desc} indicating {disease}"
|
111 |
+
desc_log_probs.append(prob_to_log_prob(pos_probs[desc]))
|
112 |
+
if do_negative_prompting:
|
113 |
+
desc_neg_log_probs.append(prob_to_log_prob(negative_probs[desc]))
|
114 |
+
disease_log_prob = sum(sorted(desc_log_probs, reverse=True)) / len(desc_log_probs)
|
115 |
+
if do_negative_prompting:
|
116 |
+
disease_neg_log_prob = sum(desc_neg_log_probs) / len(desc_neg_log_probs)
|
117 |
+
disease_probs[disease] = log_prob_to_prob(disease_log_prob)
|
118 |
+
if do_negative_prompting:
|
119 |
+
disease_neg_probs[disease] = log_prob_to_prob(disease_neg_log_prob)
|
120 |
+
|
121 |
+
return disease_probs, disease_neg_probs
|
122 |
+
|
123 |
+
# Threshold Based
|
124 |
+
def get_predictions(self, disease_descriptors, threshold, disease_probs, keys):
|
125 |
+
predicted_diseases = []
|
126 |
+
prob_vector = torch.zeros(len(keys), dtype=torch.float) # num of diseases
|
127 |
+
for idx, disease in enumerate(disease_descriptors):
|
128 |
+
if disease == 'No Finding':
|
129 |
+
continue
|
130 |
+
prob_vector[keys.index(disease)] = disease_probs[disease]
|
131 |
+
if disease_probs[disease] > threshold:
|
132 |
+
predicted_diseases.append(disease)
|
133 |
+
|
134 |
+
if len(predicted_diseases) == 0: # No finding rule based
|
135 |
+
prob_vector[0] = 1.0 - max(prob_vector)
|
136 |
+
else:
|
137 |
+
prob_vector[0] = 1.0 - max(prob_vector)
|
138 |
+
|
139 |
+
return predicted_diseases, prob_vector
|
140 |
+
|
141 |
+
# Negative vs Positive Prompting
|
142 |
+
def get_predictions_bin_prompting(self, disease_descriptors, disease_probs, negative_disease_probs, keys):
|
143 |
+
predicted_diseases = []
|
144 |
+
prob_vector = torch.zeros(len(keys), dtype=torch.float) # num of diseases
|
145 |
+
for idx, disease in enumerate(disease_descriptors):
|
146 |
+
if disease == 'No Finding':
|
147 |
+
continue
|
148 |
+
pos_neg_scores = torch.tensor([disease_probs[disease], negative_disease_probs[disease]])
|
149 |
+
prob_vector[keys.index(disease)] = pos_neg_scores[0]
|
150 |
+
if torch.argmax(pos_neg_scores) == 0: # Positive is More likely
|
151 |
+
predicted_diseases.append(disease)
|
152 |
+
|
153 |
+
if len(predicted_diseases) == 0: # No finding rule based
|
154 |
+
prob_vector[0] = 1.0 - max(prob_vector)
|
155 |
+
else:
|
156 |
+
prob_vector[0] = 1.0 - max(prob_vector)
|
157 |
+
|
158 |
+
return predicted_diseases, prob_vector
|
plot.png
ADDED
![]() |
pre-requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
hi-ml-multimodal==0.1.2
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scikit-learn==1.0.2
|
2 |
+
transformers==4.17.0
|
3 |
+
gradio==3.34.0
|
4 |
+
pandas==1.3.5
|
5 |
+
torch==1.13.0
|
utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import log, exp
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.metrics import roc_auc_score
|
5 |
+
|
6 |
+
|
7 |
+
def cos_sim_to_prob(sim):
|
8 |
+
return (sim + 1) / 2 # linear transformation to 0 and 1
|
9 |
+
|
10 |
+
|
11 |
+
def log_prob_to_prob(log_prob):
|
12 |
+
return exp(log_prob)
|
13 |
+
|
14 |
+
|
15 |
+
def prob_to_log_prob(prob):
|
16 |
+
return log(prob)
|
17 |
+
|
18 |
+
|
19 |
+
def calculate_auroc(all_disease_probs, gt_diseases):
|
20 |
+
'''
|
21 |
+
Calculates the AUROC (Area Under the Receiver Operating Characteristic curve) for multiple diseases.
|
22 |
+
|
23 |
+
Parameters:
|
24 |
+
all_disease_probs (numpy array): predicted disease labels, a multi-hot vector of shape (N_samples, 14)
|
25 |
+
gt_diseases (numpy array): ground truth disease labels, a multi-hot vector of shape (N_samples, 14)
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
overall_auroc (float): the overall AUROC score
|
29 |
+
per_disease_auroc (numpy array): an array of shape (14,) containing the AUROC score for each disease
|
30 |
+
'''
|
31 |
+
|
32 |
+
per_disease_auroc = np.zeros((gt_diseases.shape[1],)) # num of diseases
|
33 |
+
for i in range(gt_diseases.shape[1]):
|
34 |
+
# Compute the AUROC score for each disease
|
35 |
+
per_disease_auroc[i] = roc_auc_score(gt_diseases[:, i], all_disease_probs[:, i])
|
36 |
+
|
37 |
+
# Compute the overall AUROC score
|
38 |
+
overall_auroc = roc_auc_score(gt_diseases, all_disease_probs, average='macro')
|
39 |
+
|
40 |
+
return overall_auroc, per_disease_auroc
|