from typing import Tuple

import torch
import torch.nn as nn

HIDDEN_SIZE = 64
VOCAB_SIZE = 196906
EMBEDDING_DIM = 64
SEQ_LEN = 100
BATCH_SIZE = 16
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over LSTM outputs."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        keys = self.W_k(lstm_outputs)          # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        query = self.W_q(final_hidden)         # BATCH_SIZE x HIDDEN_SIZE
        # Additive scoring: v^T tanh(W_q q + W_k k) for every time step.
        scores = self.tanh(query.unsqueeze(1) + keys)
        scores = self.W_v(scores).squeeze(-1)  # BATCH_SIZE x SEQ_LEN
        att_weights = torch.softmax(scores, dim=-1)
        # Context vector: attention-weighted sum over the key projections.
        context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze(1)
        return context, att_weights
# Smoke test: the attention weights should have shape (BATCH_SIZE, SEQ_LEN).
BahdanauAttention()(
    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
)[1].shape
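
# Optional extension of the smoke test above (a minimal sketch, assuming the same
# random inputs): it also confirms that the context vector comes back as
# BATCH_SIZE x HIDDEN_SIZE alongside the BATCH_SIZE x SEQ_LEN attention weights.
_context, _weights = BahdanauAttention()(
    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
)
assert _context.shape == (BATCH_SIZE, HIDDEN_SIZE)
assert _weights.shape == (BATCH_SIZE, SEQ_LEN)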
class LSTMConcatAttentionEmbed(nn.Module):
    """LSTM encoder with Bahdanau attention and an MLP classification head."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        # Alternatively, plug in a pretrained embedding layer:
        # self.embedding = embedding_layer
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = self.embedding(x)             # BATCH_SIZE x SEQ_LEN x EMBEDDING_DIM
        outputs, (h_n, _) = self.lstm(embeddings)  # outputs: BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(att_hidden)                 # BATCH_SIZE x 1 (raw logit)
        return out, att_weights
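
# Usage sketch (an assumption, not part of the original: inputs are integer token
# ids in [0, VOCAB_SIZE), and the single logit output would typically be paired
# with nn.BCEWithLogitsLoss for binary classification).
model = LSTMConcatAttentionEmbed()
dummy_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
logits, weights = model(dummy_ids)
print(logits.shape)   # torch.Size([16, 1])
print(weights.shape)  # torch.Size([16, 100])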