DHRUV SHEKHAWAT committed on
Commit
1dd09ef
·
1 Parent(s): 52f9f0f

Upload 2 files

Browse files
Files changed (2) hide show
  1. models.py +162 -0
  2. utils.py +91 -0
models.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import torch.nn.functional as F
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+
9
class Embeddings(nn.Module):
    """
    Word embeddings plus fixed sinusoidal positional encodings.

    Args:
        vocab_size: number of tokens in the vocabulary.
        d_model: embedding dimension (assumed even — the sin/cos pairs
            fill dimensions two at a time).
        max_len: maximum sequence length the positional table supports.
    """
    def __init__(self, vocab_size, d_model, max_len=50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        # Built once on CPU; moved to the input's device at forward time,
        # so the module no longer depends on a module-level `device` global.
        self.pe = self.create_positinal_encoding(max_len, self.d_model)
        # FIX: the original assigned self.dropout twice; once is enough.
        self.dropout = nn.Dropout(0.1)

    def create_positinal_encoding(self, max_len, d_model):
        """Return a (1, max_len, d_model) sinusoidal position table.

        (Method name kept as-is — including the "positinal" misspelling —
        for backward compatibility with existing callers.)
        """
        pe = torch.zeros(max_len, d_model)
        for pos in range(max_len):          # each position in the sequence
            for i in range(0, d_model, 2):  # each (sin, cos) dimension pair
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
        return pe.unsqueeze(0)  # leading 1 broadcasts over the batch dim

    def forward(self, encoded_words):
        """Embed token ids (batch, seq_len) -> (batch, seq_len, d_model)."""
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
        # Slice the table to the actual sequence length; match input device.
        embedding = embedding + self.pe[:, :embedding.size(1)].to(embedding.device)
        return self.dropout(embedding)
35
+
36
+
37
+
38
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a final output projection."""

    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, 512)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        batch_size = query.size(0)

        def split_heads(projected):
            # (batch, max_len, d_model) -> (batch, heads, max_len, d_k)
            return projected.view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)

        q = split_heads(self.query(query))
        k = split_heads(self.key(key))
        v = split_heads(self.value(value))

        # Attention scores: (batch, heads, max_len, max_len)
        scores = q.matmul(k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back together.
        context = weights.matmul(v)                     # (batch, heads, max_len, d_k)
        context = context.transpose(1, 2).contiguous()  # (batch, max_len, heads, d_k)
        context = context.view(batch_size, -1, self.heads * self.d_k)
        return self.concat(context)
79
+
80
+
81
+
82
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, middle_dim=2048):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply the two-layer MLP along the last dimension of `x`."""
        hidden = self.fc1(x).relu()
        return self.fc2(self.dropout(hidden))
95
+
96
+
97
class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention then feed-forward,
    each followed by dropout and a residual + LayerNorm.

    NOTE(review): a single LayerNorm module is shared by both sublayers
    (the standard Transformer uses one per sublayer) — kept as-is to
    preserve behavior and checkpoint compatibility.
    """

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        # Sublayer 1: self-attention with residual connection.
        attended = self.self_multihead(embeddings, embeddings, embeddings, mask)
        attended = self.layernorm(self.dropout(attended) + embeddings)
        # Sublayer 2: position-wise feed-forward with residual connection.
        transformed = self.feed_forward(attended)
        return self.layernorm(self.dropout(transformed) + attended)
112
+
113
+
114
class DecoderLayer(nn.Module):
    """One Transformer decoder layer: masked self-attention, encoder-decoder
    attention, then feed-forward — each with dropout + residual + LayerNorm.

    NOTE(review): one LayerNorm module is shared by all three sublayers —
    kept as-is to preserve behavior and checkpoint compatibility.
    """

    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        # Sublayer 1: causal self-attention over the target sequence.
        self_attended = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self.layernorm(self.dropout(self_attended) + embeddings)
        # Sublayer 2: attend over the encoder output.
        cross_attended = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self.layernorm(self.dropout(cross_attended) + query)
        # Sublayer 3: position-wise feed-forward.
        transformed = self.feed_forward(interacted)
        return self.layernorm(self.dropout(transformed) + interacted)
132
+
133
+
134
class Transformer(nn.Module):
    """Full encoder-decoder Transformer over a fixed vocabulary.

    The embedding table is shared between the encoder and decoder inputs.
    """

    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_words, src_mask):
        """Embed source tokens and run them through every encoder layer."""
        state = self.embed(src_words)
        for encoder_layer in self.encoder:
            state = encoder_layer(state, src_mask)
        return state

    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        """Embed target tokens and run them through every decoder layer."""
        state = self.embed(target_words)
        for decoder_layer in self.decoder:
            state = decoder_layer(state, src_embeddings, src_mask, target_mask)
        return state

    def forward(self, src_words, src_mask, target_words, target_mask):
        """Return per-token log-probabilities, shape (batch, tgt_len, vocab)."""
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        return F.log_softmax(self.logit(decoded), dim=2)
utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset
4
+ import torch.utils.data
5
+ import json
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
class Dataset(Dataset):
    """Question/reply pairs loaded from 'pairs_encoded.json'.

    Each element of the JSON list is a pair [question_ids, reply_ids];
    both halves are returned as LongTensors.

    NOTE(review): the class name shadows torch.utils.data.Dataset — kept
    for backward compatibility with existing callers.
    """

    def __init__(self):
        # FIX: use a context manager so the file handle is closed
        # (the original `json.load(open(...))` leaked it).
        with open('pairs_encoded.json') as f:
            self.pairs = json.load(f)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        return question, reply

    def __len__(self):
        return self.dataset_size
25
+
26
+
27
def create_masks(question, reply_input, reply_target):
    """Build attention masks for a batch of (question, reply) pairs.

    Padding is assumed to be token id 0.

    Returns:
        question_mask:     (batch, 1, 1, q_len)    True on non-pad tokens
        reply_input_mask:  (batch, 1, r_len, r_len) padding AND causal mask
        reply_target_mask: (batch, r_len)           padding mask for the loss
    """

    def subsequent_mask(size):
        # Lower-triangular causal mask. FIX: bool replaces the deprecated
        # uint8 mask dtype (identical values — the uint8 was converted to
        # bool by type_as below anyway — but avoids masked_fill warnings).
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).bool()
        return mask.unsqueeze(0)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, q_len)

    reply_input_mask = reply_input != 0
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch, 1, r_len)
    # AND the padding mask with the causal mask so a position sees only
    # earlier, non-pad tokens.
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch, 1, r_len, r_len)
    reply_target_mask = reply_target != 0  # (batch, r_len)

    return question_mask, reply_input_mask, reply_target_mask
43
+
44
+
45
class AdamWarmup:
    """Noam learning-rate schedule wrapped around an optimizer.

    lr = model_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    i.e. linear warmup for `warmup_steps` steps, then inverse-sqrt decay.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0   # number of completed calls to step()
        self.lr = 0             # last learning rate applied

    def get_lr(self):
        """Learning rate for `current_step` (valid once current_step >= 1)."""
        scale = self.model_size ** (-0.5)
        warmup_term = self.current_step * self.warmup_steps ** (-1.5)
        decay_term = self.current_step ** (-0.5)
        return scale * min(decay_term, warmup_term)

    def step(self):
        """Advance one step: set the new lr on every param group, then step."""
        self.current_step += 1
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        self.lr = new_lr
        self.optimizer.step()
67
+
68
class LossWithLS(nn.Module):
    """KL-divergence loss with label smoothing, averaged over non-pad tokens.

    Args:
        size: vocabulary size (number of classes).
        smooth: probability mass spread uniformly over non-target classes.
    """

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        # FIX: reduction='none' replaces the long-deprecated
        # size_average=False, reduce=False pair; the elementwise
        # (unreduced) behavior is identical.
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size) — log-probabilities.
        target and mask of shape: (batch_size, max_words); mask is 1 on real tokens.
        """
        prediction = prediction.view(-1, prediction.size(-1))  # (batch*max_words, vocab)
        target = target.contiguous().view(-1)                  # (batch*max_words,)
        mask = mask.float()
        mask = mask.view(-1)                                   # (batch*max_words,)
        # Smoothed one-hot targets: smooth/(size-1) everywhere, then the
        # confidence mass scattered onto the true class.
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)              # (batch*max_words, vocab)
        # Sum over the vocabulary, zero out padding, average per real token.
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss