import argparse
import os

import clip
import torch
from tqdm import tqdm


def zeroshot_classifier(model, classnames, templates, device):
    """Build zero-shot classification weights from text prompts.

    For every class name, each template is called with that name to produce
    a prompt. The prompts are tokenized with ``clip.tokenize``, encoded with
    ``model.encode_text``, L2-normalised along the last dimension, averaged
    over dimension 0 and normalised again to give one embedding per class.

    Args:
        model: Object exposing ``encode_text(tokens)``
            (presumably a CLIP model -- confirm against the caller).
        classnames: Iterable of class names passed to each template.
        templates: Iterable of callables; each is called as
            ``template(classname)`` and its result is handed to
            ``clip.tokenize``.
        device: Device the tokens and the stacked weights are moved to.

    Returns:
        The per-class embeddings stacked along dim 1, transposed and
        multiplied by 100. Assuming ``encode_text`` returns a
        ``(num_templates, dim)`` tensor, the result has one row per class
        -- TODO confirm the shape against the model in use.
    """
    print('Building zero-shot classifier.')

    per_class = []
    with torch.no_grad():
        for name in tqdm(classnames):
            # One prompt per template for the current class.
            prompts = [make_prompt(name) for make_prompt in templates]
            tokens = clip.tokenize(prompts).to(device)

            # Normalise each prompt embedding in place before averaging.
            prompt_embeddings = model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)

            # Average over the templates, then bring the mean back to unit
            # length.
            mean_embedding = prompt_embeddings.mean(dim=0)
            per_class.append(mean_embedding / mean_embedding.norm())

        stacked = torch.stack(per_class, dim=1).to(device)

    # NOTE(review): the factor 100 looks like a fixed logit scale -- confirm.
    return 100 * stacked.t()