import librosa
import pytorch_lightning as pl
import torch
from einops.layers.torch import Rearrange
from torch import nn


class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class MLPBlock(nn.Module):
    def __init__(self, dim, mlp_dim, dropout=0., init_values=1e-4):
        super().__init__()
        self.pre_affine = Aff(dim)
        # Token mixing along the time axis with a unidirectional LSTM.
        self.inter = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=1,
                             bidirectional=False, batch_first=True)
        self.ff = nn.Sequential(
            FeedForward(dim, mlp_dim, dropout),
        )
        self.post_affine = Aff(dim)
        # LayerScale-style per-channel residual gains, initialised to a small value.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x, state=None):
        x = self.pre_affine(x)
        if state is None:
            inter, _ = self.inter(x)
        else:
            inter, state = self.inter(x, (state[0], state[1]))
        x = x + self.gamma_1 * inter
        x = self.post_affine(x)
        x = x + self.gamma_2 * self.ff(x)
        if state is None:
            return x
        # Stack (h_n, c_n) into one tensor so the caller can carry the state forward.
        state = torch.stack(state, 0)
        return x, state


class Encoder(nn.Module):
    def __init__(self, in_dim, dim, depth, mlp_dim):
        super().__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.depth = depth
        self.mlp_dim = mlp_dim
        # Flatten the (channel, frequency) axes of each frame and project to dim.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c f t -> b t (c f)'),
            nn.Linear(in_dim, dim),
            nn.GELU()
        )
        self.mlp_blocks = nn.ModuleList([])
        for _ in range(depth):
            self.mlp_blocks.append(MLPBlock(self.dim, mlp_dim, dropout=0.15))
        # Project back to in_dim and restore the (B, 2, F, T) layout.
        self.affine = nn.Sequential(
            Aff(self.dim),
            nn.Linear(dim, in_dim),
            Rearrange('b t (c f) -> b c f t', c=2),
        )

    def forward(self, x_in, states=None):
        x = self.to_patch_embedding(x_in)
        if states is not None:
            out_states = []
        for i, mlp_block in enumerate(self.mlp_blocks):
            if states is None:
                x = mlp_block(x)
            else:
                x, state = mlp_block(x, states[i])
                out_states.append(state)
        x = self.affine(x)
        # Global residual connection around the whole encoder.
        x = x + x_in
        if states is None:
            return x
        else:
            return x, torch.stack(out_states, 0)


class Predictor(pl.LightningModule):  # mel
    def __init__(self, window_size=1536, sr=48000, lstm_dim=256, lstm_layers=3, n_mels=64):
        super(Predictor, self).__init__()
        self.window_size = window_size
        self.hop_size = window_size // 2
        self.lstm_dim = lstm_dim
        self.n_mels = n_mels
        self.lstm_layers = lstm_layers
        # Mel filterbank with the DC bin dropped, kept as a (1, 1, n_mels, hop_size) tensor.
        fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
        self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
        self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
                            num_layers=self.lstm_layers, batch_first=True)
        self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
        # Learned projection from the mel domain back to hop_size frequency bins.
        self.inv_mel = nn.Linear(self.n_mels, self.hop_size)

    def forward(self, x, state=None):  # B, 2, F, T
        self.fb = self.fb.to(x.device)
        # Project onto the mel basis and take the log.
        x = torch.log(torch.matmul(self.fb, x) + 1e-8)
        B, C, F, T = x.shape
        x = x.reshape(B, F * C, T)
        x = x.permute(0, 2, 1)
        if state is None:
            x, _ = self.lstm(x)
        else:
            x, state = self.lstm(x, (state[0], state[1]))
        x = self.expand_dim(x)
        # Undo the log with exp, map back to frequency bins, and keep the output non-negative.
        x = torch.abs(self.inv_mel(torch.exp(x)))
        x = x.permute(0, 2, 1)
        x = x.reshape(B, C, -1, T)
        if state is None:
            return x
        else:
            return x, torch.stack(state, 0)
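

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a quick shape check
# of Encoder and Predictor with made-up sizes. The dummy tensors, dimensions,
# and the single-channel input to Predictor are illustrative assumptions; the
# real pipeline that feeds these modules may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, frames = 2, 10

    # Encoder expects (B, 2, F, T) with in_dim == 2 * F (see the Rearrange layers).
    freq_bins = 320
    enc = Encoder(in_dim=2 * freq_bins, dim=384, depth=4, mlp_dim=768)
    spec = torch.randn(batch, 2, freq_bins, frames)

    out = enc(spec)                                    # stateless call
    print("Encoder out:", out.shape)                   # (2, 2, 320, 10)

    # Stateful call: one stacked (h, c) pair per MLPBlock, each of shape
    # (num_layers=1, B, dim), stacked along dim 0 to (depth, 2, 1, B, dim).
    states = torch.zeros(enc.depth, 2, 1, batch, enc.dim)
    out, new_states = enc(spec, states)
    print("Encoder states:", new_states.shape)         # (4, 2, 1, 2, 384)

    # Predictor: the forward comment says (B, 2, F, T), but the LSTM is built
    # with input_size=n_mels, so after reshape(B, F * C, T) only C == 1 fits.
    # This sketch therefore feeds a single-channel magnitude-like spectrogram
    # with F = window_size // 2 = 768 bins, matching the cropped mel basis.
    pred = Predictor(window_size=1536, sr=48000)
    mag = torch.rand(batch, 1, pred.hop_size, frames)
    est = pred(mag)
    print("Predictor out:", est.shape)                 # (2, 1, 768, 10)

    # Stateful call: stacked (h, c) of shape (2, lstm_layers, B, lstm_dim).
    state = torch.zeros(2, pred.lstm_layers, batch, pred.lstm_dim)
    est, new_state = pred(mag, state)
    print("Predictor state:", new_state.shape)         # (2, 3, 2, 256)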