{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from model import Transformer\n",
    "from config import get_config, get_weights_file_path\n",
    "from train import get_model, get_ds, greedy_decode\n",
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the device\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Using device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = get_config()\n",
    "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
    "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)\n",
    "\n",
    "# Inference only: eval mode disables any train-only behaviour (e.g. dropout),\n",
    "# so the visualised attention maps are deterministic.\n",
    "model.eval()\n",
    "\n",
    "# Load the pretrained weights. map_location lets a checkpoint that was saved\n",
    "# on a GPU be loaded on a CPU-only machine as well.\n",
    "model_filename = get_weights_file_path(config, \"29\")\n",
    "state = torch.load(model_filename, map_location=device)\n",
    "model.load_state_dict(state['model_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persistent iterator so that successive calls advance through the validation\n",
    "# set instead of building a brand-new iterator on every call.\n",
    "_val_batches = iter(val_dataloader)\n",
    "\n",
    "\n",
    "def load_next_batch():\n",
    "    \"\"\"Fetch the next validation batch and run the model on it.\n",
    "\n",
    "    `greedy_decode` is called for its side effect only: the attention blocks\n",
    "    are expected to cache their `attention_scores`, which `get_attn_map`\n",
    "    reads afterwards (NOTE(review): confirm against model.py / train.py).\n",
    "\n",
    "    Returns:\n",
    "        A tuple (batch, encoder_input_tokens, decoder_input_tokens); the two\n",
    "        token lists hold the string tokens of sample 0 of the batch.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the batch size is not 1.\n",
    "    \"\"\"\n",
    "    global _val_batches\n",
    "    try:\n",
    "        batch = next(_val_batches)\n",
    "    except StopIteration:\n",
    "        # Validation set exhausted: start over.\n",
    "        _val_batches = iter(val_dataloader)\n",
    "        batch = next(_val_batches)\n",
    "\n",
    "    encoder_input = batch[\"encoder_input\"].to(device)\n",
    "    encoder_mask = batch[\"encoder_mask\"].to(device)\n",
    "    decoder_input = batch[\"decoder_input\"].to(device)\n",
    "\n",
    "    # Check the batch size before sample 0 is indexed below.\n",
    "    assert encoder_input.size(0) == 1, \"Batch size must be 1 for validation\"\n",
    "\n",
    "    # .tolist() yields plain Python ints for the tokenizer lookup.\n",
    "    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].tolist()]\n",
    "    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].tolist()]\n",
    "\n",
    "    # The decoded output itself is not used here and no gradients are needed.\n",
    "    with torch.no_grad():\n",
    "        greedy_decode(\n",
    "            model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)\n",
    "\n",
    "    return batch, encoder_input_tokens, decoder_input_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mtx2df(m, max_row, max_col, row_tokens, col_tokens):\n",
    "    \"\"\"Flatten the top-left corner of a 2-D matrix into a long-format DataFrame.\n",
    "\n",
    "    Args:\n",
    "        m: 2-D matrix indexable as `m[r, c]` and exposing `.shape`.\n",
    "        max_row: maximum number of rows to include.\n",
    "        max_col: maximum number of columns to include.\n",
    "        row_tokens: labels for the rows; a missing label becomes \"\".\n",
    "        col_tokens: labels for the columns; a missing label becomes \"\".\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one record per cell and the columns\n",
    "        row, column, value, row_token, col_token.\n",
    "    \"\"\"\n",
    "    n_rows = min(max_row, m.shape[0])\n",
    "    n_cols = min(max_col, m.shape[1])\n",
    "    return pd.DataFrame(\n",
    "        [\n",
    "            (\n",
    "                r,\n",
    "                c,\n",
    "                float(m[r, c]),\n",
    "                # The zero-padded index prefix keeps repeated tokens distinct.\n",
    "                \"%.3d %s\" % (r, row_tokens[r] if len(row_tokens) > r else \"\"),\n",
    "                \"%.3d %s\" % (c, col_tokens[c] if len(col_tokens) > c else \"\"),\n",
    "            )\n",
    "            for r in range(n_rows)\n",
    "            for c in range(n_cols)\n",
    "        ],\n",
    "        columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
    "    )\n",
    "\n",
    "\n",
    "def get_attn_map(attn_type: str, layer: int, head: int):\n",
    "    \"\"\"Return the cached attention matrix of one head for sample 0.\n",
    "\n",
    "    Args:\n",
    "        attn_type: \"encoder\", \"decoder\" or \"encoder-decoder\".\n",
    "        layer: index into the encoder/decoder layer list.\n",
    "        head: index of the attention head.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `attn_type` is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if attn_type == \"encoder\":\n",
    "        attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
    "    elif attn_type == \"decoder\":\n",
    "        attn = model.decoder.layers[layer].self_attention_block.attention_scores\n",
    "    elif attn_type == \"encoder-decoder\":\n",
    "        attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown attention type: {attn_type!r}\")\n",
    "    return attn[0, head].detach()\n",
    "\n",
    "\n",
    "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
    "    \"\"\"Build an interactive Altair heatmap for one attention head.\n",
    "\n",
    "    At most `max_sentence_len` rows and columns are shown; `row_tokens` and\n",
    "    `col_tokens` label the y and x axes respectively.\n",
    "    \"\"\"\n",
    "    df = mtx2df(\n",
    "        get_attn_map(attn_type, layer, head),\n",
    "        max_sentence_len,\n",
    "        max_sentence_len,\n",
    "        row_tokens,\n",
    "        col_tokens,\n",
    "    )\n",
    "    return (\n",
    "        alt.Chart(data=df)\n",
    "        .mark_rect()\n",
    "        .encode(\n",
    "            x=alt.X(\"col_token\", axis=alt.Axis(title=\"\")),\n",
    "            y=alt.Y(\"row_token\", axis=alt.Axis(title=\"\")),\n",
    "            color=\"value\",\n",
    "            tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
    "        )\n",
    "        .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
    "        .interactive()\n",
    "    )\n",
    "\n",
    "\n",
    "def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):\n",
    "    \"\"\"Arrange the heatmaps of several heads and layers in a grid.\n",
    "\n",
    "    Each layer becomes one horizontal row of charts (one chart per head);\n",
    "    the rows are stacked vertically in the order given by `layers`.\n",
    "    \"\"\"\n",
    "    layer_rows = [\n",
    "        alt.hconcat(\n",
    "            *[\n",
    "                attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)\n",
    "                for head in heads\n",
    "            ]\n",
    "        )\n",
    "        for layer in layers\n",
    "    ]\n",
    "    return alt.vconcat(*layer_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()\n",
    "print(f'Source: {batch[\"src_text\"][0]}')\n",
    "print(f'Target: {batch[\"tgt_text\"][0]}')\n",
    "\n",
    "# Number of tokens before the padding starts; fall back to the full length\n",
    "# when the sentence fills the whole sequence and contains no \"[PAD]\" token.\n",
    "sentence_len = (\n",
    "    encoder_input_tokens.index(\"[PAD]\")\n",
    "    if \"[PAD]\" in encoder_input_tokens\n",
    "    else len(encoder_input_tokens)\n",
    ")\n",
    "\n",
    "# Cap the number of tokens per axis so the charts stay readable.\n",
    "MAX_TOKENS_SHOWN = 20\n",
    "max_len = min(MAX_TOKENS_SHOWN, sentence_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = [0, 1, 2]\n",
    "heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
    "\n",
    "# Encoder Self-Attention\n",
    "get_all_attention_maps(\"encoder\", layers, heads, encoder_input_tokens, encoder_input_tokens, max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder Self-Attention\n",
    "get_all_attention_maps(\"decoder\", layers, heads, decoder_input_tokens, decoder_input_tokens, max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-Decoder Cross-Attention\n",
    "# Rows are the decoder (query) positions and columns the encoder (key) positions.\n",
    "# NOTE(review): assumes attention_scores is laid out as (batch, head, query, key);\n",
    "# confirm against model.py.\n",
    "get_all_attention_maps(\"encoder-decoder\", layers, heads, decoder_input_tokens, encoder_input_tokens, max_len)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}