msarmi9 committed
Commit
8c7a320
1 Parent(s): cc430e7

initial commit

Files changed (4)
  1. app.py +75 -0
  2. attention.py +53 -0
  3. models.py +164 -0
  4. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ import random
+ from typing import *
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+ import sentencepiece as sp
+ import torch
+
+ from huggingface_hub import hf_hub_download
+ from torchtext.datasets import Multi30k
+
+ from models import Seq2Seq
+
+
+ # Load model
+ model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/model.bin")
+ model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
+ model.load_state_dict(torch.load(model_path))
+ model.eval()
+
+ # Load sentencepiece tokenizers
+ source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
+ target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
+ source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
+ target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)
+
+ # Load test set for example inputs
+ normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip())
+ test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
+
+
+ def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> plt.Figure:
+     figure = plt.figure(dpi=800, tight_layout=True)
+     axes = sns.heatmap(weights, cmap="gray", cbar=False)
+     axes.set_xticklabels(input_tokens, rotation=90)
+     axes.set_yticklabels(output_tokens, rotation=0)
+     axes.tick_params(axis="both", length=0)
+     axes.xaxis.tick_top()
+     plt.close()
+     return figure
+
+
+ @torch.inference_mode()
+ def run(input: str) -> Tuple[str, plt.Figure]:
+     """Run inference on a single sentence. Returns prediction and attention heatmap."""
+     input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
+     output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
+     output = target_spm.decode(output.detach().tolist())
+     input_tokens = source_spm.encode(input, out_type=str)
+     output_tokens = target_spm.encode(output, out_type=str)
+     return output, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())
+
+
+ if __name__ == "__main__":
+     interface = gr.Interface(
+         run,
+         inputs=gr.inputs.Textbox(lines=4, label="German"),
+         outputs=[
+             gr.outputs.Textbox(label="English"),
+             gr.outputs.Image(type="plot", label="Attention Heatmap"),
+         ],
+         title="Multi30k Translation Widget",
+         examples=random.sample(test_source, k=30),
+         examples_per_page=10,
+         allow_flagging="never",
+         theme="huggingface",
+         live=True,
+     )
+
+     interface.launch(
+         enable_queue=True,
+         cache_examples=True,
+     )
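
Once the requirements are installed, `python app.py` downloads the checkpoint and tokenizers from the Hub and serves the widget. For a quick smoke test without the Gradio UI, the inference function can also be called directly; the sketch below is illustrative and not part of the commit (the German sentence is an arbitrary example, and importing `app` runs its module-level setup, including the Hub downloads):

    # Hypothetical smoke test for app.py's inference path (assumes downloads succeed).
    from app import run

    translation, figure = run("ein hund läuft über eine wiese.")  # arbitrary German input
    print(translation)               # decoded English sentence
    figure.savefig("attention.png")  # persist the attention heatmap
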
attention.py ADDED
@@ -0,0 +1,53 @@
+ from typing import *
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ Tensor = torch.Tensor
+
+
+ class Attention(nn.Module):
+     """Container for applying an attention scoring function."""
+
+     def __init__(self, score: nn.Module, dropout: nn.Module = None):
+         super().__init__()
+         self.score = score
+         self.dropout = dropout
+
+     def forward(self, decoder_state: Tensor, encoder_state: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
+         """Return context and attention weights. Accepts a boolean mask indicating padding in the source sequence."""
+         (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape
+         scores = self.score(decoder_state, encoder_state)         # (B, L, T)
+         if source_mask is not None:                               # (B, T)
+             scores.masked_fill_(source_mask.view(B, 1, T), -1e4)
+         weights = F.softmax(scores, dim=-1)                       # (B, L, T)
+         if self.dropout is not None:
+             weights = self.dropout(weights)
+         context = weights @ encoder_state                         # (B, L, _)
+         return context, weights                                   # (B, L, _), (B, L, T)
+
+
+ class ConcatScore(nn.Module):
+     """A two-layer network as an attention scoring function. Expects a bidirectional encoder."""
+
+     def __init__(self, d: int):
+         super().__init__()
+         self.w = nn.Linear(3*d, d)
+         self.v = nn.Linear(d, 1, bias=False)
+         self.initialize_parameters()
+
+     def forward(self, decoder_state: Tensor, encoder_state: Tensor) -> Tensor:
+         """Return attention scores."""
+         (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape   # (B, L, D), (B, T, 2*D)
+         decoder_state = decoder_state.repeat_interleave(T, dim=1)         # (B, L*T, D)
+         encoder_state = encoder_state.repeat(1, L, 1)                     # (B, L*T, 2*D)
+         concatenated = torch.cat((decoder_state, encoder_state), dim=-1)  # (B, L*T, 3*D)
+         scores = self.v(torch.tanh(self.w(concatenated)))                 # (B, L*T, 1)
+         return scores.view(B, L, T)                                       # (B, L, T)
+
+     @torch.no_grad()
+     def initialize_parameters(self):
+         nn.init.xavier_uniform_(self.w.weight)
+         nn.init.xavier_uniform_(self.v.weight, gain=nn.init.calculate_gain("tanh"))
+         nn.init.zeros_(self.w.bias)
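
Since `ConcatScore` concatenates a `D`-dimensional decoder state with a `2*D`-dimensional bidirectional encoder state, its first layer is `3*D` wide. A shape check on dummy tensors (a minimal sketch, not from the commit; all sizes are arbitrary):

    # Toy shape check for Attention + ConcatScore.
    import torch
    from attention import Attention, ConcatScore

    B, L, T, D = 2, 1, 7, 512                   # batch, decoder steps, source length, hidden dim
    attention = Attention(ConcatScore(D))
    decoder_state = torch.randn(B, L, D)        # current decoder state
    encoder_state = torch.randn(B, T, 2 * D)    # bidirectional encoder output
    mask = torch.zeros(B, T, dtype=torch.bool)  # no padding in this toy batch

    context, weights = attention(decoder_state, encoder_state, mask)
    assert context.shape == (B, L, 2 * D) and weights.shape == (B, L, T)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(B, L))  # softmax rows sum to one
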
models.py ADDED
@@ -0,0 +1,164 @@
+ import random
+ from collections import OrderedDict
+ from typing import *
+
+ import torch
+ import torch.nn as nn
+
+ from attention import Attention
+ from attention import ConcatScore
+
+ Tensor = torch.Tensor
+
+
+ class Encoder(nn.Module):
+     """Single layer recurrent bidirectional encoder."""
+
+     def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
+         super().__init__()
+         self.embedding = nn.Sequential(
+             OrderedDict(
+                 embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
+                 dropout=nn.Dropout(p=0.33),
+             )
+         )
+         self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
+         self.fc = nn.Linear(2*hidden_dim, hidden_dim)
+         self.initialize_parameters()
+
+     def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
+         """Encode a sequence of tokens as a sequence of hidden states."""
+         B, T = input.shape
+         embedded = self.embedding(input)                    # (B, T, D)
+         output, hidden = self.gru(embedded)                 # (B, T, 2*D), (2, B, D)
+         hidden = torch.cat((hidden[0], hidden[1]), dim=-1)  # (B, 2*D)
+         hidden = torch.tanh(self.fc(hidden))                # (B, D)
+         return output, hidden.unsqueeze(0)                  # (B, T, 2*D), (1, B, D)
+
+     @torch.no_grad()
+     def initialize_parameters(self):
+         """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
+         for name, parameters in self.named_parameters():
+             if "embedding" in name:
+                 nn.init.xavier_uniform_(parameters)
+             elif "weight_ih" in name:
+                 w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
+                 nn.init.xavier_uniform_(w_ir)
+                 nn.init.xavier_uniform_(w_iz)
+                 nn.init.xavier_uniform_(w_in)
+             elif "weight_hh" in name:
+                 w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
+                 nn.init.orthogonal_(w_hr)
+                 nn.init.orthogonal_(w_hz)
+                 nn.init.orthogonal_(w_hn)
+             elif "weight" in name:
+                 nn.init.xavier_uniform_(parameters)
+             elif "bias" in name:
+                 nn.init.zeros_(parameters)
+
+
+ class Decoder(nn.Module):
+     """Single layer recurrent decoder."""
+
+     def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
+         super().__init__()
+         self.embedding = nn.Sequential(
+             OrderedDict(
+                 embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
+                 dropout=nn.Dropout(p=0.33),
+             )
+         )
+         self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
+         self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
+         self.fc = nn.Sequential(
+             OrderedDict(
+                 fc1=nn.Linear(4*hidden_dim, hidden_dim),
+                 layer_norm=nn.LayerNorm(hidden_dim),
+                 gelu=nn.GELU(),
+                 fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
+             )
+         )
+         self.fc.fc2.weight = self.embedding.embedding.weight
+         self.temperature = temperature
+         self.initialize_parameters()
+
+     def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
+         """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
+         B, = input.shape  # L=1
+         embedded = self.embedding(input.view(B, 1))  # (B, 1, D)
+         context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)  # (B, 1, 2*D), (B, 1, T)
+         output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)  # (B, 1, D), (1, B, D)
+         predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature  # (B, 1, V)
+         return predictions.view(B, -1), hidden, weights.view(B, -1)  # (B, V), (1, B, D), (B, T)
+
+
+     @torch.no_grad()
+     def initialize_parameters(self):
+         """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
+         for name, parameters in self.named_parameters():
+             if "norm" in name:
+                 continue
+             elif "embedding" in name:
+                 nn.init.xavier_uniform_(parameters)
+             elif "weight_ih" in name:
+                 w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
+                 nn.init.xavier_uniform_(w_ir)
+                 nn.init.xavier_uniform_(w_iz)
+                 nn.init.xavier_uniform_(w_in)
+             elif "weight_hh" in name:
+                 w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
+                 nn.init.orthogonal_(w_hr)
+                 nn.init.orthogonal_(w_hz)
+                 nn.init.orthogonal_(w_hn)
+             elif "weight" in name:
+                 nn.init.xavier_uniform_(parameters)
+             elif "bias" in name:
+                 nn.init.zeros_(parameters)
+
+
+ class Seq2Seq(nn.Module):
+     """Seq2seq with attention."""
+
+     def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
+         super().__init__()
+         self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
+         self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
+         self.bos_idx = bos_idx
+         self.eos_idx = eos_idx
+         self.pad_idx = pad_idx
+         self.teacher_forcing = teacher_forcing
+
+     def forward(self, source: Tensor, target: Tensor) -> Tensor:
+         """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
+         (B, T), (B, L) = source.shape, target.shape
+         encoder_output, hidden = self.encoder(source)  # (B, T, 2*D), (1, B, D)
+         decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
+         source_mask = source == self.pad_idx  # (B, T)
+
+         output = []
+         for i in range(L):
+             predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
+             output.append(predictions)
+             if self.training and random.random() < self.teacher_forcing:
+                 decoder_input = target[:,i]  # (B,)
+             else:
+                 decoder_input = predictions.argmax(dim=1)  # (B,)
+         return torch.stack(output, dim=1)  # (B, L, V)
+
+     @torch.inference_mode()
+     def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
+         """Decode a single sequence at inference time. Returns output sequence and attention weights."""
+         B, (T,) = 1, source.shape
+         encoder_output, hidden = self.encoder(source.view(B, T))  # (B, T, 2*D), (1, B, D)
+         decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
+
+         output, attention = [], []
+         for i in range(max_decode_length):
+             predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
+             output.append(predictions.argmax(dim=-1))  # (B,)
+             attention.append(weights)  # (B, T)
+             if output[i] == self.eos_idx:
+                 break
+             else:
+                 decoder_input = output[i]  # (B,)
+         return torch.cat(output, dim=0), torch.cat(attention, dim=0)  # (L,), (L, T)
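
To exercise the model end to end on toy data (a minimal sketch, not part of the commit; the special-token indices match those used in app.py, and all sizes are arbitrary):

    # Toy end-to-end check for Seq2Seq.
    import torch
    from models import Seq2Seq

    model = Seq2Seq(vocab_size=100, hidden_dim=32, bos_idx=1, eos_idx=2, pad_idx=3)

    # Training-style forward pass with teacher forcing.
    source = torch.randint(4, 100, (2, 7))  # (B, T) source token ids
    target = torch.randint(4, 100, (2, 5))  # (B, L) target token ids
    logits = model(source, target)
    assert logits.shape == (2, 5, 100)      # (B, L, V) unnormalized predictions

    # Greedy decoding of a single sequence.
    model.eval()
    tokens, weights = model.decode(source[0], max_decode_length=10)
    print(tokens.shape, weights.shape)      # (L,) and (L, T) with L <= 10
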
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ huggingface_hub
+ matplotlib
+ numpy
+ seaborn
+ sentencepiece
+ torch
+ torchtext