import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 100)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels (C) of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default; higher values mean higher frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # Initialize alpha: zeros in log scale (exp(0) = 1), ones in linear scale.
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        # Small epsilon to avoid division by zero when alpha is near 0.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake(x) := x + (1/a) * sin^2(a * x)
        '''
        # Reshape alpha from (C,) to (1, C, 1) so it broadcasts over (B, C, T).
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
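
# A minimal sanity-check sketch (illustrative assumption, not part of the
# original module): as alpha -> 0, Snake approaches the identity, since
# (1/a) * sin^2(a*x) ~= a * x^2 for small a.
#
#     act = Snake(in_features=4, alpha=1e-4)
#     x = torch.randn(2, 4, 8)                    # (B, C, T)
#     assert torch.allclose(act(x), x, atol=1e-3)
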
class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(1, 256, 100)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels (C) of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default; higher values mean higher frequency.
            beta is initialized to 1 by default; higher values mean higher magnitude.
            alpha and beta will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha and beta: zeros in log scale (exp(0) = 1),
        # ones in linear scale.
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Small epsilon to avoid division by zero when beta is near 0.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta(x) := x + (1/b) * sin^2(a * x)
        '''
        # Reshape parameters from (C,) to (1, C, 1) so they broadcast over (B, C, T).
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
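

# Hedged usage sketch (this demo block is an illustrative assumption, not part
# of the original file): exercise both activations on a (B, C, T) tensor,
# check that shapes are preserved, and confirm gradients reach the parameters.
if __name__ == "__main__":
    B, C, T = 2, 256, 100  # hypothetical batch/channel/time sizes
    x = torch.randn(B, C, T)

    snake = Snake(C, alpha=1.0, alpha_logscale=False)
    snake_beta = SnakeBeta(C, alpha=1.0, alpha_logscale=True)

    y1 = snake(x)
    y2 = snake_beta(x)
    assert y1.shape == x.shape and y2.shape == x.shape

    # alpha and beta are trainable Parameters, so gradients flow through them.
    y2.sum().backward()
    assert snake_beta.alpha.grad is not None
    assert snake_beta.beta.grad is not None
    print("Snake / SnakeBeta forward and backward OK:", tuple(y1.shape))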