| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from torch.nn.utils import weight_norm |
| |
|
| |
|
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and return it wrapped with ``weight_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv1d``; the weight-normalized module is returned.
    """
    conv = nn.Conv1d(*args, **kwargs)
    wrapped = weight_norm(conv)
    return wrapped
| |
|
| |
|
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` and return it wrapped with ``weight_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.ConvTranspose1d``; the weight-normalized module is returned.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    wrapped = weight_norm(deconv)
    return wrapped
| |
|
| |
|
| | |
@torch.jit.script
def snake(x, alpha):
    """Apply the snake activation ``x + sin(alpha * x) ** 2 / alpha``.

    ``x`` is viewed as ``(x.shape[0], x.shape[1], -1)`` while the activation
    is computed, so ``alpha`` must broadcast against that 3-D view
    (presumably ``(1, C, 1)``; confirm at call sites). The result is reshaped
    back, so the output has the same shape as ``x``.
    """
    orig_shape = x.shape
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    # The 1e-9 offset keeps 1/alpha finite when alpha == 0. The sin**2 term
    # is then also 0, so the product is 0 rather than inf * 0 = NaN.
    inv_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inv_alpha * torch.sin(alpha * flat).pow(2)
    return flat.reshape(orig_shape)
| |
|
| |
|
class Snake1d(nn.Module):
    """Snake activation module with one learnable ``alpha`` per channel.

    Args:
        channels: number of channels; sets the size of dimension 1 of the
            learnable ``alpha`` parameter.
    """

    def __init__(self, channels):
        super(Snake1d, self).__init__()
        # Shape (1, channels, 1), initialised to ones, so it broadcasts over
        # dims 0 and 2 of the 3-D view that ``snake`` takes of its input.
        init_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(init_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)``; output has the same shape as ``x``."""
        activated = snake(x, self.alpha)
        return activated
| |
|