berkaygkv54's picture
first push
19759e2
raw
history blame contribute delete
No virus
3.4 kB
# NOTE: This script is currently not supported for CLAP.
import logging
from contextlib import suppress
import torch
import torch.nn.functional as F
from tqdm import tqdm
from clap_module import tokenize
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
def zero_shot_classifier(model, classnames, templates, args):
    """Build a zero-shot classification matrix from templated text prompts.

    For each class name, every template is rendered into a prompt; the prompts
    are tokenized, moved to ``args.device`` and encoded with ``encode_text``.
    The per-prompt embeddings are L2-normalized, averaged over prompts, and the
    average is L2-normalized again.

    Args:
        model: model exposing ``encode_text``. When ``args.distributed`` is
            true and ``args.horovod`` is false, the call goes through
            ``model.module`` instead.
        classnames: iterable of class names (iterated once, with a progress bar).
        templates: iterable of callables, each mapping a class name to a prompt.
        args: namespace providing ``device``, ``distributed`` and ``horovod``.

    Returns:
        Tensor on ``args.device`` holding one class embedding per column
        (the averaged embeddings are stacked along ``dim=1``). Assumes
        ``encode_text`` returns one row per prompt -- TODO confirm for CLAP.
    """
    with torch.no_grad():
        columns = []
        for name in tqdm(classnames):
            prompts = [render(name) for render in templates]
            tokens = tokenize(prompts).to(args.device)
            # Under non-Horovod distributed training the model is wrapped, so
            # the real encoder lives on ``.module``.
            text_model = model.module if (args.distributed and not args.horovod) else model
            per_prompt = F.normalize(text_model.encode_text(tokens), dim=-1)
            centroid = per_prompt.mean(dim=0)
            # Averaging unit vectors shrinks the norm; rescale back to unit length.
            centroid /= centroid.norm()
            columns.append(centroid)
        weights = torch.stack(columns, dim=1).to(args.device)
    return weights
def accuracy(output, target, topk=(1,)):
    """Count how many samples have their target among the top-k predictions.

    Args:
        output: score tensor; predictions are ranked along dim 1 (one row per
            sample, one column per class).
        target: class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        List of Python floats, one per entry of ``topk``, each the number
        (not the fraction) of samples whose target is in the top-k scores.
    """
    # Clamp k to the number of classes: ``topk`` would otherwise raise, and
    # when k >= num_classes every sample is trivially a top-k hit.
    maxk = min(max(topk), output.size(1))
    # (maxk, batch): row i holds each sample's i-th ranked class index.
    pred = output.topk(maxk, dim=1, largest=True, sorted=True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # ``.item()`` yields a Python scalar directly; converting a 1-element
    # ndarray with ``float()`` is deprecated in NumPy >= 1.25.
    return [float(correct[:k].sum().item()) for k in topk]
def run(model, classifier, dataloader, args):
    """Measure zero-shot top-1 and top-5 accuracy over ``dataloader``.

    Each batch of images is encoded with ``encode_image`` and L2-normalized.
    The features are projected onto ``classifier`` and scaled by 100 to form
    logits, then scored with ``accuracy``. Inference runs under
    ``torch.no_grad()``, and under CUDA autocast when ``args.precision == 'amp'``.

    Args:
        model: model exposing ``encode_image``. It is reached through
            ``model.module`` when ``args.distributed`` is true and
            ``args.horovod`` is false.
        classifier: class-embedding matrix right-multiplied with the image
            features (presumably the output of ``zero_shot_classifier``).
        dataloader: iterable yielding ``(images, target)`` batches.
        args: namespace providing ``precision``, ``batch_size``, ``device``,
            ``distributed`` and ``horovod``.

    Returns:
        ``(top1, top5)``: hit counts divided by the number of images seen.
    """
    amp_context = suppress  # no-op context manager unless AMP is requested
    if args.precision == 'amp':
        amp_context = torch.cuda.amp.autocast
    with torch.no_grad():
        hits1 = hits5 = seen = 0.
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            target = target.to(args.device)
            with amp_context():
                image_model = model.module if (args.distributed and not args.horovod) else model
                feats = F.normalize(image_model.encode_image(images), dim=-1)
                logits = 100. * feats @ classifier
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            hits1 += acc1
            hits5 += acc5
            seen += images.size(0)
    return hits1 / seen, hits5 / seen
def zero_shot_eval(model, data, epoch, args):
    """Run zero-shot ImageNet evaluation when the schedule says so.

    Returns ``{}`` without doing any work in three cases, checked in order:
    neither ``'imagenet-val'`` nor ``'imagenet-v2'`` is in ``data``;
    ``args.zeroshot_frequency`` is 0; or ``epoch`` is not a multiple of the
    frequency and is not the final epoch (``args.epochs``). Otherwise it
    builds the text classifier once and evaluates every available split.

    Args:
        model: model passed through to ``zero_shot_classifier`` and ``run``.
        data: mapping of split name to an object with a ``dataloader`` attribute.
        epoch: current epoch number.
        args: namespace providing ``zeroshot_frequency`` and ``epochs``, plus
            whatever ``zero_shot_classifier`` and ``run`` read.

    Returns:
        Dict of top-1 / top-5 accuracies keyed per evaluated split.
    """
    if 'imagenet-val' not in data and 'imagenet-v2' not in data:
        return {}
    freq = args.zeroshot_frequency
    # ``freq == 0`` is tested first so the modulo below never divides by zero.
    if freq == 0 or (epoch % freq != 0 and epoch != args.epochs):
        return {}

    logging.info('Starting zero-shot imagenet.')
    logging.info('Building zero-shot classifier')
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
    logging.info('Using classifier')

    # (split key in ``data``, top-1 result key, top-5 result key)
    splits = (
        ('imagenet-val', 'imagenet-zeroshot-val-top1', 'imagenet-zeroshot-val-top5'),
        ('imagenet-v2', 'imagenetv2-zeroshot-val-top1', 'imagenetv2-zeroshot-val-top5'),
    )
    results = {}
    for split, key1, key5 in splits:
        if split not in data:
            continue
        results[key1], results[key5] = run(model, classifier, data[split].dataloader, args)

    logging.info('Finished zero-shot imagenet.')
    return results