|
|
|
from typing import Tuple |
|
import torch |
|
import torch.nn as nn |
|
|
|
# Width of the LSTM hidden state; also the feature size of the attention
# projections and the input size of the classifier head.
HIDDEN_SIZE = 32

# Number of rows in the embedding table.
VOCAB_SIZE = 196906

# Size of each embedding vector, i.e. the LSTM input size.
EMBEDDING_DIM = 64

# Sequence length of the random tensor used by the module-level attention
# shape check.
SEQ_LEN = 100

# Batch size of the random tensors used by the module-level attention shape
# check.
BATCH_SIZE = 64
|
|
|
|
|
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of LSTM outputs.

    Each time step is scored as ``W_v(tanh(W_q(query) + W_k(outputs)))``; the
    scores are turned into weights with a softmax over the sequence axis, and
    the context is the weighted sum of the *projected* keys
    (``W_k(outputs)``), not of the raw LSTM outputs.
    """

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        """Create the query, key and score projections.

        Args:
            hidden_size: Size of the last dimension of both ``forward``
                inputs.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)

        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,
        final_hidden: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the attention context and the attention weights.

        Args:
            lstm_outputs: LSTM hidden states for every time step,
                shape (batch, seq_len, hidden_size).
            final_hidden: Final LSTM hidden state, shape (batch, hidden_size).

        Returns:
            A ``(context, att_weights)`` tuple where ``context`` has shape
            (batch, hidden_size) and ``att_weights`` has shape
            (batch, seq_len) and sums to 1 along the last axis.
        """
        keys = self.W_k(lstm_outputs)   # (batch, seq_len, hidden_size)
        query = self.W_q(final_hidden)  # (batch, hidden_size)

        # Insert a time axis so the single query broadcasts over every step.
        energy = self.tanh(query.unsqueeze(1) + keys)

        # W_v maps each step to one score; drop that trailing size-1 axis.
        scores = self.W_v(energy).squeeze(-1)  # (batch, seq_len)
        att_weights = torch.softmax(scores, dim=-1)

        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden).
        # Squeeze only axis 1: a bare squeeze() would also drop the batch
        # axis when batch == 1 (or the feature axis when hidden_size == 1).
        context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze(1)

        return context, att_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shape check: run a freshly initialised attention module on random inputs and
# look at the shape of the second returned element (the attention weights).
# NOTE(review): this runs on every import and the value is discarded outside an
# interactive session -- presumably notebook residue; confirm before removing.
BahdanauAttention()(torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE))[1].shape
|
|
|
|
|
class LSTMConcatAttentionEmbed(nn.Module):
    """Embedding -> LSTM -> Bahdanau attention -> two-layer head.

    The head ends in ``nn.Linear(128, 1)`` with no activation after it, so the
    model emits one raw score per attention context vector.
    """

    def __init__(self) -> None:
        """Build the sub-modules from the module-level size constants."""
        super().__init__()

        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)

        # batch_first=True: the LSTM consumes and returns (batch, seq, feature).
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score a batch of token-index sequences.

        Args:
            x: Integer token indices fed to the embedding layer; presumably
                shaped (batch, seq_len) -- TODO confirm against the caller.

        Returns:
            A ``(out, att_weights)`` tuple: ``out`` is the classifier output
            for the attention context and ``att_weights`` are the attention
            weights returned by ``self.attn``.
        """
        embeddings = self.embedding(x)
        outputs, (h_n, _) = self.lstm(embeddings)

        # The LSTM is single-layer and unidirectional, so h_n has a leading
        # axis of size 1; drop it to get one hidden vector per sequence.
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))

        out = self.clf(att_hidden)
        return out, att_weights
|
|
|
|