|
import argparse |
|
import gc |
|
from pathlib import Path |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from chestxray14 import ChestXray14Dataset |
|
from chexpert import CheXpertDataset |
|
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14 |
|
from model import InferenceModel |
|
from utils import calculate_auroc |
|
|
|
torch.multiprocessing.set_sharing_strategy('file_system') |
|
|
|
|
|
def inference_chexpert(): |
|
split = 'test' |
|
dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv') |
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0) |
|
inference_model = InferenceModel() |
|
all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert) |
|
|
|
all_labels = [] |
|
all_probs_neg = [] |
|
|
|
for batch in tqdm(dataloader): |
|
batch = batch[0] |
|
image_paths, labels, keys = batch |
|
image_paths = [Path(image_path) for image_path in image_paths] |
|
agg_probs = [] |
|
agg_negative_probs = [] |
|
for image_path in image_paths: |
|
probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors) |
|
agg_probs.append(probs) |
|
agg_negative_probs.append(negative_probs) |
|
probs = {} |
|
negative_probs = {} |
|
for key in agg_probs[0].keys(): |
|
probs[key] = sum([p[key] for p in agg_probs]) / len(agg_probs) |
|
|
|
for key in agg_negative_probs[0].keys(): |
|
negative_probs[key] = sum([p[key] for p in agg_negative_probs]) / len(agg_negative_probs) |
|
|
|
disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chexpert, pos_probs=probs, |
|
negative_probs=negative_probs) |
|
predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chexpert, |
|
disease_probs=disease_probs, |
|
negative_disease_probs=negative_disease_probs, |
|
keys=keys) |
|
all_labels.append(labels) |
|
all_probs_neg.append(prob_vector_neg_prompt) |
|
|
|
all_labels = torch.stack(all_labels) |
|
all_probs_neg = torch.stack(all_probs_neg) |
|
|
|
|
|
existing_mask = sum(all_labels, 0) > 0 |
|
all_labels_clean = all_labels[:, existing_mask] |
|
all_probs_neg_clean = all_probs_neg[:, existing_mask] |
|
all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]] |
|
|
|
overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean) |
|
print(f"AUROC: {overall_auroc:.5f}\n") |
|
for idx, key in enumerate(all_keys_clean): |
|
print(f'{key}: {per_disease_auroc[idx]:.5f}') |
|
|
|
|
|
def inference_chestxray14(): |
|
dataset = ChestXray14Dataset(f'data/chestxray14/Data_Entry_2017_v2020_modified.csv') |
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=1) |
|
inference_model = InferenceModel() |
|
all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14) |
|
|
|
all_labels = [] |
|
all_probs_neg = [] |
|
for batch in tqdm(dataloader): |
|
batch = batch[0] |
|
image_path, labels, keys = batch |
|
image_path = Path(image_path) |
|
probs, negative_probs = inference_model.get_descriptor_probs(image_path, descriptors=all_descriptors) |
|
disease_probs, negative_disease_probs = inference_model.get_diseases_probs(disease_descriptors_chestxray14, pos_probs=probs, |
|
negative_probs=negative_probs) |
|
predicted_diseases, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(disease_descriptors_chestxray14, |
|
disease_probs=disease_probs, |
|
negative_disease_probs=negative_disease_probs, |
|
keys=keys) |
|
all_labels.append(labels) |
|
all_probs_neg.append(prob_vector_neg_prompt) |
|
gc.collect() |
|
|
|
all_labels = torch.stack(all_labels) |
|
all_probs_neg = torch.stack(all_probs_neg) |
|
|
|
existing_mask = sum(all_labels, 0) > 0 |
|
all_labels_clean = all_labels[:, existing_mask] |
|
all_probs_neg_clean = all_probs_neg[:, existing_mask] |
|
all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]] |
|
|
|
overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean[:, 1:], all_labels_clean[:, 1:]) |
|
print(f"AUROC: {overall_auroc:.5f}\n") |
|
for idx, key in enumerate(all_keys_clean[1:]): |
|
print(f'{key}: {per_disease_auroc[idx]:.5f}') |
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--dataset', type=str, default='chexpert', help='chexpert or chestxray14') |
|
args = parser.parse_args() |
|
|
|
if args.dataset == 'chexpert': |
|
inference_chexpert() |
|
elif args.dataset == 'chestxray14': |
|
inference_chestxray14() |
|
|