import logging
from contextlib import suppress

import torch
import torch.nn.functional as F
from tqdm import tqdm

from open_clip import tokenize
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
|
|
def zero_shot_classifier(model, classnames, templates, args):
    """Build a zero-shot classification head from class names and prompt templates."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            # Expand the class name into one prompt per template, then tokenize.
            texts = [template(classname) for template in templates]
            texts = tokenize(texts).to(args.device)
            if args.distributed and not args.horovod:
                class_embeddings = model.module.encode_text(texts)
            else:
                class_embeddings = model.encode_text(texts)
            # Average the normalized per-prompt embeddings, then re-normalize
            # so each class weight is a unit vector.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Shape: [embed_dim, num_classes], so `image_features @ classifier` works.
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
    return zeroshot_weights
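
# Illustrative sketch (not part of the pipeline): `templates` is a sequence of
# callables mapping a class name to a prompt string, as in
# `openai_imagenet_template`. A minimal custom set might look like:
#
#     my_templates = [
#         lambda c: f"a photo of a {c}.",
#         lambda c: f"a blurry photo of a {c}.",
#     ]
#     # classifier = zero_shot_classifier(model, ["cat", "dog"], my_templates, args)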
|
|
def accuracy(output, target, topk=(1,)):
    """Return the number of correct predictions for each k in `topk` (counts, not rates)."""
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [
        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        for k in topk
    ]
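
# Quick sanity check for `accuracy` (values are made up for illustration):
#
#     logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # 2 samples, 2 classes
#     target = torch.tensor([1, 0])                    # both predicted correctly
#     accuracy(logits, target, topk=(1,))              # -> [2.0], a count not a rate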
|
|
def run(model, classifier, dataloader, args):
    # Use CUDA autocast under AMP; otherwise fall back to a no-op context manager.
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    with torch.no_grad():
        top1, top5, n = 0.0, 0.0, 0.0
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            target = target.to(args.device)

            with autocast():
                if args.distributed and not args.horovod:
                    image_features = model.module.encode_image(images)
                else:
                    image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                # Cosine similarity against each class embedding, scaled by 100,
                # the logit scale conventionally used for CLIP zero-shot evaluation.
                logits = 100.0 * image_features @ classifier

            # Accumulate correct-prediction counts; normalized by sample count below.
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = top1 / n
    top5 = top5 / n
    return top1, top5
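
# Hedged usage note: `run` returns accuracies as fractions in [0, 1], e.g.
#
#     top1, top5 = run(model, classifier, val_loader, args)  # val_loader is illustrative
#     logging.info(f"top1={top1:.4f} top5={top5:.4f}")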
|
|
def zero_shot_eval(model, data, epoch, args):
    if "imagenet-val" not in data and "imagenet-v2" not in data:
        return {}
    if args.zeroshot_frequency == 0:
        return {}
    # Evaluate every `zeroshot_frequency` epochs, and always on the final epoch.
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}

    logging.info("Starting zero-shot ImageNet evaluation.")

    logging.info("Building zero-shot classifier.")
    classifier = zero_shot_classifier(
        model, imagenet_classnames, openai_imagenet_template, args
    )

    logging.info("Evaluating with the zero-shot classifier.")
    results = {}
    if "imagenet-val" in data:
        top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args)
        results["imagenet-zeroshot-val-top1"] = top1
        results["imagenet-zeroshot-val-top5"] = top5
    if "imagenet-v2" in data:
        top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args)
        results["imagenetv2-zeroshot-val-top1"] = top1
        results["imagenetv2-zeroshot-val-top5"] = top5

    logging.info("Finished zero-shot ImageNet evaluation.")

    return results
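
# Hedged usage sketch: `zero_shot_eval` expects `args` to carry at least
# `device`, `distributed`, `horovod`, `precision`, `batch_size`,
# `zeroshot_frequency`, and `epochs`, and `data` to map split names to objects
# exposing a `.dataloader` attribute:
#
#     metrics = zero_shot_eval(model, data, epoch=args.epochs, args=args)
#     # -> {"imagenet-zeroshot-val-top1": <float in [0, 1]>, ...}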