# Source-page metadata (not code): uploaded by crimeacs; commit 2bbf18f,
# "Fixed imports"; file size 9.53 kB; virus scan: no virus.
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchmetrics import MeanAbsoluteError
from torch.optim.lr_scheduler import ReduceLROnPlateau
import lightning as pl
class BlurPool1D(nn.Module):
    """Anti-aliased 1-D downsampling: binomial low-pass filter followed by subsampling.

    Each channel is convolved independently (``groups`` equal to the channel
    count) with a normalized binomial filter of length ``filt_size``. The
    output is then taken every ``stride`` samples along the last axis.

    Args:
        channels: Number of channels. The filter bank holds this many copies of
            the kernel, so it must match ``inp.shape[1]`` in ``forward``.
        pad_type: Padding-mode name understood by ``get_pad_layer_1d``.
        filt_size: Number of filter taps, any integer >= 1. Sizes 1-7 give
            exactly the taps of the original hard-coded tables.
        stride: Subsampling factor.
        pad_off: Extra padding added on both sides.

    Raises:
        ValueError: If ``filt_size`` is not an integer >= 1.
    """

    def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
        super(BlurPool1D, self).__init__()
        # Validate and build the taps first, so a bad filt_size fails with a clear error.
        taps = self._binomial_filter(filt_size)
        n_taps = len(taps)
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Left/right padding is floor/ceil of (n_taps - 1) / 2 (asymmetric for
        # even-length filters), each widened by pad_off.
        self.pad_sizes = [(n_taps - 1) // 2 + pad_off, n_taps // 2 + pad_off]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)  # not used in this class; kept for attribute compatibility
        self.channels = channels
        filt = torch.Tensor(taps)
        filt = filt / torch.sum(filt)
        # Depthwise conv1d weight: one (1, n_taps) kernel per channel -> (channels, 1, n_taps).
        self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    @staticmethod
    def _binomial_filter(filt_size):
        """Return the unnormalized binomial taps (row ``filt_size - 1`` of Pascal's triangle).

        Args:
            filt_size: Number of taps; must be an integer >= 1.

        Returns:
            1-D float ``np.ndarray`` of length ``filt_size``.

        Raises:
            ValueError: If ``filt_size`` is not an integer >= 1.
        """
        n_taps = int(filt_size)
        if n_taps != filt_size or n_taps < 1:
            raise ValueError(f"filt_size must be an integer >= 1, got {filt_size!r}")
        taps = np.ones(1)
        # Each convolution with [1, 1] advances one row of Pascal's triangle.
        for _ in range(n_taps - 1):
            taps = np.convolve(taps, [1.0, 1.0])
        return taps

    def forward(self, inp):
        """Blur and subsample ``inp`` along its last axis.

        Args:
            inp: Tensor indexed as (batch, channels, length).

        Returns:
            The filtered tensor, subsampled by ``self.stride`` along the last axis.
        """
        if self.filt_size == 1:
            # A single tap is the identity, so only subsample (padding only if requested).
            padded = inp if self.pad_off == 0 else self.pad(inp)
            return padded[:, :, :: self.stride]
        return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
def get_pad_layer_1d(pad_type):
    """Map a padding-mode name to the matching ``torch.nn`` 1-D padding class.

    Args:
        pad_type: ``"refl"``/``"reflect"``, ``"repl"``/``"replicate"``, or ``"zero"``.

    Returns:
        The padding layer class (not an instance). Instantiate it with the pad sizes.

    Raises:
        ValueError: If ``pad_type`` is not a recognized name.
    """
    if pad_type in ("refl", "reflect"):
        return nn.ReflectionPad1d
    if pad_type in ("repl", "replicate"):
        return nn.ReplicationPad1d
    if pad_type == "zero":
        return nn.ZeroPad1d
    # Fail loudly here instead of falling through to an unbound local variable.
    raise ValueError("Pad type [%s] not recognized" % pad_type)
from masksembles import common
class Masksembles1D(nn.Module):
    """Masksembles layer for 1-D feature maps: one fixed channel mask per batch slice.

    The batch is cut into ``n`` equal contiguous slices, and slice ``i`` is
    multiplied channel-wise by mask ``i``. The masks are produced once by
    ``common.generation_wrapper(channels, n, scale)`` and stored as a
    non-trainable parameter.

    Args:
        channels: Number of channels (second axis of the input).
        n: Number of masks. The batch size must be a multiple of ``n``.
        scale: Scale argument forwarded to ``generation_wrapper``.
    """

    def __init__(self, channels: int, n: int, scale: float):
        super().__init__()
        self.channels = channels
        self.n = n
        self.scale = scale
        masks = common.generation_wrapper(channels, n, scale)
        masks = torch.from_numpy(masks)
        # Kept as a frozen Parameter (not a buffer) so parameter lists and
        # checkpoint contents stay unchanged.
        self.masks = torch.nn.Parameter(masks, requires_grad=False)

    def forward(self, inputs):
        """Multiply the ``i``-th contiguous batch slice of ``inputs`` by mask ``i``.

        Args:
            inputs: Tensor of shape (batch, channels, length), with ``batch``
                divisible by ``self.n``.

        Returns:
            A tensor with the same shape and dtype as ``inputs``.

        Raises:
            ValueError: If the batch size is not a multiple of ``self.n``.
        """
        batch, n_channels, length = inputs.shape
        if batch % self.n != 0:
            raise ValueError(
                f"Masksembles1D: batch size {batch} must be a multiple of n={self.n}"
            )
        # (batch, C, L) -> (n, batch // n, C, L); grouped[i, j] is inputs[i * (batch // n) + j].
        grouped = inputs.reshape(self.n, batch // self.n, n_channels, length)
        # masks presumably (n, C) -> (n, 1, C, 1): broadcast over slice rows and time.
        masked = grouped * self.masks[:, None, :, None]
        # Restore the original batch order; masks may have another dtype, so cast back.
        return masked.reshape(batch, n_channels, length).type(inputs.dtype)
class BasicBlock(nn.Module):
    """Residual unit: two length-preserving convolutions plus a 1x1 projection shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where
    ``shortcut`` is a 1x1 convolution followed by batch norm.

    Args:
        in_planes: Input channel count.
        planes: Output channel count of both main-path convolutions.
        stride: Stride of ``conv1`` and of the shortcut convolution. All
            convolutions use ``padding="same"``; the PyTorch Conv1d docs state
            that this supports stride 1 only.
        kernel_size: Kernel width of the main-path convolutions.
        groups: Accepted by the signature but not passed to any layer, so every
            convolution here is dense.
            NOTE(review): forwarding it would change weight shapes and break
            existing checkpoints.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
        super().__init__()
        out_planes = self.expansion * planes
        shared = dict(padding="same", bias=False)
        # Registration order (conv1, bn1, conv2, bn2, shortcut) fixes the
        # state_dict layout and parameter order, so keep it as is.
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=kernel_size, stride=stride, **shared)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=1, **shared)
        self.bn2 = nn.BatchNorm1d(planes)
        self.shortcut = nn.Sequential(
            nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, **shared),
            nn.BatchNorm1d(out_planes),
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + identity)
class Updated_onset_picker(nn.Module):
    """CNN regressor mapping a 3-channel input window to two sigmoid outputs in (0, 1).

    Nine feature blocks follow one pattern: BasicBlock, then GELU, then
    BlurPool1D downsampling, then GroupNorm. Blocks 6-9 first apply
    Masksembles1D channel masking. The result is flattened and passed through
    Linear(3072, 2) and a sigmoid. ``Onset_picker`` in this file reads output
    column 0 as the P pick and column 1 as the S pick.

    Shape constraints implied by the layers (derived by arithmetic; confirm
    against the data pipeline):
      * The input is (batch, 3, length), because block1 starts with BasicBlock(3, ...).
      * Each stride-2 BlurPool1D (filt_size 3, reflect padding 1+1) maps a
        length L to ceil(L / 2). Blocks 1-8 therefore reduce ``length`` to
        ceil(length / 256).
      * block9 emits 128 channels, and Linear(3072, ...) needs
        128 * 24 = 3072 features. So ceil(length / 256) must equal 24
        (e.g. length 6144).
      * Masksembles1D splits the batch into ``n_masks`` (128) equal slices, so
        the batch size presumably must be a multiple of 128. TODO confirm.

    NOTE(review): BasicBlock accepts the ``groups=...`` arguments used below
    but does not pass them to its convolutions, so all convolutions are dense.
    """

    def __init__(
        self,
    ):
        super().__init__()
        # Number of Masksembles masks used in blocks 6-9.
        self.n_masks = 128
        # Blocks 1-5: BasicBlock -> GELU -> BlurPool1D (halves length) -> GroupNorm.
        self.block1 = nn.Sequential(
            BasicBlock(3, 8, kernel_size=7, groups=1),
            nn.GELU(),
            BlurPool1D(8, filt_size=3, stride=2),
            nn.GroupNorm(2, 8),
        )
        self.block2 = nn.Sequential(
            BasicBlock(8, 16, kernel_size=7, groups=8),
            nn.GELU(),
            BlurPool1D(16, filt_size=3, stride=2),
            nn.GroupNorm(2, 16),
        )
        self.block3 = nn.Sequential(
            BasicBlock(16, 32, kernel_size=7, groups=16),
            nn.GELU(),
            BlurPool1D(32, filt_size=3, stride=2),
            nn.GroupNorm(2, 32),
        )
        self.block4 = nn.Sequential(
            BasicBlock(32, 64, kernel_size=7, groups=32),
            nn.GELU(),
            BlurPool1D(64, filt_size=3, stride=2),
            nn.GroupNorm(2, 64),
        )
        self.block5 = nn.Sequential(
            BasicBlock(64, 128, kernel_size=7, groups=64),
            nn.GELU(),
            BlurPool1D(128, filt_size=3, stride=2),
            nn.GroupNorm(2, 128),
        )
        # Blocks 6-8 prepend Masksembles1D channel masking to the same pattern.
        self.block6 = nn.Sequential(
            Masksembles1D(128, self.n_masks, 2.0),
            BasicBlock(128, 256, kernel_size=7, groups=128),
            nn.GELU(),
            BlurPool1D(256, filt_size=3, stride=2),
            nn.GroupNorm(2, 256),
        )
        # NOTE: in blocks 7 and 8, BlurPool1D comes before GELU (the reverse of blocks 1-6).
        self.block7 = nn.Sequential(
            Masksembles1D(256, self.n_masks, 2.0),
            BasicBlock(256, 512, kernel_size=7, groups=256),
            BlurPool1D(512, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 512),
        )
        self.block8 = nn.Sequential(
            Masksembles1D(512, self.n_masks, 2.0),
            BasicBlock(512, 1024, kernel_size=7, groups=512),
            BlurPool1D(1024, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 1024),
        )
        # Block 9: masking plus a BasicBlock down to 128 channels. No pooling,
        # activation, or normalization follows; the output is flattened directly.
        self.block9 = nn.Sequential(
            Masksembles1D(1024, self.n_masks, 2.0),
            BasicBlock(1024, 128, kernel_size=7, groups=128),
        )
        # 3072 = 128 channels x 24 remaining time steps (see class docstring).
        self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())

    def forward(self, x):
        """Return the (batch, 2) sigmoid outputs for the input batch ``x``."""
        # Feature extraction: blocks 1-8 each halve the time axis; block9 keeps it.
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        # Regressor: (batch, 128, T) -> (batch, 128 * T) -> (batch, 2).
        x = x.flatten(start_dim=1)
        x = self.out(x)
        return x
class Onset_picker(pl.LightningModule):
    """Lightning wrapper that trains a picker network to regress P and S picks.

    The wrapped ``picker`` maps a batch ``x`` to a tensor whose column 0 is the
    P pick and column 1 the S pick. Labels use 0 as a "no pick" sentinel:
    samples whose label is exactly 0 are excluded from the loss and the MAE metric.

    Args:
        picker: The network to train. It is excluded from the saved hyperparameters.
        learning_rate: Adam learning rate.
    """

    def __init__(self, picker, learning_rate):
        super().__init__()
        self.picker = picker
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=['picker'])
        self.mae = MeanAbsoluteError()

    def compute_loss(self, y, pick, mae_name=False):
        """L1 loss between labels and picks, over labelled samples only.

        Args:
            y: Label tensor. Entries equal to 0 are treated as missing. It is
                used as a boolean mask on ``pick``, so it presumably has the
                same shape as ``pick``. TODO confirm.
            pick: Predicted picks.
            mae_name: If truthy (e.g. ``'P'``), also logs ``MAE/{mae_name}_val``,
                the batch MAE multiplied by 60.

        Returns:
            A scalar tensor. If no sample is labelled, the result is a zero
            that stays attached to ``pick``'s autograd graph, so ``backward()``
            still works and still reaches every parameter.
        """
        labelled = y != 0
        y_filt = y[labelled]
        pick_filt = pick[labelled].flatten()
        if y_filt.numel() == 0:
            # A plain Python 0 cannot be back-propagated and is not a valid
            # training_step loss; a graph-attached zero is.
            return pick.sum() * 0.0
        loss = F.l1_loss(y_filt, pick_filt)
        if mae_name:
            # NOTE(review): the factor 60 presumably converts normalized picks to seconds. Confirm.
            mae_phase = self.mae(pick_filt, y_filt) * 60
            self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def _pick_loss(self, batch, mae_names=(False, False)):
        """Run the picker on ``batch`` and return the mean of the P and S losses."""
        x, y_p, y_s = batch
        picks = self.picker(x)
        p_loss = self.compute_loss(y_p, picks[:, 0], mae_name=mae_names[0])
        s_loss = self.compute_loss(y_s, picks[:, 1], mae_name=mae_names[1])
        return (p_loss + s_loss) / 2

    def training_step(self, batch, batch_idx):
        """Compute the training loss for ``batch`` and log it per step."""
        loss = self._pick_loss(batch)
        self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, logging per-phase MAE and the epoch-averaged loss."""
        loss = self._pick_loss(batch, mae_names=('P', 'S'))
        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return Adam with a ReduceLROnPlateau scheduler that halves the LR on plateaus."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
        # NOTE(review): 'Loss/train' is logged with on_step=True, on_epoch=False,
        # so the plateau check sees a single step's loss rather than an epoch
        # average. Confirm this is intended.
        monitor = 'Loss/train'
        return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}

    def forward(self, x):
        """Return the picker's raw output for ``x``."""
        picks = self.picker(x)
        return picks