initial commit
- app.py +75 -0
- attention.py +53 -0
- models.py +162 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,75 @@
import random
from typing import *

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch

from huggingface_hub import hf_hub_download
from torchtext.datasets import Multi30k

from models import Seq2Seq


# Load model
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
model.load_state_dict(torch.load(model_path))
model.eval()

# Load sentencepiece tokenizers
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)

# Load test set for example inputs
normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip())
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))


def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> plt.Figure:
    """Plot attention weights as a heatmap with source tokens on the x-axis and output tokens on the y-axis."""
    figure = plt.figure(dpi=800, tight_layout=True)
    axes = sns.heatmap(weights, cmap="gray", cbar=False)
    axes.set_xticklabels(input_tokens, rotation=90)
    axes.set_yticklabels(output_tokens, rotation=0)
    axes.tick_params(axis="both", length=0)
    axes.xaxis.tick_top()
    plt.close()
    return figure


@torch.inference_mode()
def run(input: str) -> Tuple[str, plt.Figure]:
    """Run inference on a single sentence. Returns prediction and attention heatmap."""
    input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
    output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
    output = target_spm.decode(output.detach().tolist())
    input_tokens = source_spm.encode(input, out_type=str)
    output_tokens = target_spm.encode(output, out_type=str)
    return output, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())


if __name__ == "__main__":
    interface = gr.Interface(
        run,
        inputs=gr.inputs.Textbox(lines=4, label="German"),
        outputs=[
            gr.outputs.Textbox(label="English"),
            gr.outputs.Image(type="plot", label="Attention Heatmap"),
        ],
        title="Multi30k Translation Widget",
        examples=random.sample(test_source, k=30),
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,
    )

    interface.launch(
        enable_queue=True,
        cache_examples=True,
    )
attention.py
ADDED
@@ -0,0 +1,53 @@
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F

Tensor = torch.Tensor


class Attention(nn.Module):
    """Container for applying an attention scoring function."""

    def __init__(self, score: nn.Module, dropout: nn.Module = None):
        super().__init__()
        self.score = score
        self.dropout = dropout

    def forward(self, decoder_state: Tensor, encoder_state: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor]:
        """Return context and attention weights. Accepts a boolean mask indicating padding in the source sequence."""
        (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape
        scores = self.score(decoder_state, encoder_state)           # (B, L, T)
        if source_mask is not None:                                 # (B, T)
            scores.masked_fill_(source_mask.view(B, 1, T), -1e4)
        weights = F.softmax(scores, dim=-1)                         # (B, L, T)
        if self.dropout is not None:
            weights = self.dropout(weights)
        context = weights @ encoder_state                           # (B, L, _)
        return context, weights                                     # (B, L, _), (B, L, T)


class ConcatScore(nn.Module):
    """A two-layer network used as an attention scoring function. Expects a bidirectional encoder."""

    def __init__(self, d: int):
        super().__init__()
        self.w = nn.Linear(3*d, d)
        self.v = nn.Linear(d, 1, bias=False)
        self.initialize_parameters()

    def forward(self, decoder_state: Tensor, encoder_state: Tensor) -> Tensor:
        """Return attention scores."""
        (B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape   # (B, L, D), (B, T, 2*D)
        decoder_state = decoder_state.repeat_interleave(T, dim=1)         # (B, L*T, D)
        encoder_state = encoder_state.repeat(1, L, 1)                     # (B, L*T, 2*D)
        concatenated = torch.cat((decoder_state, encoder_state), dim=-1)  # (B, L*T, 3*D)
        scores = self.v(torch.tanh(self.w(concatenated)))                 # (B, L*T, 1)
        return scores.view(B, L, T)                                       # (B, L, T)

    @torch.no_grad()
    def initialize_parameters(self):
        nn.init.xavier_uniform_(self.w.weight)
        nn.init.xavier_uniform_(self.v.weight, gain=nn.init.calculate_gain("tanh"))
        nn.init.zeros_(self.w.bias)
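A minimal sketch of how these two modules fit together, using made-up toy sizes (batch 2, decoder length 4, source length 6, hidden size 8) rather than the deployed model configuration:

import torch
import torch.nn as nn

from attention import Attention, ConcatScore

B, L, T, D = 2, 4, 6, 8                               # toy sizes for illustration only
attention = Attention(ConcatScore(D), nn.Dropout(p=0.1))
attention.eval()                                      # disable dropout on the attention weights

decoder_state = torch.randn(B, L, D)                  # decoder hidden states
encoder_state = torch.randn(B, T, 2 * D)              # bidirectional encoder output
source_mask = torch.zeros(B, T, dtype=torch.bool)     # True marks padding; none in this toy batch

with torch.no_grad():
    context, weights = attention(decoder_state, encoder_state, source_mask)

print(context.shape)   # torch.Size([2, 4, 16])
print(weights.shape)   # torch.Size([2, 4, 6])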
models.py
ADDED
@@ -0,0 +1,162 @@
import random
from collections import OrderedDict
from typing import *

import torch
import torch.nn as nn

from attention import Attention
from attention import ConcatScore

Tensor = torch.Tensor


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)                     # (B, T, D)
        output, hidden = self.gru(embedded)                  # (B, T, 2*D), (2, B, D)
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)   # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))                 # (B, D)
        return output, hidden.unsqueeze(0)                   # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Decoder(nn.Module):
    """Single layer recurrent decoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4*hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        self.fc.fc2.weight = self.embedding.embedding.weight   # tie output projection to input embedding
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
        B, = input.shape                                                                          # L=1
        embedded = self.embedding(input.view(B, 1))                                               # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)     # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)                 # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature  # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)                               # (B, V), (1, B, D), (B, T)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)                         # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        source_mask = source == self.pad_idx                                  # (B, T)

        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
            output.append(predictions)
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]                                  # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)                     # (B,)
        return torch.stack(output, dim=1)                                     # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time. Returns output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))              # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)

        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))                         # (B,)
            attention.append(weights)                                         # (B, T)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]                                     # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)          # (L,), (L, T)
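A minimal usage sketch of Seq2Seq.decode on a fake source sequence, with a deliberately tiny vocabulary and hidden size (the deployed model in app.py uses vocab_size=8000 and hidden_dim=512):

import torch

from models import Seq2Seq

model = Seq2Seq(vocab_size=100, hidden_dim=16, bos_idx=1, eos_idx=2, pad_idx=3)  # toy configuration
model.eval()

source = torch.randint(4, 100, (7,))                         # one fake source sequence of 7 token ids
output, weights = model.decode(source, max_decode_length=10)

print(output.shape)    # at most torch.Size([10]); shorter if <eos> is produced early
print(weights.shape)   # (len(output), 7) attention weights over the source tokens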
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio
huggingface_hub
matplotlib
numpy
seaborn
sentencepiece
torch
torchtext