In [None]:
import open_clip
import torch
from tqdm.notebook import tqdm
import pandas as pd
import os

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Prompt templates. `{0}` is filled with "<first_name> <last_name>" through
# str.format when the prompt table is built.
PROMPTS = [
    # bare name and generic image/photo phrasings
    '{0}',
    'an image of {0}',
    'a photo of {0}',
    '{0} on a photo',
    # phrasings that state explicitly that {0} is a person's name
    'a photo of a person named {0}',
    'a person named {0}',
    'a man named {0}',
    'a woman named {0}',
    'the name of the person is {0}',
    'a photo of a person with the name {0}',
    # celebrity / acting context
    '{0} at a gala',
    'a photo of the celebrity {0}',
    'actor {0}',
    'actress {0}',
    # photo-style variations
    'a colored photo of {0}',
    'a black and white photo of {0}',
    'a cool photo of {0}',
    'a cropped photo of {0}',
    'a cropped image of {0}',
    # clothing
    '{0} in a suit',
    '{0} in a dress',
]

# Architectures loaded with the `laion400m_e32` pretrained tag.
OPEN_CLIP_LAION400M_MODEL_NAMES = [
    'ViT-B-32',
    'ViT-B-16',
    'ViT-L-14',
]

# (architecture, pretrained tag) pairs for the LAION-2B checkpoints.
OPEN_CLIP_LAION2B_MODEL_NAMES = [
    ('ViT-B-32', 'laion2b_s34b_b79k'),
    ('ViT-L-14', 'laion2b_s32b_b82k'),
]

# Architectures loaded with the `openai` pretrained tag.
OPEN_AI_MODELS = [
    'ViT-B-32',
    'ViT-B-16',
    'ViT-L-14',
]

SEED = 42

In [None]:
def _load_clip_entry(model_name, pretrained_tag):
    """Load one open_clip checkpoint and bundle it with its helpers.

    Args:
        model_name: architecture name passed to open_clip (e.g. 'ViT-B-32').
        pretrained_tag: value forwarded as ``pretrained=`` to
            ``open_clip.create_model_and_transforms``.

    Returns:
        dict with the model switched to eval mode ('model_instance'), the
        third value returned by ``create_model_and_transforms``
        ('preprocessing'), the architecture name ('model_name') and the
        tokenizer from ``open_clip.get_tokenizer`` ('tokenizer').
    """
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=pretrained_tag,
    )
    return {
        'model_instance': model.eval(),
        'preprocessing': preprocess,
        'model_name': model_name,
        'tokenizer': open_clip.get_tokenizer(model_name),
    }


# One (architecture, registry key suffix, pretrained tag) triple per
# checkpoint, in the order the entries are inserted into MODELS.
_MODEL_SPECS = (
    [(name, 'laion400m', 'laion400m_e32') for name in OPEN_CLIP_LAION400M_MODEL_NAMES]
    + [(name, 'laion2b', tag) for name, tag in OPEN_CLIP_LAION2B_MODEL_NAMES]
    + [(name, 'openai', 'openai') for name in OPEN_AI_MODELS]
)

# Registry keyed by (architecture, lower-cased dataset name).
MODELS = {}
for model_name, dataset_key, pretrained_tag in _MODEL_SPECS:
    MODELS[(model_name, dataset_key)] = _load_clip_entry(model_name, pretrained_tag)

In [None]:
@torch.no_grad()
def get_text_embeddings(model, context, context_batchsize=1_000, use_tqdm=False):
    """Encode tokenized text into embeddings, one batch at a time.

    Args:
        model: object exposing ``encode_text(tokens, normalize=True)``.
        context: token tensor whose last dimension is the sequence length.
            Inputs with fewer than three dimensions get a leading dimension
            of size 1 before processing.
        context_batchsize: number of sequences per batch for each visible
            CUDA device. The effective batch size is this value multiplied by
            the CUDA device count, treating "no CUDA device" as one device.
        use_tqdm: show a progress bar over the batches when True.

    Returns:
        CPU tensor with the same leading dimensions as ``context`` (after the
        optional unsqueeze) and the encoder output as the last dimension.

    Raises:
        ValueError: if ``context_batchsize`` is smaller than 1.
    """
    if context_batchsize < 1:
        raise ValueError(f"context_batchsize must be >= 1, got {context_batchsize}")

    # torch.cuda.device_count() is 0 on CPU-only hosts; without the floor of 1
    # the batch size would collapse to 0 and range() below would raise.
    context_batchsize = context_batchsize * max(1, torch.cuda.device_count())

    # Give un-batched input a leading dimension so the output shape is uniform.
    if context.dim() < 3:
        context = context.unsqueeze(0)

    # Flatten every leading dimension into one list of sequences. reshape
    # (unlike view) also accepts non-contiguous tensors.
    seq_len = context.shape[-1]
    flat_context = context.reshape(-1, seq_len)

    batch_starts = range(0, len(flat_context), context_batchsize)
    if use_tqdm:
        batch_starts = tqdm(batch_starts, desc="Calculating Text Embeddings")

    text_features = []
    for start in batch_starts:
        context_batch = flat_context[start:start + context_batchsize]
        # Move each batch result to the CPU right away so results do not
        # pile up on the encoding device.
        text_features.append(model.encode_text(context_batch, normalize=True).cpu())

    # Restore the original leading dimensions; the last one is inferred.
    return torch.cat(text_features).reshape(list(context.shape[:-1]) + [-1])

In [None]:
# Load the candidate names; the first CSV column is used as the index.
possible_names = pd.read_csv('./full_names.csv', index_col=0)

# The prompt-building cell reads `first_name` and `last_name` from every row,
# so fail here with a clear message instead of a KeyError in the middle of
# that loop.
missing_columns = {'first_name', 'last_name'} - set(possible_names.columns)
assert not missing_columns, f'full_names.csv is missing columns: {sorted(missing_columns)}'

# Preview only the first rows rather than rendering the whole frame.
possible_names.head()

In [None]:
# Build one record per name: the original columns of the row plus one
# `prompt_<i>` column for every template in PROMPTS.
prompt_records = []
for _, name_row in possible_names.iterrows():
    full_name = f'{name_row["first_name"]} {name_row["last_name"]}'
    record = name_row.to_dict()
    record.update(
        {
            f'prompt_{template_idx}': template.format(full_name)
            for template_idx, template in enumerate(PROMPTS)
        }
    )
    prompt_records.append(record)

# A single DataFrame construction from the collected records.
prompts = pd.DataFrame(prompt_records)
prompts

In [None]:
# Tokenize every prompt column with each model's own tokenizer and stack the
# per-prompt results along a new leading dimension.
label_context_vecs_per_model = {}
for dict_key, model_dict in MODELS.items():
    tokenize = model_dict['tokenizer']
    per_prompt_tokens = [
        tokenize(prompts[f'prompt_{prompt_idx}'].to_numpy())
        for prompt_idx in range(len(PROMPTS))
    ]
    label_context_vecs = torch.stack(per_prompt_tokens)
    label_context_vecs_per_model[dict_key] = label_context_vecs

In [None]:
# Compute the text embeddings for every registered model. Only one model is
# moved to `device` at a time.
text_embeddings_per_model = {}
for dict_key, model_dict in MODELS.items():
    model = model_dict['model_instance'].to(device)
    try:
        text_embeddings_per_model[dict_key] = get_text_embeddings(
            model,
            # The device copy of the tokens is a temporary that is released
            # once the call returns; the CPU original stays in
            # label_context_vecs_per_model.
            label_context_vecs_per_model[dict_key].to(device),
            use_tqdm=True,
            context_batchsize=5_000,
        )
    finally:
        # Move the model back even when encoding fails, so an aborted run
        # does not leave it occupying device memory for the next attempt.
        model.cpu()

In [None]:
# Make sure the output directory for the saved embeddings exists.
# exist_ok=True replaces the separate exists() check, so there is no window
# between checking and creating in which the directory could appear.
os.makedirs('./prompt_text_embeddings', exist_ok=True)

In [None]:
# Write one .pt file per registered (architecture, dataset) pair.
for model_key in MODELS:
    model_name, dataset_name = model_key
    output_path = f'./prompt_text_embeddings/{model_name}_{dataset_name}_prompt_text_embeddings.pt'
    torch.save(text_embeddings_per_model[model_key], output_path)