# dmsp/src/model/nn/blocks.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils import misc as utils
def apply_gain(x, gain, fn=None):
    """Split x into len(gain) chunks along the last dim and scale each chunk
    by its own scalar gain (optionally squashed by fn, e.g. torch.tanh)."""
    gain = fn(gain) if fn is not None else gain
    x_list = x.chunk(len(gain), -1)
    x_list = [gain[i] * x_i for i, x_i in enumerate(x_list)]
    return torch.cat(x_list, dim=-1)
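# Usage sketch (illustrative values, not from this repo): scale a feature
# tensor whose last dim splits into 4 chunks of 16 each.
#
#   feats = torch.randn(2, 100, 64)              # (B, T, 4 * 16)
#   gains = torch.randn(4)                       # one scalar gain per chunk
#   out = apply_gain(feats, gains, torch.tanh)   # same shape as feats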
class FMBlock(nn.Module):
    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        # Fourier-embedded input (embed_dim) concatenated with the features.
        concat_size = embed_dim * num_features + embed_dim
        feature_dim = embed_dim * num_features
        self.rff2 = RFF2(input_dim, embed_dim // 2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2 * input_dim)
        self.activation = nn.GLU(dim=-1)
        # Per-feature input gains plus a small scalar output gain.
        gain_in = torch.randn(num_features) / 2
        gain_out = torch.tensor([0.1])
        self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))
        self.register_parameter('gain_out', nn.Parameter(gain_out, requires_grad=True))
    def forward(self, input, feature, slider, omega):
        ''' input  : (B, T, input_dim)
            feature: (B, T, feature_dim)
            slider : (B, T, 1)
            omega  : (B, T, 1)
        '''
        # Rough normalization before the Fourier embedding (the 1.3*pi scale
        # suggests phases of roughly [0, 2.6*pi] map to roughly [-1, 1]).
        _input = input / (1.3 * math.pi) - 1
        _input = self.rff2(_input)
        feature = apply_gain(feature, self.gain_in, torch.tanh)
        x = torch.cat((_input, feature), dim=-1)
        # Skip connection around the MLP before projecting and gating.
        x = torch.cat((self.tmlp(x), _input), dim=-1)
        x = self.activation(self.proj(x))
        # Slider-controlled gate: no modulation when slider == 1.
        gate = torch.tanh((slider - 1) * self.gain_out)
        return input + omega * x * gate
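# Usage sketch (hedged): shapes assume utils.fourier_feature(x, B) returns
# concatenated sin/cos features of size 2 * B.size(-1), which is what the
# dimensions wired up in this file imply but is not confirmed here.
#
#   block = FMBlock(input_dim=1, embed_dim=64, num_features=4)
#   phase = torch.rand(2, 100, 1) * 2.6 * math.pi   # (B, T, input_dim)
#   feats = torch.randn(2, 100, 64 * 4)             # (B, T, embed_dim * num_features)
#   slider = torch.rand(2, 100, 1)
#   omega = torch.rand(2, 100, 1)
#   out = block(phase, feats, slider, omega)        # (B, T, input_dim)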
class AMBlock(nn.Module):
    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        concat_size = embed_dim * num_features + embed_dim
        feature_dim = embed_dim * num_features
        self.rff2 = RFF2(input_dim, embed_dim // 2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2 * input_dim)
        self.activation = nn.GLU(dim=-1)
        gain_in = torch.randn(num_features) / 2
        self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))
    def forward(self, input, feature, slider):
        ''' input  : (B, T, input_dim)
            feature: (B, T, feature_dim)
            slider : (B, T, 1), unused here but kept for interface parity
        '''
        # Rough normalization before the Fourier embedding (the constants
        # suggest amplitudes of roughly [0, 0.01] map to roughly [-0.55, 0.55]).
        _input = input * 110 - 0.55
        _input = self.rff2(_input)
        feature = apply_gain(feature, self.gain_in, torch.tanh)
        x = torch.cat((_input, feature), dim=-1)
        x = torch.cat((self.tmlp(x), _input), dim=-1)
        x = self.activation(self.proj(x))
        # Multiplicative (amplitude) modulation of the input.
        return input * (1 + x)
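# Usage sketch (hedged, same fourier_feature assumption as above):
#
#   block = AMBlock(input_dim=1, embed_dim=64, num_features=4)
#   amp = torch.rand(2, 100, 1) * 0.01     # (B, T, input_dim)
#   feats = torch.randn(2, 100, 64 * 4)    # (B, T, embed_dim * num_features)
#   slider = torch.rand(2, 100, 1)         # accepted but unused
#   out = block(amp, feats, slider)        # (B, T, input_dim)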
class ModBlock(nn.Module):
    def __init__(self, input_dim, feature_dim, embed_dim):
        super().__init__()
        # input_dim and embed_dim are unused here; kept for interface parity.
        cat_size = 1 + feature_dim
        self.tmlp = mlp(cat_size, feature_dim, 2)
        self.proj = nn.Linear(cat_size, 2)
        self.activation = nn.GLU(dim=-1)
    def forward(self, input, feature, slider):
        ''' input  : (B, T, input_dim)
            feature: (B, T, feature_dim)
            slider : (B, T, 1), unused here but kept for interface parity
        '''
        # Treat each input channel independently: broadcast the features
        # across the input_dim axis and modulate each channel.
        input = input.unsqueeze(-1)  # (B, T, input_dim, 1)
        feature = feature.unsqueeze(-2).repeat(1, 1, input.size(-2), 1)
        x = torch.cat((input, feature), dim=-1)
        x = torch.cat((self.tmlp(x), input), dim=-1)
        x = self.activation(self.proj(x))
        return (input * (1 + x)).squeeze(-1)
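# Usage sketch (illustrative dims):
#
#   block = ModBlock(input_dim=8, feature_dim=32, embed_dim=64)
#   x = torch.randn(2, 100, 8)       # (B, T, input_dim)
#   feats = torch.randn(2, 100, 32)  # (B, T, feature_dim)
#   slider = torch.rand(2, 100, 1)   # accepted but unused
#   out = block(x, feats, slider)    # (B, T, input_dim)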
def mlp(in_size, hidden_size, n_layers):
    """Stack of n_layers Linear + PReLU blocks mapping in_size -> hidden_size."""
    channels = [in_size] + n_layers * [hidden_size]
    net = []
    for i in range(n_layers):
        net.append(nn.Linear(channels[i], channels[i + 1]))
        net.append(nn.PReLU())
    return nn.Sequential(*net)
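# Usage sketch:
#
#   net = mlp(16, 32, 3)   # Linear(16->32), PReLU, Linear(32->32), PReLU, Linear(32->32), PReLU
#   y = net(torch.randn(2, 100, 16))   # (2, 100, 32)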
class RFF2(nn.Module):
    """ Random Fourier Features Module """
    def __init__(self, input_dim, embed_dim, scale=1.):
        super().__init__()
        # NOTE: despite the name, the projection here is a fixed constant;
        # a random variant would be torch.randn(input_dim, embed_dim).
        N = torch.ones((input_dim, embed_dim)) / input_dim / embed_dim
        # Frozen projection as a plain buffer; only the scale e is learnable.
        self.register_buffer('N', N)
        self.register_parameter('e', nn.Parameter(torch.tensor([scale]), requires_grad=True))
    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, 2*embed_dim), assuming utils.fourier_feature
               concatenates sin and cos components
        '''
        B = self.e * self.N
        return utils.fourier_feature(x, B)
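# Usage sketch (hedged, fourier_feature convention assumed as above):
#
#   rff2 = RFF2(input_dim=1, embed_dim=32)
#   y = rff2(torch.rand(2, 100, 1))   # expected (2, 100, 64)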
class RFF(nn.Module):
    """ Random Fourier Features Module (per-dimension learnable log-scales) """
    def __init__(self, scales, embed_dim):
        super().__init__()
        input_dim = len(scales)
        # Frozen random projection as a plain buffer; per-dim scales are learnable.
        self.register_buffer('N', torch.randn(input_dim, embed_dim))
        self.register_parameter('e', nn.Parameter(torch.tensor(scales).view(-1, 1), requires_grad=True))
    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, 2*input_dim*embed_dim), assuming utils.fourier_feature
               concatenates sin and cos components
        '''
        xs = x.chunk(self.N.size(0), -1)      # input_dim tensors of (Bs, Nt, 1)
        Ns = self.N.chunk(self.N.size(0), 0)  # input_dim tensors of (1, embed_dim)
        # Each input dimension gets its own log-scaled projection.
        Bs = [torch.pow(10, self.e[i]) * N for i, N in enumerate(Ns)]
        x_embd = [utils.fourier_feature(xs[i], B) for i, B in enumerate(Bs)]
        return torch.cat(x_embd, dim=-1)
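# Usage sketch (hedged, same assumption): with 5 scales and embed_dim=32,
# the output below would carry 5 * 2 * 32 = 320 features per frame.
#
#   rff = RFF(scales=[1.] * 5, embed_dim=32)
#   y = rff(torch.rand(2, 100, 5))   # expected (2, 100, 320)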
class ModeEstimator(nn.Module):
    def __init__(self, n_modes, hidden_dim, kappa_scale=None, gamma_scale=None, inharmonic=True, sr=48000):
        super().__init__()
        self.sr = sr
        self.kappa_scale = kappa_scale
        self.gamma_scale = gamma_scale
        # 5 conditioning scalars: pluck position, pluck amplitude, x_p, kappa, gamma.
        self.rff = RFF([1.] * 5, hidden_dim // 2)
        self.a_mlp = mlp(5 * hidden_dim, hidden_dim, 2)
        self.a_proj = nn.Linear(hidden_dim, n_modes)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        if inharmonic:
            # Separate head predicting (cumulative) mode frequencies.
            self.f_mlp = mlp(5 * hidden_dim, hidden_dim, 2)
            self.f_proj = nn.Linear(hidden_dim, n_modes)
        else:
            self.f_mlp = None
            self.f_proj = None
    def forward(self, u_0, x_p, kappa, gamma):
        ''' u_0   : (b, 1, x)
            x_p   : (b, 1, 1)
            kappa : (b, 1, 1)
            gamma : (b, 1, 1)
        '''
        # Pluck position and amplitude from the initial condition; the 255
        # and 0.02 divisors look like dataset-specific normalizers (assumed
        # 256-point spatial grid and ~0.02 maximum amplitude).
        p_x = torch.argmax(u_0, dim=-1, keepdim=True) / 255.         # (b, 1, 1)
        p_a = torch.max(u_0, dim=-1, keepdim=True).values / 0.02     # (b, 1, 1)
        kappa = self.normalize_kappa(kappa)
        gamma = self.normalize_gamma(gamma)
        con = torch.cat((p_x, p_a, x_p, kappa, gamma), dim=-1)       # (b, 1, 5)
        con = self.rff(con)                                          # (b, 1, 5*hidden_dim)
        mode_amps = self.a_mlp(con)                                  # (b, 1, hidden_dim)
        mode_amps = self.tanh(1e-3 * self.a_proj(mode_amps))         # (b, 1, n_modes)
        if self.f_mlp is not None:
            # Inharmonic: predict positive frequency increments and cumsum
            # them so mode frequencies are monotonically increasing.
            mode_freq = self.f_mlp(con)                              # (b, 1, hidden_dim)
            mode_freq = 0.3 * self.sigmoid(self.f_proj(mode_freq))   # (b, 1, n_modes)
            mode_freq = mode_freq.cumsum(-1)
        else:
            # Harmonic: integer multiples of the fundamental omega. Note that
            # gamma was already min-max normalized above when gamma_scale is set.
            int_mults = torch.ones_like(mode_amps).cumsum(-1)        # (b, 1, n_modes)
            omega = gamma / self.sr * (2 * math.pi)
            mode_freq = omega * int_mults
        return mode_amps, mode_freq
    def normalize_gamma(self, x):
        # Min-max normalize to [0, 1] using the configured range, if any.
        if self.gamma_scale is not None:
            minval = min(self.gamma_scale)
            denval = max(self.gamma_scale) - minval
            x = (x - minval) / denval
        return x

    def normalize_kappa(self, x):
        # Min-max normalize to [0, 1] using the configured range, if any.
        if self.kappa_scale is not None:
            minval = min(self.kappa_scale)
            denval = max(self.kappa_scale) - minval
            x = (x - minval) / denval
        return x
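# Usage sketch (hedged; fourier_feature assumption as above, all values
# illustrative rather than from this repo):
#
#   est = ModeEstimator(n_modes=32, hidden_dim=64,
#                       kappa_scale=(0.0, 1.0), gamma_scale=(20.0, 1000.0),
#                       inharmonic=True, sr=48000)
#   u_0 = torch.zeros(2, 1, 256)
#   u_0[:, :, 64] = 0.01                       # simple pluck-like initial state
#   x_p = torch.rand(2, 1, 1)
#   kappa = torch.rand(2, 1, 1)
#   gamma = 20.0 + 980.0 * torch.rand(2, 1, 1)
#   amps, freqs = est(u_0, x_p, kappa, gamma)  # each (2, 1, 32)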