TheComputerMan committed on
Commit 3ae2f30 (1 parent: 83cd2da)

Upload Attention.py

Files changed (1): Attention.py (+324, -0)

Attention.py (added):

# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn

from Utility.utils import make_non_pad_mask


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """
        Construct a MultiHeadedAttention object.
        """
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """
        Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v
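
    # Shape sketch for forward_qkv (illustrative; assumes n_head=4 and n_feat=256, hence d_k=64):
    #     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    #     q, k, v = mha.forward_qkv(torch.randn(2, 7, 256),   # query: (batch=2, time1=7,  size=256)
    #                               torch.randn(2, 11, 256),  # key:   (batch=2, time2=11, size=256)
    #                               torch.randn(2, 11, 256))  # value: (batch=2, time2=11, size=256)
    #     # q: (2, 4, 7, 64), k: (2, 4, 11, 64), v: (2, 4, 11, 64)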

    def forward_attention(self, value, scores, mask):
        """
        Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k))  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """
        Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
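
# Illustrative usage of MultiHeadedAttention (assumed sizes, not tied to a specific model):
#     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     q = torch.randn(2, 7, 256)                     # (batch, time1, size)
#     kv = torch.randn(2, 11, 256)                   # (batch, time2, size)
#     mask = torch.ones(2, 1, 11, dtype=torch.bool)  # (batch, 1, time2); entries equal to 0 are masked out
#     out = mha(q, kv, kv, mask)                     # -> (2, 7, 256)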


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """
    Multi-Head Attention layer with relative position encoding.
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of the attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """
        Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of the query vector.

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x
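
    # Shape sketch for rel_shift (illustrative; time1 = 7, so the input carries 2*7-1 = 13 relative offsets):
    #     attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    #     scores = torch.randn(2, 4, 7, 13)   # (batch, head, time1, 2*time1-1)
    #     attn.rel_shift(scores).shape        # torch.Size([2, 4, 7, 7])
    # After the shift, entry (i, j) holds the score belonging to the relative offset
    # between query position i and key position j.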

    def forward(self, query, key, value, pos_emb, mask):
        """
        Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
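
# Illustrative self-attention usage of RelPositionMultiHeadedAttention (assumed sizes;
# pos_emb would normally come from a relative positional encoding module, here it is a
# stand-in tensor of the expected shape):
#     attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 7, 256)                 # (batch, time1, size)
#     pos_emb = torch.randn(1, 2 * 7 - 1, 256)   # (1, 2*time1-1, size)
#     out = attn(x, x, x, pos_emb, mask=None)    # -> (2, 7, 256)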


class GuidedAttentionLoss(torch.nn.Module):
    """
    Guided attention loss function module.

    This module calculates the guided attention loss described
    in `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969
    """

    def __init__(self, sigma=0.4, alpha=1.0):
        """
        Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation that controls
                how close the attention must stay to the diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """
        Calculate forward propagation.

        Args:
            att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.
        """
        self._reset_masks()
        self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
        self.masks = self._make_masks(ilens, olens).to(att_ws.device)
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        self._reset_masks()
        return self.alpha * loss
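
    # Illustrative usage (assumed lengths; att_ws would normally come from a decoder's
    # attention module):
    #     loss_fn = GuidedAttentionLoss(sigma=0.4, alpha=1.0)
    #     att_ws = torch.softmax(torch.randn(2, 30, 12), dim=-1)  # (B, T_max_out, T_max_in)
    #     ilens = torch.tensor([12, 10])
    #     olens = torch.tensor([30, 25])
    #     loss = loss_fn(att_ws, ilens, olens)                    # scalar tensor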

    def _make_guided_attention_masks(self, ilens, olens):
        n_batches = len(ilens)
        max_ilen = max(ilens)
        max_olen = max(olens)
        guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device)
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """
        Make guided attention mask.
        """
        grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float())
        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
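
    # Worked example of the penalty weight (illustrative): with sigma = 0.4, a cell exactly on
    # the diagonal (grid_y / ilen == grid_x / olen) gets weight 0.0, while a cell whose
    # normalised positions differ by 0.5 gets 1 - exp(-0.25 / 0.32) ≈ 0.54, so attention far
    # from the diagonal is penalised more heavily.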

    @staticmethod
    def _make_masks(ilens, olens):
        """
        Make masks indicating non-padded part.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).
            olens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor indicating non-padded part
                (dtype=torch.uint8 in PyTorch < 1.2, dtype=torch.bool in PyTorch >= 1.2).
        """
        in_masks = make_non_pad_mask(ilens, device=ilens.device)  # (B, T_in)
        out_masks = make_non_pad_mask(olens, device=olens.device)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)


class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """
    Guided attention loss function module for multi-head attention.

    Args:
        sigma (float, optional): Standard deviation that controls
            how close the attention must stay to the diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
    """

    def forward(self, att_ws, ilens, olens):
        """
        Calculate forward propagation.

        Args:
            att_ws (Tensor):
                Batch of multi head attention weights (B, H, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.
        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1))
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        # always reset the cached masks, mirroring GuidedAttentionLoss.forward
        self._reset_masks()

        return self.alpha * loss
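

# Minimal smoke test (illustrative sketch with assumed shapes and lengths; the guided
# attention losses rely on make_non_pad_mask from Utility.utils being importable).
if __name__ == "__main__":
    mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 7, 256)  # (batch, time, features)
    print(mha(x, x, x, mask=None).shape)  # torch.Size([2, 7, 256])

    rel_mha = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    pos_emb = torch.randn(1, 2 * 7 - 1, 256)  # stand-in for a relative positional encoding
    print(rel_mha(x, x, x, pos_emb, mask=None).shape)  # torch.Size([2, 7, 256])

    att_ws = torch.softmax(torch.randn(2, 30, 12), dim=-1)  # (B, T_out, T_in)
    ilens, olens = torch.tensor([12, 10]), torch.tensor([30, 25])
    print(GuidedAttentionLoss(sigma=0.4, alpha=1.0)(att_ws, ilens, olens))

    att_ws_multi = torch.softmax(torch.randn(2, 4, 30, 12), dim=-1)  # (B, H, T_out, T_in)
    print(GuidedMultiHeadAttentionLoss(sigma=0.4, alpha=1.0)(att_ws_multi, ilens, olens))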