# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Causal convolution layer modules."""

import torch


class CausalConv1d(torch.nn.Module):
    """CausalConv1d module with customized initialization.

    Pads the time axis by ``(kernel_size - 1) * dilation`` and keeps only the
    first ``T`` output frames. Output frame ``t`` therefore depends only on
    input frames ``<= t``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        bias=True,
        pad="ConstantPad1d",
        pad_params=None,
    ):
        """Initialize CausalConv1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the convolution.
            dilation (int): Dilation of the convolution.
            bias (bool): Whether to add a bias parameter to the convolution.
            pad (str): Name of a padding module in ``torch.nn`` (e.g.
                "ConstantPad1d"). It is constructed as
                ``pad((kernel_size - 1) * dilation, **pad_params)``.
            pad_params (dict, optional): Extra keyword arguments for the
                padding module. Defaults to ``{"value": 0.0}`` when None.

        """
        super(CausalConv1d, self).__init__()
        # Avoid a shared mutable default; build a fresh dict per instance.
        if pad_params is None:
            pad_params = {"value": 0.0}
        # An int padding pads both ends of the time axis; the right-hand
        # excess is trimmed in forward(), leaving effective left-only padding.
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size, dilation=dilation, bias=bias
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        # Keep the first T frames so no output sees future input frames.
        return self.conv(self.pad(x))[:, :, : x.size(2)]


class CausalConvTranspose1d(torch.nn.Module):
    """CausalConvTranspose1d module with customized initialization.

    Wraps ``torch.nn.ConvTranspose1d`` (padding=0) and drops the last
    ``stride`` output frames. The output length is
    ``(T_in - 1) * stride + kernel_size - stride``. With
    ``kernel_size == 2 * stride``, this equals ``T_in * stride``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Initialize CausalConvTranspose1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the transposed convolution.
            stride (int): Stride (upsampling factor) of the transposed
                convolution; also the number of trailing frames removed.
            bias (bool): Whether to add a bias parameter.

        """
        super(CausalConvTranspose1d, self).__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride, bias=bias
        )
        self.stride = stride

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out).

        """
        # Remove the trailing `stride` frames produced by the transposed conv.
        return self.deconv(x)[:, :, : -self.stride]