"""Building blocks for a pre-norm Transformer encoder layer.

Contains multi-head self-attention, a position-wise feed-forward network and
the residual block that combines the two.
"""

import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input tensor.

    Args:
        d_model: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_model: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Raise explicitly instead of asserting: asserts vanish under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError("Number of dimensions should be divisible by heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation stay the same.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.projection = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(
        self, tensor: torch.Tensor, batch_size: int, seq_length: int
    ) -> torch.Tensor:
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        return tensor.view(
            batch_size, seq_length, self.num_heads, self.d_k
        ).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch_size, seq_length, d_model).
            attention_mask: Optional tensor of shape (batch_size, seq_length).
                Positions equal to 0 (e.g. PAD tokens) are excluded as
                attention keys; every other value is attended to.

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        batch_size, seq_length, _ = x.shape

        queries = self._split_heads(self.W_q(x), batch_size, seq_length)
        keys = self._split_heads(self.W_k(x), batch_size, seq_length)
        values = self._split_heads(self.W_v(x), batch_size, seq_length)

        # (batch_size, num_heads, seq_length, seq_length)
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_k)

        if attention_mask is not None:
            # (batch_size, 1, 1, seq_length): broadcast over heads and queries.
            mask = attention_mask.unsqueeze(1).unsqueeze(2).to(scores.device)
            # Use the most negative finite value rather than -inf: a row whose
            # keys are all masked would otherwise turn into NaN after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = self.dropout(torch.softmax(scores, dim=-1))

        # (batch_size, num_heads, seq_length, d_k) -> (batch, seq, d_model)
        context = weights @ values
        context = (
            context.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.d_model)
        )
        return self.projection(context)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion.

    Args:
        d_model: Size of the last dimension of the input and output.
        dropout: Dropout probability applied after the GELU activation.
    """

    def __init__(self, d_model: int = 256, dropout: float = 0.1) -> None:
        super().__init__()
        hidden_dim = d_model * 4
        # Layer order inside the Sequential is part of the state_dict keys.
        self.projection = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output; the last dimension stays ``d_model``."""
        return self.projection(x)


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: attention and feed-forward with residuals.

    Args:
        d_model: Size of the last dimension of the input and output; used for
            the attention, the feed-forward network and both LayerNorms.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(
        self,
        d_model: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Forward d_model to the sub-layers so they match the LayerNorms; with
        # default-constructed sub-layers any d_model other than 256 breaks.
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block on ``x``.

        Args:
            x: Input of shape (batch_size, seq_length, d_model).
            attn_mask: Optional (batch_size, seq_length) mask passed to the
                attention layer; positions equal to 0 are not attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Out-of-place additions keep the sub-layer outputs intact for autograd.
        x = x + self.attn(self.norm1(x), attn_mask)
        x = x + self.ffn(self.norm2(x))
        return x