from typing import Tuple
import torch
import torch.nn as nn

HIDDEN_SIZE = 32       # LSTM hidden state size
VOCAB_SIZE = 196906    # vocabulary size of the embedding layer
EMBEDDING_DIM = 64     # embedding dimension
SEQ_LEN = 100          # input sequence length
BATCH_SIZE = 64        # batch size used in the shape checks below


class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size)  # projects the query (final hidden state)
        self.W_k = nn.Linear(hidden_size, hidden_size)  # projects the keys (LSTM outputs)
        self.W_v = nn.Linear(hidden_size, 1)            # scores each time step
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
        final_hidden: torch.Tensor,  # (BATCH_SIZE, HIDDEN_SIZE)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Bahdanau (additive) attention.

        Args:
            lstm_outputs (torch.Tensor): LSTM hidden states, used as keys and values
                (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
            final_hidden (torch.Tensor): final LSTM hidden state, used as the query
                (BATCH_SIZE, HIDDEN_SIZE)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                context vector (BATCH_SIZE, HIDDEN_SIZE)
                attention weights (BATCH_SIZE, SEQ_LEN)
        """
        # Project keys and query into the shared attention space.
        keys = self.W_k(lstm_outputs)   # (B, T, H)
        query = self.W_q(final_hidden)  # (B, H)

        # Additive score for every time step t: v^T tanh(W_q q + W_k k_t).
        scores = self.W_v(self.tanh(query.unsqueeze(1) + keys)).squeeze(-1)  # (B, T)

        # Normalize the scores over the sequence dimension.
        att_weights = torch.softmax(scores, dim=-1)  # (B, T)

        # Context vector: attention-weighted sum of the encoder outputs (values).
        context = torch.bmm(att_weights.unsqueeze(1), lstm_outputs).squeeze(1)  # (B, H)

        return context, att_weights


# Sanity check on random inputs: attention weights should have shape (BATCH_SIZE, SEQ_LEN)
print(
    BahdanauAttention()(
        torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE),
        torch.randn(BATCH_SIZE, HIDDEN_SIZE),
    )[1].shape
)
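
# A minimal sketch (added here for illustration, not part of the original checks):
# verify two Bahdanau-attention invariants on random inputs - the weights are
# non-negative and sum to 1 over the sequence dimension, and the context vector
# has shape (BATCH_SIZE, HIDDEN_SIZE). The underscore-prefixed names are throwaway.
_attn = BahdanauAttention()
_context, _weights = _attn(
    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE),
    torch.randn(BATCH_SIZE, HIDDEN_SIZE),
)
assert _context.shape == (BATCH_SIZE, HIDDEN_SIZE)
assert (_weights >= 0).all()
assert torch.allclose(_weights.sum(dim=-1), torch.ones(BATCH_SIZE))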


class LSTMConcatAttentionEmbed(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        # self.embedding = embedding_layer  # alternative: plug in a pretrained embedding layer
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        # Binary-classification head on top of the attention context vector.
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = self.embedding(x)             # (B, T, EMBEDDING_DIM)
        outputs, (h_n, _) = self.lstm(embeddings)  # (B, T, H), h_n: (1, B, H)
        context, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(context)                    # (B, 1) logit
        return out, att_weights
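
# A minimal usage sketch, assuming a binary-classification setup: run a dummy batch
# of token ids through LSTMConcatAttentionEmbed and take one optimization step with
# BCEWithLogitsLoss. The random tokens, labels, and Adam learning rate below are
# placeholders, not values from the original training code.
model = LSTMConcatAttentionEmbed()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

tokens = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))  # dummy token ids
labels = torch.randint(0, 2, (BATCH_SIZE, 1)).float()         # dummy binary labels

logits, att_weights = model(tokens)  # (BATCH_SIZE, 1), (BATCH_SIZE, SEQ_LEN)
loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}, attention weights: {tuple(att_weights.shape)}")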