FrankZxShen committed on
Commit 650b192
1 parent: f29e685

Update attentions.py

Files changed (1)
  1. attentions.py +0 -240
attentions.py CHANGED
@@ -9,8 +9,6 @@ import commons
 import modules
 from modules import LayerNorm
 
-from loralib_tmp import layers
-import loralib as lora
 
 
 class Encoder(nn.Module):
@@ -49,41 +47,6 @@ class Encoder(nn.Module):
     x = x * x_mask
     return x
 
-class Encoder_lora(nn.Module):
-  def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
-    super().__init__()
-    self.hidden_channels = hidden_channels
-    self.filter_channels = filter_channels
-    self.n_heads = n_heads
-    self.n_layers = n_layers
-    self.kernel_size = kernel_size
-    self.p_dropout = p_dropout
-    self.window_size = window_size
-
-    self.drop = nn.Dropout(p_dropout)
-    self.attn_layers = nn.ModuleList()
-    self.norm_layers_1 = nn.ModuleList()
-    self.ffn_layers = nn.ModuleList()
-    self.norm_layers_2 = nn.ModuleList()
-    for i in range(self.n_layers):
-      self.attn_layers.append(MultiHeadAttention_lora(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
-      self.norm_layers_1.append(LayerNorm(hidden_channels))
-      self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
-      self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-  def forward(self, x, x_mask):
-    attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-    x = x * x_mask
-    for i in range(self.n_layers):
-      y = self.attn_layers[i](x, x, attn_mask)
-      y = self.drop(y)
-      x = self.norm_layers_1[i](x + y)
-
-      y = self.ffn_layers[i](x, x_mask)
-      y = self.drop(y)
-      x = self.norm_layers_2[i](x + y)
-    x = x * x_mask
-    return x
 
 class Decoder(nn.Module):
   def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
@@ -135,211 +98,8 @@ class Decoder(nn.Module):
     x = x * x_mask
     return x
 
-# class Decoder_lora(nn.Module):
-#   def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
-#     super().__init__()
-#     self.hidden_channels = hidden_channels
-#     self.filter_channels = filter_channels
-#     self.n_heads = n_heads
-#     self.n_layers = n_layers
-#     self.kernel_size = kernel_size
-#     self.p_dropout = p_dropout
-#     self.proximal_bias = proximal_bias
-#     self.proximal_init = proximal_init
-
-#     self.drop = nn.Dropout(p_dropout)
-#     self.self_attn_layers = nn.ModuleList()
-#     self.norm_layers_0 = nn.ModuleList()
-#     self.encdec_attn_layers = nn.ModuleList()
-#     self.norm_layers_1 = nn.ModuleList()
-#     self.ffn_layers = nn.ModuleList()
-#     self.norm_layers_2 = nn.ModuleList()
-#     for i in range(self.n_layers):
-#       self.self_attn_layers.append(MultiHeadAttention_lora(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
-#       self.norm_layers_0.append(LayerNorm(hidden_channels))
-#       self.encdec_attn_layers.append(MultiHeadAttention_lora(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
-#       self.norm_layers_1.append(LayerNorm(hidden_channels))
-#       self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
-#       self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-#   def forward(self, x, x_mask, h, h_mask):
-#     """
-#     x: decoder input
-#     h: encoder output
-#     """
-#     self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
-#     encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-#     x = x * x_mask
-#     for i in range(self.n_layers):
-#       y = self.self_attn_layers[i](x, x, self_attn_mask)
-#       y = self.drop(y)
-#       x = self.norm_layers_0[i](x + y)
-
-#       y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
-#       y = self.drop(y)
-#       x = self.norm_layers_1[i](x + y)
-
-#       y = self.ffn_layers[i](x, x_mask)
-#       y = self.drop(y)
-#       x = self.norm_layers_2[i](x + y)
-#     x = x * x_mask
-#     return x
-
-class MultiHeadAttention_lora(nn.Module):
-  def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
-    super().__init__()
-    assert channels % n_heads == 0
-
-    self.channels = channels
-    self.out_channels = out_channels
-    self.n_heads = n_heads
-    self.p_dropout = p_dropout
-    self.window_size = window_size
-    self.heads_share = heads_share
-    self.block_length = block_length
-    self.proximal_bias = proximal_bias
-    self.proximal_init = proximal_init
-    self.attn = None
-
-    self.k_channels = channels // n_heads
-    self.conv_q = nn.Conv1d(channels, channels, 1)
-    self.conv_k = nn.Conv1d(channels, channels, 1)
-    self.conv_v = nn.Conv1d(channels, channels, 1)
-    self.conv_o = nn.Conv1d(channels, out_channels, 1)
-    self.drop = nn.Dropout(p_dropout)
-
-    if window_size is not None:
-      n_heads_rel = 1 if heads_share else n_heads
-      rel_stddev = self.k_channels**-0.5
-      self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
-      self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
-
-    nn.init.xavier_uniform_(self.conv_q.weight)
-    nn.init.xavier_uniform_(self.conv_k.weight)
-    nn.init.xavier_uniform_(self.conv_v.weight)
-    if proximal_init:
-      with torch.no_grad():
-        self.conv_k.weight.copy_(self.conv_q.weight)
-        self.conv_k.bias.copy_(self.conv_q.bias)
-
-  def forward(self, x, c, attn_mask=None):
-    q = self.conv_q(x)
-    k = self.conv_k(c)
-    v = self.conv_v(c)
-
-    x, self.attn = self.attention(q, k, v, mask=attn_mask)
-    x = self.conv_o(x)
-    return x
-
-  def attention(self, query, key, value, mask=None):
-    # reshape [b, d, t] -> [b, n_h, t, d_k]
-    b, d, t_s, t_t = (*key.size(), query.size(2))
-    query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-    key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-    value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-    scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-    if self.window_size is not None:
-      assert t_s == t_t, "Relative attention is only available for self-attention."
-      key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-      rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
-      scores_local = self._relative_position_to_absolute_position(rel_logits)
-      scores = scores + scores_local
-    if self.proximal_bias:
-      assert t_s == t_t, "Proximal bias is only available for self-attention."
-      scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
-    if mask is not None:
-      scores = scores.masked_fill(mask == 0, -1e4)
-      if self.block_length is not None:
-        assert t_s == t_t, "Local attention is only available for self-attention."
-        block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
-        scores = scores.masked_fill(block_mask == 0, -1e4)
-    p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
-    p_attn = self.drop(p_attn)
-    output = torch.matmul(p_attn, value)
-    if self.window_size is not None:
-      relative_weights = self._absolute_position_to_relative_position(p_attn)
-      value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-      output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
-    output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
-    return output, p_attn
-
-  def _matmul_with_relative_values(self, x, y):
-    """
-    x: [b, h, l, m]
-    y: [h or 1, m, d]
-    ret: [b, h, l, d]
-    """
-    ret = torch.matmul(x, y.unsqueeze(0))
-    return ret
-
-  def _matmul_with_relative_keys(self, x, y):
-    """
-    x: [b, h, l, d]
-    y: [h or 1, m, d]
-    ret: [b, h, l, m]
-    """
-    ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-    return ret
 
-  def _get_relative_embeddings(self, relative_embeddings, length):
-    max_relative_position = 2 * self.window_size + 1
-    # Pad first before slice to avoid using cond ops.
-    pad_length = max(length - (self.window_size + 1), 0)
-    slice_start_position = max((self.window_size + 1) - length, 0)
-    slice_end_position = slice_start_position + 2 * length - 1
-    if pad_length > 0:
-      padded_relative_embeddings = F.pad(
-          relative_embeddings,
-          commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
-    else:
-      padded_relative_embeddings = relative_embeddings
-    used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
-    return used_relative_embeddings
-
-  def _relative_position_to_absolute_position(self, x):
-    """
-    x: [b, h, l, 2*l-1]
-    ret: [b, h, l, l]
-    """
-    batch, heads, length, _ = x.size()
-    # Concat columns of pad to shift from relative to absolute indexing.
-    x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
-
-    # Concat extra elements so as to add up to shape (len+1, 2*len-1).
-    x_flat = x.view([batch, heads, length * 2 * length])
-    x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
-
-    # Reshape and slice out the padded elements.
-    x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
-    return x_final
-
-  def _absolute_position_to_relative_position(self, x):
-    """
-    x: [b, h, l, l]
-    ret: [b, h, l, 2*l-1]
-    """
-    batch, heads, length, _ = x.size()
-    # pad along column
-    x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
-    x_flat = x.view([batch, heads, length**2 + length*(length -1)])
-    # add 0's in the beginning that will skew the elements after reshape
-    x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-    x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
-    return x_final
-
-  def _attention_bias_proximal(self, length):
-    """Bias for self-attention to encourage attention to close positions.
-    Args:
-      length: an integer scalar.
-    Returns:
-      a Tensor with shape [1, 1, length, length]
-    """
-    r = torch.arange(length, dtype=torch.float32)
-    diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-    return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
 
-# note: change this back
 class MultiHeadAttention(nn.Module):
   def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
    super().__init__()
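Note on the removal: as the hunks above show, the deleted MultiHeadAttention_lora had the same body as the MultiHeadAttention class that remains (plain nn.Conv1d 1x1 projections for q/k/v/o), so the loralib_tmp / loralib imports were effectively unused and dropping the whole block does not change model behavior. If LoRA-adapted projections were ever re-introduced, one minimal way to do it is sketched below. This is an illustrative assumption, not this repository's code or loralib's API: the LoRAConv1x1 class and its r / alpha parameters are hypothetical names, and it hand-rolls the low-rank update on top of a frozen 1x1 convolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAConv1x1(nn.Module):
    """Hypothetical LoRA wrapper for a 1x1 Conv1d projection:
    a frozen base convolution plus a trainable low-rank update (B @ A),
    applied frame-wise, scaled by alpha / r."""

    def __init__(self, in_channels, out_channels, r=4, alpha=1.0):
        super().__init__()
        self.base = nn.Conv1d(in_channels, out_channels, 1)
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pretrained projection
        # Low-rank factors: delta_W = B @ A has shape [out_channels, in_channels]
        self.lora_A = nn.Parameter(torch.randn(r, in_channels) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_channels, r))  # zero init => no change at start
        self.scale = alpha / r

    def forward(self, x):  # x: [batch, in_channels, time]
        delta_w = (self.lora_B @ self.lora_A).unsqueeze(-1)  # [out, in, 1] conv kernel
        return self.base(x) + F.conv1d(x, delta_w) * self.scale


# Hypothetical usage inside an attention module's __init__ (not the removed code):
#   self.conv_q = LoRAConv1x1(channels, channels, r=8)
```

Swapping such a wrapper in for self.conv_q / self.conv_k / self.conv_v and training only the lora_* parameters is one way the _lora variants could have been made meaningful; it is not what this commit, or the code it removes, actually did.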
 