DHRUV SHEKHAWAT
committed
Commit 1dd09ef
Parent(s): 52f9f0f
Upload 2 files
models.py
ADDED
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import math
+import torch.nn.functional as F
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Embeddings(nn.Module):
+    """
+    Implements embeddings of the words and adds their positional encodings.
+    """
+    def __init__(self, vocab_size, d_model, max_len=50):
+        super(Embeddings, self).__init__()
+        self.d_model = d_model
+        self.embed = nn.Embedding(vocab_size, d_model)
+        self.pe = self.create_positional_encoding(max_len, self.d_model)  # (1, max_len, d_model)
+        self.dropout = nn.Dropout(0.1)
+
+    def create_positional_encoding(self, max_len, d_model):
+        pe = torch.zeros(max_len, d_model).to(device)
+        for pos in range(max_len):          # for each position of the word
+            for i in range(0, d_model, 2):  # for each dimension of each position
+                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
+                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
+        pe = pe.unsqueeze(0)                # include the batch size
+        return pe
+
+    def forward(self, encoded_words):
+        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
+        embedding += self.pe[:, :embedding.size(1)]  # pe is broadcast over the batch dimension of encoded_words
+        embedding = self.dropout(embedding)
+        return embedding
+
+
+class MultiHeadAttention(nn.Module):
+
+    def __init__(self, heads, d_model):
+        super(MultiHeadAttention, self).__init__()
+        assert d_model % heads == 0
+        self.d_k = d_model // heads
+        self.heads = heads
+        self.dropout = nn.Dropout(0.1)
+        self.query = nn.Linear(d_model, d_model)
+        self.key = nn.Linear(d_model, d_model)
+        self.value = nn.Linear(d_model, d_model)
+        self.concat = nn.Linear(d_model, d_model)
+
+    def forward(self, query, key, value, mask):
+        """
+        query, key, value of shape: (batch_size, max_len, d_model)
+        mask of shape: (batch_size, 1, 1, max_words)
+        """
+        # (batch_size, max_len, d_model)
+        query = self.query(query)
+        key = self.key(key)
+        value = self.value(value)
+
+        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
+        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
+        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
+        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
+
+        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
+        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))
+        scores = scores.masked_fill(mask == 0, -1e9)  # (batch_size, h, max_len, max_len)
+        weights = F.softmax(scores, dim=-1)           # (batch_size, h, max_len, max_len)
+        weights = self.dropout(weights)
+        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
+        context = torch.matmul(weights, value)
+        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
+        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
+        # (batch_size, max_len, h * d_k)
+        interacted = self.concat(context)
+        return interacted
+
+
+class FeedForward(nn.Module):
+
+    def __init__(self, d_model, middle_dim=2048):
+        super(FeedForward, self).__init__()
+        self.fc1 = nn.Linear(d_model, middle_dim)
+        self.fc2 = nn.Linear(middle_dim, d_model)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, x):
+        out = F.relu(self.fc1(x))
+        out = self.fc2(self.dropout(out))
+        return out
+
+
+class EncoderLayer(nn.Module):
+
+    def __init__(self, d_model, heads):
+        super(EncoderLayer, self).__init__()
+        self.layernorm = nn.LayerNorm(d_model)
+        self.self_multihead = MultiHeadAttention(heads, d_model)
+        self.feed_forward = FeedForward(d_model)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, embeddings, mask):
+        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
+        interacted = self.layernorm(interacted + embeddings)
+        feed_forward_out = self.dropout(self.feed_forward(interacted))
+        encoded = self.layernorm(feed_forward_out + interacted)
+        return encoded
+
+
+class DecoderLayer(nn.Module):
+
+    def __init__(self, d_model, heads):
+        super(DecoderLayer, self).__init__()
+        self.layernorm = nn.LayerNorm(d_model)
+        self.self_multihead = MultiHeadAttention(heads, d_model)
+        self.src_multihead = MultiHeadAttention(heads, d_model)
+        self.feed_forward = FeedForward(d_model)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, embeddings, encoded, src_mask, target_mask):
+        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
+        query = self.layernorm(query + embeddings)
+        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
+        interacted = self.layernorm(interacted + query)
+        feed_forward_out = self.dropout(self.feed_forward(interacted))
+        decoded = self.layernorm(feed_forward_out + interacted)
+        return decoded
+
+
+class Transformer(nn.Module):
+
+    def __init__(self, d_model, heads, num_layers, word_map):
+        super(Transformer, self).__init__()
+        self.d_model = d_model
+        self.vocab_size = len(word_map)
+        self.embed = Embeddings(self.vocab_size, d_model)
+        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
+        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
+        self.logit = nn.Linear(d_model, self.vocab_size)
+
+    def encode(self, src_words, src_mask):
+        src_embeddings = self.embed(src_words)
+        for layer in self.encoder:
+            src_embeddings = layer(src_embeddings, src_mask)
+        return src_embeddings
+
+    def decode(self, target_words, target_mask, src_embeddings, src_mask):
+        tgt_embeddings = self.embed(target_words)
+        for layer in self.decoder:
+            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
+        return tgt_embeddings
+
+    def forward(self, src_words, src_mask, target_words, target_mask):
+        encoded = self.encode(src_words, src_mask)
+        decoded = self.decode(target_words, target_mask, encoded, src_mask)
+        out = F.log_softmax(self.logit(decoded), dim=2)
+        return out
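
The model above can be sanity-checked on its own. Below is a minimal smoke test that is not part of the commit: it assumes a toy vocabulary of 100 ids, d_model = 512, 8 heads, 2 layers, and all-ones masks (so nothing is masked out), and only verifies that a forward pass returns log-probabilities of shape (batch_size, target_len, vocab_size).

import torch
from models import Transformer, device

# Hypothetical vocabulary and hyperparameters, chosen only for illustration.
word_map = {f"word{i}": i for i in range(100)}
model = Transformer(d_model=512, heads=8, num_layers=2, word_map=word_map).to(device)

src = torch.randint(1, len(word_map), (2, 10)).to(device)          # (batch_size, src_len)
tgt = torch.randint(1, len(word_map), (2, 12)).to(device)          # (batch_size, tgt_len)
src_mask = torch.ones(2, 1, 1, 10, dtype=torch.uint8).to(device)   # no padding masked out
tgt_mask = torch.ones(2, 1, 12, 12, dtype=torch.uint8).to(device)  # no look-ahead mask in this smoke test

out = model(src, src_mask, tgt, tgt_mask)
print(out.shape)  # torch.Size([2, 12, 100]): log-probabilities over the vocabulary
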
utils.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset as TorchDataset  # aliased so the local Dataset class below does not shadow it
+import torch.utils.data
+import json
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Dataset(TorchDataset):
+
+    def __init__(self):
+        self.pairs = json.load(open('pairs_encoded.json'))
+        self.dataset_size = len(self.pairs)
+
+    def __getitem__(self, i):
+        question = torch.LongTensor(self.pairs[i][0])
+        reply = torch.LongTensor(self.pairs[i][1])
+        return question, reply
+
+    def __len__(self):
+        return self.dataset_size
+
+
+def create_masks(question, reply_input, reply_target):
+
+    def subsequent_mask(size):
+        # lower-triangular matrix: position i may attend only to positions <= i
+        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
+        return mask.unsqueeze(0)
+
+    question_mask = (question != 0).to(device)
+    question_mask = question_mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, max_words)
+
+    reply_input_mask = reply_input != 0
+    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words)
+    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
+    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words, max_words)
+    reply_target_mask = reply_target != 0             # (batch_size, max_words)
+
+    return question_mask, reply_input_mask, reply_target_mask
+
+
+class AdamWarmup:
+
+    def __init__(self, model_size, warmup_steps, optimizer):
+        self.model_size = model_size
+        self.warmup_steps = warmup_steps
+        self.optimizer = optimizer
+        self.current_step = 0
+        self.lr = 0
+
+    def get_lr(self):
+        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
+
+    def step(self):
+        # Increment the number of steps each time we call the step function
+        self.current_step += 1
+        lr = self.get_lr()
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = lr
+        # update the learning rate
+        self.lr = lr
+        self.optimizer.step()
+
+
+class LossWithLS(nn.Module):
+
+    def __init__(self, size, smooth):
+        super(LossWithLS, self).__init__()
+        self.criterion = nn.KLDivLoss(reduction='none')  # element-wise loss; masked and averaged manually below
+        self.confidence = 1.0 - smooth
+        self.smooth = smooth
+        self.size = size
+
+    def forward(self, prediction, target, mask):
+        """
+        prediction of shape: (batch_size, max_words, vocab_size)
+        target and mask of shape: (batch_size, max_words)
+        """
+        prediction = prediction.view(-1, prediction.size(-1))  # (batch_size * max_words, vocab_size)
+        target = target.contiguous().view(-1)                  # (batch_size * max_words)
+        mask = mask.float()
+        mask = mask.view(-1)                                   # (batch_size * max_words)
+        labels = prediction.data.clone()
+        labels.fill_(self.smooth / (self.size - 1))
+        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
+        loss = self.criterion(prediction, labels)              # (batch_size * max_words, vocab_size)
+        loss = (loss.sum(1) * mask).sum() / mask.sum()
+        return loss
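
Taken together, utils.py supplies everything around models.py: Dataset serves (question, reply) id pairs, create_masks builds the padding and look-ahead masks, AdamWarmup applies the warmup schedule lr = model_size^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)), and LossWithLS is a label-smoothed KL-divergence loss averaged over non-padded positions. The sketch below wires them into a single hypothetical training step; the vocabulary file name, hyperparameters, batch size and smoothing value are illustrative assumptions, and it presumes pairs_encoded.json holds padded id sequences so the default DataLoader collation can stack them.

import json
import torch
import torch.utils.data
from models import Transformer, device
from utils import Dataset, create_masks, AdamWarmup, LossWithLS

word_map = json.load(open('WORDMAP.json'))  # hypothetical path: whatever mapping was used to encode pairs_encoded.json
d_model, heads, num_layers = 512, 8, 1      # illustrative hyperparameters

model = Transformer(d_model=d_model, heads=heads, num_layers=num_layers, word_map=word_map).to(device)
adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam)
criterion = LossWithLS(len(word_map), smooth=0.1)

train_loader = torch.utils.data.DataLoader(Dataset(), batch_size=32, shuffle=True)

for question, reply in train_loader:
    question, reply = question.to(device), reply.to(device)
    reply_input = reply[:, :-1]   # decoder input: reply shifted right
    reply_target = reply[:, 1:]   # training target: the next token at each position

    question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

    out = model(question, question_mask, reply_input, reply_input_mask)
    loss = criterion(out, reply_target, reply_target_mask)

    optimizer.optimizer.zero_grad()  # zero_grad lives on the wrapped Adam instance
    loss.backward()
    optimizer.step()                 # AdamWarmup sets the learning rate, then calls Adam's step()
    break                            # single step, for illustration only
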