# Xplainer / inference.py
# ChantalPellegrini - first commit (06257c8)
import argparse
import gc
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from chestxray14 import ChestXray14Dataset
from chexpert import CheXpertDataset
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
from model import InferenceModel
from utils import calculate_auroc
# Select the 'file_system' strategy for sharing tensors between processes
# (relevant once a DataLoader below runs with worker processes).
# NOTE(review): presumably chosen to avoid running out of file descriptors
# ("too many open files") with the default strategy - confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def inference_chexpert():
    """Run descriptor-based inference on the CheXpert split and print AUROC scores.

    Each dataset item provides a list of image paths, a label vector and the
    list of class keys. Descriptor probabilities are computed per image and
    mean-aggregated over all images of the item before they are turned into
    per-disease probabilities. After the whole dataset has been processed,
    classes whose label column has no positive sum are dropped and the
    overall and per-class AUROC values are printed.

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    split = 'test'
    dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv')
    # batch_size=1 with an identity collate_fn: every batch is a one-element
    # list holding the raw (image_paths, labels, keys) sample.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert)

    all_labels = []
    all_probs_neg = []
    keys = None
    for batch in tqdm(dataloader):
        image_paths, labels, keys = batch[0]

        # Descriptor probabilities for every image belonging to this sample.
        per_image_probs = []
        per_image_negative_probs = []
        for image_path in image_paths:
            image_probs, image_negative_probs = inference_model.get_descriptor_probs(
                Path(image_path), descriptors=all_descriptors)
            per_image_probs.append(image_probs)
            per_image_negative_probs.append(image_negative_probs)

        # Mean aggregation over the images of the sample, per descriptor key.
        num_images = len(per_image_probs)
        probs = {
            key: sum(p[key] for p in per_image_probs) / num_images
            for key in per_image_probs[0]
        }
        negative_probs = {
            key: sum(p[key] for p in per_image_negative_probs) / num_images
            for key in per_image_negative_probs[0]
        }

        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chexpert, pos_probs=probs, negative_probs=negative_probs)
        # Only the probability vector is needed for the AUROC evaluation; the
        # discrete predictions returned alongside it are not used here.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chexpert,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=keys)
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)

    if not all_labels:
        # Fail with a clear message instead of an opaque torch.stack error.
        raise RuntimeError(f'No samples found in data/chexpert/{split}_labels.csv; nothing to evaluate.')

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # evaluation
    # Keep only the classes whose label column sums to a positive value.
    # NOTE(review): presumably because AUROC is undefined for a class without
    # any positive sample - confirm against calculate_auroc.
    existing_mask = all_labels.sum(dim=0) > 0
    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    # `keys` comes from the last sample; this assumes every sample reports the
    # same class keys in the same order - TODO confirm in CheXpertDataset.
    all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]

    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
    print(f"AUROC: {overall_auroc:.5f}\n")
    for idx, key in enumerate(all_keys_clean):
        print(f'{key}: {per_disease_auroc[idx]:.5f}')
def _identity_collate(batch):
    """Return the list of samples unchanged (no tensor collation).

    Defined at module level rather than as a lambda so it can be pickled when
    DataLoader worker processes are started with the 'spawn' start method.
    """
    return batch


def inference_chestxray14():
    """Run descriptor-based inference on ChestXray14 and print AUROC scores.

    Each dataset item provides one image path, a label vector and the list of
    class keys. Per-disease probabilities are derived from the descriptor
    probabilities of the image. After the whole dataset has been processed,
    classes whose label column has no positive sum are dropped, the first
    remaining class is excluded, and the overall and per-class AUROC values
    are printed.

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    dataset = ChestXray14Dataset('data/chestxray14/Data_Entry_2017_v2020_modified.csv')
    # batch_size=1 with an identity collate_fn: every batch is a one-element
    # list holding the raw (image_path, labels, keys) sample.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=_identity_collate, num_workers=1)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14)

    all_labels = []
    all_probs_neg = []
    keys = None
    for batch in tqdm(dataloader):
        image_path, labels, keys = batch[0]
        probs, negative_probs = inference_model.get_descriptor_probs(
            Path(image_path), descriptors=all_descriptors)
        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chestxray14, pos_probs=probs, negative_probs=negative_probs)
        # Only the probability vector is needed for the AUROC evaluation; the
        # discrete predictions returned alongside it are not used here.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chestxray14,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=keys)
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)
        # NOTE(review): explicit collection after every sample, presumably to
        # keep memory bounded over the long run - confirm it is still needed.
        gc.collect()

    if not all_labels:
        # Fail with a clear message instead of an opaque torch.stack error.
        raise RuntimeError('No samples found in data/chestxray14/Data_Entry_2017_v2020_modified.csv; '
                           'nothing to evaluate.')

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # Keep only the classes whose label column sums to a positive value.
    existing_mask = all_labels.sum(dim=0) > 0
    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    # `keys` comes from the last sample; this assumes every sample reports the
    # same class keys in the same order - TODO confirm in ChestXray14Dataset.
    all_keys_clean = [key for idx, key in enumerate(keys) if existing_mask[idx]]

    # The first remaining class is left out of the evaluation.
    # NOTE(review): presumably the "no finding" entry; because the slice is
    # applied after masking, this relies on the first class having at least
    # one positive label - confirm.
    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean[:, 1:], all_labels_clean[:, 1:])
    print(f"AUROC: {overall_auroc:.5f}\n")
    for idx, key in enumerate(all_keys_clean[1:]):
        print(f'{key}: {per_disease_auroc[idx]:.5f}')
if __name__ == '__main__':
    # Map every supported --dataset value to its inference routine. Passing
    # the keys as `choices` makes argparse reject unknown names with a usage
    # error instead of the script silently doing nothing.
    runners = {
        'chexpert': inference_chexpert,
        'chestxray14': inference_chestxray14,
    }
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='chexpert', choices=list(runners),
                        help='chexpert or chestxray14')
    args = parser.parse_args()
    runners[args.dataset]()