{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "yvBTTFubgJ0i" }, "source": [ "# From attention to transformers\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "LEyAG_Chgq4B" }, "source": [ "In this tutorial, we focus on the inner workings of the attention mechanism. By the end, you will be able to build a self-attention layer and construct your own transformer model from scratch.\n", "\n", "In many well-established libraries like **torch**, the code tends to be hard to decipher because of efficiency optimizations and the many conditional paths guarded by **if** and **else**. Here, we will craft **a more intelligible yet functionally equivalent model** and verify its behavior against the official implementation.\n", "\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "COoPVuuDrCel" }, "source": [ "### General note for GPU training (in Colab)\n", "\n", "* First, please use the GPU runtime; once it is active, `!nvidia-smi` returns no error.\n", " 1. Click on \"Runtime\" in the top menu bar.\n", " 2. Select \"Change runtime type\" from the drop-down menu.\n", " 3. In the \"Runtime type\" section, select \"GPU\" as the hardware accelerator.\n", " 4. Click \"Save\" to apply the changes.\n", "\n", "\n", "* What should I do about a **CUDA out of memory** error? (This is THE most common error in DL.)\n", "![](https://miro.medium.com/v2/resize:fit:828/format:webp/1*enMsxkgJ1eb9XvtWju5V8Q.png)\n", " 1. In a Colab notebook, **unfortunately, you need to restart the kernel after an OOM error**; otherwise it will keep happening no matter what you do.\n", " 2. Shrink the model to save memory: decrease the batch size, the number of layers, the maximum sequence length, or the hidden / embedding dimension.\n", " 3. If you know mixed precision training, you can switch to low-precision `fp16` numbers for weights and inputs (see the sketch below).\n", "\n", "* What should I do about a **device-side assert triggered** error?\n", " > RuntimeError: CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n", " \n", " * Usually it's because an embedding layer receives an index (a token id or position id) that is out of its range.\n", " * It could also be something else, which will be harder to debug..." ] }
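, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to try mixed precision (point 3 above), the usual `torch.cuda.amp` pattern looks roughly like the sketch below. Here `model`, `optimizer` and `train_loader` are placeholders for whatever you end up training, so treat this as a template rather than a cell to run as-is:\n", "\n", "```python\n", "scaler = torch.cuda.amp.GradScaler()\n", "for imgs, labels in train_loader:\n", "    optimizer.zero_grad()\n", "    with torch.cuda.amp.autocast():\n", "        loss = F.cross_entropy(model(imgs.cuda()), labels.cuda())\n", "    scaler.scale(loss).backward()  # scale the loss so small fp16 gradients do not underflow\n", "    scaler.step(optimizer)\n", "    scaler.update()\n", "```\n" ] }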
, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 284, "status": "ok", "timestamp": 1706009824050, "user": { "displayName": "hu jian", "userId": "11648317062194590196" }, "user_tz": 0 }, "id": "wymzVx-nrWCx", "outputId": "d5d63f04-03f1-4177-faff-7426866436a6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tue Jan 23 11:37:03 2024 \n", "+---------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", "|-----------------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|=========================================+======================+======================|\n", "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", "| N/A 63C P8 11W / 70W | 0MiB / 15360MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+----------------------+----------------------+\n", " \n", "+---------------------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=======================================================================================|\n", "| No running processes found |\n", "+---------------------------------------------------------------------------------------+\n" ] } ], "source": [ "# import locale\n", "# locale.getpreferredencoding = lambda: \"UTF-8\" # to fix a potential locale bug\n", "!nvidia-smi" ] },
{ "cell_type": "markdown", "metadata": { "id": "1FvTRFr_M-Rx" }, "source": [ "### Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VCs18-JIMx2A" }, "outputs": [], "source": [ "!pip install torch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import math\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x1754xrzOU4B" }, "outputs": [], "source": [ "seed = 42\n", "np.random.seed(seed)\n", "torch.manual_seed(seed)\n", "torch.cuda.manual_seed(seed)" ] },
{ "cell_type": "markdown", "metadata": { "id": "YyAokROySZWa" }, "source": [ "## Self-Attention Mechanism: Single Head" ] }, { "cell_type": "markdown", "metadata": { "id": "2ioS3Nrun62q" }, "source": [ "![](https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/main/media/AttentionSchematics_white-01.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EOsGVI0tNH9m" }, "outputs": [], "source": [ "embdim = 256\n", "headdim = 64\n", "tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding\n", "Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)\n", "Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)\n", "Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Q3GOB0ld4TNN" }, "source": [ "Fill in the score matrix computation, i.e. the dot products between every query and every key." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sGz9Zeh-Orkw" }, "outputs": [], "source": [ "qis = torch.einsum(\"BSE,EH->BSH\", tokens, Wq) # batch x seqlen x headdim\n", "kis = torch.einsum(\"BTE,EH->BTH\", tokens, Wk) # batch x seqlen x headdim\n", "vis = torch.einsum(\"BTE,EF->BTF\", tokens, Wv) # batch x seqlen x embeddim\n", "#### ------ Add your code here: compute query-key similarities. ------ ####\n", "scoremat = # output: batch x seqlen (Query) x seqlen (Key)\n", "#### ------ End ------ ####\n", "attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)" ] },
{ "cell_type": "markdown", "metadata": { "id": "l4csHSFhUKCy" }, "source": [ "Below are some checks to make sure each score corresponds to the dot product of the right query-key pair."
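, "\n", "\n", "But first, in case you got stuck on the fill-in above, here is one possible reference solution. It is only a sketch: `scoremat_ref` and `attmat_ref` are illustrative names, and other patterns (e.g. `torch.bmm` or `qis @ kis.transpose(1, 2)`) are equally valid. The checks in the following cell apply to your own `scoremat`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# One possible solution: contract the head dimension of queries and keys.\n", "# scoremat_ref[b, s, t] is the dot product of query s and key t in batch b.\n", "scoremat_ref = torch.einsum(\"BSH,BTH->BST\", qis, kis) # batch x seqlen (Query) x seqlen (Key)\n", "attmat_ref = F.softmax(scoremat_ref / math.sqrt(headdim), dim=2)\n", "print(scoremat_ref.shape, attmat_ref.shape) # expected: torch.Size([1, 5, 5]) for both"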
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "myg-C-doPKF6" }, "outputs": [], "source": [ "assert(torch.isclose(scoremat[0,1,2], qis[0,1,:]@kis[0,2,:]))\n", "assert(torch.isclose(scoremat[0,3,4], qis[0,3,:]@kis[0,4,:]))\n", "assert(torch.isclose(scoremat[0,2,2], qis[0,2,:]@kis[0,2,:]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "usrpNpDSQZz5" }, "outputs": [], "source": [ "zis = torch.einsum(\"BST,BTF->BSF\", attmat, vis)" ] },
{ "cell_type": "markdown", "metadata": { "id": "GxiBZ97ZQt1D" }, "source": [ "In PyTorch, these operations are packed into the function `F.scaled_dot_product_attention`, so let's test our implementation of single-head attention against it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K4FvFf8vPg_0" }, "outputs": [], "source": [ "attn_torch = F.scaled_dot_product_attention(qis,kis,vis)\n", "assert(torch.allclose(attn_torch, zis, atol=1E-6,rtol=1E-6))" ] },
{ "cell_type": "markdown", "metadata": { "id": "cdRNYM5FNIOC" }, "source": [ "## Multi-head attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ripFxfj-hjuB" }, "outputs": [], "source": [ "embdim = 768\n", "headcnt = 12\n", "headdim = embdim // headcnt\n", "assert headdim * headcnt == embdim\n", "tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding\n", "Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim\n", "Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim\n", "Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qN4P4TBZiRHx" }, "outputs": [], "source": [ "batch, token_num, _ = tokens.shape\n", "qis = torch.einsum(\"BSE,EH->BSH\", tokens, Wq)\n", "kis = torch.einsum(\"BTE,EH->BTH\", tokens, Wk)\n", "vis = torch.einsum(\"BTE,EH->BTH\", tokens, Wv)\n", "# split the single hidden dim into the heads\n", "qis_mh = qis.view(batch, token_num, headcnt, headdim)\n", "kis_mh = kis.view(batch, token_num, headcnt, headdim)\n", "vis_mh = vis.view(batch, token_num, headcnt, headdim)" ] },
{ "cell_type": "markdown", "metadata": { "id": "HXgazGbY4vf7" }, "source": [ "Now your challenge is to compute multi-head attention using `einsum`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vVGExSq3i-rh" }, "outputs": [], "source": [ "#### ------ Add your code here: compute query-key similarities. ------ ####\n", "scoremat_mh = # Output: batch x headcnt x seqlen (query) x seqlen (key)\n", "#### ------ End ------ ####\n", "attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)\n", "zis_mh = torch.einsum(\"BCST,BTCH->BSCH\", attmat_mh, vis_mh) # batch x seqlen (query) x headcnt x headdim\n", "zis = zis_mh.reshape(batch, token_num, headcnt * headdim)" ] },
{ "cell_type": "markdown", "metadata": { "id": "kLJo-3CL3BWQ" }, "source": [ "Let's validate that the tensor multiplication is correct." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5yCZ0BI6zLRH" }, "outputs": [], "source": [ "# raw attention scores of attention head 1 (the second head, zero-indexed)\n", "assert (torch.allclose(scoremat_mh[0, 1], qis_mh[0,:,1] @ kis_mh[0,:,1,:].T))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oy5EBnnetHoU" }, "outputs": [], "source": [ "print(tokens.shape)\n", "print(qis_mh.shape)\n", "print(kis_mh.shape)\n", "print(vis_mh.shape)\n", "print(attmat_mh.shape)\n", "print(zis_mh.shape)\n", "print(zis.shape)" ] },
{ "cell_type": "markdown", "metadata": { "id": "gr5UOaYo1Rtf" }, "source": [ "In `torch` this operation is packed into `nn.MultiheadAttention`, which includes the input projection, the attention itself and the output projection. Note that the inputs to the `mha.forward` function are the *token embeddings*, not the Q, K, V tensors we passed to `F.scaled_dot_product_attention`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "apQhwC6nU5Uy" }, "outputs": [], "source": [ "mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)\n", "print(mha.in_proj_weight.shape) # 3 * embdim x embdim\n", "mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQcH-49V0Pyw" }, "outputs": [], "source": [ "attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False,)\n", "assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)" ] },
{ "cell_type": "markdown", "metadata": { "id": "ukm63wFj0WC3" }, "source": [ "In `nn.MultiheadAttention`, there is an output projection `out_proj` (a linear layer with bias) applied to the attended values. We can validate that, after going through this projection, our output `zis` matches the output of `mha`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ci12a8np0VQH" }, "outputs": [], "source": [ "print(mha.out_proj)\n", "assert torch.allclose(attn_out, mha.out_proj(zis), atol=1e-6, rtol=1e-6)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Da0lDydB3zoP" }, "source": [ "### Causal attention mask\n", "\n", "For models such as GPT, each token can only attend to the tokens before it, so the attention scores need to be modified before entering the softmax.\n", "\n", "The common way of masking is to add a large negative number to the locations that you do not want the model to attend to."
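, "\n", "\n", "As a tiny illustration of why this works (a toy example, unrelated to the model tensors above): once `-1e4` is added to a score, its softmax weight collapses to essentially zero, so the corresponding value vector is effectively ignored.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy demo: adding a large negative number to a score removes its softmax weight.\n", "demo_scores = torch.tensor([2.0, 1.0, 0.5])\n", "print(F.softmax(demo_scores, dim=0)) # every position gets some weight\n", "print(F.softmax(demo_scores + torch.tensor([0.0, 0.0, -1e4]), dim=0)) # last position gets ~0 weight"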
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8s_xVQiN4TNf" }, "outputs": [], "source": [ "attn_mask = torch.ones(token_num,token_num,)\n", "attn_mask = -1E4 * torch.triu(attn_mask,1)\n", "attn_mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IXbr6nVQ4L2e" }, "outputs": [], "source": [ "scoremat_mh_msk = torch.einsum(\"BSCH,BTCH->BCST\", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key)\n", "scoremat_mh_msk += attn_mask # add the attn mask to the scores before SoftMax normalization\n", "attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)\n", "zis_mh_msk = torch.einsum(\"BCST,BTCH->BSCH\", attmat_mh_msk, vis_mh) # batch x seqlen (query) x headcnt x headdim\n", "zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Es5ABKzQ5phg" }, "source": [ "**Note**: the `is_causal` parameter should create the causal mask automatically, but due to a recent PyTorch bug it may not work as expected, so beware:\n", "https://github.com/pytorch/pytorch/issues/99282" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ySFCHtmE46QA" }, "outputs": [], "source": [ "attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fcmb5yX16buH" }, "outputs": [], "source": [ "assert torch.allclose(attn_weights_causal, attmat_mh_msk, atol=1e-6, rtol=1e-6)\n", "assert torch.allclose(attn_out_causal, mha.out_proj(zis_msk), atol=1e-6, rtol=1e-6)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "kHUZgGm74_qu" }, "outputs": [], "source": [ "plt.figure()\n", "for head in range(headcnt):\n", " plt.subplot(3, 4, head + 1)\n", " plt.imshow(attn_weights_causal[0, head].detach().numpy())\n", " plt.title(f\"head {head}\")\n", " plt.axis(\"off\")\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": { "id": "3JuaP78yWMva" }, "source": [ "## Transformer Block" ] }, { "cell_type": "markdown", "metadata": { "id": "BqscmIG11NDn" }, "source": [ "Having gained some intuition about the attention layer, let's build it into a transformer. A vanilla transformer block usually looks like this. Note there are slight differences between the transformer blocks in GPT-2, BERT and other models, but they generally have the following components:\n", "\n", "* Transformer Block\n", " * Layernorm\n", " * Skip connections\n", " * Multi-head attention\n", " * MLP, Feedforward net\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "zZBi1l6-WMTy" }, "outputs": [], "source": [ "class TransformerBlock_simple(nn.Module):\n", "\n", " def __init__(self, embdim, headcnt, *args, dropout=0.0, **kwargs) -> None:\n", " super().__init__(*args, **kwargs)\n", " self.ln1 = nn.LayerNorm(embdim)\n", " self.ln2 = nn.LayerNorm(embdim)\n", " self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)\n", " self.ffn = nn.Sequential(\n", " nn.Linear(embdim, 4 * embdim),\n", " nn.GELU(),\n", " nn.Linear(4 * embdim, embdim),\n", " nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x, is_causal=True):\n", " batch, token_num, hidden_dim = x.shape\n", " if is_causal:\n", " attn_mask = torch.ones(token_num, token_num,)\n", " attn_mask = -1E4 * torch.triu(attn_mask,1)\n", " else:\n", " attn_mask = None\n", "\n", " residue = x\n", " x = self.ln1(x)\n", " #### ------ Add your code here: multihead attention ------ ####\n", " attn_output, attn_weights = # first output is the output latent states\n", " #### ------ End ------ ####\n", " x = residue + attn_output\n", "\n", " residue = x\n", " x = self.ln2(x)\n", " ffn_output = self.ffn(x)\n", " output = residue + ffn_output\n", " return output" ] }
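, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you have completed the attention fill-in above, a quick optional sanity check is to push a random batch through the block: the output should keep the `(batch, tokens, embdim)` shape of the input. (`test_block` and `test_tokens` are just illustrative names for this sketch.)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: a transformer block maps (batch, tokens, embdim) -> (batch, tokens, embdim).\n", "test_block = TransformerBlock_simple(embdim=768, headcnt=12)\n", "test_tokens = torch.randn(2, 5, 768)\n", "test_out = test_block(test_tokens, is_causal=True)\n", "print(test_out.shape) # expected: torch.Size([2, 5, 768])" ] }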
, { "cell_type": "markdown", "metadata": { "id": "xbR05_AZUl78" }, "source": [ "Compare the implementation with the schematic and see if it makes more sense!\n", "\n", "\n", "*Attention Block*\n", "\n", "\n", "![BERT (Transformer encoder)](https://iq.opengenus.org/content/images/2020/06/encoder-1.png)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "LOPq1yyzrf8t" }, "source": [ "# Image Classification" ] }, { "cell_type": "markdown", "metadata": { "id": "sVc7yB5Dslys" }, "source": [ "Now we employ the Transformer architecture for image classification." ] }, { "cell_type": "markdown", "metadata": { "id": "xeRlyRpM7Dg9" }, "source": [ "### Imports" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "cWosexjnIu29" }, "outputs": [], "source": [ "!pip install transformers\n", "!pip install torchvision\n", "\n", "## Import transformers\n", "from transformers import get_linear_schedule_with_warmup\n", "from transformers import BertForSequenceClassification\n", "from transformers import BertModel, BertTokenizer, BertConfig\n", "\n", "import os\n", "from os.path import join\n", "from tqdm.notebook import tqdm, trange\n", "import math\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.optim import AdamW, Adam\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision.utils import make_grid, save_image\n", "import matplotlib.pyplot as plt\n", "from torchvision.datasets import MNIST, CIFAR10\n", "from torchvision import datasets, transforms\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "BJWHJvqHKSvj" }, "source": [ "### Preparing Image Dataset\n", "Load the dataset. Note that the augmentations are necessary: without them, the Transformer will overfit very quickly."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EhhFrgSeI647" }, "outputs": [], "source": [ "!mkdir data\n", "dataset = CIFAR10(root='./data/', train=True, download=True, transform=\n", "transforms.Compose([\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32, padding=4),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", "]))\n", "# augmentations are super important for CNN trainings, or it will overfit very fast without achieving good generalization accuracy\n", "val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=transforms.Compose(\n", " [transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]))\n", "#%%" ] }, { "cell_type": "markdown", "metadata": { "id": "3mz8eJD4JJWQ" }, "source": [ "Citing https://openreview.net/pdf?id=SCN8UaetXx,\n", "\n", "> \"Visual Transformers. Despite some previous work in which attention is used inside the convolutional layers of a CNN [57, 26], the first fully-transformer architectures for vision are iGPT [8] and ViT [17]. The former is trained using a \"masked-pixel\" self-supervised approach, similar in spirit to the common masked-word task used, for instance, in BERT [15] and in GPT [45] (see below). On the other hand, ViT is trained in a supervised way, using a special \"class token\" and a classification head attached to the final embedding of this token. Both methods are computationally expensive and, despite their very good results when trained on huge datasets, they underperform ResNet architectures when trained from scratch using only ImageNet-1K [17, 8]. VideoBERT [51] is conceptually similar to iGPT, but, rather than using pixels as tokens, each frame of a video is holistically represented by a feature vector, which is quantized using an off-the-shelf pretrained video classification model. 
DeiT [53] trains ViT using distillation information provided by a pretrained CNN.\"" ] }, { "cell_type": "markdown", "metadata": { "id": "QYITiFu1KsCy" }, "source": [ "### Transformer model for images\n", "\n", "We reuse `BertModel` as a ViT-style encoder: a 4x4 `Conv2d` turns each 32x32 CIFAR image into 8 x 8 = 64 patch embeddings, a learnable CLS token is prepended (65 tokens in total), and a small MLP head reads the class logits off the final CLS embedding." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "GTq_7HqyKREb" }, "outputs": [], "source": [ "config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12,\n", " num_attention_heads=8, max_position_embeddings=256,\n", " vocab_size=100, bos_token_id=101, eos_token_id=102,\n", " cls_token_id=103, )\n", "model = BertModel(config).cuda()\n", "patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()\n", "CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device=\"cuda\") / math.sqrt(config.hidden_size))\n", "readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),\n", " nn.GELU(),\n", " nn.Linear(config.hidden_size, 10)\n", " ).cuda()\n", "for module in [patch_embed, readout, model, CLS_token]:\n", " module.cuda()\n", "\n", "optimizer = AdamW([*model.parameters(),\n", " *patch_embed.parameters(),\n", " *readout.parameters(),\n", " CLS_token], lr=5e-4)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "n5AccS_YKprw" }, "outputs": [], "source": [ "batch_size = 192 # 96\n", "train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n", "loss_list = []\n", "acc_list = []\n", "for epoch in trange(10, leave=False):\n", " # reset the running statistics and switch back to train mode at the start of every epoch\n", " model.train()\n", " correct_cnt = 0\n", " total_loss = 0\n", " pbar = tqdm(train_loader, leave=False)\n", " for i, (imgs, labels) in enumerate(pbar):\n", " patch_embs = patch_embed(imgs.cuda())\n", " #### ------ Add your code here: replace the None with the correct order of the embedding dimension. ------ ####\n", " patch_embs = patch_embs.flatten(2).permute(None, None, None) # hint: (batch_size, HW, hidden)\n", " #### ------ End ------ ####\n", " # print(patch_embs.shape)\n", " input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)\n", " # print(input_embs.shape)\n", " output = model(inputs_embeds=input_embs)\n", " logit = readout(output.last_hidden_state[:, 0, :])\n", " loss = F.cross_entropy(logit, labels.cuda())\n", " # print(loss)\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " pbar.set_description(f\"loss: {loss.item():.4f}\")\n", " total_loss += loss.item() * imgs.shape[0]\n", " correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()\n", "\n", " loss_list.append(round(total_loss / len(dataset), 4))\n", " acc_list.append(round(correct_cnt / len(dataset), 4))\n", " # test on validation set\n", " model.eval()\n", " correct_cnt = 0\n", " total_loss = 0\n", "\n", " for i, (imgs, labels) in enumerate(val_loader):\n", " patch_embs = patch_embed(imgs.cuda())\n", " #### ------ Add your code here: replace the None with the correct order of the embedding dimension. 
------ ####\n", " patch_embs = patch_embs.flatten(2).permute(None, None, None) # hint: (batch_size, HW, hidden)\n", " #### ------ End ------ ####\n", " input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)\n", " output = model(inputs_embeds=input_embs)\n", " logit = readout(output.last_hidden_state[:, 0, :])\n", " loss = F.cross_entropy(logit, labels.cuda())\n", " total_loss += loss.item() * imgs.shape[0]\n", " correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()\n", "\n", " print(f\"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NJYWb7AB3TGw" }, "outputs": [], "source": [ "#### ------ Add your code here: plot the training loss curve to show its variation with the epoch. ------ ####\n", "# hints: use the data in list 'loss_list' and 'acc_list' to plot the curve via plt.plot()\n", "#### ------ End ------ ####" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ftBcE8hC-64T" }, "outputs": [], "source": [ "#### ------ Add your code here: plot the accuracy score curve to show its variation with the epoch. ------ ####\n", "# hints: use the data in list 'loss_list' and 'acc_list' to plot the curve via plt.plot()\n", "#### ------ End ------ ####" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xdbiPh7hK7IU" }, "outputs": [], "source": [ "torch.save(model.state_dict(),\"bert.pth\")\n", "!du -sh bert.pth" ] }, { "cell_type": "markdown", "metadata": { "id": "KrmYqzYCl6iE" }, "source": [ "**Reference:**\n", "Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch (https://github.com/Animadversio/TransformerFromScratch?tab=readme-ov-file)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [ { "file_id": "1EblhYF9_aBCKOLEjQiBDRqvDnAaecQ_d", "timestamp": 1700846708592 }, { "file_id": "1ZuhA6khlWm57WGZ8i38JH-gc5aJrvpvs", "timestamp": 1700834170533 } ] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }