Hugo Flores commited on
Commit
04c5b94
1 Parent(s): 534a89c

remove wavenet, readability

Browse files
vampnet/modules/layers.py CHANGED
@@ -8,6 +8,24 @@ import torch.nn.functional as F
8
  from einops import rearrange
9
  from torch.nn.utils import weight_norm
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def num_params(model):
13
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
8
  from einops import rearrange
9
  from torch.nn.utils import weight_norm
10
 
11
+ # Scripting this brings model speed up 1.4x
12
+ @torch.jit.script
13
+ def snake(x, alpha):
14
+ shape = x.shape
15
+ x = x.reshape(shape[0], shape[1], -1)
16
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
17
+ x = x.reshape(shape)
18
+ return x
19
+
20
+
21
+ class Snake1d(nn.Module):
22
+ def __init__(self, channels):
23
+ super().__init__()
24
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
25
+
26
+ def forward(self, x):
27
+ return snake(x, self.alpha)
28
+
29
 
30
  def num_params(model):
31
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
vampnet/modules/transformer.py CHANGED
@@ -377,7 +377,7 @@ class TransformerStack(nn.Module):
377
  n_heads,
378
  bidirectional,
379
  is_decoder,
380
- has_relative_attention_bias=(i == 0),
381
  flash_attn=flash_attn,
382
  dropout=dropout,
383
  )
 
377
  n_heads,
378
  bidirectional,
379
  is_decoder,
380
+ has_relative_attention_bias=True if (i == 0) else False,
381
  flash_attn=flash_attn,
382
  dropout=dropout,
383
  )
vampnet/modules/wavenet.py DELETED
@@ -1,90 +0,0 @@
1
- import torch.nn as nn
2
- from einops import rearrange
3
-
4
- from voicegpt.nn import WaveNet
5
-
6
- class AutoregMLP(nn.Module):
7
- """Implements an autoregressive ConvNet decoder
8
- Refer to SampleRNN (https://arxiv.org/abs/1612.07837) for motivation
9
- """
10
-
11
- def __init__(
12
- self,
13
- vocab_size: int,
14
- d_model: int,
15
- n_layers: int,
16
- n_fine_tokens: int = 6,
17
- n_tokens: int = 9,
18
- dropout: float = 0.1,
19
- activation: str = "gelu",
20
- causal: bool = True,
21
- ):
22
- super().__init__()
23
- self.n_fine = n_fine_tokens
24
- self.n_layers = n_layers
25
- self.upsampler = nn.Linear(d_model, d_model * n_fine_tokens)
26
-
27
- self.wavenet = WaveNet(
28
- d_model,
29
- d_model,
30
- d_model,
31
- n_layers,
32
- n_fine_tokens,
33
- dropout=dropout,
34
- activation=activation,
35
- causal=causal,
36
- )
37
- self.ff_output = nn.Linear(d_model, vocab_size * n_tokens, bias=False)
38
-
39
- def time_upsample(self, h_t_coarse):
40
- """Upsamples the conditioning hidden states to match the time resolution
41
- of output tokens
42
- Parameters
43
- ----------
44
- h_t_coarse : Tensor[B x T_coarse x D]
45
- Conditioning hidden states in coarse time-scale
46
- Returns
47
- -------
48
- Tensor[B x T_fine x D]
49
- Conditioning hidden states in fine time-scale
50
- """
51
- # Upsample the transformer hidden states to fine scale
52
- h_t_fine = rearrange(
53
- self.upsampler(h_t_coarse), "b t (n d) -> b (t n) d", n=self.n_fine
54
- )
55
- return h_t_fine
56
-
57
- def decode_logits(self, x_tm1, h_t_fine):
58
- """Decodes output logits conditioned on previous output
59
- tokens (upto timestep t-1) and conditioning hidden states
60
- using an autoregressive WaveNet
61
- Parameters
62
- ----------
63
- x_tm1 : Tensor[B x T x D]
64
- h_t_fine : Tensor[B x T x D]
65
- Returns
66
- -------
67
- Tensor[B x T x vocab_size]
68
- Predicted logits
69
- """
70
-
71
- # Compute wavenet layers and predict logits
72
- o_t = self.wavenet(x_tm1, h_t_fine)
73
- return self.ff_output(o_t)
74
-
75
- def forward(self, x_tm1, h_t_coarse):
76
- """Computes autoregressive conditional probability distribution
77
- using a WaveNet decoder
78
- Parameters
79
- ----------
80
- x_tm1 : Tensor[B x T_fine x D]
81
- Embeddings of tokens at fine time-scale
82
- h_t_coarse : Tensor[B x T_coarse x D]
83
- Hidden states at coarse time scale
84
- Returns
85
- -------
86
- Tensor[B x T_fine x vocab_size]
87
- Predicted logits at fine time-scale
88
- """
89
- h_t_fine = self.time_upsample(h_t_coarse)
90
- return self.decode_logits(x_tm1, h_t_fine)