import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange

from src.utils import misc as utils


def apply_gain(x, gain, fn=None):
    ''' Scale each feature group of x by its own gain.
        x    : (..., len(gain) * group_size)
        gain : (len(gain),) per-group gains, optionally squashed by fn
    '''
    gain = fn(gain) if fn is not None else gain
    x_list = x.chunk(len(gain), -1)
    x_list = [gain[i] * x_i for i, x_i in enumerate(x_list)]
    return torch.cat(x_list, dim=-1)

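# Usage sketch (illustrative shapes, not from the original config): split a
# (B, T, 4*16) feature tensor into 4 groups and scale each group by tanh of a
# learned per-group gain.
#
#   feat = torch.randn(2, 100, 4 * 16)         # (B, T, num_features * embed_dim)
#   gain = nn.Parameter(torch.randn(4) / 2)    # one gain per feature group
#   out  = apply_gain(feat, gain, torch.tanh)  # same shape as feat

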
class FMBlock(nn.Module):
    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        concat_size = embed_dim * num_features + embed_dim
        feature_dim = embed_dim * num_features
        self.rff2 = RFF2(input_dim, embed_dim // 2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2 * input_dim)
        self.activation = nn.GLU(dim=-1)

        gain_in = torch.randn(num_features) / 2
        gain_out = torch.Tensor([0.1])
        self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))
        self.register_parameter('gain_out', nn.Parameter(gain_out, requires_grad=True))

    def forward(self, input, feature, slider, omega):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1)
            omega  : broadcastable with (B T input_dim)
        '''
        # rescale the input before the Fourier embedding
        _input = input / (1.3 * math.pi) - 1
        _input = self.rff2(_input)
        feature = apply_gain(feature, self.gain_in, torch.tanh)

        x = torch.cat((_input, feature), dim=-1)
        x = torch.cat((self.tmlp(x), _input), dim=-1)  # skip connection to the embedding
        x = self.activation(self.proj(x))

        # the gate vanishes at slider == 1, so the block reduces to the identity
        gate = torch.tanh((slider - 1) * self.gain_out)
        return input + omega * x * gate

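# Usage sketch for FMBlock (illustrative dimensions; omega is assumed to
# broadcast with the input, e.g. per-channel angular frequencies):
#
#   fm = FMBlock(input_dim=8, embed_dim=16, num_features=4)
#   phase   = torch.rand(2, 100, 8)        # (B, T, input_dim)
#   feature = torch.randn(2, 100, 4 * 16)  # (B, T, feature_dim)
#   slider  = torch.ones(2, 100, 1)        # gate closed: output == phase
#   omega   = torch.rand(2, 100, 8)
#   out = fm(phase, feature, slider, omega)

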
class AMBlock(nn.Module):
    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        concat_size = embed_dim * num_features + embed_dim
        feature_dim = embed_dim * num_features
        self.rff2 = RFF2(input_dim, embed_dim // 2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2 * input_dim)
        self.activation = nn.GLU(dim=-1)

        gain_in = torch.randn(num_features) / 2
        self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))

    def forward(self, input, feature, slider):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1), unused here; kept for interface parity
        '''
        # rescale the input before the Fourier embedding
        _input = input * 110 - 0.55
        _input = self.rff2(_input)
        feature = apply_gain(feature, self.gain_in, torch.tanh)

        x = torch.cat((_input, feature), dim=-1)
        x = torch.cat((self.tmlp(x), _input), dim=-1)  # skip connection to the embedding
        x = self.activation(self.proj(x))

        # multiplicative correction; identity when x == 0
        return input * (1 + x)

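# Usage mirrors FMBlock minus omega (illustrative dimensions):
#
#   am = AMBlock(input_dim=8, embed_dim=16, num_features=4)
#   out = am(torch.rand(2, 100, 8) * 0.01,  # (B, T, input_dim)
#            torch.randn(2, 100, 4 * 16),   # (B, T, feature_dim)
#            torch.rand(2, 100, 1))         # slider (unused)

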
class ModBlock(nn.Module):
    def __init__(self, input_dim, feature_dim, embed_dim):
        super().__init__()
        # note: input_dim and embed_dim are not used below; the block operates
        # pointwise over the last input axis
        cat_size = 1 + feature_dim
        self.tmlp = mlp(cat_size, feature_dim, 2)
        self.proj = nn.Linear(cat_size, 2)
        self.activation = nn.GLU(dim=-1)

    def forward(self, input, feature, slider):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1), unused here; kept for interface parity
        '''
        input = input.unsqueeze(-1)                                       # (B, T, input_dim, 1)
        feature = feature.unsqueeze(-2).repeat(1, 1, input.size(-2), 1)   # (B, T, input_dim, feature_dim)
        x = torch.cat((input, feature), dim=-1)
        x = torch.cat((self.tmlp(x), input), dim=-1)
        x = self.activation(self.proj(x))
        return (input * (1 + x)).squeeze(-1)

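# Shape walkthrough for ModBlock.forward, with N = input_dim, F = feature_dim:
#   input   (B, T, N) -> unsqueeze(-1)          -> (B, T, N, 1)
#   feature (B, T, F) -> unsqueeze(-2) + repeat -> (B, T, N, F)
#   cat -> (B, T, N, 1+F) --tmlp--> (B, T, N, F) --cat input--> (B, T, N, 1+F)
#   proj -> (B, T, N, 2) --GLU--> (B, T, N, 1); input*(1+x) squeezed -> (B, T, N)

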
def mlp(in_size, hidden_size, n_layers):
    channels = [in_size] + n_layers * [hidden_size]
    net = []
    for i in range(n_layers):
        net.append(nn.Linear(channels[i], channels[i + 1]))
        net.append(nn.PReLU())
    return nn.Sequential(*net)

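# For example, mlp(16, 32, 2) builds (with one PReLU per linear layer, as above):
#   Sequential(Linear(16, 32), PReLU(), Linear(32, 32), PReLU())

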
class RFF2(nn.Module):
    """ Random Fourier Features module with a single learnable scale. """
    def __init__(self, input_dim, embed_dim, scale=1.):
        super().__init__()
        # fixed (non-trainable) projection matrix; only the scale `e` is learned
        N = torch.ones((input_dim, embed_dim)) / input_dim / embed_dim
        e = torch.Tensor([scale])
        self.register_buffer('N', N)
        self.register_parameter('e', nn.Parameter(e, requires_grad=True))

    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, 2 * embed_dim), sin and cos features;
               see the note after this class
        '''
        B = self.e * self.N
        x_embd = utils.fourier_feature(x, B)
        return x_embd

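# NOTE: `utils.fourier_feature` lives in src/utils/misc.py and is not shown
# here. The shape annotations in this file assume the usual random-Fourier-
# feature mapping (an assumption, not verified against misc.py):
#   fourier_feature(x, B) = cat(sin(2*pi*(x @ B)), cos(2*pi*(x @ B)), dim=-1)
# so a projection matrix B of shape (input_dim, embed_dim // 2) produces an
# embedding of size embed_dim, matching the concat sizes used in FMBlock and
# AMBlock, where RFF2 is constructed with embed_dim // 2.

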
class RFF(nn.Module):
    """ Random Fourier Features module with one learnable log-scale per input dim. """
    def __init__(self, scales, embed_dim):
        super().__init__()
        input_dim = len(scales)
        # fixed (non-trainable) projection matrix; the per-dim log10 scales are learned
        N = torch.randn(input_dim, embed_dim)
        e = torch.Tensor(scales).view(-1, 1)
        self.register_buffer('N', N)
        self.register_parameter('e', nn.Parameter(e, requires_grad=True))

    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, input_dim * 2 * embed_dim)
        '''
        # embed each input dimension separately, with its own scale 10**e[i]
        xs = x.chunk(self.N.size(0), -1)
        Ns = self.N.chunk(self.N.size(0), 0)
        Bs = [torch.pow(10, self.e[i]) * N for i, N in enumerate(Ns)]
        x_embd = [utils.fourier_feature(xs[i], B) for i, B in enumerate(Bs)]
        return torch.cat(x_embd, dim=-1)

class ModeEstimator(nn.Module):
    def __init__(self, n_modes, hidden_dim, kappa_scale=None, gamma_scale=None,
                 inharmonic=True, sr=48000):
        super().__init__()
        self.sr = sr
        self.kappa_scale = kappa_scale
        self.gamma_scale = gamma_scale
        self.rff = RFF([1.] * 5, hidden_dim // 2)
        self.a_mlp = mlp(5 * hidden_dim, hidden_dim, 2)
        self.a_proj = nn.Linear(hidden_dim, n_modes)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        if inharmonic:
            self.f_mlp = mlp(5 * hidden_dim, hidden_dim, 2)
            self.f_proj = nn.Linear(hidden_dim, n_modes)
        else:
            self.f_mlp = None
            self.f_proj = None

    def forward(self, u_0, x_p, kappa, gamma):
        ''' u_0   : (b, 1, x)
            x_p   : (b, 1, 1)
            kappa : (b, 1, 1)
            gamma : (b, 1, 1)
        '''
        # summarize the initial condition by its peak position and amplitude
        # (the constants assume a 256-point grid and a 0.02 maximum amplitude)
        p_x = torch.argmax(u_0, dim=-1, keepdim=True) / 255.
        p_a = torch.max(u_0, dim=-1, keepdim=True).values / 0.02
        kappa = self.normalize_kappa(kappa)
        gamma = self.normalize_gamma(gamma)
        con = torch.cat((p_x, p_a, x_p, kappa, gamma), dim=-1)
        con = self.rff(con)

        mode_amps = self.a_mlp(con)
        mode_amps = self.tanh(1e-3 * self.a_proj(mode_amps))

        if self.f_mlp is not None:
            # inharmonic: predict positive frequency increments and accumulate
            mode_freq = self.f_mlp(con)
            mode_freq = 0.3 * self.sigmoid(self.f_proj(mode_freq))
            mode_freq = mode_freq.cumsum(-1)
        else:
            # harmonic: integer multiples of the fundamental; note that gamma
            # was normalized above unless gamma_scale is None
            int_mults = torch.ones_like(mode_amps).cumsum(-1)
            omega = gamma / self.sr * (2 * math.pi)
            mode_freq = omega * int_mults

        return mode_amps, mode_freq

    def normalize_gamma(self, x):
        if self.gamma_scale is not None:
            minval = min(self.gamma_scale)
            denval = max(self.gamma_scale) - minval
            x = (x - minval) / denval
        return x

    def normalize_kappa(self, x):
        if self.kappa_scale is not None:
            minval = min(self.kappa_scale)
            denval = max(self.kappa_scale) - minval
            x = (x - minval) / denval
        return x
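

if __name__ == '__main__':
    # Minimal smoke test (illustrative dimensions only; assumes
    # utils.fourier_feature follows the sin/cos mapping noted above,
    # so that RFF emits 5 * hidden_dim features).
    est = ModeEstimator(n_modes=32, hidden_dim=64,
                        kappa_scale=(0., 1.), gamma_scale=(20., 1000.))
    u_0 = torch.rand(2, 1, 256) * 0.02     # initial condition on a 256-point grid
    x_p = torch.rand(2, 1, 1)              # position input in [0, 1]
    kappa = torch.rand(2, 1, 1)
    gamma = 20. + 980. * torch.rand(2, 1, 1)
    amps, freqs = est(u_0, x_p, kappa, gamma)
    print(amps.shape, freqs.shape)         # torch.Size([2, 1, 32]) for both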