{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import json\n", "from PIL import Image\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "from pathlib import Path\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "from src.data.embs import ImageDataset\n", "from src.model.blip_embs import blip_embs" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def get_blip_config(model=\"base\"):\n", " config = dict()\n", " if model == \"base\":\n", " config[\n", " \"pretrained\"\n", " ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth \"\n", " config[\"vit\"] = \"base\"\n", " config[\"batch_size_train\"] = 32\n", " config[\"batch_size_test\"] = 16\n", " config[\"vit_grad_ckpt\"] = True\n", " config[\"vit_ckpt_layer\"] = 4\n", " config[\"init_lr\"] = 1e-5\n", " elif model == \"large\":\n", " config[\n", " \"pretrained\"\n", " ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\"\n", " config[\"vit\"] = \"large\"\n", " config[\"batch_size_train\"] = 16\n", " config[\"batch_size_test\"] = 32\n", " config[\"vit_grad_ckpt\"] = True\n", " config[\"vit_ckpt_layer\"] = 12\n", " config[\"init_lr\"] = 5e-6\n", "\n", " config[\"image_size\"] = 384\n", " config[\"queue_size\"] = 57600\n", " config[\"alpha\"] = 0.4\n", " config[\"k_test\"] = 256\n", " config[\"negative_all_rank\"] = True\n", "\n", " return config" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model\n", "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\n", "missing keys:\n", "[]\n" ] }, { "data": { "text/plain": [ "BLIPEmbs(\n", " (visual_encoder): VisionTransformer(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (blocks): ModuleList(\n", " (0): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.004)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.009)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.013)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.017)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.022)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (6): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.026)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (7): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.030)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (8): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.035)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (9): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.039)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (10): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.043)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (11): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.048)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (12): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.052)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (13): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.057)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (14): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.061)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (15): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.065)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (16): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.070)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (17): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.074)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (18): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.078)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (19): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.083)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (20): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.087)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (21): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.091)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (22): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.096)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (23): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.100)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " )\n", " (text_encoder): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30524, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (crossattention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=1024, out_features=768, bias=True)\n", " (value): Linear(in_features=1024, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (vision_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (text_proj): Linear(in_features=768, out_features=256, bias=True)\n", ")" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Creating model\")\n", "config = get_blip_config(\"large\")\n", "\n", "model = blip_embs(\n", " pretrained=config[\"pretrained\"],\n", " image_size=config[\"image_size\"],\n", " vit=config[\"vit\"],\n", " vit_grad_ckpt=config[\"vit_grad_ckpt\"],\n", " vit_ckpt_layer=config[\"vit_ckpt_layer\"],\n", " queue_size=config[\"queue_size\"],\n", " negative_all_rank=config[\"negative_all_rank\"],\n", " )\n", "\n", "model = model.to(device)\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BLIPEmbs(\n", " (visual_encoder): VisionTransformer(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (blocks): ModuleList(\n", " (0): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.004)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.009)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.013)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.017)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.022)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (6): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.026)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (7): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.030)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (8): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.035)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (9): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.039)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (10): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.043)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (11): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.048)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (12): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.052)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (13): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.057)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (14): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.061)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (15): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.065)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (16): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.070)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (17): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.074)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (18): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.078)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (19): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.083)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (20): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.087)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (21): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.091)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (22): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.096)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (23): Block(\n", " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): DropPath(drop_prob=0.100)\n", " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n", " )\n", " (text_encoder): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30524, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (crossattention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=1024, out_features=768, bias=True)\n", " (value): Linear(in_features=1024, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (vision_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (text_proj): Linear(in_features=768, out_features=256, bias=True)\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Read all database image features and create a list" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "df = pd.read_json(\"datasets/sidechef/my_recipes.json\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['recipe_name', 'recipe_time', 'recipe_yields', 'recipe_ingredients',\n", " 'recipe_instructions', 'recipe_image', 'blogger', 'recipe_nutrients',\n", " 'tags', 'id_'],\n", " dtype='object')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading Target Embedding\n" ] } ], "source": [ "print(\"Loading Target Embedding\")\n", "tar_img_feats = []\n", "for _id in df[\"id_\"].tolist(): \n", " tar_img_feats.append(torch.load(\"datasets/sidechef/blip-embs-large/{:07d}.pth\".format(_id)).unsqueeze(0))\n", "\n", "tar_img_feats = torch.cat(tar_img_feats, dim=0)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from src.data.transforms import transform_test\n", "\n", "transform = transform_test(384)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "image = Image.open(\"datasets/sidechef/images/{:07d}.png\".format(3)).convert(\"RGB\")" ] }, { "cell_type": "code", "execution_count": 134, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "img = transform(image).unsqueeze(0)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "img = img.to(device)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "img_embs = model.visual_encoder(img)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 577, 1024])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img_embs.shape" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 256])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img_feats.shape" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 256])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tar_img_feats[0].shape" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "tar_img_feats = torch.cat(tar_img_feats, dim=0)" ] }, { "cell_type": "code", "execution_count": 159, "metadata": {}, "outputs": [], "source": [ "score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()" ] }, { "cell_type": "code", "execution_count": 165, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 165, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argsort(score)[::-1][0]" ] }, { "cell_type": "code", "execution_count": 168, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "recipe_name Farmers Market Breakfast Pizza\n", "recipe_time 0\n", "recipe_yields 2 servings\n", "recipe_ingredients [1/2 Pizza Dough, 1/2 cup Kale, 1/2 cup Onion,...\n", "recipe_instructions For homemade pizza sauce, finely chop the Swee...\n", "recipe_image https://www.sidechef.com/recipe/1cd15944-9411-...\n", "blogger sidechef.com\n", "recipe_nutrients {'calories': '315 calories', 'proteinContent':...\n", "tags [Breakfast, Brunch, Main Dish, Budget-Friendly...\n", "id_ 4\n", "Name: 3, dtype: object" ] }, "execution_count": 168, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.iloc[2+1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "class StoppingCriteriaSub(StoppingCriteria):\n", "\n", " def __init__(self, stops=[], encounters=1):\n", " super().__init__()\n", " self.stops = stops\n", "\n", " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n", " for stop in self.stops:\n", " if torch.all(input_ids[:, -len(stop):] == stop).item():\n", " return True\n", "\n", " return False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image_path = \"datasets/sidechef/images/{:07d}.png\".format(3)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "class Chat:\n", "\n", " def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):\n", " self.device = device\n", " self.model = model\n", " self.transform = transform\n", " self.df = dataframe\n", " self.tar_img_feats = tar_img_feats\n", " self.img_feats = None\n", " self.target_recipe = None\n", " self.messages = []\n", "\n", " if stopping_criteria is not None:\n", " self.stopping_criteria = stopping_criteria\n", " else:\n", " stop_words_ids = [torch.tensor([2]).to(self.device)]\n", " self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])\n", "\n", " def encode_image(self, image_path):\n", " img = Image.fromarray(image_path).convert(\"RGB\")\n", " img = self.transform(img).unsqueeze(0)\n", " img = img.to(self.device)\n", " img_embs = model.visual_encoder(img)\n", " img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()\n", "\n", " self.img_feats = img_feats \n", "\n", " self.get_target(self.img_feats, self.tar_img_feats)\n", "\n", " def get_target(self, img_feats, tar_img_feats) : \n", " score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()\n", " index = np.argsort(score)[::-1][0] + 1\n", " self.target_recipe = df.iloc[index]\n", "\n", " def ask(self, msg):\n", " if \"nutrition\" in msg or \"nutrients\" in msg : \n", " return json.dumps(self.target_recipe[\"recipe_nutrients\"], indent=4)\n", " elif \"instruction\" in msg :\n", " return json.dumps(self.target_recipe[\"recipe_instructions\"], indent=4)\n", " elif \"ingredients\" in msg :\n", " return json.dumps(self.target_recipe[\"recipe_ingredients\"], indent=4)\n", " elif \"tag\" in msg or \"class\" in msg :\n", " return json.dumps(self.target_recipe[\"tags\"], indent=4)\n", " else:\n", " return \"Conversational capabilities will be included later.\"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "chat = Chat(model,transform,df,tar_img_feats)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "import gradio as gr " ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7874\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example_images = gr.Dataset(components=[image], label=\"Food Examples\",\n", " samples=[\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000018.png\")],\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000021.png\")],\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000035.png\")],\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000038.png\")],\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000090.png\")],\n", " [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000122.png\")],\n", " ])\n", "\n", "example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],\n", " label=\"Prompt Examples\",\n", " samples=[\n", " [\"Describe the given chest x-ray image in detail.\"],\n", " [\"Take a look at this chest x-ray and describe the findings and impression.\"],\n", " [\"Could you provide a detailed description of the given x-ray image?\"],\n", " [\"Describe the given chest x-ray image as detailed as possible.\"],\n", " [\"What are the key findings in this chest x-ray image?\"],\n", " [\"Could you highlight any abnormalities or concerns in this chest x-ray image?\"],\n", " [\"What specific features of the lungs and heart are visible in this chest x-ray image?\"],\n", " [\"What is the most prominent feature visible in this chest x-ray image, and how is it indicative of the patient's health?\"],\n", " [\"Based on the findings in this chest x-ray image, what is the overall impression?\"],\n", " ],)\n", "\n", "\n", "\n", "def respond_to_user(image, message):\n", " # Process the image and message here\n", " # For demonstration, I'll just return a simple text response\n", " chat = Chat(model,transform,df,tar_img_feats)\n", " chat.encode_image(image)\n", " response = chat.ask(message)\n", " return response\n", "\n", "iface = gr.Interface(\n", " fn=respond_to_user,\n", " inputs=[gr.Image(height=\"70%\"), gr.Textbox(label=\"Ask Query\"),],\n", " outputs=[gr.Textbox(label=\"Nutrition-GPT\")],\n", " title=\"Nutrition-GPT Demo\",\n", " description=\"Upload an food image and ask queries!\",\n", " css=\".component-12 {background-color: red}\",\n", " \n", ")\n", "\n", "iface.launch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7876\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "def respond_to_user(image, message):\n", " # Process the image and message here\n", " # For demonstration, I'll just return a simple text response\n", " chat = Chat(model,transform,df,tar_img_feats)\n", " chat.encode_image(image)\n", " response = chat.ask(message)\n", " return response\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"Nutrition-GPT Demo\")\n", "\n", " with gr.Row():\n", " with gr.Column():\n", " image = gr.Image()\n", " text_input = gr.Textbox(label='Ask Query')\n", " submit_button = gr.Button(value=\"Upload Food Image and Submit Query\", interactive=True, variant=\"primary\")\n", " clear = gr.Button(\"Reset\")\n", " \n", "\n", " with gr.Column():\n", " text_output = gr.Textbox(label=\"Nutrition-GPT\")\n", "\n", " \n", " submit_button.click(respond_to_user, inputs=[image, text_input], outputs=text_output)\n", " \n", "\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def set_example_image(example: list) -> dict:\n", " return gr.Image.update(value=example[0])\n", "\n", "def upload_img(gr_img, text_input, chat_state):\n", " if gr_img is None:\n", " return None, None, gr.update(interactive=True), chat_state, None\n", " chat_state = CONV_VISION.copy()\n", " img_list = []\n", " llm_message = chat.upload_img(gr_img, chat_state, img_list)\n", " return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value=\"Start Chatting\", interactive=False), chat_state, img_list\n", "\n", "\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"Nutrition-GPT Demo\")\n", "\n", " with gr.Row():\n", " with gr.Column(scale=0.5):\n", " image = gr.Image(type=\"pil\")\n", " upload_button = gr.Button(value=\"Upload Food Image and Ask Queries\", interactive=True, variant=\"primary\")\n", " clear = gr.Button(\"Reset\")\n", " \n", "\n", " with gr.Column():\n", " chat_state = gr.State()\n", " img_list = gr.State()\n", " chatbot = gr.Chatbot(label='Nutrition-GPT')\n", " text_input = gr.Textbox(label='User', placeholder='Please upload food image.', interactive=False)\n", "\n", "\n", " with gr.Row():\n", " example_images = gr.Dataset(components=[image], label=\"X-Ray Examples\",\n", " samples=[\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000018.png\")],\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000021.png\")],\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000035.png\")],\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000038.png\")],\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000090.png\")],\n", " [os.path.join(os.path.dirname(__file__), \"./datasets/sidechef/sample_images/0000122.png\")],\n", " ])\n", " \n", "\n", " with gr.Row():\n", " example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],\n", " label=\"Prompt Examples\",\n", " samples=[\n", " [\"Describe the given chest x-ray image in detail.\"],\n", " [\"Take a look at this chest x-ray and describe the findings and impression.\"],\n", " [\"Could you provide a detailed description of the given x-ray image?\"],\n", " [\"Describe the given chest x-ray image as detailed as possible.\"],\n", " [\"What are the key findings in this chest x-ray image?\"],\n", " [\"Could you highlight any abnormalities or concerns in this chest x-ray image?\"],\n", " [\"What specific features of the lungs and heart are visible in this chest x-ray image?\"],\n", " [\"What is the most prominent feature visible in this chest x-ray image, and how is it indicative of the patient's health?\"],\n", " [\"Based on the findings in this chest x-ray image, what is the overall impression?\"],\n", " ],)\n", " \n", " example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)\n", "\n", " upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])\n", " \n", " example_texts.click(set_example_text_input, inputs=example_texts, outputs=text_input).then(\n", " gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(\n", " gradio_answer, [chatbot, chat_state, img_list], [chatbot, chat_state, img_list]\n", " )\n", " \n", " text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(\n", " gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]\n", " )\n", " clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)\n", " \n", " gr.Markdown(disclaimer)\n", "\n", "demo.launch(share=True, enable_queue=True)" ] } ], "metadata": { "kernelspec": { "display_name": "covr", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 2 }