from typing import Tuple

import torch
import torch.nn as nn

HIDDEN_SIZE = 64
VOCAB_SIZE = 196906
EMBEDDING_DIM = 64  # embedding_dim
SEQ_LEN = 100
BATCH_SIZE = 16


class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of LSTM outputs.

    Scores every time step with ``W_v(tanh(W_q(query) + W_k(outputs)))``,
    normalizes the scores with a softmax over the sequence axis, and returns
    the attention-weighted sum of the projected keys.
    """

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # batch x seq_len x hidden_size
        final_hidden: torch.Tensor,  # batch x hidden_size
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the context vector and the attention weights.

        Args:
            lstm_outputs: Per-step outputs, shape (batch, seq_len, hidden_size).
            final_hidden: Query vector, shape (batch, hidden_size).

        Returns:
            A ``(context, att_weights)`` tuple with shapes
            (batch, hidden_size) and (batch, seq_len) respectively.
        """
        keys = self.W_k(lstm_outputs)  # (batch, seq_len, hidden)
        query = self.W_q(final_hidden)  # (batch, hidden)

        # Broadcast the query over the sequence axis before the non-linearity.
        combined = self.tanh(query.unsqueeze(1) + keys)
        scores = self.W_v(combined).squeeze(-1)  # (batch, seq_len)
        att_weights = torch.softmax(scores, -1)

        # NOTE: the context is a weighted sum of the *projected* keys, not of
        # the raw lstm_outputs.
        # squeeze(1) removes only the singleton query axis produced by bmm; a
        # bare squeeze() would also drop the batch axis when batch == 1 (or
        # the feature axis when hidden_size == 1).
        context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze(1)
        return context, att_weights


# Notebook-style shape check: evaluated at import time, result is discarded.
BahdanauAttention()(
    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE),
    torch.randn(BATCH_SIZE, HIDDEN_SIZE),
)[1].shape


class LSTMConcatAttentionEmbed(nn.Module):
    """Embedding -> LSTM -> additive attention -> MLP head producing one logit.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector (LSTM input size).
        hidden_size: LSTM hidden size, also used by the attention layer.
    """

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_size: int = HIDDEN_SIZE,
    ) -> None:
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.attn = BahdanauAttention(hidden_size)
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model on a batch of token-id sequences.

        Args:
            x: Integer token ids, shape (batch, seq_len).

        Returns:
            A ``(out, att_weights)`` tuple with shapes (batch, 1) and
            (batch, seq_len) respectively.
        """
        embeddings = self.embedding(x)
        outputs, (h_n, _) = self.lstm(embeddings)
        # h_n is (num_layers, batch, hidden); with the single-layer LSTM built
        # above, h_n[-1] is the same tensor as h_n.squeeze(0).
        att_hidden, att_weights = self.attn(outputs, h_n[-1])
        out = self.clf(att_hidden)
        return out, att_weights