{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torch \n", "import torch.nn as nn\n", "\n", "class InputEmbeddingsLayer(nn.Module):\n", " def __init__(self, d_model: int, vocab_size: int) -> None:\n", " super().__init__()\n", " self.d_model = d_model\n", " self.vocab_size = vocab_size\n", " self.embedding = nn.Embedding(vocab_size, d_model)\n", " def forward(self, x):\n", " return self.embedding(x) * math.sqrt(self.d_model)\n", "\n", "class PositionalEncodingLayer(nn.Module):\n", " def __init__(self, d_model: int, sequence_length: int, dropout: float) -> None:\n", " super().__init__()\n", " self.d_model = d_model\n", " self.sequence_length = sequence_length\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " PE = torch.zeros(sequence_length, d_model)\n", " Position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)\n", " deviation_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n", " \n", " PE[:, 0::2] = torch.sin(Position * deviation_term)\n", " PE[:, 1::2] = torch.cos(Position * deviation_term)\n", " PE = PE.unsqueeze(0)\n", " self.register_buffer(\"PE\", PE)\n", " def forward(self, x):\n", " x = x + (self.PE[:, :x.shape[1], :]).requires_grad_(False)\n", " return self.dropout(x)\n", "\n", "class NormalizationLayer(nn.Module):\n", " def __init__(self, Epslone: float = 10**-6) -> None:\n", " super().__init__()\n", " self.Epslone = Epslone\n", " self.Alpha = nn.Parameter(torch.ones(1))\n", " self.Bias = nn.Parameter(torch.ones(1))\n", " def forward(self, x):\n", " mean = x.mean(dim = -1, keepdim = True)\n", " std = x.std(dim = -1, keepdim = True)\n", " return self.Alpha * (x - mean) / (std + self.Epslone) + self.Bias\n", "\n", "class FeedForwardBlock(nn.Module):\n", " def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:\n", " super().__init__()\n", " self.Linear_1 = nn.Linear(d_model, d_ff)\n", " self.dropout = nn.Dropout(dropout)\n", " self.Linear_2 = nn.Linear(d_ff, d_model)\n", " def forward(self, x):\n", " return self.Linear_2(self.dropout(torch.relu(self.Linear_1(x))))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttentionBlock(nn.Module):\n", " def __init__(self, d_model: int, heads: int, dropout: float) -> None:\n", " super().__init__()\n", " self.d_model = d_model\n", " self.heads = heads \n", " assert d_model % heads == 0, \"d_model is not divisable by heads\"\n", "\n", " self.d_k = d_model // heads\n", "\n", " self.W_Q = nn.Linear(d_model, d_model)\n", " self.W_K = nn.Linear(d_model, d_model)\n", " self.W_V = nn.Linear(d_model, d_model)\n", "\n", " self.W_O = nn.Linear(d_model, d_model)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " @staticmethod\n", " def Attention(Query, Key, Value, mask, dropout: nn.Module):\n", " d_k = Query.shape[-1]\n", "\n", " self_attention_score = (Query @ Key.transpose(-2,-1)) / math.sqrt(d_k)\n", " if mask is not None:\n", " self_attention_score.masked_fill_(mask == 0, -1e9)\n", " self_attention_score = self_attention_score.softmax(dim = -1)\n", "\n", " if dropout is not None:\n", " self_attention_score = dropout(self_attention_score)\n", " return self_attention_score @ Value\n", " def forward(self, query, key, value, mask):\n", " Query = self.W_Q(query)\n", " Key = self.W_K(key)\n", " Value = self.W_V(value)\n", "\n", " Query = Query.view(Query.shape[0], Query.shape[1], self.heads, self.d_k).transpose(1,2)\n", " Key = 
  "        Key = Key.view(Key.shape[0], Key.shape[1], self.heads, self.d_k).transpose(1, 2)\n",
  "        Value = Value.view(Value.shape[0], Value.shape[1], self.heads, self.d_k).transpose(1, 2)\n",
  "\n",
  "        x, self.self_attention_score = MultiHeadAttentionBlock.Attention(Query, Key, Value, mask, self.dropout)\n",
  "        # (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)\n",
  "        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)\n",
  "        return self.W_O(x)\n",
  "\n",
  "# Pre-norm residual connection: x + Dropout(subLayer(Norm(x))).\n",
  "class ResidualConnection(nn.Module):\n",
  "    def __init__(self, dropout: float) -> None:\n",
  "        super().__init__()\n",
  "        self.dropout = nn.Dropout(dropout)\n",
  "        self.normalization = NormalizationLayer()\n",
  "\n",
  "    def forward(self, x, subLayer):\n",
  "        return x + self.dropout(subLayer(self.normalization(x)))"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Building the decoder block\n",
  "class DecoderBlock(nn.Module):\n",
  "    def __init__(self, decoder_self_attention_block: MultiHeadAttentionBlock, decoder_cross_attention_block: MultiHeadAttentionBlock, decoder_feed_forward_block: FeedForwardBlock, dropout: float) -> None:\n",
  "        super().__init__()\n",
  "        self.decoder_self_attention_block = decoder_self_attention_block\n",
  "        self.decoder_cross_attention_block = decoder_cross_attention_block\n",
  "        self.decoder_feed_forward_block = decoder_feed_forward_block\n",
  "        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])\n",
  "\n",
  "    def forward(self, x, Encoder_output, source_mask, target_mask):\n",
  "        # Masked self-attention over the target sequence uses the (causal) target mask.\n",
  "        x = self.residual_connection[0](x, lambda x: self.decoder_self_attention_block(x, x, x, target_mask))\n",
  "        # Cross-attention: queries come from the decoder, keys/values from the encoder output,\n",
  "        # masked with the source (padding) mask.\n",
  "        x = self.residual_connection[1](x, lambda x: self.decoder_cross_attention_block(x, Encoder_output, Encoder_output, source_mask))\n",
  "        x = self.residual_connection[2](x, self.decoder_feed_forward_block)\n",
  "        return x\n",
  "\n",
  "# Stack of decoder layers followed by a final normalization.\n",
  "class Decoder(nn.Module):\n",
  "    def __init__(self, Layers: nn.ModuleList) -> None:\n",
  "        super().__init__()\n",
  "        self.Layers = Layers\n",
  "        self.normalization = NormalizationLayer()\n",
  "\n",
  "    def forward(self, x, Encoder_output, source_mask, target_mask):\n",
  "        for layer in self.Layers:\n",
  "            x = layer(x, Encoder_output, source_mask, target_mask)\n",
  "        return self.normalization(x)\n",
  "\n",
  "# Projects the decoder output from d_model to vocabulary-size logits.\n",
  "class LinearLayer(nn.Module):\n",
  "    def __init__(self, d_model: int, vocab_size: int) -> None:\n",
  "        super().__init__()\n",
  "        self.Linear = nn.Linear(d_model, vocab_size)\n",
  "\n",
  "    def forward(self, x):\n",
  "        return self.Linear(x)"
 ] }
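 ,
 { "cell_type": "markdown", "metadata": {}, "source": [
  "A quick shape check for the decoder stack defined above. This is a minimal, illustrative sketch rather than part of the original model: the hyperparameters (`d_model=512`, `heads=8`, `d_ff=2048`, `dropout=0.1`), the toy vocabulary and sequence sizes, and the random tensor standing in for the encoder output are all assumptions made here. The cell only verifies that the blocks compose and that the projected output has shape `(batch_size, sequence_length, vocab_size)`."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Minimal smoke test for the modules above (hyperparameters are assumed, illustrative values).\n",
  "d_model, heads, d_ff, dropout = 512, 8, 2048, 0.1\n",
  "vocab_size, sequence_length, batch_size = 1000, 20, 2\n",
  "\n",
  "embedding = InputEmbeddingsLayer(d_model, vocab_size)\n",
  "positional_encoding = PositionalEncodingLayer(d_model, sequence_length, dropout)\n",
  "decoder = Decoder(nn.ModuleList([\n",
  "    DecoderBlock(\n",
  "        MultiHeadAttentionBlock(d_model, heads, dropout),  # self-attention\n",
  "        MultiHeadAttentionBlock(d_model, heads, dropout),  # cross-attention\n",
  "        FeedForwardBlock(d_model, d_ff, dropout),\n",
  "        dropout,\n",
  "    )\n",
  "    for _ in range(2)\n",
  "]))\n",
  "projection = LinearLayer(d_model, vocab_size)\n",
  "\n",
  "# Toy inputs: random target token ids and a random stand-in for the encoder output.\n",
  "target_tokens = torch.randint(0, vocab_size, (batch_size, sequence_length))\n",
  "encoder_output = torch.rand(batch_size, sequence_length, d_model)\n",
  "\n",
  "# Causal mask: each target position attends only to itself and earlier positions.\n",
  "target_mask = torch.tril(torch.ones(1, sequence_length, sequence_length)).int()\n",
  "# All-ones source mask (no padding in this toy example).\n",
  "source_mask = torch.ones(1, 1, sequence_length).int()\n",
  "\n",
  "x = positional_encoding(embedding(target_tokens))\n",
  "x = decoder(x, encoder_output, source_mask, target_mask)\n",
  "logits = projection(x)\n",
  "print(logits.shape)  # expected: torch.Size([2, 20, 1000])"
 ] }
 ], "metadata": { "language_info": { "name": "python" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }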