import random
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn

from attention import Attention, ConcatScore

Tensor = torch.Tensor


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""""
        B, T = input.shape
        embedded = self.embedding(input)                     # (B, T, D)
        output, hidden = self.gru(embedded)                  # (B, T, 2*D), (2, B, D)
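        # Combine the final forward and backward hidden states into one decoder-sized initial state.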
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)   # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))                 # (B, D)
        return output, hidden.unsqueeze(0)                   # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and bias to zero."""""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Decoder(nn.Module):
    """Single layer recurrent decoder."""""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
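        # Additive (concat-score) attention over the bidirectional encoder states; the Dropout
        # passed here is assumed to regularize the attention distribution (see attention.py).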
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
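        # The GRU consumes the embedded token (D) concatenated with the attention context (2*D).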
        self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4*hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
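        # Weight tying: the output projection shares its (vocab_size, hidden_dim) matrix with the input embedding.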
        self.fc.fc2.weight = self.embedding.embedding.weight
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""""
        B, = input.shape                                                                           # L=1
        embedded = self.embedding(input.view(B, 1))                                                # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)      # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)                  # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature   # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)                                # (B, V), (1, B, D), (B, T)


    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and bias to zero."""""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)                          # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)   # (B,)
        source_mask = (source == self.pad_idx).unsqueeze(1)                    # (B, 1, T)

        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)   # (B, V), (1, B, D)
            output.append(predictions)
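            # Teacher forcing: with probability `teacher_forcing`, feed the ground-truth token at the
            # next step instead of the model's own greedy prediction (training only).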
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:,i]                                    # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)                      # (B,)
        return torch.stack(output, dim=1)                                      # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time.  Returns output sequence and attention weights."""""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))               # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)   # (B,)

        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)   # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))                          # (B,)
            attention.append(weights)                                          # (B, T)
            if output[i].item() == self.eos_idx:
                break
            else:
                decoder_input = output[i]                                      # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)           # (L,), (L, T)
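

# Minimal smoke test: a sketch only. The vocabulary size, hidden dimension, special-token
# indices, and sequence lengths below are illustrative assumptions, not values taken from
# this repository, and it assumes attention.Attention/ConcatScore behave as described above.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = Seq2Seq(vocab_size=100, hidden_dim=32, bos_idx=1, eos_idx=2, pad_idx=0)

    # Training-style forward pass on a small batch of random token ids.
    source = torch.randint(3, 100, (4, 11))                               # (B, T)
    target = torch.randint(3, 100, (4, 7))                                # (B, L)
    predictions = model(source, target)                                   # (B, L, V)
    print(predictions.shape)

    # Greedy decoding of a single sequence at inference time.
    model.eval()
    output, attention = model.decode(source[0], max_decode_length=20)     # (L,), (L, T)
    print(output.shape, attention.shape)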