import torch
import torch.nn as nn
import math
from einops import rearrange
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=32, q_norm=True):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)
        self.q_norm = q_norm

    def forward(self, x):
        # x: (b, l, c) -> (b, c, l) for the 1x1 convolutions
        x = x.permute(0, 2, 1)
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) l -> qkv b heads c l',
                            heads=self.heads, qkv=3)
        # normalize keys over the sequence dim and (optionally) queries over the feature dim
        k = k.softmax(dim=-1)
        if self.q_norm:
            q = q.softmax(dim=-2)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c l -> b (heads c) l',
                        heads=self.heads)
        # project back to the input dim and restore (b, l, c)
        return self.to_out(out).permute(0, 2, 1)
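
# Note on the block above: this is the linear ("kernelized") attention formulation,
# where softmax is applied to K over the sequence dim and to Q over the feature dim,
# so the (dim_head x dim_head) context matrix K.V^T is built first and attention
# costs O(L * d^2) rather than the O(L^2 * d) of standard softmax attention.
# Minimal usage sketch (dims chosen for illustration only, not taken from this repo):
#   attn = LinearAttention(dim=256, heads=8, dim_head=32)
#   y = attn(torch.rand(2, 64, 256))   # (batch, length, dim) -> same shape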
class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads=4, layer_norm_first=True):
        super(TransformerBlock, self).__init__()
        dim_head = dim // n_heads
        self.attention = LinearAttention(dim, heads=n_heads, dim_head=dim_head)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(nn.Linear(dim, dim * 2),
                                          nn.SiLU(),
                                          nn.Linear(dim * 2, dim))
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
        # NOTE: forward() below always applies pre-norm; this flag is currently unused.
        self.layer_norm_first = layer_norm_first

    def forward(self, x):
        # pre-norm residual attention, then pre-norm residual feed-forward
        nx = self.norm1(x)
        x = x + self.dropout1(self.attention(nx))
        nx = self.norm2(x)
        nx = x + self.dropout2(self.feed_forward(nx))
        return nx
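
# Standalone shape check for TransformerBlock (a sketch with hypothetical dims,
# not part of the original file; the block is shape-preserving):
#   block = TransformerBlock(dim=256, n_heads=4)
#   y = block(torch.rand(2, 64, 256))   # (batch, length, dim) -> same shape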
class PitchFormer(nn.Module):
    def __init__(self, n_mels, hidden_size, attn_layers=4):
        super(PitchFormer, self).__init__()
        # per-frame projections of the spectral feature and the MIDI pitch,
        # each into half of the hidden size so they can be concatenated
        self.sp_linear = nn.Sequential(nn.Conv1d(n_mels, hidden_size, kernel_size=1),
                                       nn.SiLU(),
                                       nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1)
                                       )
        self.midi_linear = nn.Sequential(nn.Conv1d(1, hidden_size, kernel_size=1),
                                         nn.SiLU(),
                                         nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1),
                                         )
        self.hidden_size = hidden_size
        # convolutional positional encoding, added residually in forward()
        self.pos_conv = nn.Conv1d(hidden_size, hidden_size,
                                  kernel_size=63,
                                  padding=31,
                                  )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (63 * hidden_size))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, nn.SiLU())
        self.attn_block = nn.ModuleList([TransformerBlock(hidden_size, 4) for _ in range(attn_layers)])
        self.linear = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                    nn.SiLU(),
                                    nn.Linear(hidden_size, 1))
    def forward(self, midi, sp):
        # midi: (batch, frames), sp: (batch, n_mels, frames)
        midi = midi.unsqueeze(1)              # (batch, 1, frames)
        midi = self.midi_linear(midi)         # (batch, hidden//2, frames)
        sp = self.sp_linear(sp)               # (batch, hidden//2, frames)
        x = torch.cat([midi, sp], dim=1)      # (batch, hidden, frames)
        # add convolutional position encoding
        x_conv = self.pos_conv(x)
        x = x + x_conv
        x = x.permute(0, 2, 1)                # (batch, frames, hidden)
        for layer in self.attn_block:
            x = layer(x)
        x = self.linear(x)                    # (batch, frames, 1)
        return x.squeeze(-1)                  # (batch, frames)
if __name__ == '__main__':
    model = PitchFormer(100, 256)
    x = torch.rand((4, 64))        # unused; kept from the original script
    sp = torch.rand((4, 100, 64))  # (batch, n_mels, frames)
    midi = torch.rand((4, 64))     # (batch, frames)
    y = model(midi, sp)
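    # Sanity check (added sketch): PitchFormer maps per-frame MIDI + spectral
    # features to one pitch value per frame, so y should be (batch, frames).
    print(y.shape)
    assert y.shape == (4, 64)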