"""refactored from `main` in `eval_zeroshot.py` (SLIP) for clarity. |
""" |
import random |
import torch |
import json |
import os |
from tqdm import tqdm |
from sklearn import metrics |
from constants import RSNA_CLASS_PROMPTS_webdataset, modality_indices_radimagenet_test_set |
from collections import defaultdict |
def load_metadata(metadir="clipeval"):
    """Read the three JSON metadata files stored under ``metadir``.

    Returns:
        A ``(catalog, all_templates, all_labels)`` tuple holding the parsed contents of
        ``dataset_catalog.json``, ``templates.json`` and ``labels.json`` in that order.
    """
    def _read_json(file_name):
        with open(os.path.join(metadir, file_name)) as handle:
            return json.load(handle)

    # Files are read in order, so a missing catalog fails before the others are opened.
    return tuple(_read_json(name) for name in ('dataset_catalog.json', 'templates.json', 'labels.json'))
def evaluate(d, val_loader, templates, labels, model, tokenizer, classnorm=False):
    """Run zero-shot evaluation on dataset ``d`` and reduce the raw result to its headline metric.

    ``validate_zeroshot`` performs the forward passes; this function only picks the reduction:

    * FGVCAircraft / OxfordPets / Caltech101 / Flowers102: mean per-class accuracy.
    * Kinetics700: average of top-1 and top-5 accuracy, as a Python number.
    * HatefulMemes: ROC-AUC of the class-1 minus class-0 logit margin.
    * radimagenet: dict with overall and per-modality (US / MRI / CT) accuracy.
    * anything else: whatever ``validate_zeroshot`` returned (an accuracy or a metrics dict).
    """
    print('Evaluating {}'.format(d))
    per_class_datasets = ('FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102')
    # These datasets need the raw (logits, targets) pair instead of a ready-made accuracy.
    is_acc = d not in per_class_datasets + ('Kinetics700', 'HatefulMemes')
    result = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc, d, classnorm)

    if d == 'radimagenet':
        acc, us_acc, mri_acc, ct_acc = result
        return {"acc": acc, "US acc": us_acc, "MRI acc": mri_acc, "CT acc": ct_acc}
    if d in per_class_datasets:
        return mean_per_class(*result)
    if d == 'Kinetics700':
        top1, top5 = accuracy(*result, topk=(1, 5))
        return ((top1 + top5) / 2).item()
    if d == 'HatefulMemes':
        return roc_auc(*result)
    return result
@torch.no_grad()
def build_text_features(templates, labels, model, tokenizer, skip_text_projection=False, classnorm=False):
    """Encode class prompts into one L2-normalised text embedding per class.

    Three template layouts are supported:

    * ``dict`` mapping class name -> list of prompt strings. Each class's prompts are
      encoded, L2-normalised and averaged. ``labels`` is ignored.
    * ``list`` whose first entry is the RadImageNet sentinel caption. Every entry is one
      class caption and is encoded on its own. ``labels`` is ignored.
    * any other list of format strings. Each template is filled with every label, or with
      every synonym when a label is itself a list. The normalised embeddings are then
      averaged per label.

    Args:
        templates: prompts in one of the layouts above.
        labels: class names used to fill format-string templates; unused otherwise.
        model: object exposing ``parameters()`` and ``encode_text(tokens)``.
        tokenizer: callable turning a string or a list of strings into a tensor-like
            object that supports ``.to(device, non_blocking=True)``.
        skip_text_projection: accepted for interface compatibility; currently unused.
        classnorm: if True, standardise the features across classes (per-dimension mean
            and std over the class axis) before the final L2 normalisation.

    Returns:
        ``(text_features, mean, std)``. ``text_features`` has one row per class.
        ``mean`` and ``std`` are ``None`` unless ``classnorm`` is set.
    """
    radimagenet_sentinel = "Meniscal abnormality detected in MRI imaging of the knee."
    device = next(model.parameters()).device

    def encode(texts):
        # Tokenise, embed and L2-normalise each resulting embedding row.
        tokens = tokenizer(texts).to(device, non_blocking=True)
        embeddings = model.encode_text(tokens)
        return embeddings / embeddings.norm(dim=-1, keepdim=True)

    if isinstance(templates, dict):
        text_features = torch.stack(
            [encode(prompts).mean(dim=0) for prompts in templates.values()], dim=0
        )
    elif isinstance(templates, list) and templates and templates[0] == radimagenet_sentinel:
        print("Encoding captions for RadImageNet dataset")
        # Each caption is tokenised on its own; squeeze drops the per-caption batch axis.
        text_features = torch.stack([encode(caption) for caption in templates], dim=0).squeeze(1)
    else:
        per_label = []
        for label in labels:
            if isinstance(label, list):
                texts = [t.format(syn) for t in templates for syn in label]
            else:
                texts = [t.format(label) for t in templates]
            per_label.append(encode(texts).mean(dim=0))
        text_features = torch.stack(per_label, dim=0)

    mean, std = None, None
    if classnorm:
        mean, std = text_features.mean(dim=0)[None, :], text_features.std(dim=0)[None, :]
        text_features = (text_features - mean) / std
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features, mean, std
def generate_chexpert_class_prompts(class_prompts, n=None):
    """Generate text prompts for each CheXpert classification task.

    Parameters
    ----------
    class_prompts: dict
        class name -> dict of prompt fragments. The fragment lists under the first
        three keys (in insertion order) are combined as ``"<first> <second> <third>"``,
        enumerating every combination.
    n: int, optional
        number of prompts per class. If given and smaller than the number of
        combinations, ``n`` prompts are drawn with ``random.sample``; otherwise all
        combinations are kept.

    Returns
    -------
    class prompts : dict
        dictionary of class to prompts
    """
    prompts = {}
    for cls_name, fragments in class_prompts.items():
        keys = list(fragments.keys())
        # Inner iterables are evaluated lazily, exactly like the equivalent nested loops.
        candidates = [
            f"{first} {second} {third}"
            for first in fragments[keys[0]]
            for second in fragments[keys[1]]
            for third in fragments[keys[2]]
        ]
        should_sample = n is not None and n < len(candidates)
        chosen = random.sample(candidates, n) if should_sample else candidates
        prompts[cls_name] = chosen
        print(f'sample {len(chosen)} num of prompts for {cls_name} from total {len(candidates)}')
    return prompts
def generate_rsna_class_prompts(class_prompts, n=None):
    """Generate text prompts for each RSNA Pneumonia class.

    RSNA prompt specs share the CheXpert layout (class name -> dict whose first three
    fragment lists are combined as ``"<first> <second> <third>"``). This therefore
    delegates to ``generate_chexpert_class_prompts``, which had an identical body, so
    the two generators cannot drift apart.

    Parameters
    ----------
    class_prompts: dict
        class name -> dict of prompt fragments.
    n: int, optional
        number of prompts per class. ``n`` prompts are randomly sampled when fewer than
        the total number of combinations; otherwise all combinations are kept.

    Returns
    -------
    dict
        class name -> list of prompt strings.
    """
    return generate_chexpert_class_prompts(class_prompts, n)
def _truncate_templates(templates, k=None):
    """Return a new dict with each class's prompt list cut to its first ``k`` prompts.

    ``k=None`` keeps every prompt. The input dict is not modified, so repeated
    evaluations with the same templates object see the same prompts.
    """
    return {cls_name: prompts[:k] for cls_name, prompts in templates.items()}


def _zeroshot_text_features(templates, labels, model, tokenizer, name, classnorm):
    """Choose the prompts for dataset ``name`` and encode them with ``build_text_features``.

    Returns the ``(text_features, mean, std)`` triple produced by ``build_text_features``.
    """
    if name == "chexpert-5x200":
        if not isinstance(next(iter(templates.values())), list):
            # Fragment spec: expand combinatorially and sample 10 prompts per class.
            prompted_templates = generate_chexpert_class_prompts(templates, 10)
        else:
            k = 11
            print(f"Using {k} templates for the ensembling at test time")
            prompted_templates = _truncate_templates(templates, k)
        return build_text_features(prompted_templates, None, model, tokenizer, classnorm=classnorm)

    if name == "rsna_pneumonia":
        if not isinstance(next(iter(templates.values())), list):
            temp = generate_rsna_class_prompts(templates, 10)
            prompted_templates = {'normal': RSNA_CLASS_PROMPTS_webdataset['Normal'],
                                  'pneumonia': temp['Pneumonia']}
        else:
            k = 1
            print(f"Using {k} templates for the ensembling at test time")
            prompted_templates = _truncate_templates(templates, k)
        return build_text_features(prompted_templates, None, model, tokenizer, classnorm=classnorm)

    if isinstance(templates, dict):
        print("Using all templates for the ensembling at test time")
    return build_text_features(templates, labels, model, tokenizer, classnorm=classnorm)


@torch.no_grad()
def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc, name, classnorm=False):
    """Zero-shot classify every batch of ``val_loader`` against text prompts.

    Text features are built lazily on the first batch. Logits are image features (not
    L2-normalised) dotted with the normalised text features. The model and images are
    moved with ``.cuda()``, so a CUDA device is required.

    Args:
        val_loader: iterable of ``(images, targets, ...)`` tuples/lists or of dicts with
            ``"pixel_values"`` and ``"targets"``.
        templates, labels: prompt spec forwarded (after dataset-specific selection) to
            ``build_text_features``. Not modified.
        model: exposes ``cuda()``, ``eval()``, ``encode_image`` and ``encode_text``.
        tokenizer: text tokenizer passed to ``build_text_features``.
        is_acc: if False, skip accuracy bookkeeping and return the raw logits/targets.
        name: dataset name; selects prompt handling and the returned metrics.
        classnorm: standardise image features with the class-wise text statistics.

    Returns:
        * ``is_acc`` False: ``(logits, targets)`` concatenated over the loader.
        * ``name == "radimagenet"``: ``(overall, US, MRI, CT)`` top-1 accuracy in percent.
        * ``name`` in the top-1-only set below: overall top-1 accuracy in percent.
        * otherwise: dict with ``acc``, ``auc_roc``, ``f1_score``, ``precision_score``
          and ``recall_score`` (all in percent).

    Raises:
        ValueError: if a batch is neither a tuple/list nor a dict.
    """
    top1_only_datasets = {
        'chexpert-5x200', 'CT_sagittal', 'CT_axial', 'CT_coronal', 'dr_uwf', 'dr_regular',
        'PCAM', 'LC25000_lung', 'LC25000_colon', 'NCK_CRC', 'BACH', 'Osteo',
        'skin_cancer', 'skin_tumor', 'SICAPv2', 'five_retina', 'odir_retina',
    }
    model.cuda()
    model.eval()

    total_top1 = 0
    total_images = 0
    all_outputs = []
    all_targets = []
    text_features = mean = std = None
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    for samples in tqdm(val_loader):
        if text_features is None:
            print('=> encoding captions')
            text_features, mean, std = _zeroshot_text_features(
                templates, labels, model, tokenizer, name, classnorm)
            # Image features are intentionally left un-normalised; announce it once, not per batch.
            print("no normalizing this time)")

        if isinstance(samples, (tuple, list)):
            images, target = samples[0], samples[1]
        elif isinstance(samples, dict):
            images, target = samples["pixel_values"], samples["targets"]
        else:
            raise ValueError("unknown sample type", type(samples))

        images = images.cuda(non_blocking=True)
        image_features = model.encode_image(images)
        if classnorm:
            image_features = (image_features - mean) / std

        logits_per_image = (image_features @ text_features.t()).cpu()
        # Targets are only compared on the CPU, so no GPU round-trip is needed.
        target = target.cpu()
        if name == "chexpert-5x200":
            # CheXpert targets arrive one-hot; convert them to class indices.
            target = torch.argmax(target, dim=1)

        if is_acc:
            pred = logits_per_image.argmax(dim=1)
            total_top1 += pred.eq(target).sum().item()
            total_images += images.size(0)
            if name == "radimagenet":
                for t, p in zip(target.tolist(), pred.tolist()):
                    class_correct[t] += int(p == t)
                    class_total[t] += 1
        all_outputs.append(logits_per_image)
        all_targets.append(target)

    if not is_acc:
        return torch.cat(all_outputs), torch.cat(all_targets)

    overall_acc = 100 * total_top1 / total_images
    if name == "radimagenet":
        def modality_acc(modality):
            indices = modality_indices_radimagenet_test_set[modality]
            correct = sum(class_correct[i] for i in indices)
            total = sum(class_total[i] for i in indices)
            return 100 * correct / total

        return overall_acc, modality_acc('US'), modality_acc('MRI'), modality_acc('CT')

    if name in top1_only_datasets:
        return overall_acc

    outputs = torch.cat(all_outputs)
    targets = torch.cat(all_targets)
    return {"acc": overall_acc,
            "auc_roc": roc_auc(outputs, targets),
            "f1_score": F1_score(outputs, targets),
            "precision_score": Precision_score(outputs, targets),
            "recall_score": Recall_score(outputs, targets)}
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every ``k`` in ``topk``.

    ``output`` is a 2-D score matrix (samples x classes) and ``target`` holds class
    indices. The result is a list with one single-element tensor per requested ``k``.
    """
    with torch.no_grad():
        batch = target.size(0)
        _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = ranked.t()  # row i holds every sample's (i+1)-th best class
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch) for k in topk]
def Recall_score(outputs, targets, average="binary"):
    """Return recall, in percent, of the argmax predictions of ``outputs`` against ``targets``.

    Args:
        outputs: score matrix with one column per class; predictions are ``outputs.argmax(1)``.
        targets: ground-truth class indices.
        average: forwarded to ``sklearn.metrics.recall_score``. The default ``"binary"``
            keeps the previous behaviour; pass e.g. ``"macro"`` for multi-class targets.
    """
    pred = outputs.argmax(1)
    # Named ``recall`` so the local no longer shadows this function's own name.
    recall = metrics.recall_score(targets, pred, average=average)
    return 100 * recall
def F1_score(outputs, targets, average="binary"):
    """Return the F1 score, in percent, of the argmax predictions of ``outputs`` against ``targets``.

    Args:
        outputs: score matrix with one column per class; predictions are ``outputs.argmax(1)``.
        targets: ground-truth class indices.
        average: forwarded to ``sklearn.metrics.f1_score``. The default ``"binary"``
            keeps the previous behaviour; pass e.g. ``"macro"`` for multi-class targets.
    """
    pred = outputs.argmax(1)
    # Named ``f1`` so the local no longer shadows this function's own name.
    f1 = metrics.f1_score(targets, pred, average=average)
    return 100 * f1
def Precision_score(outputs, targets, average="binary"):
    """Return precision, in percent, of the argmax predictions of ``outputs`` against ``targets``.

    Args:
        outputs: score matrix with one column per class; predictions are ``outputs.argmax(1)``.
        targets: ground-truth class indices.
        average: forwarded to ``sklearn.metrics.precision_score``. The default ``"binary"``
            keeps the previous behaviour; pass e.g. ``"macro"`` for multi-class targets.
    """
    pred = outputs.argmax(1)
    # Named ``precision`` so the local no longer shadows this function's own name.
    precision = metrics.precision_score(targets, pred, average=average)
    return 100 * precision
def mean_per_class(outputs, targets):
    """Return the mean per-class recall (balanced accuracy), in percent.

    Predictions are ``outputs.argmax(1)``. ``confusion_matrix`` builds its rows from the
    union of true and predicted labels. A class that is only ever predicted therefore
    gets a row with zero support, and its 0/0 recall would make the whole mean NaN.
    Such rows are excluded, so the average runs over classes present in ``targets``.
    """
    pred = outputs.argmax(1)
    confusion_matrix = metrics.confusion_matrix(targets, pred)
    support = confusion_matrix.sum(axis=1)
    present = support > 0
    per_classes = confusion_matrix.diagonal()[present] / support[present]
    return 100 * per_classes.mean()
def roc_auc(outputs, targets):
    """Return the binary ROC-AUC, in percent.

    Each sample is scored by the margin ``outputs[:, 1] - outputs[:, 0]``; any further
    columns of ``outputs`` are ignored.
    """
    margin = outputs[:, 1] - outputs[:, 0]
    return 100 * metrics.roc_auc_score(targets, margin)
if __name__ == '__main__':
    # Smoke-test the metric helpers on random logits. `evaluate` itself needs a real
    # data loader, templates, labels, model and tokenizer, so it cannot be called here.
    logits = torch.randn(128, 10)
    targets = torch.randint(size=(128,), low=0, high=10)
    top1, top5 = accuracy(logits, targets, topk=(1, 5))
    print(f"top1: {top1.item():.2f}  top5: {top5.item():.2f}")
    print(f"mean per-class acc: {mean_per_class(logits, targets):.2f}")