diff --git "a/RobustViT.ipynb" "b/RobustViT.ipynb"
new file mode 100644
--- /dev/null
+++ "b/RobustViT.ipynb"
@@ -0,0 +1,1454 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "RobustViT.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "authorship_tag": "ABX9TyNP00yXydKk0stZEJQyT5pO",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hzG96yZskJSy",
+ "outputId": "3eab22fa-e246-4cfb-d4a9-c35878cf75f2"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'RobustViT'...\n",
+ "remote: Enumerating objects: 139, done.\u001b[K\n",
+ "remote: Counting objects: 100% (139/139), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (119/119), done.\u001b[K\n",
+ "remote: Total 139 (delta 54), reused 84 (delta 18), pack-reused 0\u001b[K\n",
+ "Receiving objects: 100% (139/139), 4.50 MiB | 17.11 MiB/s, done.\n",
+ "Resolving deltas: 100% (54/54), done.\n"
+ ]
+ }
+ ],
+ "source": [
+ "!git clone https://github.com/hila-chefer/RobustViT.git\n",
+ "\n",
+ "import os\n",
+ "os.chdir(f'./RobustViT')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install timm\n",
+ "!pip install einops"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hZK84BL3mZQg",
+ "outputId": "f9273be0-3410-47cf-f52e-2a45989e9b1f"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting timm\n",
+ " Downloading timm-0.5.4-py3-none-any.whl (431 kB)\n",
+ "\u001b[K |████████████████████████████████| 431 kB 4.4 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from timm) (1.11.0+cu113)\n",
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm) (0.12.0+cu113)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4->timm) (4.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (1.21.6)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (7.1.2)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (2.23.0)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (3.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2022.5.18.1)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (1.24.3)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2.10)\n",
+ "Installing collected packages: timm\n",
+ "Successfully installed timm-0.5.4\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting einops\n",
+ " Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
+ "Installing collected packages: einops\n",
+ "Successfully installed einops-0.4.1\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from PIL import Image\n",
+ "import torchvision.transforms as transforms\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import cv2\n",
+ "from CLS2IDX import CLS2IDX\n",
+ "\n",
+ "normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
+ "transform = transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(224),\n",
+ " transforms.ToTensor(),\n",
+ " normalize,\n",
+ "])\n",
+ "transform_224 = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " normalize,\n",
+ "])\n",
+ "\n",
+ "# create heatmap from mask on image\n",
+ "def show_cam_on_image(img, mask):\n",
+ " heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n",
+ " heatmap = np.float32(heatmap) / 255\n",
+ " cam = heatmap + np.float32(img)\n",
+ " cam = cam / np.max(cam)\n",
+ " return cam"
+ ],
+ "metadata": {
+ "id": "uqDTsTS2k8pl"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
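+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick sanity check for `show_cam_on_image` (an added sketch, not part of the original notebook): overlay a uniform dummy mask on a random image and confirm the overlay keeps the expected shape and stays in `[0, 1]`. The `_img`/`_mask` arrays are illustrative placeholders only."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical smoke test for show_cam_on_image (dummy inputs, no real data).\n",
+ "_img = np.random.rand(224, 224, 3).astype(np.float32)   # fake RGB image in [0, 1]\n",
+ "_mask = np.full((224, 224), 0.5, dtype=np.float32)      # uniform relevance mask\n",
+ "_cam = show_cam_on_image(_img, _mask)\n",
+ "assert _cam.shape == (224, 224, 3) and 0.0 <= _cam.min() and _cam.max() <= 1.0"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },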
+ {
+ "cell_type": "code",
+ "source": [
+ "from pydrive.auth import GoogleAuth\n",
+ "from pydrive.drive import GoogleDrive\n",
+ "from google.colab import auth\n",
+ "from oauth2client.client import GoogleCredentials\n",
+ "\n",
+ "# Authenticate and create the PyDrive client.\n",
+ "auth.authenticate_user()\n",
+ "gauth = GoogleAuth()\n",
+ "gauth.credentials = GoogleCredentials.get_application_default()\n",
+ "drive = GoogleDrive(gauth)\n",
+ "\n",
+ "# downloads weights\n",
+ "ids = ['1jbWiuBrL4sKpAjG3x4oGbs3WOC2UdbIb', '1DHKX_s8rVCDiX4pwnuCCZdGWsOl4SFMn', '1vDmuvbdLbYVAqWz6yVM4vT1Wdzt8KV-g']\n",
+ "for file_id in ids:\n",
+ " downloaded = drive.CreateFile({'id':file_id})\n",
+ " downloaded.FetchMetadata(fetch_all=True)\n",
+ " downloaded.GetContentFile(downloaded.metadata['title'])"
+ ],
+ "metadata": {
+ "id": "eImo3TAenFbo"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
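+ {
+ "cell_type": "markdown",
+ "source": [
+ "Alternative download path (an added sketch, not in the original notebook): if the OAuth flow above is inconvenient and the checkpoint files are shared publicly, the same three file ids can be fetched with `gdown` instead of PyDrive."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical alternative to the PyDrive download above; assumes the files are publicly shared.\n",
+ "!pip install -q gdown\n",
+ "!gdown 'https://drive.google.com/uc?id=1jbWiuBrL4sKpAjG3x4oGbs3WOC2UdbIb'\n",
+ "!gdown 'https://drive.google.com/uc?id=1DHKX_s8rVCDiX4pwnuCCZdGWsOl4SFMn'\n",
+ "!gdown 'https://drive.google.com/uc?id=1vDmuvbdLbYVAqWz6yVM4vT1Wdzt8KV-g'"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },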
+ {
+ "cell_type": "code",
+ "source": [
+ "model_name = 'ar_base' #@param ['ar_base','vit_base', 'deit_base']\n",
+ "\n",
+ "if model_name == 'ar_base':\n",
+ " from ViT.ViT_new import vit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('ar_base.tar')\n",
+ "\n",
+ "if model_name == 'vit_base':\n",
+ " from ViT.ViT import vit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('vit_base.tar')\n",
+ "\n",
+ "if model_name == 'deit_base':\n",
+ " from ViT.ViT import deit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('deit_base.tar')\n",
+ "\n",
+ "model_finetuned.load_state_dict(checkpoint['state_dict'])\n",
+ "model_finetuned.eval()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "cellView": "form",
+ "id": "tayM9OIWlLOT",
+ "outputId": "ac6f819a-ef56-4bbf-c25c-5f9cc9b2bb5c"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "VisionTransformer(\n",
+ " (patch_embed): PatchEmbed(\n",
+ " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
+ " (norm): Identity()\n",
+ " )\n",
+ " (pos_drop): Dropout(p=0.0, inplace=False)\n",
+ " (blocks): ModuleList(\n",
+ " (0): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (pre_logits): Identity()\n",
+ " (head): Linear(in_features=768, out_features=1000, bias=True)\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ]
+ },
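+ {
+ "cell_type": "markdown",
+ "source": [
+ "Optional smoke test (an added sketch, not part of the original notebook): both the pretrained and the finetuned model should map a single 3x224x224 input to a 1x1000 logits tensor; a random tensor is used here purely as a shape check."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical shape check with a random input (no real image involved).\n",
+ "with torch.no_grad():\n",
+ "    _x = torch.randn(1, 3, 224, 224).cuda()\n",
+ "    print(model(_x).shape, model_finetuned(_x).shape)  # both should be torch.Size([1, 1000])"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },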
+ {
+ "cell_type": "code",
+ "source": [
+ "start_layer = 0\n",
+ "\n",
+ "# rule 5 from paper\n",
+ "def avg_heads(cam, grad):\n",
+ " cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])\n",
+ " grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])\n",
+ " cam = grad * cam\n",
+ " cam = cam.clamp(min=0).mean(dim=0)\n",
+ " return cam\n",
+ "\n",
+ "# rule 6 from paper\n",
+ "def apply_self_attention_rules(R_ss, cam_ss):\n",
+ " R_ss_addition = torch.matmul(cam_ss, R_ss)\n",
+ " return R_ss_addition\n",
+ "\n",
+ "def generate_relevance(model, input, index=None):\n",
+ " output = model(input, register_hook=True)\n",
+ " if index == None:\n",
+ " index = np.argmax(output.cpu().data.numpy(), axis=-1)\n",
+ "\n",
+ " one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)\n",
+ " one_hot[0, index] = 1\n",
+ " one_hot_vector = one_hot\n",
+ " one_hot = torch.from_numpy(one_hot).requires_grad_(True)\n",
+ " one_hot = torch.sum(one_hot.cuda() * output)\n",
+ " model.zero_grad()\n",
+ " one_hot.backward(retain_graph=True)\n",
+ "\n",
+ " num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]\n",
+ " R = torch.eye(num_tokens, num_tokens).cuda()\n",
+ " for i,blk in enumerate(model.blocks):\n",
+ " if i < start_layer:\n",
+ " continue\n",
+ " grad = blk.attn.get_attn_gradients()\n",
+ " cam = blk.attn.get_attention_map()\n",
+ " cam = avg_heads(cam, grad)\n",
+ " R += apply_self_attention_rules(R.cuda(), cam.cuda())\n",
+ " return R[0, 1:]"
+ ],
+ "metadata": {
+ "id": "tKp64OSWlC7w"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
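+ {
+ "cell_type": "markdown",
+ "source": [
+ "For reference (an added note, not part of the original notebook), the cell above implements the two relevance rules referenced in its comments:\n",
+ "\n",
+ "- rule 5 (`avg_heads`): $\\bar{A}^{(b)} = \\mathbb{E}_h\\big[(\\nabla A^{(b)} \\odot A^{(b)})^{+}\\big]$, the gradient-weighted attention map of block $b$, clamped at zero and averaged over heads;\n",
+ "- rule 6 (`apply_self_attention_rules`): $R \\leftarrow R + \\bar{A}^{(b)} \\cdot R$, accumulated over the blocks from `start_layer` on, with $R$ initialized to the identity.\n",
+ "\n",
+ "`generate_relevance` then returns the CLS-token row of $R$ over the patch tokens (`R[0, 1:]`)."
+ ],
+ "metadata": {}
+ },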
+ {
+ "cell_type": "code",
+ "source": [
+ "def generate_visualization(model, original_image, class_index=None):\n",
+ " with torch.enable_grad():\n",
+ " transformer_attribution = generate_relevance(model, original_image.unsqueeze(0).cuda(), index=class_index).detach()\n",
+ " transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)\n",
+ " transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')\n",
+ " transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()\n",
+ " transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())\n",
+ " \n",
+ " image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()\n",
+ " image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())\n",
+ " vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)\n",
+ " vis = np.uint8(255 * vis)\n",
+ " vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)\n",
+ " return vis\n",
+ "\n",
+ "def print_top_classes(predictions, **kwargs): \n",
+ " # Print Top-5 predictions\n",
+ " prob = torch.softmax(predictions, dim=1)\n",
+ " class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()\n",
+ " max_str_len = 0\n",
+ " class_names = []\n",
+ " for cls_idx in class_indices:\n",
+ " class_names.append(CLS2IDX[cls_idx])\n",
+ " if len(CLS2IDX[cls_idx]) > max_str_len:\n",
+ " max_str_len = len(CLS2IDX[cls_idx])\n",
+ " \n",
+ " print('Top 5 classes:')\n",
+ " for cls_idx in class_indices:\n",
+ " output_string = '\\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])\n",
+ " output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\\t\\t'\n",
+ " output_string += 'value = {:.3f}\\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])\n",
+ " print(output_string)"
+ ],
+ "metadata": {
+ "id": "rmQ9pacLoGze"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
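+ {
+ "cell_type": "markdown",
+ "source": [
+ "A short note on the shapes used above (added for clarity): with 16x16 patches on a 224x224 crop there are $224 / 16 = 14$ patches per side, i.e. $14 \\times 14 = 196$ patch tokens, so the relevance returned by `generate_relevance` has 196 entries. `generate_visualization` reshapes it to $1 \\times 1 \\times 14 \\times 14$, upsamples it bilinearly by a factor of 16 back to $224 \\times 224$, min-max normalizes it, and overlays it on the (also min-max normalized) image via `show_cam_on_image`."
+ ],
+ "metadata": {}
+ },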
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# ImageNet-A"
+ ],
+ "metadata": {
+ "id": "5o-p6euNMibE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with torch.no_grad():\n",
+ " image = Image.open(f'samples/{model_name}/a.png')\n",
+ " dog_cat_image = transform_224(image)\n",
+ "\n",
+ " fig, axs = plt.subplots(1, 2)\n",
+ " fig.set_size_inches(10, 7)\n",
+ " axs[0].imshow(image);\n",
+ " axs[0].axis('off');\n",
+ "\n",
+ " output = model(dog_cat_image.unsqueeze(0).cuda())\n",
+ " print(\"original model\")\n",
+ " print_top_classes(output)\n",
+ "\n",
+ " out = generate_visualization(model, dog_cat_image)\n",
+ "\n",
+ " fig.suptitle('original model',y=0.8)\n",
+ " axs[1].imshow(out);\n",
+ " axs[1].axis('off');\n",
+ "\n",
+ " fig, axs = plt.subplots(1, 2)\n",
+ " fig.set_size_inches(10, 7)\n",
+ " axs[0].imshow(image);\n",
+ " axs[0].axis('off');\n",
+ " output = model_finetuned(dog_cat_image.unsqueeze(0).cuda())\n",
+ " print(\"finetuned model\")\n",
+ " print_top_classes(output)\n",
+ "\n",
+ " out = generate_visualization(model_finetuned, dog_cat_image)\n",
+ "\n",
+ " fig.suptitle('finetuned model',y=0.8)\n",
+ " axs[1].imshow(out);\n",
+ " axs[1].axis('off');"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 842
+ },
+ "id": "q8hsWi_WMlkb",
+ "outputId": "2a89083c-53b6-4b5e-c7b0-c01d9c90446c"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "original model\n",
+ "Top 5 classes:\n",
+ "\t829 : streetcar, tram, tramcar, trolley, trolley car\t\tvalue = 10.911\t prob = 57.2%\n",
+ "\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 10.221\t prob = 28.7%\n",
+ "\t466 : bullet train, bullet \t\tvalue = 6.897\t prob = 1.0%\n",
+ "\t733 : pole \t\tvalue = 6.878\t prob = 1.0%\n",
+ "\t547 : electric locomotive \t\tvalue = 6.626\t prob = 0.8%\n",
+ "finetuned model\n",
+ "Top 5 classes:\n",
+ "\t847 : tank, army tank, armored combat vehicle, armoured combat vehicle\t\tvalue = 11.573\t prob = 60.1%\n",
+ "\t408 : amphibian, amphibious vehicle \t\tvalue = 10.085\t prob = 13.6%\n",
+ "\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 9.585\t prob = 8.2%\n",
+ "\t829 : streetcar, tram, tramcar, trolley, trolley car \t\tvalue = 9.583\t prob = 8.2%\n",
+ "\t586 : half track \t\tvalue = 7.935\t prob = 1.6%\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "