TheComputerMan committed on
Commit
1e3e10b
1 Parent(s): baf679b

Upload DurationPredictor.py

Browse files
Files changed (1) hide show
  1. DurationPredictor.py +139 -0
DurationPredictor.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+
6
+ import torch
7
+
8
+ from Layers.LayerNorm import LayerNorm
9
+
10
+
11
class DurationPredictor(torch.nn.Module):
    """
    Duration predictor module.

    This is a module of duration predictor described
    in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The duration predictor predicts one duration value per position of the
    input sequence, in log domain, from the hidden embeddings of the encoder.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Note:
        The calculation domain of outputs is different
        between in `forward` and in `inference`. In `forward`,
        the outputs are calculated in log domain but in `inference`,
        those are calculated in linear domain.

    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
        """
        Initialize duration predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.

        """
        super().__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        for idx in range(n_layers):
            # Only the first layer sees the raw input dimension.
            in_chans = idim if idx == 0 else n_chans
            self.conv.append(
                torch.nn.Sequential(
                    # With stride 1, padding of (kernel_size - 1) // 2 keeps the
                    # time axis length unchanged for odd kernel sizes only; an
                    # even kernel size shortens the sequence by one per layer.
                    torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
                    torch.nn.ReLU(),
                    LayerNorm(n_chans, dim=1),
                    torch.nn.Dropout(dropout_rate),
                )
            )
        self.linear = torch.nn.Linear(n_chans, 1)

    def _forward(self, xs, x_masks=None, is_inference=False):
        """
        Shared implementation behind `forward` and `inference`.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).
                Non-zero / True entries are set to zero in the output.
            is_inference (bool, optional): If True, convert the log-domain
                prediction into non-negative integer durations.

        Returns:
            Tensor: Log-domain durations (B, Tmax) when `is_inference` is
                False, otherwise a LongTensor of linear-domain durations.

        """
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for layer in self.conv:
            xs = layer(xs)  # (B, C, Tmax)

        # NOTE: calculate in log domain
        xs = self.linear(xs.transpose(1, -1)).squeeze(-1)  # (B, Tmax)

        if is_inference:
            # NOTE: calculate in linear domain; clamp to avoid negative values
            xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()

        if x_masks is not None:
            # Recent PyTorch releases only accept boolean masks in masked_fill,
            # so legacy uint8 masks are converted here. The integer fill value
            # is valid for both the float and the long output.
            xs = xs.masked_fill(x_masks.to(torch.bool), 0)

        return xs

    def forward(self, xs, x_masks=None):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).

        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """
        Inference duration.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (BoolTensor or ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            LongTensor: Batch of predicted durations in linear domain (B, Tmax).

        """
        return self._forward(xs, x_masks, True)
99
+
100
+
101
class DurationPredictorLoss(torch.nn.Module):
    """
    Loss function module for duration predictor.

    The loss value is calculated in log domain to make it Gaussian.

    """

    def __init__(self, offset=1.0, reduction="mean"):
        """
        Initialize duration predictor loss module.

        Args:
            offset (float, optional): Offset value to avoid nan in log domain.
            reduction (str): Reduction type in loss calculation
                (passed through to `torch.nn.MSELoss`).

        """
        super().__init__()
        self.criterion = torch.nn.MSELoss(reduction=reduction)
        self.offset = offset

    def forward(self, outputs, targets):
        """
        Calculate forward propagation.

        Args:
            outputs (Tensor): Batch of prediction durations in log domain (B, T)
            targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)

        Returns:
            Tensor: Mean squared error loss value, reduced according to the
                `reduction` given at construction time.

        Note:
            `outputs` is in log domain but `targets` is in linear domain.

        """
        # NOTE: outputs is in log domain while targets are linear, so the
        # targets are moved to log domain (with the offset guarding log(0))
        # before the comparison.
        log_targets = torch.log(targets.float() + self.offset)
        return self.criterion(outputs, log_targets)