Spaces:
Sleeping
Sleeping
File size: 2,045 Bytes
cb2adb5 3a905e4 cb2adb5 3a905e4 cb2adb5 3a905e4 cb2adb5 3a905e4 cb2adb5 3a905e4 cb2adb5 3a905e4 cb2adb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from typing import Optional, Tuple

import torch
import torch.nn as nn
# Shared hyper-parameters for the modules in this file.
HIDDEN_SIZE = 64  # LSTM hidden-state width, also the attention feature width
VOCAB_SIZE = 196_906  # number of rows in the token embedding table
EMBEDDING_DIM = 64  # width of each token embedding vector
SEQ_LEN = 100  # sequence length used for the attention shape check
BATCH_SIZE = 16  # batch size used for the attention shape check
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of LSTM outputs.

    Each time step is scored as ``W_v(tanh(W_q(query) + W_k(output_t)))``.
    The scores are softmax-normalised over the sequence axis and used to
    take a weighted sum of the projected keys.

    NOTE(review): the context vector is a weighted sum of the *projected*
    keys (``W_k(lstm_outputs)``), not of ``lstm_outputs`` directly. Left
    as-is because trained weights depend on this choice.
    """

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        """Create the query/key projections and the scalar scoring layer.

        Args:
            hidden_size: Feature width of both the sequence outputs and the
                query vector.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)  # one score per time step
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``lstm_outputs`` using ``final_hidden`` as the query.

        Args:
            lstm_outputs: Sequence features, shape (batch, seq_len, hidden).
            final_hidden: Query vector per example, shape (batch, hidden).

        Returns:
            A ``(context, att_weights)`` tuple where ``context`` has shape
            (batch, hidden) and ``att_weights`` has shape (batch, seq_len),
            each row summing to 1.
        """
        keys = self.W_k(lstm_outputs)  # (batch, seq_len, hidden)
        query = self.W_q(final_hidden)  # (batch, hidden)
        # Broadcast the query across every time step before the non-linearity.
        energy = self.tanh(query.unsqueeze(1) + keys)  # (batch, seq_len, hidden)
        scores = self.W_v(energy).squeeze(-1)  # (batch, seq_len)
        att_weights = torch.softmax(scores, dim=-1)
        # Drop only the singleton "query" axis from (batch, 1, hidden). A bare
        # .squeeze() would also drop the batch axis when batch size is 1.
        context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze(1)
        return context, att_weights
if __name__ == "__main__":
    # Shape sanity check for BahdanauAttention on random inputs. It is guarded
    # so that importing this module neither runs a forward pass nor advances
    # the global torch RNG.
    _check_weights = BahdanauAttention()(
        torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE),
        torch.randn(BATCH_SIZE, HIDDEN_SIZE),
    )[1]
    # The attention weights hold one value per (example, time step).
    assert tuple(_check_weights.shape) == (BATCH_SIZE, SEQ_LEN), _check_weights.shape
    print(_check_weights.shape)
class LSTMConcatAttentionEmbed(nn.Module):
    """Sequence classifier: embedding -> LSTM -> additive attention -> MLP.

    The final ``nn.Linear(64, 1)`` produces a single score per example.
    Attention weights are returned alongside so they can be inspected.
    """

    def __init__(
        self,
        vocab_size: Optional[int] = None,
        embedding_dim: Optional[int] = None,
        hidden_size: Optional[int] = None,
        embedding_layer: Optional[nn.Embedding] = None,
    ) -> None:
        """Build the model; calling with no arguments uses the module constants.

        Args:
            vocab_size: Rows of the embedding table. Defaults to ``VOCAB_SIZE``.
                Ignored when ``embedding_layer`` is given.
            embedding_dim: Embedding width. Defaults to ``EMBEDDING_DIM``.
                Ignored when ``embedding_layer`` is given.
            hidden_size: LSTM hidden / attention width. Defaults to
                ``HIDDEN_SIZE``.
            embedding_layer: Optional ready-made embedding (e.g. pretrained)
                to use instead of a freshly initialised ``nn.Embedding``.
        """
        super().__init__()
        # Defaults are resolved here (not in the signature) so they always
        # track the current module-level constants.
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        if embedding_layer is None:
            vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size
            embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
            embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.embedding = embedding_layer
        # Size the LSTM input from the actual embedding width, so a supplied
        # embedding layer of any width works.
        self.lstm = nn.LSTM(
            self.embedding.embedding_dim, hidden_size, batch_first=True
        )
        self.attn = BahdanauAttention(hidden_size)
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.Dropout(p=0.5),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Dropout(p=0.5),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score a batch of token-id sequences.

        Args:
            x: Integer token ids. Presumably (batch, seq_len), given the
                LSTM is built with ``batch_first=True`` (TODO confirm with
                callers).

        Returns:
            ``(out, att_weights)``: ``out`` is the classifier output (one value
            per example from the final ``Linear(64, 1)``) and ``att_weights``
            are the attention weights over the sequence.
        """
        embeddings = self.embedding(x)
        outputs, (h_n, _) = self.lstm(embeddings)
        # h_n stacks one final state per LSTM layer/direction along dim 0.
        # Use the last one as the attention query. This equals the previous
        # h_n.squeeze(0) for the single-layer LSTM built above.
        att_hidden, att_weights = self.attn(outputs, h_n[-1])
        out = self.clf(att_hidden)
        return out, att_weights
|