import math

import torch
import torch.nn as nn
from einops import rearrange


class LinearAttention(nn.Module):
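    """Multi-head linear attention over channel-last sequences of shape (b, l, dim).

    Keys are softmax-normalised along the sequence axis and queries
    (optionally) along the channel axis, so the key/value context can be
    aggregated first and the cost is linear rather than quadratic in the
    sequence length.
    """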

    def __init__(self, dim, heads=8, dim_head=32, q_norm=True):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)
        self.q_norm = q_norm

    def forward(self, x):
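        # (b, l, dim) -> (b, dim, l) for the 1x1 convolutions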
        x = x.permute(0, 2, 1)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) l -> qkv b heads c l',
                            heads=self.heads, qkv=3)
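
        # normalise keys over the sequence axis (l) and, optionally,
        # queries over the channel axis (c)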
        k = k.softmax(dim=-1)
        if self.q_norm:
            q = q.softmax(dim=-2)
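
        # aggregate a per-head (dim_head x dim_head) context from keys and
        # values over the whole sequence, then apply it to the queries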
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c l -> b (heads c) l',
                        heads=self.heads)
        return self.to_out(out).permute(0, 2, 1)


class TransformerBlock(nn.Module):
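    """Pre-norm residual block: linear attention followed by a position-wise
    feed-forward network, each with dropout on the residual branch.
    `layer_norm_first` is stored but not currently used in forward().
    """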

    def __init__(self, dim, n_heads=4, layer_norm_first=True):
        super().__init__()
        dim_head = dim // n_heads
        self.attention = LinearAttention(dim, heads=n_heads, dim_head=dim_head)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.feed_forward = nn.Sequential(nn.Linear(dim, dim * 2),
                                          nn.SiLU(),
                                          nn.Linear(dim * 2, dim))

        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)

        self.layer_norm_first = layer_norm_first

    def forward(self, x):
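        # attention sub-layer: pre-norm, residual connection with dropout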
        nx = self.norm1(x)
        x = x + self.dropout1(self.attention(nx))
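
        # feed-forward sub-layer: pre-norm, residual connection with dropout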
        nx = self.norm2(x)
        nx = x + self.dropout2(self.feed_forward(nx))
        return nx


class PitchFormer(nn.Module):
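    """Frame-level pitch predictor.

    A per-frame MIDI contour and an (n_mels, frames) spectral input are each
    projected to hidden_size // 2 channels and concatenated, combined with a
    convolutional positional encoding, and processed by a stack of
    linear-attention transformer blocks; an MLP head outputs one value per
    frame.
    """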

    def __init__(self, n_mels, hidden_size, attn_layers=4):
        super().__init__()
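
        # project the spectral input and the MIDI contour to hidden_size // 2
        # channels each; the two streams are concatenated along channels in forward()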
        self.sp_linear = nn.Sequential(nn.Conv1d(n_mels, hidden_size, kernel_size=1),
                                       nn.SiLU(),
                                       nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1)
                                       )

        self.midi_linear = nn.Sequential(nn.Conv1d(1, hidden_size, kernel_size=1),
                                         nn.SiLU(),
                                         nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1),
                                         )

        self.hidden_size = hidden_size
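
        # convolutional positional encoding; the initialisation and weight
        # normalisation follow the wav2vec 2.0-style positional convolution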
        self.pos_conv = nn.Conv1d(hidden_size, hidden_size,
                                  kernel_size=63,
                                  padding=31,
                                  )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (63 * hidden_size))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, nn.SiLU())

        self.attn_block = nn.ModuleList([TransformerBlock(hidden_size, 4) for _ in range(attn_layers)])
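
        # regression head: one scalar per frame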
        self.linear = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                    nn.SiLU(),
                                    nn.Linear(hidden_size, 1))

    def forward(self, midi, sp):
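        # midi: (b, frames) -> (b, 1, frames); sp: (b, n_mels, frames)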
        midi = midi.unsqueeze(1)
        midi = self.midi_linear(midi)
        sp = self.sp_linear(sp)
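
        # concatenate the two (b, hidden_size // 2, frames) streams into (b, hidden_size, frames)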
        x = torch.cat([midi, sp], dim=1)
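
        # add the convolutional positional encoding as a residual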
        x_conv = self.pos_conv(x)
        x = x + x_conv
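
        # (b, hidden_size, frames) -> (b, frames, hidden_size) for the transformer blocks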
        x = x.permute(0, 2, 1)
        for layer in self.attn_block:
            x = layer(x)
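
        # per-frame scalar prediction: (b, frames, 1) -> (b, frames)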
        x = self.linear(x)

        return x.squeeze(-1)


if __name__ == '__main__':
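    # smoke test: batch of 4, 100 mel bins, 64 frames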
    model = PitchFormer(100, 256)

    sp = torch.rand((4, 100, 64))
    midi = torch.rand((4, 64))

    y = model(midi, sp)
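    print(y.shape)  # expected: torch.Size([4, 64])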