TheComputerMan commited on
Commit
8aa300f
1 Parent(s): 5e25351

Upload PositionalEncoding.py

Browse files
Files changed (1) hide show
  1. PositionalEncoding.py +166 -0
PositionalEncoding.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+
10
+ class PositionalEncoding(torch.nn.Module):
11
+ """
12
+ Positional encoding.
13
+
14
+ Args:
15
+ d_model (int): Embedding dimension.
16
+ dropout_rate (float): Dropout rate.
17
+ max_len (int): Maximum input length.
18
+ reverse (bool): Whether to reverse the input position.
19
+ """
20
+
21
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
22
+ """
23
+ Construct an PositionalEncoding object.
24
+ """
25
+ super(PositionalEncoding, self).__init__()
26
+ self.d_model = d_model
27
+ self.reverse = reverse
28
+ self.xscale = math.sqrt(self.d_model)
29
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
30
+ self.pe = None
31
+ self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len))
32
+
33
+ def extend_pe(self, x):
34
+ """
35
+ Reset the positional encodings.
36
+ """
37
+ if self.pe is not None:
38
+ if self.pe.size(1) >= x.size(1):
39
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
40
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
41
+ return
42
+ pe = torch.zeros(x.size(1), self.d_model)
43
+ if self.reverse:
44
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
45
+ else:
46
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
47
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model))
48
+ pe[:, 0::2] = torch.sin(position * div_term)
49
+ pe[:, 1::2] = torch.cos(position * div_term)
50
+ pe = pe.unsqueeze(0)
51
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
52
+
53
+ def forward(self, x):
54
+ """
55
+ Add positional encoding.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor (batch, time, `*`).
59
+
60
+ Returns:
61
+ torch.Tensor: Encoded tensor (batch, time, `*`).
62
+ """
63
+ self.extend_pe(x)
64
+ x = x * self.xscale + self.pe[:, : x.size(1)]
65
+ return self.dropout(x)
66
+
67
+
68
+ class RelPositionalEncoding(torch.nn.Module):
69
+ """
70
+ Relative positional encoding module (new implementation).
71
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
72
+ See : Appendix B in https://arxiv.org/abs/1901.02860
73
+ Args:
74
+ d_model (int): Embedding dimension.
75
+ dropout_rate (float): Dropout rate.
76
+ max_len (int): Maximum input length.
77
+ """
78
+
79
+ def __init__(self, d_model, dropout_rate, max_len=5000):
80
+ """
81
+ Construct an PositionalEncoding object.
82
+ """
83
+ super(RelPositionalEncoding, self).__init__()
84
+ self.d_model = d_model
85
+ self.xscale = math.sqrt(self.d_model)
86
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
87
+ self.pe = None
88
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
89
+
90
+ def extend_pe(self, x):
91
+ """Reset the positional encodings."""
92
+ if self.pe is not None:
93
+ # self.pe contains both positive and negative parts
94
+ # the length of self.pe is 2 * input_len - 1
95
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
96
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
97
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
98
+ return
99
+ # Suppose `i` means to the position of query vecotr and `j` means the
100
+ # position of key vector. We use position relative positions when keys
101
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
102
+ pe_positive = torch.zeros(x.size(1), self.d_model, device=x.device)
103
+ pe_negative = torch.zeros(x.size(1), self.d_model, device=x.device)
104
+ position = torch.arange(0, x.size(1), dtype=torch.float32, device=x.device).unsqueeze(1)
105
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / self.d_model))
106
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
107
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
108
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
109
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
110
+
111
+ # Reserve the order of positive indices and concat both positive and
112
+ # negative indices. This is used to support the shifting trick
113
+ # as in https://arxiv.org/abs/1901.02860
114
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
115
+ pe_negative = pe_negative[1:].unsqueeze(0)
116
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
117
+ self.pe = pe.to(dtype=x.dtype)
118
+
119
+ def forward(self, x):
120
+ """
121
+ Add positional encoding.
122
+ Args:
123
+ x (torch.Tensor): Input tensor (batch, time, `*`).
124
+ Returns:
125
+ torch.Tensor: Encoded tensor (batch, time, `*`).
126
+ """
127
+ self.extend_pe(x)
128
+ x = x * self.xscale
129
+ pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1: self.pe.size(1) // 2 + x.size(1), ]
130
+ return self.dropout(x), self.dropout(pos_emb)
131
+
132
+
133
+ class ScaledPositionalEncoding(PositionalEncoding):
134
+ """
135
+ Scaled positional encoding module.
136
+
137
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
138
+
139
+ Args:
140
+ d_model (int): Embedding dimension.
141
+ dropout_rate (float): Dropout rate.
142
+ max_len (int): Maximum input length.
143
+
144
+ """
145
+
146
+ def __init__(self, d_model, dropout_rate, max_len=5000):
147
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
148
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
149
+
150
+ def reset_parameters(self):
151
+ self.alpha.data = torch.tensor(1.0)
152
+
153
+ def forward(self, x):
154
+ """
155
+ Add positional encoding.
156
+
157
+ Args:
158
+ x (torch.Tensor): Input tensor (batch, time, `*`).
159
+
160
+ Returns:
161
+ torch.Tensor: Encoded tensor (batch, time, `*`).
162
+
163
+ """
164
+ self.extend_pe(x)
165
+ x = x + self.alpha * self.pe[:, : x.size(1)]
166
+ return self.dropout(x)