# ---
# jupyter:
#   kernelspec:
#     display_name: env2
#     language: python
#     name: env2
# ---
#
# Notebook cells in jupytext "py:percent" form: every `# %%` starts a cell and
# IPython magics are written as comments (the format's convention for magics).

# %% [markdown]
# # Systematic
#
# For each LVIS sample, CLIP scores the prompt of the target class against prompts
# for the other annotated objects of the same image. Every image preprocessing
# function below is compared with the unprocessed baseline by the mean change of
# the softmax probability assigned to the target prompt.

# %%
# %load_ext autoreload
# %autoreload 2

# Kept first so that the explicit imports below take precedence over any name the
# star import brings in. The names `np`, `expanduser`, `basename`, `DataLoader`,
# `nnf` and `plot_data` are not imported explicitly anywhere in this notebook, so
# they presumably come from here -- TODO confirm and replace by explicit imports.
from general_utils import *

import json
import os
from collections import defaultdict
from functools import partial

import clip
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image
from torchvision import transforms

from general_utils import get_batch
from evaluation_utils import denorm, img_preprocess, norm
from datasets.lvis_oneshot3 import LVIS_OneShot3
from models.clipseg import CLIPDensePredTMasked

# %%
# Configuration: everything a reader might want to tune.
CLIP_VERSION = 'ViT-B/16'
clip_device = 'cuda'

LVIS_ANNOTATION_DIR = expanduser('~/datasets/LVIS')
SAMPLE_IMAGE_DIR = expanduser('~/cloud/resources/sample_images')
FIGURE_PATH = 'figures/prompt_engineering.pdf'

SEED = 571
BATCH_SIZE = 32
MAX_BATCHES = 50            # batches evaluated when a dataset (not a batch list) is passed
MIN_PROMPTS = 3             # samples with fewer prompts are excluded from the averages
N_PREPROCESSING_VARIANTS = 6  # length of the list returned by `all_preprocessing`

# %%
clip_model, preprocess = clip.load(CLIP_VERSION, device=clip_device)
clip_model.eval();

clip_mask_model = CLIPDensePredTMasked(version=CLIP_VERSION).to(clip_device)
clip_mask_model.eval();

# %%
lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False,
                     text_class_labels=True, image_size=352, min_area=0.1,
                     min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)

# %%
plot_data(lvis)


# %%
def load_objects_per_image(annotation_paths, category_names):
    """Map every annotated image id to the names of the categories annotated in it.

    Args:
        annotation_paths: iterable of paths to LVIS-style JSON files that contain an
            ``'annotations'`` list whose items have ``'image_id'`` and ``'category_id'``.
        category_names: object indexable by category id that yields the category name.

    Returns:
        dict ``{image_id: [category_name, ...]}``; every category appears once per image.
    """
    category_ids = defaultdict(set)
    for path in annotation_paths:
        # `with` closes the handle; only one annotation file is held in memory at a time.
        with open(path) as fh:
            annotations = json.load(fh)['annotations']
        for ann in annotations:
            category_ids[ann['image_id']].add(ann['category_id'])

    return {image_id: [category_names[cat_id] for cat_id in cat_ids]
            for image_id, cat_ids in category_ids.items()}


# %%
objects_per_image = load_objects_per_image(
    [os.path.join(LVIS_ANNOTATION_DIR, 'lvis_v1_train.json'),
     os.path.join(LVIS_ANNOTATION_DIR, 'lvis_v1_val.json')],
    lvis.category_names,
)

# %%
# Optional: pre-load a fixed list of batches and pass it to `get_similarities`
# instead of the dataset.
#bs = 32
#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]


# %%
def _encode_images(processed_batch):
    """Embed a preprocessed batch with `clip_mask_model` and L2-normalise the result.

    A dict is forwarded as keyword arguments to `visual_forward`; anything else is
    treated as an image tensor, moved to `clip_device` and resized to 224x224 first.
    (The original author's commented-out alternative for the tensor path was
    `clip_model.encode_image`.)
    """
    if isinstance(processed_batch, dict):
        features = clip_mask_model.visual_forward(**processed_batch)[0]
    else:
        pixels = processed_batch.to(clip_device)
        pixels = nnf.interpolate(pixels, (224, 224), mode='bilinear')
        features = clip_mask_model.visual_forward(pixels)[0]

    features = features.to(clip_device).half()
    return features / features.norm(dim=-1, keepdim=True)


def _build_prompts(sample_index, target_label):
    """Return the prompts for one sample: target class first, then the image's other objects."""
    c, _, sid, qid = lvis.sample_ids[sample_index]
    support_image = basename(lvis.samples[c][sid])

    img_objs = [o.replace('_', ' ') for o in objects_per_image[int(support_image)]]

    # NOTE(review): `target_label` is compared (and prompted) without the underscore
    # replacement applied to `img_objs`; if labels can contain '_', the target class
    # would also show up among the distractors -- confirm against the dataset.
    other_words = [f'a photo of a {o}' for o in img_objs if o != target_label]
    return [f'a photo of a {target_label}'] + other_words


def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False,
                     batch_size=BATCH_SIZE, max_batches=MAX_BATCHES):
    """Score each sample's target prompt against prompts for the other objects in its image.

    Args:
        batches_or_dataset: either a list of already loaded ``(batch, batch_y)`` pairs
            (all of them are used) or a dataset, which is wrapped in an unshuffled
            `DataLoader` and evaluated for at most `max_batches` batches.
        process: callable applied to `batch`; returns an image tensor or a dict of
            keyword arguments for `clip_mask_model.visual_forward`.
        mask: unused; kept for backward compatibility.
        clipmask: unused; kept for backward compatibility.
        batch_size: `DataLoader` batch size (dataset input only).
        max_batches: number of batches to evaluate (dataset input only).

    Returns:
        ``(valid_sims, all_prompts)``: per sample, the softmax over its prompts
        (index 0 is the target class) and the list of prompt strings.

    Note:
        Sample metadata is always looked up in the module-level `lvis` dataset and in
        `objects_per_image`, assuming the batches enumerate `lvis` in order from index 0.
    """
    valid_sims, all_prompts = [], []

    with torch.no_grad():
        torch.manual_seed(SEED)

        if isinstance(batches_or_dataset, list):
            loader = batches_or_dataset  # already loaded
            max_iter = float('inf')
        else:
            loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=batch_size)
            max_iter = max_batches

        # Running index into `lvis.sample_ids`; stays correct when the last batch is short.
        sample_offset = 0

        for i_batch, (batch, batch_y) in enumerate(loader):
            if i_batch >= max_iter:
                break

            image_features = _encode_images(process(batch))
            n_samples = len(batch[0])

            for j in range(n_samples):
                prompts = _build_prompts(sample_offset + j, batch_y[2][j])

                text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))
                text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

                logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T

                valid_sims.append(torch.softmax(logits, dim=-1))
                all_prompts.append(prompts)

            sample_offset += n_samples

    return valid_sims, all_prompts


def new_img_preprocess(x):
    """Build keyword arguments for `clip_mask_model.visual_forward` from a batch.

    `x[1]` is passed as `x_inp` and `x[2]` as the last element of the `mask` tuple.
    Presumably `(11, 'cls_token', ...)` selects layer 11 and the CLS token, going by
    the commented-out 'clip mask CLS L11' entry below -- TODO confirm.
    """
    return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}


# %%
# Smoke test of the baseline; the result is discarded.
# NOTE(review): this runs a full evaluation pass that the next cell repeats for `base`;
# it can probably be dropped if the dataset sampling does not depend on call count.
#get_similarities(lvis, partial(img_preprocess, center_context=0.5));
get_similarities(lvis, lambda x: x[1]);

# %%
preprocessing_functions = [
#     ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],
#     ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],
#     ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],
#     ['colorize object red', partial(img_preprocess, colorize=True)],
#     ['add red outline', partial(img_preprocess, outline=True)],

#     ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],
#     ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],
#     ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],
#     ['BG blur', partial(img_preprocess, blur=3)],
#     ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],

#     ['crop large context', partial(img_preprocess, center_context=0.5)],
#     ['crop small context', partial(img_preprocess, center_context=0.1)],
    ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],
    # NOTE(review): this label says "crop" but no `center_context` is passed; the
    # arguments equal those of the commented-out 'BG blur & intensity 10%' entry.
    # Left unchanged because it is unclear whether the label or the arguments are wrong.
    ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],
#     ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],
]

base, base_p = get_similarities(lvis, lambda x: x[1])
outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]

# %%
outs2 = [get_similarities(lvis, partial(img_preprocess, bg_fac=0.0))]


# %%
def mean_target_shift(sims, base_sims, base_prompts, min_prompts=MIN_PROMPTS):
    """Mean change of the target-prompt probability relative to the baseline.

    Args:
        sims: per-sample softmax tensors of the preprocessing variant.
        base_sims: per-sample softmax tensors of the baseline (same sample order).
        base_prompts: per-sample prompt lists of the baseline.
        min_prompts: samples with fewer prompts than this are skipped.

    Returns:
        Mean over the kept samples of ``sims[i][0] - base_sims[i][0]``.
    """
    return np.mean([sims[i][0].cpu() - base_sims[i][0].cpu()
                    for i in range(len(base_sims))
                    if len(base_prompts[i]) >= min_prompts])


# %%
print(mean_target_shift(outs2[0][0], base, base_p))

# %%
tab = {name: mean_target_shift(outs[j][0], base, base_p)
       for j, (name, _) in enumerate(preprocessing_functions)}

# One LaTeX table row per preprocessing function, in percentage points.
print('\n'.join(f'{k} & {v*100:.2f} \\\\' for k, v in tab.items()))

# %% [markdown]
# # Visual
#
# Qualitative figure: each sample image is shown under several preprocessing variants
# (top rows) together with the CLIP softmax over a fixed list of query objects (bar
# plots below). The first query of each list is drawn in a different colour.


# %%
def load_sample(filename, filename2, base_path=SAMPLE_IMAGE_DIR):
    """Load an image and its mask in the ``[None, image, mask]`` layout used below.

    Args:
        filename: image file, relative to `base_path`.
        filename2: mask file, relative to `base_path`.
        base_path: directory that contains both files.

    Returns:
        ``[None, image, mask]`` where `image` is the normalised, resized and
        centre-cropped image with a leading batch dimension and `mask` is the first
        channel of the resized and centre-cropped mask file.
    """
    resize_crop = [transforms.Resize(224), transforms.CenterCrop(224)]
    image_tf = transforms.Compose([
        transforms.ToTensor(),
        # Widely used ImageNet statistics; presumably what `denorm` inverts -- TODO confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        *resize_crop,
    ])
    mask_tf = transforms.Compose([transforms.ToTensor(), *resize_crop])

    # Context managers close the underlying files once the tensors are materialised.
    with Image.open(os.path.join(base_path, filename)) as img:
        image = image_tf(img)
    with Image.open(os.path.join(base_path, filename2)) as img:
        mask = mask_tf(img)

    return [None, image.unsqueeze(0), mask[:1]]


def all_preprocessing(inp1):
    """Apply the preprocessing variants shown in the figure to one loaded sample.

    Returns a list of `N_PREPROCESSING_VARIANTS` results of `img_preprocess`: default
    arguments, colorize, outline, blur, background factor 0.1, and blur combined with
    background factor 0.5 and ``center_context=0.5``.
    """
    return [
        img_preprocess(inp1),
        img_preprocess(inp1, colorize=True),
        img_preprocess(inp1, outline=True),
        img_preprocess(inp1, blur=3),
        img_preprocess(inp1, bg_fac=0.1),
        #img_preprocess(inp1, bg_fac=0.5),
        #img_preprocess(inp1, blur=3, bg_fac=0.5),
        img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),
    ]


# %%
images_queries = [
    [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],
    [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],
]

# Loaded once and under its own name, so the CUDA `clip_model` used by
# `get_similarities` is not overwritten when this cell runs.
clip_model_cpu, preprocess_cpu = clip.load(CLIP_VERSION, device='cpu')

fig, ax = plt.subplots(2 * len(images_queries), N_PREPROCESSING_VARIANTS,
                       figsize=(14, 4.5 * len(images_queries)))

for j, (images, objects) in enumerate(images_queries):

    # Stack the variants and drop the singleton dimension at index 1.
    joint_image = torch.stack(all_preprocessing(images))[:, 0]
    assert len(joint_image) == N_PREPROCESSING_VARIANTS, 'figure columns must match the variants'

    prompts = [f'a photo of a {obj}' for obj in objects]

    with torch.no_grad():
        image_features = clip_model_cpu.encode_image(joint_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_cond = clip_model_cpu.encode_text(clip.tokenize(prompts))
        text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)

        logits = clip_model_cpu.logit_scale.exp() * image_features @ text_cond.T
        sim = torch.softmax(logits, dim=-1).cpu()

    image_row, bar_row = ax[2 * j], ax[2 * j + 1]

    for i in range(len(joint_image)):
        image_row[i].axis('off')
        image_row[i].imshow(torch.clamp(denorm(joint_image[i]).permute(1, 2, 0), 0, 1))

        bar_row[i].grid(True)
        bar_row[i].set_ylim(0, 1)
        bar_row[i].set_yticklabels([])
        bar_row[i].set_xticks([])  # labels are drawn inside the bars instead

        for k in range(len(sim[i])):
            # The first query gets its own colour.
            bar_row[i].bar(k, sim[i][k].item(), color=plt.cm.tab20(1) if k != 0 else plt.cm.tab20(3))
            bar_row[i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)

plt.tight_layout()
# `savefig` does not create missing directories.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
plt.savefig(FIGURE_PATH, bbox_inches='tight')