# Spaces: Runtime error  (page-status banner captured with the listing; not part of the module)
import gin | |
import torch | |
import torch.fft | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .dynamic import FiLM, TimeDistributedMLP | |
class Sine(nn.Module):
    """Sine activation packaged as an ``nn.Module``.

    Wrapping ``torch.sin`` in a module lets it be instantiated and placed
    inside containers such as ``nn.Sequential``.
    """

    def forward(self, x: torch.Tensor):
        """Return the element-wise sine of ``x``."""
        return torch.sin(x)
class TrainableNonlinearity(nn.Module):
    """Bank of independent, learnable per-channel nonlinear functions.

    Each channel is first multiplied by its own learnable gain and then passed
    through a stack of ``depth`` kernel-size-1 convolutions. The convolutions
    use ``groups=channels``, so no information is mixed between channels: each
    channel gets its own small point-wise network with ``width`` hidden units.

    Args:
        channels: Number of input (and output) channels.
        width: Hidden units per channel; hidden layers have
            ``channels * width`` channels in total.
        nonlinearity: Module class instantiated after every layer but the last.
        final_nonlinearity: Module class instantiated after the last layer.
        depth: Number of convolution layers.
    """

    def __init__(
        self, channels, width, nonlinearity=nn.ReLU, final_nonlinearity=Sine, depth=3
    ):
        super().__init__()
        # Per-channel input gain, initialised from a standard normal scaled by
        # 10. Shape (1, channels, 1) broadcasts over the first and last axes
        # of the input.
        self.input_scale = nn.Parameter(torch.randn(1, channels, 1) * 10)

        hidden_channels = channels * width
        layers = []
        for i in range(depth):
            is_last = i == depth - 1
            layers.append(
                nn.Conv1d(
                    channels if i == 0 else hidden_channels,
                    channels if is_last else hidden_channels,
                    1,
                    groups=channels,  # keep every channel's network separate
                )
            )
            layers.append(final_nonlinearity() if is_last else nonlinearity())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Scale ``x`` per channel and apply the point-wise network.

        Args:
            x: Tensor accepted by ``nn.Conv1d`` with ``channels`` channels.

        Returns:
            Tensor with the same number of channels as the input.
        """
        return self.net(self.input_scale * x)
class NEWT(nn.Module):
    """Bank of learned waveshapers driven by a control embedding.

    An MLP maps the control embedding to four groups of ``n_waveshapers``
    channels. After being resampled to the exciter length, those groups are
    used as the (gamma, beta) arguments of two ``FiLM`` modules: one applied
    before the learned shaping function and one after it. A kernel-size-1
    convolution then mixes the waveshaper channels down to ``out_channels``.

    Args:
        n_waveshapers: Number of parallel shaping functions.
        control_embedding_size: Size passed as the first two arguments of
            ``TimeDistributedMLP`` (presumably input and hidden size —
            confirm against ``.dynamic``).
        shaping_fn_size: ``width`` of the ``TrainableNonlinearity``.
        out_channels: Number of channels produced by the output mixer.
    """

    def __init__(
        self,
        n_waveshapers: int,
        control_embedding_size: int,
        shaping_fn_size: int = 16,
        out_channels: int = 1,
    ):
        super().__init__()

        self.n_waveshapers = n_waveshapers

        # 4 * n_waveshapers outputs: gamma/beta for the index FiLM plus
        # gamma/beta for the normalising FiLM (see the split in ``forward``).
        self.mlp = TimeDistributedMLP(
            control_embedding_size, control_embedding_size, n_waveshapers * 4, depth=4
        )

        self.waveshaping_index = FiLM()
        self.shaping_fn = TrainableNonlinearity(
            n_waveshapers, shaping_fn_size, nonlinearity=Sine
        )
        self.normalising_coeff = FiLM()

        self.mixer = nn.Sequential(
            nn.Conv1d(n_waveshapers, out_channels, 1),
        )

    def forward(self, exciter, control_embedding):
        """Shape ``exciter`` under control of ``control_embedding``.

        Args:
            exciter: Signal to be waveshaped; its last axis sets the output
                length. NOTE(review): must be compatible with ``FiLM`` and
                with ``n_waveshapers`` channels — confirm against callers.
            control_embedding: Input to the control MLP.

        Returns:
            Output of the mixing convolution.
        """
        film_params = self.mlp(control_embedding)
        # ``F.upsample`` is deprecated; ``F.interpolate`` is the function it
        # forwards to, so the result is unchanged.
        film_params = F.interpolate(
            film_params, size=exciter.shape[-1], mode="linear"
        )
        # Chunks of n_waveshapers channels along dim 1, in MLP output order.
        gamma_index, beta_index, gamma_norm, beta_norm = torch.split(
            film_params, self.n_waveshapers, 1
        )

        x = self.waveshaping_index(exciter, gamma_index, beta_index)
        x = self.shaping_fn(x)
        x = self.normalising_coeff(x, gamma_norm, beta_norm)

        return self.mixer(x)
class FastNEWT(NEWT):
    """Lookup-table version of a ``NEWT`` module.

    The MLP, both FiLM modules and the mixer are shared with the source
    ``newt``. Only the shaping network is replaced: it is sampled once on a
    uniform grid over ``[table_min, table_max]`` and later evaluated by
    linear interpolation between neighbouring table entries.

    Args:
        newt: The ``NEWT`` instance to convert. Its submodules are reused,
            not copied.
        table_size: Number of samples stored per waveshaper.
        table_min: Input value mapped to the first table entry.
        table_max: Input value mapped to the last table entry.
    """

    def __init__(
        self,
        newt: NEWT,
        table_size: int = 4096,
        table_min: float = -3.0,
        table_max: float = 3.0,
    ):
        # ``NEWT.__init__`` has required arguments, so a bare
        # ``super().__init__()`` raises TypeError. It would also build fresh
        # submodules that are replaced below. Only the ``nn.Module``
        # machinery needs initialising here.
        nn.Module.__init__(self)
        self.table_size = table_size
        self.table_min = table_min
        self.table_max = table_max

        self.n_waveshapers = newt.n_waveshapers
        self.mlp = newt.mlp

        self.waveshaping_index = newt.waveshaping_index
        self.normalising_coeff = newt.normalising_coeff
        self.mixer = newt.mixer

        self.lookup_table = self._init_lookup_table(
            newt, table_size, self.n_waveshapers, table_min, table_max
        )
        self.to(next(newt.parameters()).device)

    def _init_lookup_table(
        self,
        newt: NEWT,
        table_size: int,
        n_waveshapers: int,
        table_min: float,
        table_max: float,
    ):
        """Sample ``newt.shaping_fn`` on a uniform grid.

        Returns:
            ``nn.Parameter`` holding the first batch element of the shaping
            function's output for a ``(1, n_waveshapers, table_size)`` input.
        """
        device = next(newt.parameters()).device
        # The same grid is presented to every waveshaper channel.
        sample_values = torch.linspace(
            table_min, table_max, table_size, device=device
        ).expand(1, n_waveshapers, table_size)
        # The table becomes a fresh leaf parameter, so there is no need to
        # record an autograd graph while sampling.
        with torch.no_grad():
            lookup_table = newt.shaping_fn(sample_values)[0]
        return nn.Parameter(lookup_table)

    def _lookup(self, idx):
        """Gather table entries: ``out[b, s, ...] = lookup_table[s, idx[b, s, ...]]``.

        Args:
            idx: Integer tensor of table positions whose second axis indexes
                the waveshaper.

        Returns:
            Tensor of table values with the same shape as ``idx``.
        """
        # Broadcast a waveshaper index along axis 1 so one advanced-indexing
        # operation replaces the per-batch / per-shaper Python loops.
        shaper_shape = (1, idx.shape[1]) + (1,) * (idx.dim() - 2)
        shaper = torch.arange(idx.shape[1], device=idx.device).view(shaper_shape)
        return self.lookup_table[shaper, idx]

    def shaping_fn(self, x):
        """Evaluate the tabulated shaping functions by linear interpolation.

        Inputs above ``table_max`` return the last table entry. Inputs below
        ``table_min`` are linearly extrapolated from the first table segment,
        because the fractional part is measured from the clamped lower index.

        Args:
            x: Values to shape; the second axis indexes the waveshaper.

        Returns:
            Tensor of interpolated table values with the same shape as ``x``.
        """
        last = self.table_size - 1
        # ``linspace`` puts ``table_size`` samples ``table_size - 1`` steps
        # apart, so that is the factor mapping the input range onto indices.
        idx = last * (x - self.table_min) / (self.table_max - self.table_min)

        lower = torch.floor(idx).long().clamp(0, last)
        upper = (lower + 1).clamp(max=last)
        fract = idx - lower

        lower_v = self._lookup(lower)
        upper_v = self._lookup(upper)
        return (upper_v - lower_v) * fract + lower_v
class Reverb(nn.Module):
    """Learnable convolution reverb added on top of the dry signal.

    The impulse response has ``sr * length_in_seconds`` taps. The first tap
    is a fixed zero buffer and the rest are learnable, so the wet path has no
    zero-delay component; the dry signal is added back in ``forward``.

    Args:
        length_in_seconds: Impulse-response duration in seconds.
        sr: Sample rate in samples per second.
    """

    def __init__(self, length_in_seconds, sr):
        super().__init__()
        # ``int`` keeps a fractional duration from producing a float size.
        n_taps = int(sr * length_in_seconds)
        self.ir = nn.Parameter(torch.randn(1, n_taps - 1) * 1e-6)
        self.register_buffer("initial_zero", torch.zeros(1, 1))

    def forward(self, x):
        """Return ``x`` plus ``x`` convolved with the impulse response.

        Args:
            x: Signal with time on the last axis; leading axes must broadcast
                against the impulse response's leading axis of size 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        ir_ = torch.cat((self.initial_zero, self.ir), dim=-1)
        n_samples = x.shape[-1]
        # Zero-pad both transforms to the full linear-convolution length so
        # the reverb tail cannot wrap around to the start of the output.
        # Passing ``n`` to ``irfft`` also makes odd lengths round-trip
        # exactly; without it the output can be one sample short.
        fft_size = n_samples + ir_.shape[-1] - 1
        wet = torch.fft.irfft(
            torch.fft.rfft(x, n=fft_size) * torch.fft.rfft(ir_, n=fft_size),
            n=fft_size,
        )
        return x + wet[..., :n_samples]