import torch
from torch import nn
from torch.nn import functional as F
					
						
class Linear(nn.Module):
    """Linear layer with a specific initialization.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the layer. Defaults to True.
        init_gain (str, optional): method to compute the gain in the weight initialization
            based on the nonlinear activation used afterwards. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        # Xavier (Glorot) uniform initialization, scaled by the gain of the
        # activation that follows this layer.
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)
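
# Usage sketch (illustrative, not part of the original module). `Linear` behaves
# like `torch.nn.Linear` with Xavier-uniform initialized weights; the feature
# sizes below are arbitrary assumptions:
#
#   layer = Linear(80, 256, init_gain="relu")
#   y = layer(torch.randn(4, 80))  # -> shape [4, 256]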
					
						
class LinearBN(nn.Module):
    """Linear layer with Batch Normalization.

    x -> linear -> BN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
        init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        """
        Shapes:
            x: [T, B, C] or [B, C]
        """
        out = self.linear_layer(x)
        if len(out.shape) == 3:
            # BatchNorm1d expects [B, C, T], so move the channel dim to the middle.
            out = out.permute(1, 2, 0)
        out = self.batch_normalization(out)
        if len(out.shape) == 3:
            # Restore the original [T, B, C] layout.
            out = out.permute(2, 0, 1)
        return out
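
# Usage sketch (illustrative, not part of the original module). With a 3D input
# the layer transposes to [B, C, T] for BatchNorm1d and back again; the shapes
# below are arbitrary assumptions:
#
#   layer = LinearBN(80, 256)
#   x = torch.randn(30, 2, 80)   # [T, B, C]
#   y = layer(x)                 # -> shape [30, 2, 256]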
					
						
class Prenet(nn.Module):
    """Tacotron-specific Prenet with an optional Batch Normalization.

    Note:
        Prenet with BN improves the model performance significantly, especially
        if it is enabled after learning a diagonal attention alignment with the
        original prenet. However, if the target dataset is high quality, it also
        works from the start. It is also suggested to disable dropout if BN is in use.

        prenet_type == "original"
            x -> [linear -> ReLU -> Dropout]xN -> o

        prenet_type == "bn"
            x -> [linear -> BN -> ReLU -> Dropout]xN -> o

    Args:
        in_features (int): number of channels in the input tensor and the inner layers.
        prenet_type (str, optional): prenet type, "original" or "bn". Defaults to "original".
        prenet_dropout (bool, optional): enable/disable dropout. Defaults to True.
        dropout_at_inference (bool, optional): use dropout at inference. It leads to better
            quality for some models. Defaults to False.
        out_features (list, optional): list of output channels for each prenet block.
            Its length also defines the number of prenet blocks. Defaults to [256, 256].
        bias (bool, optional): enable/disable bias in the prenet linear layers. Defaults to True.
    """

    def __init__(
        self,
        in_features,
        prenet_type="original",
        prenet_dropout=True,
        dropout_at_inference=False,
        out_features=[256, 256],
        bias=True,
    ):
        super().__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.dropout_at_inference = dropout_at_inference
        # Each block takes the previous block's output as input.
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.linear_layers = nn.ModuleList(
                [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        elif prenet_type == "original":
            self.linear_layers = nn.ModuleList(
                [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        else:
            raise ValueError(f" [!] Unknown prenet_type: {prenet_type}")

    def forward(self, x):
        for linear in self.linear_layers:
            if self.prenet_dropout:
                # Dropout stays active at inference when `dropout_at_inference` is set.
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
            else:
                x = F.relu(linear(x))
        return x
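

# A minimal, illustrative smoke test (not part of the original module); the
# feature sizes and batch shapes below are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    frames = torch.randn(8, 80)  # a batch of 8 feature frames with 80 channels

    prenet = Prenet(in_features=80, prenet_type="original", out_features=[256, 256])
    print(prenet(frames).shape)  # torch.Size([8, 256])

    bn_prenet = Prenet(in_features=80, prenet_type="bn", prenet_dropout=False)
    print(bn_prenet(frames).shape)  # torch.Size([8, 256])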