{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{},"source":["# Specific Test V. Exploring Transformers\n","\n","> This is a copy of the training notebook used for Specific Test 5 of the DeepLense tests. The original notebook, datasets and results can be found [here](https://github.com/gauthamk02/deeplense-test-2023)\n","\n","**Task:** Use a vision transformer method of your choice to build a robust and efficient model for binary classification or unsupervised anomaly detection on the provided dataset. In the case of unsupervised anomaly detection, train your model to learn the distribution of the provided strong lensing images with no substructure. Please implement your approach in PyTorch or Keras and discuss your strategy.\n","\n","**Dataset Description:** A set of simulated strong gravitational lensing images with and without substructure. \n","\n","**Evaluation Metrics:** ROC curve (Receiver Operating Characteristic curve) and AUC score (Area Under the ROC Curve)"]},{"cell_type":"code","execution_count":27,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:14.076073Z","iopub.status.busy":"2023-03-11T09:43:14.075565Z","iopub.status.idle":"2023-03-11T09:43:24.878632Z","shell.execute_reply":"2023-03-11T09:43:24.877344Z","shell.execute_reply.started":"2023-03-11T09:43:14.076037Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: einops in /opt/conda/lib/python3.7/site-packages (0.6.0)\n","\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n","\u001b[0m"]}],"source":["!pip install einops"]},{"cell_type":"code","execution_count":28,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.883148Z","iopub.status.busy":"2023-03-11T09:43:24.882826Z","iopub.status.idle":"2023-03-11T09:43:24.890557Z","shell.execute_reply":"2023-03-11T09:43:24.889077Z","shell.execute_reply.started":"2023-03-11T09:43:24.883115Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torchvision.transforms as transforms\n","from torch.utils.data import DataLoader, Dataset\n","import matplotlib.pyplot as plt\n","from sklearn.model_selection import train_test_split\n","import numpy as np\n","import pandas as pd\n","import os\n","from PIL import Image\n","from tqdm import tqdm\n","import seaborn as sns\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"]},{"cell_type":"markdown","metadata":{},"source":["## Loading Data"]},{"cell_type":"code","execution_count":29,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.892682Z","iopub.status.busy":"2023-03-11T09:43:24.892290Z","iopub.status.idle":"2023-03-11T09:43:24.942546Z","shell.execute_reply":"2023-03-11T09:43:24.941484Z","shell.execute_reply.started":"2023-03-11T09:43:24.892644Z"},"trusted":true},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
pathlabel
0/kaggle/input/ml4sci-test5/lenses/no_sub/image...no_sub
1/kaggle/input/ml4sci-test5/lenses/no_sub/image...no_sub
2/kaggle/input/ml4sci-test5/lenses/no_sub/image...no_sub
3/kaggle/input/ml4sci-test5/lenses/no_sub/image...no_sub
4/kaggle/input/ml4sci-test5/lenses/no_sub/image...no_sub
\n","
"],"text/plain":[" path label\n","0 /kaggle/input/ml4sci-test5/lenses/no_sub/image... no_sub\n","1 /kaggle/input/ml4sci-test5/lenses/no_sub/image... no_sub\n","2 /kaggle/input/ml4sci-test5/lenses/no_sub/image... no_sub\n","3 /kaggle/input/ml4sci-test5/lenses/no_sub/image... no_sub\n","4 /kaggle/input/ml4sci-test5/lenses/no_sub/image... no_sub"]},"execution_count":29,"metadata":{},"output_type":"execute_result"}],"source":["base_dir = '../datasets/test5/lenses'\n","\n","train_df = pd.DataFrame(columns = ['path', 'label'])\n","\n","label_map = {'no_sub':0, 'sub':1}\n","val_label_map = {0:'no_sub', 1:'sub'}\n","\n","for i in label_map.keys():\n"," entries = [os.path.join(base_dir, i, j) for j in os.listdir(os.path.join(base_dir, i))]\n"," temp_df = pd.DataFrame({'path':entries, 'label':i})\n"," train_df = pd.concat([train_df, temp_df], ignore_index=True)\n","\n","train_df.head()"]},{"cell_type":"code","execution_count":30,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.945833Z","iopub.status.busy":"2023-03-11T09:43:24.945428Z","iopub.status.idle":"2023-03-11T09:43:24.964359Z","shell.execute_reply":"2023-03-11T09:43:24.963191Z","shell.execute_reply.started":"2023-03-11T09:43:24.945794Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Train split:\n","sub 4500\n","no_sub 4500\n","Name: label, dtype: int64\n","\n","Test split:\n","sub 500\n","no_sub 500\n","Name: label, dtype: int64\n"]}],"source":["train_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42, stratify=train_df['label'])\n","\n","print(f\"Train split:\\n{train_df['label'].value_counts()}\\n\")\n","print(f\"Test split:\\n{test_df['label'].value_counts()}\")"]},{"cell_type":"code","execution_count":31,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.967780Z","iopub.status.busy":"2023-03-11T09:43:24.967441Z","iopub.status.idle":"2023-03-11T09:43:24.975704Z","shell.execute_reply":"2023-03-11T09:43:24.974493Z","shell.execute_reply.started":"2023-03-11T09:43:24.967752Z"},"trusted":true},"outputs":[],"source":["class Dataset(Dataset):\n"," def __init__(self, df, transform=None):\n"," self.df = df\n"," self.transform = transform\n","\n"," def __len__(self):\n"," return len(self.df)\n","\n"," def __getitem__(self, idx):\n"," img_path = self.df.iloc[idx]['path']\n"," img = Image.open(img_path)\n"," \n"," if self.transform:\n"," img = self.transform(img)\n","\n"," label = self.df.iloc[idx]['label']\n"," label = label_map[label]\n"," \n"," return img, label"]},{"cell_type":"code","execution_count":32,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.978749Z","iopub.status.busy":"2023-03-11T09:43:24.977569Z","iopub.status.idle":"2023-03-11T09:43:24.990137Z","shell.execute_reply":"2023-03-11T09:43:24.989173Z","shell.execute_reply.started":"2023-03-11T09:43:24.978711Z"},"trusted":true},"outputs":[],"source":["train_transforms = transforms.Compose([\n"," transforms.ColorJitter(brightness=0.2, contrast=0.2),\n"," transforms.RandomRotation(180),\n"," transforms.RandomHorizontalFlip(),\n"," transforms.Resize(256),\n"," transforms.ToTensor()\n","])\n","\n","test_transforms = transforms.Compose([\n"," transforms.Resize(256),\n"," transforms.ToTensor()\n","])\n","\n","train_dataset = Dataset(train_df, transform=train_transforms)\n","test_dataset = Dataset(test_df, transform=test_transforms)\n","\n","train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)\n","test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=2)"]},{"cell_type":"code","execution_count":33,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:24.992209Z","iopub.status.busy":"2023-03-11T09:43:24.991353Z","iopub.status.idle":"2023-03-11T09:43:25.550563Z","shell.execute_reply":"2023-03-11T09:43:25.549252Z","shell.execute_reply.started":"2023-03-11T09:43:24.992172Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["torch.Size([64, 1, 256, 256])\n"]}],"source":["img, label = next(iter(train_loader))\n","print(img.shape)"]},{"cell_type":"code","execution_count":34,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:25.554842Z","iopub.status.busy":"2023-03-11T09:43:25.552310Z","iopub.status.idle":"2023-03-11T09:43:26.475677Z","shell.execute_reply":"2023-03-11T09:43:26.474519Z","shell.execute_reply.started":"2023-03-11T09:43:25.554796Z"},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["def show_batch(images, labels, class_map):\n"," fig, ax = plt.subplots(3, 3, figsize=(6, 6))\n"," for i in range(3):\n"," for j in range(3):\n"," image = images[i*3 + j]\n"," label = torch.argmax(labels[i*3 + j])\n"," ax[i][j].imshow(image.permute(1, 2, 0))\n"," title = [k for k, v in class_map.items() if v == label][0]\n"," ax[i][j].set_title(title)\n"," ax[i][j].axis('off')\n"," ax[i][j].title.set_fontsize(10)\n","\n"," plt.show()\n","\n","images, labels = next(iter(train_loader))\n","show_batch(images, labels, label_map)"]},{"cell_type":"markdown","metadata":{},"source":["## Defining Model"]},{"cell_type":"code","execution_count":35,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:26.478507Z","iopub.status.busy":"2023-03-11T09:43:26.477831Z","iopub.status.idle":"2023-03-11T09:43:26.510319Z","shell.execute_reply":"2023-03-11T09:43:26.509443Z","shell.execute_reply.started":"2023-03-11T09:43:26.478461Z"},"trusted":true},"outputs":[],"source":["import torch\n","from torch import nn, einsum\n","import torch.nn.functional as F\n","\n","from einops import rearrange\n","from einops.layers.torch import Rearrange\n","\n","# helper methods\n","\n","def group_dict_by_key(cond, d):\n"," return_val = [dict(), dict()]\n"," for key in d.keys():\n"," match = bool(cond(key))\n"," ind = int(not match)\n"," return_val[ind][key] = d[key]\n"," return (*return_val,)\n","\n","def group_by_key_prefix_and_remove_prefix(prefix, d):\n"," kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)\n"," kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n"," return kwargs_without_prefix, kwargs\n","\n","# classes\n","\n","class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1\n"," def __init__(self, dim, eps = 1e-5):\n"," super().__init__()\n"," self.eps = eps\n"," self.g = nn.Parameter(torch.ones(1, dim, 1, 1))\n"," self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))\n","\n"," def forward(self, x):\n"," var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n"," mean = torch.mean(x, dim = 1, keepdim = True)\n"," return (x - mean) / (var + self.eps).sqrt() * self.g + self.b\n","\n","class PreNorm(nn.Module):\n"," def __init__(self, dim, fn):\n"," super().__init__()\n"," self.norm = LayerNorm(dim)\n"," self.fn = fn\n"," def forward(self, x, **kwargs):\n"," x = self.norm(x)\n"," return self.fn(x, **kwargs)\n","\n","class FeedForward(nn.Module):\n"," def __init__(self, dim, mult = 4, dropout = 0.):\n"," super().__init__()\n"," self.net = nn.Sequential(\n"," nn.Conv2d(dim, dim * mult, 1),\n"," nn.GELU(),\n"," nn.Dropout(dropout),\n"," nn.Conv2d(dim * mult, dim, 1),\n"," nn.Dropout(dropout)\n"," )\n"," def forward(self, x):\n"," return self.net(x)\n","\n","class DepthWiseConv2d(nn.Module):\n"," def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):\n"," super().__init__()\n"," self.net = nn.Sequential(\n"," nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),\n"," nn.BatchNorm2d(dim_in),\n"," nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)\n"," )\n"," def forward(self, x):\n"," return self.net(x)\n","\n","class Attention(nn.Module):\n"," def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):\n"," super().__init__()\n"," inner_dim = dim_head * heads\n"," padding = proj_kernel // 2\n"," self.heads = heads\n"," self.scale = dim_head ** -0.5\n","\n"," self.attend = nn.Softmax(dim = -1)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)\n"," self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)\n","\n"," self.to_out = nn.Sequential(\n"," nn.Conv2d(inner_dim, dim, 1),\n"," nn.Dropout(dropout)\n"," )\n","\n"," def forward(self, x):\n"," shape = x.shape\n"," b, n, _, y, h = *shape, self.heads\n"," q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))\n"," q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))\n","\n"," dots = einsum('b i d, b j d -> b i j', q, k) * self.scale\n","\n"," attn = self.attend(dots)\n"," attn = self.dropout(attn)\n","\n"," out = einsum('b i j, b j d -> b i d', attn, v)\n"," out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)\n"," return self.to_out(out)\n","\n","class Transformer(nn.Module):\n"," def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):\n"," super().__init__()\n"," self.layers = nn.ModuleList([])\n"," for _ in range(depth):\n"," self.layers.append(nn.ModuleList([\n"," PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),\n"," PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))\n"," ]))\n"," def forward(self, x):\n"," for attn, ff in self.layers:\n"," x = attn(x) + x\n"," x = ff(x) + x\n"," return x\n","\n","class CvT(nn.Module):\n"," def __init__(\n"," self,\n"," *,\n"," num_classes,\n"," s1_emb_dim = 64,\n"," s1_emb_kernel = 7,\n"," s1_emb_stride = 4,\n"," s1_proj_kernel = 3,\n"," s1_kv_proj_stride = 2,\n"," s1_heads = 1,\n"," s1_depth = 1,\n"," s1_mlp_mult = 4,\n"," s2_emb_dim = 192,\n"," s2_emb_kernel = 3,\n"," s2_emb_stride = 2,\n"," s2_proj_kernel = 3,\n"," s2_kv_proj_stride = 2,\n"," s2_heads = 3,\n"," s2_depth = 2,\n"," s2_mlp_mult = 4,\n"," s3_emb_dim = 384,\n"," s3_emb_kernel = 3,\n"," s3_emb_stride = 2,\n"," s3_proj_kernel = 3,\n"," s3_kv_proj_stride = 2,\n"," s3_heads = 6,\n"," s3_depth = 10,\n"," s3_mlp_mult = 4,\n"," dropout = 0.\n"," ):\n"," super().__init__()\n"," kwargs = dict(locals())\n","\n"," dim = 1\n"," layers = []\n","\n"," for prefix in ('s1', 's2', 's3'):\n"," config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)\n","\n"," layers.append(nn.Sequential(\n"," nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),\n"," LayerNorm(config['emb_dim']),\n"," Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)\n"," ))\n","\n"," dim = config['emb_dim']\n","\n"," self.layers = nn.Sequential(*layers)\n","\n"," self.to_logits = nn.Sequential(\n"," nn.AdaptiveAvgPool2d(1),\n"," Rearrange('... () () -> ...'),\n"," nn.Linear(dim, num_classes)\n"," )\n","\n"," def forward(self, x):\n"," latents = self.layers(x)\n"," return self.to_logits(latents)"]},{"cell_type":"markdown","metadata":{},"source":["## Hyper-parameters and Training"]},{"cell_type":"code","execution_count":36,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:26.515399Z","iopub.status.busy":"2023-03-11T09:43:26.514765Z","iopub.status.idle":"2023-03-11T09:43:26.529464Z","shell.execute_reply":"2023-03-11T09:43:26.528682Z","shell.execute_reply.started":"2023-03-11T09:43:26.515359Z"},"trusted":true},"outputs":[],"source":["def train(model, epochs, optimizer, criterion, scheduler, device, trainloader, testloader):\n","\n"," train_losses = []\n"," train_acc = []\n"," test_losses = []\n"," test_acc = []\n"," \n"," for i in range(epochs):\n"," running_loss = 0.0\n"," running_correct = 0\n"," total = 0\n"," best_acc = 0.0\n"," \n"," print(f\"Epoch: {i + 1}\")\n"," \n"," for images, targets in tqdm(trainloader, desc= \"Train\\t\"):\n"," \n"," images, targets = images.to(device), targets.to(device)\n"," optimizer.zero_grad()\n"," output = model(images)\n"," loss = criterion(output, targets)\n"," loss.backward()\n"," optimizer.step()\n","\n"," running_loss += loss.item()\n"," pred = torch.argmax(output, dim=1)\n"," \n"," running_correct += (pred == targets).sum().item()\n"," total += targets.size(0)\n","\n"," scheduler.step()\n"," \n"," train_losses.append(running_loss / len(trainloader))\n"," train_acc.append(running_correct / total)\n","\n"," running_val_loss = 0.0\n"," correct = 0\n"," total = 0\n"," \n"," with torch.no_grad():\n","\n"," for images, targets in tqdm(testloader, desc= \"Test\\t\"):\n"," images, targets = images.to(device), targets.to(device)\n","\n"," output = model(images)\n"," preds = torch.argmax(output, dim=1)\n","\n"," correct += (preds == targets).sum().item()\n"," running_val_loss += criterion(output, targets).item()\n"," total += targets.size(0)\n","\n"," acc = correct / total\n"," test_acc.append(acc)\n"," test_losses.append(running_val_loss / len(testloader))\n","\n"," if test_acc[-1] > best_acc:\n"," best_acc = test_acc[-1]\n"," torch.save(model.state_dict(), '../models/test-5-cvt-model.pth') \n","\n"," print(f\"Train Loss: {train_losses[-1]:.3f}, Train Acc: {train_acc[-1]:.3f}, Test Loss: {test_losses[-1]:.3f}, Test Acc: {test_acc[-1]:.3f}\\n\")\n","\n"," return train_losses, train_acc, test_losses, test_acc, best_acc"]},{"cell_type":"code","execution_count":37,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:26.531845Z","iopub.status.busy":"2023-03-11T09:43:26.531003Z","iopub.status.idle":"2023-03-11T09:43:26.780820Z","shell.execute_reply":"2023-03-11T09:43:26.779756Z","shell.execute_reply.started":"2023-03-11T09:43:26.531807Z"},"trusted":true},"outputs":[],"source":["model = CvT(num_classes= 2, dropout= 0.1).to(device)\n","\n","epochs = 30\n","lr = 0.0001\n","loss_function = torch.nn.CrossEntropyLoss()\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n","scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.000001)"]},{"cell_type":"code","execution_count":38,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T09:43:26.782804Z","iopub.status.busy":"2023-03-11T09:43:26.782330Z","iopub.status.idle":"2023-03-11T10:40:25.471349Z","shell.execute_reply":"2023-03-11T10:40:25.470223Z","shell.execute_reply.started":"2023-03-11T09:43:26.782763Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch: 1\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.692, Train Acc: 0.623, Test Loss: 0.377, Test Acc: 0.782\n","\n","Epoch: 2\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.368, Train Acc: 0.806, Test Loss: 0.261, Test Acc: 0.866\n","\n","Epoch: 3\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.304, Train Acc: 0.843, Test Loss: 0.257, Test Acc: 0.880\n","\n","Epoch: 4\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.69it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.258, Train Acc: 0.877, Test Loss: 0.201, Test Acc: 0.905\n","\n","Epoch: 5\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.59it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.229, Train Acc: 0.893, Test Loss: 0.188, Test Acc: 0.915\n","\n","Epoch: 6\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.177, Train Acc: 0.920, Test Loss: 0.143, Test Acc: 0.942\n","\n","Epoch: 7\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.161, Train Acc: 0.929, Test Loss: 0.142, Test Acc: 0.939\n","\n","Epoch: 8\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.69it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.138, Train Acc: 0.939, Test Loss: 0.190, Test Acc: 0.907\n","\n","Epoch: 9\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.71it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.129, Train Acc: 0.945, Test Loss: 0.093, Test Acc: 0.962\n","\n","Epoch: 10\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.65it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.121, Train Acc: 0.950, Test Loss: 0.092, Test Acc: 0.966\n","\n","Epoch: 11\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.69it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.104, Train Acc: 0.957, Test Loss: 0.066, Test Acc: 0.978\n","\n","Epoch: 12\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.73it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.105, Train Acc: 0.955, Test Loss: 0.060, Test Acc: 0.974\n","\n","Epoch: 13\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.68it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.087, Train Acc: 0.965, Test Loss: 0.050, Test Acc: 0.983\n","\n","Epoch: 14\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.078, Train Acc: 0.969, Test Loss: 0.057, Test Acc: 0.979\n","\n","Epoch: 15\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.70it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.080, Train Acc: 0.968, Test Loss: 0.043, Test Acc: 0.986\n","\n","Epoch: 16\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.71it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.072, Train Acc: 0.971, Test Loss: 0.055, Test Acc: 0.978\n","\n","Epoch: 17\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.069, Train Acc: 0.972, Test Loss: 0.071, Test Acc: 0.967\n","\n","Epoch: 18\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.060, Train Acc: 0.973, Test Loss: 0.039, Test Acc: 0.987\n","\n","Epoch: 19\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.053, Train Acc: 0.977, Test Loss: 0.037, Test Acc: 0.987\n","\n","Epoch: 20\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.74it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.062, Train Acc: 0.976, Test Loss: 0.049, Test Acc: 0.983\n","\n","Epoch: 21\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.70it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.048, Train Acc: 0.982, Test Loss: 0.037, Test Acc: 0.983\n","\n","Epoch: 22\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.43it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.047, Train Acc: 0.981, Test Loss: 0.022, Test Acc: 0.989\n","\n","Epoch: 23\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.74it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.038, Train Acc: 0.985, Test Loss: 0.023, Test Acc: 0.992\n","\n","Epoch: 24\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.038, Train Acc: 0.985, Test Loss: 0.031, Test Acc: 0.989\n","\n","Epoch: 25\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.037, Train Acc: 0.985, Test Loss: 0.022, Test Acc: 0.992\n","\n","Epoch: 26\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.73it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.037, Train Acc: 0.985, Test Loss: 0.028, Test Acc: 0.989\n","\n","Epoch: 27\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.66it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.034, Train Acc: 0.987, Test Loss: 0.020, Test Acc: 0.991\n","\n","Epoch: 28\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.73it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.028, Train Acc: 0.991, Test Loss: 0.023, Test Acc: 0.990\n","\n","Epoch: 29\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.031, Train Acc: 0.987, Test Loss: 0.018, Test Acc: 0.995\n","\n","Epoch: 30\n"]},{"name":"stderr","output_type":"stream","text":["Train\t: 100%|██████████| 141/141 [01:49<00:00, 1.29it/s]\n","Test\t: 100%|██████████| 16/16 [00:04<00:00, 3.69it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Train Loss: 0.027, Train Acc: 0.991, Test Loss: 0.021, Test Acc: 0.990\n","\n"]}],"source":["train_losses, train_acc, test_losses, test_acc, best_acc = train(model, epochs, optimizer, loss_function, scheduler, device, train_loader, test_loader)"]},{"cell_type":"code","execution_count":39,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:25.474303Z","iopub.status.busy":"2023-03-11T10:40:25.473880Z","iopub.status.idle":"2023-03-11T10:40:25.844491Z","shell.execute_reply":"2023-03-11T10:40:25.843460Z","shell.execute_reply.started":"2023-03-11T10:40:25.474259Z"},"trusted":true},"outputs":[{"data":{"text/plain":[""]},"execution_count":39,"metadata":{},"output_type":"execute_result"},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["plt.figure(figsize=(10, 4))\n","plt.subplot(1, 2, 1)\n","plt.plot(train_losses, label='Training loss')\n","plt.plot(test_losses, label='Testing loss')\n","plt.legend(frameon=False)\n","\n","plt.subplot(1, 2, 2)\n","plt.plot(train_acc, label='Training accuracy')\n","plt.plot(test_acc, label='Testing accuracy')\n","plt.legend(frameon=False)"]},{"cell_type":"markdown","metadata":{},"source":["## Testing"]},{"cell_type":"markdown","metadata":{},"source":["The best model is loaded for testing and the results are plotted."]},{"cell_type":"code","execution_count":40,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:25.846638Z","iopub.status.busy":"2023-03-11T10:40:25.845810Z","iopub.status.idle":"2023-03-11T10:40:26.173559Z","shell.execute_reply":"2023-03-11T10:40:26.172371Z","shell.execute_reply.started":"2023-03-11T10:40:25.846592Z"},"trusted":true},"outputs":[{"data":{"text/plain":[""]},"execution_count":40,"metadata":{},"output_type":"execute_result"}],"source":["model = CvT(num_classes= 2, dropout= 0.1).to(device)\n","model.load_state_dict(torch.load('../models/test-5-cvt-model.pth'))"]},{"cell_type":"code","execution_count":41,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:26.175694Z","iopub.status.busy":"2023-03-11T10:40:26.175083Z","iopub.status.idle":"2023-03-11T10:40:30.491683Z","shell.execute_reply":"2023-03-11T10:40:30.489847Z","shell.execute_reply.started":"2023-03-11T10:40:26.175654Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["100%|██████████| 16/16 [00:04<00:00, 3.72it/s]\n"]}],"source":["y_true = []\n","y_pred_class = []\n","y_pred_prob = []\n","\n","with torch.no_grad():\n"," for images, targets in tqdm(test_loader):\n"," images, targets = images.to(device), targets.to(device)\n"," output = model(images)\n"," preds = torch.argmax(output, dim=1)\n"," y_true.extend(targets.cpu().numpy())\n"," y_pred_class.extend(preds.cpu().numpy())\n"," y_pred_prob.extend(output.cpu().numpy())"]},{"cell_type":"code","execution_count":42,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:30.494550Z","iopub.status.busy":"2023-03-11T10:40:30.493892Z","iopub.status.idle":"2023-03-11T10:40:30.510742Z","shell.execute_reply":"2023-03-11T10:40:30.509244Z","shell.execute_reply.started":"2023-03-11T10:40:30.494480Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":[" precision recall f1-score support\n","\n"," no_sub 0.99 1.00 0.99 500\n"," sub 1.00 0.99 0.99 500\n","\n"," accuracy 0.99 1000\n"," macro avg 0.99 0.99 0.99 1000\n","weighted avg 0.99 0.99 0.99 1000\n","\n"]}],"source":["from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, classification_report, auc\n","\n","print(classification_report(y_true, y_pred_class, target_names=label_map.keys()))"]},{"cell_type":"code","execution_count":43,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:30.514018Z","iopub.status.busy":"2023-03-11T10:40:30.513193Z","iopub.status.idle":"2023-03-11T10:40:30.753110Z","shell.execute_reply":"2023-03-11T10:40:30.752187Z","shell.execute_reply.started":"2023-03-11T10:40:30.513974Z"},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["conf_matrix = confusion_matrix(y_true, y_pred_class)\n","plt.figure(figsize=(7, 5))\n","sns.heatmap(conf_matrix, annot=True, fmt='d', xticklabels=label_map.keys(), yticklabels=label_map.keys())\n","plt.ylabel('Actual')\n","plt.xlabel('Predicted')\n","plt.show()"]},{"cell_type":"code","execution_count":44,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:30.754937Z","iopub.status.busy":"2023-03-11T10:40:30.754592Z","iopub.status.idle":"2023-03-11T10:40:30.765508Z","shell.execute_reply":"2023-03-11T10:40:30.764597Z","shell.execute_reply.started":"2023-03-11T10:40:30.754900Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["ROC-AUC Score: 0.99992\n"]}],"source":["print(f\"ROC-AUC Score: {roc_auc_score(y_true, np.array(y_pred_prob)[:,1], multi_class='ovr')}\")"]},{"cell_type":"code","execution_count":45,"metadata":{"execution":{"iopub.execute_input":"2023-03-11T10:40:30.767380Z","iopub.status.busy":"2023-03-11T10:40:30.766873Z","iopub.status.idle":"2023-03-11T10:40:31.009843Z","shell.execute_reply":"2023-03-11T10:40:31.007745Z","shell.execute_reply.started":"2023-03-11T10:40:30.767343Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["no_sub ROC-AUC: 0.999888\n","sub ROC-AUC: 0.99992\n"]},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["y_true_dummies = pd.get_dummies(y_true).values\n","\n","auc_scores = {}\n","for i in range(2):\n"," fpr_i, tpr_i, thresholds_i = roc_curve(y_true_dummies[:, i], np.array(y_pred_prob)[:, i])\n"," auc_score = auc(fpr_i, tpr_i)\n"," print(f\"{(val_label_map[i]).ljust(8)} ROC-AUC: {auc_score}\")\n"," auc_scores[val_label_map[i]] = auc_score\n"," plt.plot(fpr_i, tpr_i, label=f\"{(val_label_map[i]).ljust(6)} AUC: {auc_score:.4f}\")\n","\n","plt.plot([0, 1], [0, 1], 'k--')\n","plt.xlabel('False Positive Rate')\n","plt.ylabel('True Positive Rate')\n","plt.title('ROC Curve')\n","plt.legend()\n","plt.show()"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.6"},"vscode":{"interpreter":{"hash":"916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"}}},"nbformat":4,"nbformat_minor":4}