# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py

import torch
from torch import nn

from modules.base.base_module import LayerNorm


class DurationPredictor(nn.Module):
    """Two-stage 1-D convolutional predictor with a single-channel output.

    Each stage is ``Conv1d -> ReLU -> LayerNorm -> Dropout``; a final 1x1
    convolution projects ``filter_channels`` down to one channel. When
    ``gin_channels`` is non-zero, an additional 1x1 convolution (``cond``)
    maps a conditioning tensor to ``in_channels`` so it can be added to the
    input.

    Args:
        in_channels: Number of channels of the input ``x``.
        filter_channels: Number of channels used inside the two conv stages.
        kernel_size: Kernel size of both conv stages; padding is
            ``kernel_size // 2``.
        p_dropout: Dropout probability applied after each normalization.
        gin_channels: Number of channels of the conditioning input ``g``.
            ``0`` (the default) means the ``cond`` layer is not created.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        # Keep the hyper-parameters around for introspection.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # Same padding is shared by both convolution stages.
        half_kernel = kernel_size // 2

        # NOTE: submodules are created in the same order as before so that
        # parameter initialization and state-dict layout are unaffected.
        self.drop = nn.Dropout(p_dropout)

        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=half_kernel
        )
        self.norm_1 = LayerNorm(filter_channels)

        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=half_kernel
        )
        self.norm_2 = LayerNorm(filter_channels)

        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # The conditioning projection only exists when conditioning is enabled.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Run the predictor on ``x``.

        Both ``x`` and ``g`` are detached first, so no gradient flows from
        this module's output back into whatever produced them.

        Args:
            x: Input tensor in a layout accepted by ``nn.Conv1d`` with
                ``in_channels`` channels.
            x_mask: Tensor multiplied element-wise with the activations
                before every convolution and with the final output
                (presumably a 0/1 validity mask broadcastable to ``x`` --
                confirm against the caller).
            g: Optional conditioning tensor with ``gin_channels`` channels.
                When given, ``self.cond(g)`` is added to ``x``; this requires
                the module to have been built with ``gin_channels != 0``.

        Returns:
            The single-channel projection, multiplied by ``x_mask``.
        """
        hidden = x.detach()
        if g is not None:
            hidden = hidden + self.cond(g.detach())

        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            hidden = conv(hidden * x_mask)
            hidden = self.drop(norm(torch.relu(hidden)))

        return self.proj(hidden * x_mask) * x_mask