TheComputerMan commited on
Commit
d09d67d
1 Parent(s): daa42a1

Upload PostNet.py

Browse files
Files changed (1) hide show
  1. PostNet.py +74 -0
PostNet.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ class PostNet(torch.nn.Module):
9
+ """
10
+ From Tacotron2
11
+
12
+ Postnet module for Spectrogram prediction network.
13
+
14
+ This is a module of Postnet in Spectrogram prediction network,
15
+ which described in `Natural TTS Synthesis by
16
+ Conditioning WaveNet on Mel Spectrogram Predictions`_.
17
+ The Postnet refines the predicted
18
+ Mel-filterbank of the decoder,
19
+ which helps to compensate the detail sturcture of spectrogram.
20
+
21
+ .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
22
+ https://arxiv.org/abs/1712.05884
23
+ """
24
+
25
+ def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
26
+ """
27
+ Initialize postnet module.
28
+
29
+ Args:
30
+ idim (int): Dimension of the inputs.
31
+ odim (int): Dimension of the outputs.
32
+ n_layers (int, optional): The number of layers.
33
+ n_filts (int, optional): The number of filter size.
34
+ n_units (int, optional): The number of filter channels.
35
+ use_batch_norm (bool, optional): Whether to use batch normalization..
36
+ dropout_rate (float, optional): Dropout rate..
37
+ """
38
+ super(PostNet, self).__init__()
39
+ self.postnet = torch.nn.ModuleList()
40
+ for layer in range(n_layers - 1):
41
+ ichans = odim if layer == 0 else n_chans
42
+ ochans = odim if layer == n_layers - 1 else n_chans
43
+ if use_batch_norm:
44
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
45
+ torch.nn.GroupNorm(num_groups=32, num_channels=ochans), torch.nn.Tanh(),
46
+ torch.nn.Dropout(dropout_rate), )]
47
+
48
+ else:
49
+ self.postnet += [
50
+ torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ), torch.nn.Tanh(),
51
+ torch.nn.Dropout(dropout_rate), )]
52
+ ichans = n_chans if n_layers != 1 else odim
53
+ if use_batch_norm:
54
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
55
+ torch.nn.GroupNorm(num_groups=20, num_channels=odim),
56
+ torch.nn.Dropout(dropout_rate), )]
57
+
58
+ else:
59
+ self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
60
+ torch.nn.Dropout(dropout_rate), )]
61
+
62
+ def forward(self, xs):
63
+ """
64
+ Calculate forward propagation.
65
+
66
+ Args:
67
+ xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
68
+
69
+ Returns:
70
+ Tensor: Batch of padded output tensor. (B, odim, Tmax).
71
+ """
72
+ for i in range(len(self.postnet)):
73
+ xs = self.postnet[i](xs)
74
+ return xs