from typing import Tuple

import torch
import torch.nn as nn

HIDDEN_SIZE = 64
VOCAB_SIZE = 196906
EMBEDDING_DIM = 64
SEQ_LEN = 100
BATCH_SIZE = 16


class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # Additive-attention projections: query, keys, and the scoring vector v.
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)

        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        keys = self.W_k(lstm_outputs)   # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        query = self.W_q(final_hidden)  # BATCH_SIZE x HIDDEN_SIZE

        # Additive (Bahdanau) scoring: broadcast the query over every time step.
        energy = self.tanh(query.unsqueeze(1) + keys)  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE

        scores = self.W_v(energy).squeeze(-1)  # BATCH_SIZE x SEQ_LEN

        att_weights = torch.softmax(scores, dim=-1)  # BATCH_SIZE x SEQ_LEN

        # Context vector: attention-weighted sum of the values, i.e. the raw LSTM
        # outputs rather than the projected keys; squeeze only the length-1 dim
        # so a batch of size 1 keeps its batch dimension.
        context = torch.bmm(att_weights.unsqueeze(1), lstm_outputs).squeeze(1)  # BATCH_SIZE x HIDDEN_SIZE

        return context, att_weights


# Quick shape check: attention weights for a random batch should be BATCH_SIZE x SEQ_LEN.
print(
    BahdanauAttention()(
        torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
    )[1].shape
)


class LSTMConcatAttentionEmbed(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        # self.embedding = embedding_layer  # optional: plug in a pre-trained embedding layer instead
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = self.embedding(x)             # BATCH_SIZE x SEQ_LEN x EMBEDDING_DIM
        outputs, (h_n, _) = self.lstm(embeddings)  # outputs: BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        # h_n is 1 x BATCH_SIZE x HIDDEN_SIZE for a single-layer, unidirectional LSTM.
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(att_hidden)  # BATCH_SIZE x 1 (one logit per sequence)
        return out, att_weights
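

# Minimal usage sketch (an assumption, not part of any training pipeline): random
# token ids stand in for a real tokenized, padded batch. It only verifies that the
# model wires together and that the output shapes match the comments above.
model = LSTMConcatAttentionEmbed()
dummy_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
logits, att_weights = model(dummy_ids)
print(logits.shape)       # torch.Size([16, 1])
print(att_weights.shape)  # torch.Size([16, 100])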