{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import open_clip\n", "import torch\n", "from tqdm.notebook import tqdm\n", "import pandas as pd\n", "import os\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "PROMPTS = [\n", " '{0}',\n", " 'an image of {0}',\n", " 'a photo of {0}',\n", " '{0} on a photo',\n", " 'a photo of a person named {0}',\n", " 'a person named {0}',\n", " 'a man named {0}',\n", " 'a woman named {0}',\n", " 'the name of the person is {0}',\n", " 'a photo of a person with the name {0}',\n", " '{0} at a gala',\n", " 'a photo of the celebrity {0}',\n", " 'actor {0}',\n", " 'actress {0}',\n", " 'a colored photo of {0}',\n", " 'a black and white photo of {0}',\n", " 'a cool photo of {0}',\n", " 'a cropped photo of {0}',\n", " 'a cropped image of {0}',\n", " '{0} in a suit',\n", " '{0} in a dress'\n", "]\n", "OPEN_CLIP_LAION400M_MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']\n", "OPEN_CLIP_LAION2B_MODEL_NAMES = [('ViT-B-32', 'laion2b_s34b_b79k') , ('ViT-L-14', 'laion2b_s32b_b82k')]\n", "OPEN_AI_MODELS = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']\n", "SEED = 42" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "MODELS = {}\n", "for model_name in OPEN_CLIP_LAION400M_MODEL_NAMES:\n", " dataset = 'LAION400M'\n", " model, _, preprocess = open_clip.create_model_and_transforms(\n", " model_name,\n", " pretrained=f'{dataset.lower()}_e32'\n", " )\n", " model = model.eval()\n", " MODELS[(model_name, dataset.lower())] = {\n", " 'model_instance': model,\n", " 'preprocessing': preprocess,\n", " 'model_name': model_name,\n", " 'tokenizer': open_clip.get_tokenizer(model_name),\n", " }\n", "\n", "for model_name, dataset_name in OPEN_CLIP_LAION2B_MODEL_NAMES:\n", " dataset = 'LAION2B'\n", " model, _, preprocess = open_clip.create_model_and_transforms(\n", " model_name,\n", " pretrained = dataset_name\n", " )\n", " model = model.eval()\n", " MODELS[(model_name, dataset.lower())] = {\n", " 'model_instance': model,\n", " 'preprocessing': preprocess,\n", " 'model_name': model_name,\n", " 'tokenizer': open_clip.get_tokenizer(model_name)\n", " }\n", "\n", "for model_name in OPEN_AI_MODELS:\n", " dataset = 'OpenAI'\n", " model, _, preprocess = open_clip.create_model_and_transforms(\n", " model_name,\n", " pretrained=dataset.lower()\n", " )\n", " model = model.eval()\n", " MODELS[(model_name, dataset.lower())] = {\n", " 'model_instance': model,\n", " 'preprocessing': preprocess,\n", " 'model_name': model_name,\n", " 'tokenizer': open_clip.get_tokenizer(model_name)\n", " }" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# define a function to get the predictions for an actor/actress\n", "@torch.no_grad()\n", "def get_text_embeddings(model, context, context_batchsize=1_000, use_tqdm=False):\n", " context_batchsize = context_batchsize * torch.cuda.device_count()\n", " # if there is not batches for the context unsqueeze it\n", " if context.dim() < 3:\n", " context = context.unsqueeze(0)\n", "\n", " # get the batch size, the number of labels and the sequence length\n", " seq_len = context.shape[-1]\n", " viewed_context = context.view(-1, seq_len)\n", "\n", " text_features = []\n", " for context_batch_idx in tqdm(range(0, len(viewed_context), context_batchsize), desc=\"Calculating Text Embeddings\",\n", " disable=not use_tqdm):\n", " context_batch = viewed_context[context_batch_idx:context_batch_idx + context_batchsize]\n", " batch_text_features = model.encode_text(context_batch, normalize=True).cpu()\n", "\n", " text_features.append(batch_text_features)\n", " text_features = torch.cat(text_features).view(list(context.shape[:-1]) + [-1])\n", "\n", " return text_features" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# load the possible names\n", "possible_names = pd.read_csv('./full_names.csv', index_col=0)\n", "possible_names\n", "# possible_names_list = (possible_names['first_name'] + ' ' + possible_names['last_name']).tolist()\n", "# possible_names_list[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# populate the prompts with the possible names\n", "prompts = []\n", "for idx, row in possible_names.iterrows():\n", " df_dict = row.to_dict()\n", " name = f'{row[\"first_name\"]} {row[\"last_name\"]}'\n", " for prompt_idx, prompt in enumerate(PROMPTS):\n", " df_dict[f'prompt_{prompt_idx}'] = prompt.format(name)\n", " prompts.append(df_dict)\n", "prompts = pd.DataFrame(prompts)\n", "prompts" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "label_context_vecs_per_model = {}\n", "for dict_key, model_dict in MODELS.items():\n", " label_context_vecs = []\n", " for i in range(len(PROMPTS)):\n", " context = model_dict['tokenizer'](prompts[f'prompt_{i}'].to_numpy())\n", " label_context_vecs.append(context)\n", " label_context_vecs = torch.stack(label_context_vecs)\n", " label_context_vecs_per_model[dict_key] = label_context_vecs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "text_embeddings_per_model = {}\n", "for dict_key, model_dict in MODELS.items():\n", " label_context_vecs = label_context_vecs_per_model[dict_key].to(device)\n", " model = model_dict['model_instance']\n", " model = model.to(device)\n", " text_embeddings = get_text_embeddings(model, label_context_vecs, use_tqdm=True, context_batchsize=5_000)\n", " text_embeddings_per_model[dict_key] = text_embeddings\n", " model = model.cpu()\n", " label_context_vecs = label_context_vecs.cpu()\n", "\n", "label_context_vecs = label_context_vecs.cpu()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# save the calculated embeddings to a file\n", "if not os.path.exists('./prompt_text_embeddings'):\n", " os.makedirs('./prompt_text_embeddings')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for (model_name, dataset_name), model_dict in MODELS.items():\n", " torch.save(\n", " text_embeddings_per_model[(model_name, dataset_name)],\n", " f'./prompt_text_embeddings/{model_name}_{dataset_name}_prompt_text_embeddings.pt'\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 0 }