import torch
from torch import nn
from torch.nn import functional as F


class Linear(nn.Module):
    """Linear layer with a specific initialization.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the layer. Defaults to True.
        init_gain (str, optional): nonlinearity name used to compute the gain for weight initialization (as in ``torch.nn.init.calculate_gain``), matching the activation applied after this layer. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
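        # Xavier-uniform init, scaled by the gain of the nonlinearity expected to
        # follow this layer (e.g. "relu", "tanh"); "linear" keeps the default gain of 1.0.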
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class LinearBN(nn.Module):
    """Linear layer with Batch Normalization.

    x -> linear -> BN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
        init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        """
        Shapes:
            x: [T, B, C] or [B, C]
        """
        out = self.linear_layer(x)
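        # nn.BatchNorm1d normalizes over the channel dim at index 1, so a time-major
        # [T, B, C] activation is permuted to [B, C, T] first and restored afterwards.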
        if len(out.shape) == 3:
            out = out.permute(1, 2, 0)
        out = self.batch_normalization(out)
        if len(out.shape) == 3:
            out = out.permute(2, 0, 1)
        return out


class Prenet(nn.Module):
    """Tacotron specific Prenet with an optional Batch Normalization.

    Note:
        Prenet with BN improves the model performance significantly, especially
        if it is enabled after learning a diagonal attention alignment with the
        original prenet. However, if the target dataset is of high quality, it
        also works from the start. It is also suggested to disable dropout when
        BN is in use.

        prenet_type == "original"
            x -> [linear -> ReLU -> Dropout]xN -> o

        prenet_type == "bn"
            x -> [linear -> BN -> ReLU -> Dropout]xN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
        prenet_dropout (bool, optional): enable/disable dropout after each block (p=0.5). Defaults to True.
        dropout_at_inference (bool, optional): keep dropout active at inference. It leads to better quality for some models. Defaults to False.
        out_features (list, optional): List of output channels for each prenet block.
            It also defines number of the prenet blocks based on the length of argument list.
            Defaults to [256, 256].
        bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_features,
        prenet_type="original",
        prenet_dropout=True,
        dropout_at_inference=False,
        out_features=[256, 256],
        bias=True,
    ):
        super().__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.dropout_at_inference = dropout_at_inference
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.linear_layers = nn.ModuleList(
                [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        elif prenet_type == "original":
            self.linear_layers = nn.ModuleList(
                [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        else:
            raise ValueError(f"Unknown prenet_type: {prenet_type}")

    def forward(self, x):
        for linear in self.linear_layers:
            if self.prenet_dropout:
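                # Passing `training=True` when `dropout_at_inference` is set keeps
                # dropout active at eval time as well (useful noise for some models).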
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
            else:
                x = F.relu(linear(x))
        return x
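

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: a sketch of how the
    # layers above are typically called. Feature sizes and shapes are assumptions
    # chosen for illustration only.
    torch.manual_seed(0)

    # Linear and LinearBN accept [B, C] or time-major [T, B, C] inputs.
    linear = Linear(80, 256, init_gain="relu")
    print(linear(torch.randn(4, 80)).shape)  # torch.Size([4, 256])

    linear_bn = LinearBN(80, 256)
    print(linear_bn(torch.randn(50, 4, 80)).shape)  # torch.Size([50, 4, 256])

    # Prenet with two 256-unit blocks fed by (assumed) 80-dim mel frames.
    prenet = Prenet(in_features=80, prenet_type="original", out_features=[256, 256])
    print(prenet(torch.randn(4, 80)).shape)  # torch.Size([4, 256])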