# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
#   LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import pow, sin
from torch.nn import Parameter

from .resample import DownSample1d, UpSample1d


class Activation1d(nn.Module):
    """Run an activation between an upsampling and a downsampling stage.

    The input goes through ``UpSample1d``, then the supplied ``activation``,
    then ``DownSample1d``.

    Args:
        activation: Callable applied to the upsampled signal.
        up_ratio: First argument given to ``UpSample1d``.
        down_ratio: First argument given to ``DownSample1d``.
        up_kernel_size: Second argument given to ``UpSample1d``.
        down_kernel_size: Second argument given to ``DownSample1d``.
    """

    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio, self.down_ratio = up_ratio, down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Return ``downsample(act(upsample(x)))`` for ``x`` laid out as [B, C, T]."""
        return self.downsample(self.act(self.upsample(x)))


class SnakeBeta(nn.Module):
    """Snake activation with separate frequency and magnitude parameters.

    Computes ``x + 1 / (beta + eps) * sin(alpha * x) ** 2`` elementwise, where
    ``alpha`` and ``beta`` are per-channel vectors of length ``in_features``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha: controls the frequency of the periodic term
        - beta: controls the magnitude of the periodic term

    References:
        Modified version of the Snake function from Liu Ziyin, Tilman Hartwig
        and Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = act(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True,
                 alpha_logscale=False):
        """Create the per-channel ``alpha`` and ``beta`` parameters.

        Args:
            in_features: Length of the ``alpha`` and ``beta`` vectors.
            alpha: Scalar multiplied into the initial value of both ``alpha``
                and ``beta``. In linear scale both start at ``alpha``; in log
                scale both start from zeros, so the product is zero whatever
                ``alpha`` is.
            alpha_trainable: Value assigned to ``requires_grad`` of both
                parameters.
            alpha_logscale: When true, the parameters are stored in log scale
                and exponentiated in ``forward``.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log scale starts from zeros (exp(0) == 1 in forward); linear scale
        # starts from ones.
        make_base = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = Parameter(make_base(in_features) * alpha)
        self.beta = Parameter(make_base(in_features) * alpha)

        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Small constant added to beta before taking its reciprocal.
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """Apply the activation elementwise.

        SnakeBeta = x + 1/b * sin^2 (xa)
        """
        # Reshape the per-channel vectors to [1, C, 1] so they broadcast
        # against x laid out as [B, C, T].
        freq = self.alpha.unsqueeze(0).unsqueeze(-1)
        mag = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            freq, mag = torch.exp(freq), torch.exp(mag)

        inv_mag = 1.0 / (mag + self.no_div_by_zero)
        periodic = pow(sin(x * freq), 2)
        return x + inv_mag * periodic


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``.

    Proposed in "Mish: A Self Regularized Non-Monotonic Neural Activation
    Function", https://arxiv.org/abs/1908.08681.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * tanh(softplus(x))`` elementwise."""
        gate = torch.tanh(F.softplus(x))
        return x * gate


class SnakeAlias(nn.Module):
    """Log-scale ``SnakeBeta`` placed between an upsampler and a downsampler.

    Args:
        channels: Number of features given to ``SnakeBeta``.
        up_ratio: First argument given to ``UpSample1d``.
        down_ratio: First argument given to ``DownSample1d``.
        up_kernel_size: Second argument given to ``UpSample1d``.
        down_kernel_size: Second argument given to ``DownSample1d``.
        C: Passed through as the third argument of both resamplers; its
            meaning is defined in ``.resample`` and is not visible here.
    """

    def __init__(self,
                 channels,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12,
                 C=None):
        super().__init__()
        self.up_ratio, self.down_ratio = up_ratio, down_ratio
        self.act = SnakeBeta(channels, alpha_logscale=True)
        self.upsample = UpSample1d(up_ratio, up_kernel_size, C)
        self.downsample = DownSample1d(down_ratio, down_kernel_size, C)

    def forward(self, x, C=None):
        """Upsample ``x`` ([B, C, T]), apply SnakeBeta, then downsample."""
        # C is handed to the upsampler only; the downsampler receives the
        # activated signal alone.
        upsampled = self.upsample(x, C)
        activated = self.act(upsampled)
        return self.downsample(activated)