Sifal committed
Commit c23a3d6
1 Parent(s): f637c62

Create model.py

Files changed (1):
  1. model.py +81 -0
model.py ADDED
@@ -0,0 +1,81 @@
+ import math
+ 
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import Transformer
+ 
+ # Helper module that adds positional encoding to the token embedding to introduce a notion of word order.
+ class PositionalEncoding(nn.Module):
+     def __init__(self,
+                  emb_size: int,
+                  dropout: float,
+                  maxlen: int = 5000):
+         super(PositionalEncoding, self).__init__()
+         den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
+         pos = torch.arange(0, maxlen).reshape(maxlen, 1)
+         pos_embedding = torch.zeros((maxlen, emb_size))
+         pos_embedding[:, 0::2] = torch.sin(pos * den)
+         pos_embedding[:, 1::2] = torch.cos(pos * den)
+         # Shape (1, maxlen, emb_size) so it broadcasts over the batch dimension (batch_first layout).
+         pos_embedding = pos_embedding.unsqueeze(0)
+ 
+         self.dropout = nn.Dropout(dropout)
+         self.register_buffer('pos_embedding', pos_embedding)
+ 
+     def forward(self, token_embedding: Tensor):
+         # token_embedding: (batch_size, seq_len, emb_size)
+         return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])
+ 
+ # Helper module to convert a tensor of input indices into the corresponding tensor of token embeddings.
+ class TokenEmbedding(nn.Module):
+     def __init__(self, vocab_size: int, emb_size: int):
+         super(TokenEmbedding, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, emb_size)
+         self.emb_size = emb_size
+ 
+     def forward(self, tokens: Tensor):
+         # Scale embeddings by sqrt(emb_size), as in "Attention Is All You Need".
+         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
+ 
+ class Seq2SeqTransformer(nn.Module):
+     def __init__(self,
+                  num_encoder_layers: int,
+                  num_decoder_layers: int,
+                  emb_size: int,
+                  nhead: int,
+                  src_vocab_size: int,
+                  tgt_vocab_size: int,
+                  dim_feedforward: int = 512,
+                  dropout: float = 0.1):
+         super(Seq2SeqTransformer, self).__init__()
+         self.transformer = Transformer(d_model=emb_size,
+                                        nhead=nhead,
+                                        num_encoder_layers=num_encoder_layers,
+                                        num_decoder_layers=num_decoder_layers,
+                                        dim_feedforward=dim_feedforward,
+                                        dropout=dropout,
+                                        batch_first=True)
+         self.generator = nn.Linear(emb_size, tgt_vocab_size)
+         self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
+         self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
+         self.positional_encoding = PositionalEncoding(
+             emb_size, dropout=dropout)
+ 
+     def forward(self,
+                 src: Tensor,
+                 trg: Tensor,
+                 src_mask: Tensor,
+                 tgt_mask: Tensor,
+                 src_padding_mask: Tensor,
+                 tgt_padding_mask: Tensor,
+                 memory_key_padding_mask: Tensor):
+         src_emb = self.positional_encoding(self.src_tok_emb(src))
+         tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
+         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
+                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
+         return self.generator(outs)
+ 
+     def encode(self, src: Tensor, src_mask: Tensor):
+         return self.transformer.encoder(self.positional_encoding(
+             self.src_tok_emb(src)), src_mask)
+ 
+     def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
+         return self.transformer.decoder(self.positional_encoding(
+             self.tgt_tok_emb(tgt)), memory,
+             tgt_mask)
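
For reference, a minimal usage sketch of the module added in this commit. The hyperparameters, PAD_IDX, and the toy batch below are illustrative assumptions rather than values fixed by the commit; because the model is built with batch_first=True, inputs are (batch, seq) index tensors, attention masks are (seq, seq), and padding masks are boolean (batch, seq) tensors with True marking padded positions.

import torch

from model import Seq2SeqTransformer  # assuming this file is importable as model

# Hypothetical sizes, for illustration only.
SRC_VOCAB_SIZE = 10000
TGT_VOCAB_SIZE = 10000
PAD_IDX = 1  # assumed padding index

model = Seq2SeqTransformer(num_encoder_layers=3, num_decoder_layers=3,
                           emb_size=512, nhead=8,
                           src_vocab_size=SRC_VOCAB_SIZE,
                           tgt_vocab_size=TGT_VOCAB_SIZE)

# Toy batch of token indices (batch_first layout).
src = torch.randint(0, SRC_VOCAB_SIZE, (4, 12))    # (batch, src_len)
tgt_in = torch.randint(0, TGT_VOCAB_SIZE, (4, 9))  # (batch, tgt_len), decoder input

src_len, tgt_len = src.size(1), tgt_in.size(1)

# Causal mask for the decoder (True = position may not be attended);
# the encoder gets an all-False, i.e. unrestricted, mask.
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
src_mask = torch.zeros(src_len, src_len, dtype=torch.bool)

# Boolean padding masks: True where a position is padding.
src_padding_mask = src == PAD_IDX
tgt_padding_mask = tgt_in == PAD_IDX

logits = model(src, tgt_in, src_mask, tgt_mask,
               src_padding_mask, tgt_padding_mask, src_padding_mask)
print(logits.shape)  # (4, 9, TGT_VOCAB_SIZE)

At inference time, encode and decode expose the encoder memory separately, so tokens can be generated autoregressively by re-running decode with a growing target prefix and the same style of causal mask.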