# NOTE: This script is currently not supported for CLAP.
import logging
from contextlib import suppress

import torch
import torch.nn.functional as F
from tqdm import tqdm

from open_clip import tokenize

from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
def zero_shot_classifier(model, classnames, templates, args):
    # Build one text embedding per class by averaging over all prompt templates.
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenize(texts).to(args.device)  # tokenize
            if args.distributed and not args.horovod:
                class_embeddings = model.module.encode_text(texts)
            else:
                class_embeddings = model.encode_text(texts)
            # Normalize each template embedding, average them, then re-normalize
            # so every class weight is a unit vector.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
    return zeroshot_weights
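
# Usage sketch (hypothetical `my_args`; assumes it carries the `device`,
# `distributed`, and `horovod` attributes used elsewhere in this script):
#
#   classifier = zero_shot_classifier(
#       model, imagenet_classnames, openai_imagenet_template, my_args
#   )
#   # classifier: [embed_dim, num_classes]; column j is the re-normalized mean
#   # of class j's prompt embeddings (one per template).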
def accuracy(output, target, topk=(1,)):
    # pred: [max_k, batch] indices of the highest-scoring classes per sample.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # For each k, count samples whose target appears among the top-k predictions.
    return [
        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        for k in topk
    ]
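
# Worked example (illustrative values): `accuracy` returns the *count* of
# correct predictions per k, not a fraction; `run` divides by the sample
# count afterwards.
#
#   output = torch.tensor([[0.1, 0.7, 0.2],
#                          [0.3, 0.1, 0.6]])
#   target = torch.tensor([1, 0])
#   accuracy(output, target, topk=(1, 2))  # -> [1.0, 2.0]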
def run(model, classifier, dataloader, args):
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    with torch.no_grad():
        top1, top5, n = 0.0, 0.0, 0.0
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            target = target.to(args.device)

            with autocast():
                # predict
                if args.distributed and not args.horovod:
                    image_features = model.module.encode_image(images)
                else:
                    image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                logits = 100.0 * image_features @ classifier

            # measure accuracy (running counts of correct top-1/top-5 predictions)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = top1 / n
    top5 = top5 / n
    return top1, top5
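
# Shape sketch (illustrative): with batch size B, embedding dim D, and C
# classes, `image_features` is [B, D] and `classifier` is [D, C], so
# `image_features @ classifier` yields [B, C] cosine-similarity logits.
# The 100.0 factor mirrors CLIP's logit scale; since it is a monotonic
# scaling, it does not change the argmax and hence not the accuracy.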
def zero_shot_eval(model, data, epoch, args):
    if "imagenet-val" not in data and "imagenet-v2" not in data:
        return {}
    if args.zeroshot_frequency == 0:
        return {}
    # Evaluate every `zeroshot_frequency` epochs, and always on the final epoch.
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}

    logging.info("Starting zero-shot imagenet.")
    logging.info("Building zero-shot classifier")
    classifier = zero_shot_classifier(
        model, imagenet_classnames, openai_imagenet_template, args
    )
    logging.info("Using classifier")
    results = {}
    if "imagenet-val" in data:
        top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args)
        results["imagenet-zeroshot-val-top1"] = top1
        results["imagenet-zeroshot-val-top5"] = top5
    if "imagenet-v2" in data:
        top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args)
        results["imagenetv2-zeroshot-val-top1"] = top1
        results["imagenetv2-zeroshot-val-top5"] = top5

    logging.info("Finished zero-shot imagenet.")
    return results
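
# End-to-end sketch (hypothetical; assumes `data` maps split names to objects
# with a `.dataloader` attribute, as this function expects):
#
#   metrics = zero_shot_eval(model, data, epoch, args)
#   # -> {} unless an ImageNet split is present and the zero-shot schedule
#   #    fires; otherwise a dict with keys such as "imagenet-zeroshot-val-top1",
#   #    holding top-1/top-5 accuracies as fractions in [0, 1].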