|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import numpy as np |
|
from matplotlib import pyplot as plt |
|
|
|
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14 |
|
from model import InferenceModel |
|
|
|
|
|
def plot_bars(model_output):
    """Draw one horizontal bar chart per disease and save the figure as a PNG.

    Diseases are stacked top to bottom by decreasing ``overall_probability``.
    In each subplot, every descriptor is one bar spanning from its probability
    to the 0.5 threshold: green when the probability is >= 0.5, red otherwise.
    Descriptors are ordered by decreasing probability from the top, and their
    names are written over the bars (y tick labels are hidden).

    Args:
        model_output: Mapping of disease name to a dict with keys
            ``'overall_probability'`` (number) and
            ``'descriptor_probabilities'`` (mapping descriptor -> number).

    Returns:
        Path (str) of a newly created PNG file containing the figure.

    Raises:
        ValueError: If ``model_output`` is empty (there is nothing to plot).
    """
    import tempfile

    if not model_output:
        raise ValueError("plot_bars() needs at least one disease to plot")

    # Most likely diseases first.
    ordered = sorted(model_output.items(), key=lambda item: item[1]['overall_probability'], reverse=True)

    # squeeze=False always returns a 2-D array of Axes, so a single disease needs no special case.
    fig, axs = plt.subplots(len(ordered), 1, figsize=(10, 5 * len(ordered)), squeeze=False)
    try:
        for ax, (disease, data) in zip(axs[:, 0], ordered):
            desc_probs = sorted(data['descriptor_probabilities'].items(), key=lambda item: item[1], reverse=True)
            my_labels = [label for label, _ in desc_probs]
            my_probs = [prob for _, prob in desc_probs]
            min_prob = min(my_probs)
            max_prob = max(my_probs)

            # Bar length is the distance from the 0.5 threshold; each bar starts at
            # whichever of (p, 0.5) is smaller so it always touches the threshold.
            diffs = np.abs(np.array(my_probs) - 0.5)
            left = [p if p < 0.5 else 0.5 for p in my_probs]
            colors = ['red' if p < 0.5 else 'forestgreen' for p in my_probs]
            bars = ax.barh(my_labels, diffs, left=left, color=colors, alpha=0.3)

            # Descriptor names are drawn inside the axes instead of as tick labels.
            for label, bar in zip(my_labels, bars):
                ax.text(min_prob - 0.04, bar.get_y() + bar.get_height() / 2, label,
                        ha='left', va='center', color='black', fontsize=15)

            # Keep the 0.5 threshold visible, plus a small margin around the data.
            ax.set_xlim(min(min_prob - 0.05, 0.49), max(max_prob + 0.05, 0.51))
            ax.invert_yaxis()  # highest-probability descriptor at the top
            ax.set_yticks([])

            overall = data['overall_probability']
            if overall >= 0.5:
                ax.set_title(f"{disease} : score of {overall:.2f}")
            else:
                ax.set_title(f"No {disease} : score of {overall:.2f}")
            ax.title.set_fontsize(15)
            ax.title.set_fontweight(600)

        # Operate on this figure explicitly rather than on pyplot's process-global
        # "current figure", which another concurrently handled request could change.
        fig.tight_layout()
        # A unique file per call: a fixed name such as 'plot.png' lets concurrent
        # requests overwrite each other's output before it is served.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            fig.savefig(tmp, format='png')
        file_path = tmp.name
    finally:
        # Release the figure even if drawing fails, so figures don't accumulate.
        plt.close(fig)

    return file_path
|
|
|
|
|
def classify_image(inference_model, image_path, diseases_to_predict):
    """Score an image against per-disease descriptors and render the result.

    Each descriptor ``d`` of disease ``x`` is turned into the prompt
    ``"{d} indicating {x}"``. All prompts are scored with negative prompting
    enabled, and the per-prompt scores are aggregated per disease by the model.

    Args:
        inference_model: Object exposing ``get_descriptor_probs`` and
            ``get_diseases_probs`` (e.g. ``InferenceModel``).
        image_path: Path (str or path-like) of the image to classify.
        diseases_to_predict: Mapping of disease name -> list of descriptor strings.

    Returns:
        File path of the PNG produced by ``plot_bars``.
    """
    def to_prompt(descriptor, disease):
        # Single definition of the prompt format, used both to build the query
        # list and to look the scores back up, so the two cannot drift apart.
        return f"{descriptor} indicating {disease}"

    descriptors_with_indication = [
        to_prompt(descriptor, disease)
        for disease, descriptors in diseases_to_predict.items()
        for descriptor in descriptors
    ]
    probs, negative_probs = inference_model.get_descriptor_probs(
        image_path=Path(image_path),
        descriptors=descriptors_with_indication,
        do_negative_prompting=True,
        demo=True,
    )

    # The second returned value (negative disease probabilities) is not displayed.
    disease_probs, _ = inference_model.get_diseases_probs(
        diseases_to_predict, pos_probs=probs, negative_probs=negative_probs
    )

    # NOTE(review): probs values are presumably 0-d tensors (hence ``.item()``) — confirm.
    model_output = {
        disease: {
            'overall_probability': disease_probs[disease],
            'descriptor_probabilities': {
                descriptor: probs[to_prompt(descriptor, disease)].item()
                for descriptor in descriptors
            },
        }
        for disease, descriptors in diseases_to_predict.items()
    }

    return plot_bars(model_output)
|
|
|
|
|
|
|
# Process-wide model instance, created lazily by _get_inference_model().
_inference_model = None


def _get_inference_model():
    """Return a shared InferenceModel, constructing it on first use.

    Previously a new model was built on every request. NOTE(review): this
    assumes InferenceModel keeps no per-request state and is costly to
    construct (e.g. loads weights) — confirm against model.py.
    """
    global _inference_model
    if _inference_model is None:
        _inference_model = InferenceModel()
    return _inference_model


def process_input(image_path, prompt_names: list, disease_name: str, descriptors: str):
    """Gradio callback: collect the selected disease descriptors and classify the image.

    Args:
        image_path: File path of the uploaded image (``None`` if nothing was uploaded).
        prompt_names: Selected checkbox labels. Each is either a predefined
            pathology name or ``'Custom'``.
        disease_name: Pathology name used when ``'Custom'`` is selected.
        descriptors: Newline-separated custom observations used when
            ``'Custom'`` is selected.

    Returns:
        File path of the rendered result plot.

    Raises:
        ValueError: If no image is given, nothing is selected, or ``'Custom'``
            is selected without a pathology name or any observation.
        KeyError: If a selected name is in neither predefined descriptor table.
    """
    if image_path is None:
        raise ValueError("Please upload an image to classify.")
    if not prompt_names:
        raise ValueError('Select at least one pathology (or "Custom") to classify.')

    diseases_to_predict = {}
    for prompt in prompt_names:
        if prompt == 'Custom':
            name = (disease_name or '').strip()
            # splitlines() also handles '\r\n'; blank lines would otherwise become empty prompts.
            custom = [line.strip() for line in (descriptors or '').splitlines() if line.strip()]
            if not name or not custom:
                raise ValueError('"Custom" needs a pathology name and at least one observation.')
            diseases_to_predict[name] = custom
        elif prompt in disease_descriptors_chexpert:
            diseases_to_predict[prompt] = disease_descriptors_chexpert[prompt]
        else:
            diseases_to_predict[prompt] = disease_descriptors_chestxray14[prompt]

    return classify_image(_get_inference_model(), image_path, diseases_to_predict)
|
|
|
# Markdown passed to gr.Interface below as `article` and `description`.
# Decode explicitly as UTF-8 so the result does not depend on the platform locale.
article = Path("article.md").read_text(encoding="utf-8")
description = Path("description.md").read_text(encoding="utf-8")
|
|
|
|
|
# Pathologies offered in the checkbox group and pre-selected by default.
_DEFAULT_PATHOLOGIES = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema',
                        'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
                        'Pleural Other', 'Fracture', 'Support Devices']
# Further selectable pathologies that are not pre-selected.
_EXTRA_PATHOLOGIES = ['Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural Thickening', 'Hernia']

# Uses the top-level Gradio components: the legacy gr.inputs / gr.outputs
# namespaces (and CheckboxGroup's `default=`) are deprecated and were removed in Gradio 4.
iface = gr.Interface(
    fn=process_input,
    examples=[
        ['examples/enlarged_cardiomediastinum.jpg', ['Enlarged Cardiomediastinum'], '', ''],
        ['examples/edema.jpg', ['Edema'], '', ''],
        ['examples/support_devices.jpg', ['Custom'], 'Pacemaker',
         'metalic object\nimplant on the left side of the chest\nimplanted cardiac device'],
    ],
    inputs=[
        gr.Image(type="filepath"),
        gr.CheckboxGroup(
            choices=_DEFAULT_PATHOLOGIES + _EXTRA_PATHOLOGIES + ['Custom'],
            value=list(_DEFAULT_PATHOLOGIES),
            label='Select to use predefined disease descriptors. Select "Custom" to define your own observations.',
        ),
        gr.Textbox(lines=2, placeholder="Name of pathology for which you want to define custom observations",
                   label='Pathology:'),
        gr.Textbox(
            lines=2,
            placeholder="Add your custom (positive) observations separated by a new line"
                        "\n Note: Each descriptor will automatically be embedded into our prompt format: There is/are (no) <observation> indicating <pathology>"
                        "\n Example:\n\n Opacity\nPleural Effusion\nConsolidation",
            label='Custom Observations:',
        ),
    ],
    article=article,
    description=description,
    outputs=gr.Image(type="filepath"),
)

iface.launch()
|
|