# NOTE(review): commented out GitHub web-UI scrape residue (avatar alt-text,
# commit message ":beers: cheers", commit hash 66a6dc0) that made this file
# unparsable as Python.
from re import X
import torch
import auraloss
import pytorch_lightning as pl
from typing import Tuple, List, Dict
from argparse import ArgumentParser
import deepafx_st.utils as utils
from deepafx_st.data.proxy import DSPProxyDataset
from deepafx_st.processors.proxy.tcn import ConditionalTCN
from deepafx_st.processors.spsa.channel import SPSAChannel
from deepafx_st.processors.dsp.peq import ParametricEQ
from deepafx_st.processors.dsp.compressor import Compressor
class ProxySystem(pl.LightningModule):
    """Trains a neural network proxy (ConditionalTCN) to mimic a DSP effect.

    The true effect (parametric EQ, compressor, or SPSA channel) produces the
    training targets; the proxy learns to reproduce its output given the same
    audio and normalized control parameters. Losses are a multi-resolution
    STFT loss plus a time-domain L1 loss.
    """

    def __init__(
        self,
        causal=True,
        nblocks=4,
        dilation_growth=8,
        kernel_size=13,
        channel_width=64,
        input_dir=None,
        processor="channel",
        batch_size=32,
        lr=3e-4,
        lr_patience=20,
        patience=10,
        preload=False,
        sample_rate=24000,
        shuffle=True,
        train_length=65536,
        train_examples_per_epoch=10000,
        val_length=131072,
        val_examples_per_epoch=1000,
        num_workers=16,
        output_gain=False,
        **kwargs,
    ):
        """See ``add_model_specific_args`` for per-argument descriptions.

        Extra ``kwargs`` (e.g. ``precision``, ``buffer_size_gb``,
        ``buffer_reload_rate``) are captured by ``save_hyperparameters`` and
        read from ``self.hparams`` by the dataloader builders.

        Raises:
            ValueError: if ``processor`` is not one of "peq", "comp", "channel".
        """
        super().__init__()
        self.save_hyperparameters()

        # Construct the true DSP effect that provides training targets...
        if self.hparams.processor == "peq":
            self.processor = ParametricEQ(self.hparams.sample_rate)
        elif self.hparams.processor == "comp":
            self.processor = Compressor(self.hparams.sample_rate)
        elif self.hparams.processor == "channel":
            self.processor = SPSAChannel(self.hparams.sample_rate)
        else:
            # Fail fast with a clear message instead of an AttributeError
            # later when self.processor is first accessed.
            raise ValueError(f"Invalid processor: {self.hparams.processor}")

        # ...and the neural network proxy that will learn to mimic it.
        self.proxy = ConditionalTCN(
            self.hparams.sample_rate,
            num_control_params=self.processor.num_control_params,
            causal=self.hparams.causal,
            nblocks=self.hparams.nblocks,
            channel_width=self.hparams.channel_width,
            kernel_size=self.hparams.kernel_size,
            dilation_growth=self.hparams.dilation_growth,
        )
        self.receptive_field = self.proxy.compute_receptive_field()

        # Reconstruction losses: spectral (multi-resolution STFT) + time-domain (L1).
        self.recon_losses = {}
        self.recon_loss_weights = {}
        self.recon_losses["mrstft"] = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[32, 128, 512, 2048, 8192, 32768],
            hop_sizes=[16, 64, 256, 1024, 4096, 16384],
            win_lengths=[32, 128, 512, 2048, 8192, 32768],
            w_sc=0.0,
            w_phs=0.0,
            w_lin_mag=1.0,
            w_log_mag=1.0,
        )
        self.recon_loss_weights["mrstft"] = 1.0
        self.recon_losses["l1"] = torch.nn.L1Loss()
        self.recon_loss_weights["l1"] = 100.0

    def forward(self, x, p, use_dsp=False, sample_rate=24000, **kwargs):
        """Process audio ``x`` under control parameters ``p``.

        Args:
            x: audio tensor, shape (batch, channels, samples) — TODO confirm
                channels is always 1, as the DSP branch reshapes to (bs, 1, -1).
            p: normalized control parameters in [0, 1]; the last element is
                treated as makeup gain when ``output_gain`` is set.
            use_dsp: when True, run the true DSP effect (no gradients) instead
                of the neural proxy.
            sample_rate: forwarded to the non-"channel" DSP processors.

        Returns:
            Processed audio tensor.
        """
        bs, chs, samp = x.size()
        if not use_dsp:
            y = self.proxy(x, p)
            # Manually apply the makeup gain parameter, mapping [0, 1] to
            # [-48 dB, +48 dB]. The "peq" processor has no makeup gain param.
            if self.hparams.output_gain and self.hparams.processor != "peq":
                gain_db = (p[..., -1] * 96) - 48
                gain_ln = 10 ** (gain_db / 20.0)
                y *= gain_ln.view(bs, chs, 1)
        else:
            with torch.no_grad():
                if self.hparams.output_gain and self.hparams.processor != "peq":
                    # Remember the requested makeup gain, then neutralize the
                    # DSP's own gain parameter (0.5 normalized -> 0 dB) so the
                    # gain is applied manually below.
                    gain_db = (p[..., -1] * 96) - 48
                    gain_ln = 10 ** (gain_db / 20.0)
                    # BUG FIX: clone before overriding so the caller's
                    # parameter tensor is not mutated in place.
                    p = p.clone()
                    p[..., -1] = 0.5

                if self.hparams.processor == "channel":
                    # SPSAChannel consumes torch tensors on CPU.
                    y_temp = self.processor(x.cpu(), p.cpu())
                    y_temp = y_temp.view(bs, chs, samp).type_as(x)
                else:
                    # peq/comp consume numpy arrays plus an explicit sample rate.
                    y_temp = self.processor(
                        x.cpu().numpy(),
                        p.cpu().numpy(),
                        sample_rate,
                    )
                    y_temp = torch.tensor(y_temp).view(bs, chs, samp).type_as(x)

                y = y_temp.type_as(x).view(bs, 1, -1)

                if self.hparams.output_gain and self.hparams.processor != "peq":
                    y *= gain_ln.view(bs, chs, 1)

        return y

    def common_step(
        self,
        batch: Tuple,
        batch_idx: int,
        optimizer_idx: int = 0,
        train: bool = True,
    ):
        """Shared train/val step: run the proxy and accumulate weighted losses.

        Args:
            batch: tuple of (input audio, target audio, control params).
            batch_idx: index of this batch within the epoch.
            optimizer_idx: unused; kept for Lightning API compatibility.
            train: selects the "train_*" vs "val_*" logging prefix and
                whether audio tensors are collected for later inspection.

        Returns:
            (loss, data_dict) — data_dict is only populated when not training.
        """
        loss = 0
        x, y, p = batch
        y_hat = self(x, p)

        # Accumulate each weighted reconstruction loss and log it individually.
        for loss_name, loss_fn in self.recon_losses.items():
            tmp_loss = loss_fn(y_hat.float(), y.float())
            loss += self.recon_loss_weights[loss_name] * tmp_loss
            self.log(
                f"train_loss/{loss_name}" if train else f"val_loss/{loss_name}",
                tmp_loss,
                on_step=True,
                on_epoch=True,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

        if not train:
            # store audio data
            data_dict = {
                "x": x.float().cpu(),
                "y": y.float().cpu(),
                "p": p.float().cpu(),
                "y_hat": y_hat.float().cpu(),
            }
        else:
            data_dict = {}

        self.log(
            "train_loss" if train else "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss, data_dict

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Lightning training hook; returns the scalar training loss."""
        loss, _ = self.common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; returns audio tensors for the first batch only."""
        loss, data_dict = self.common_step(batch, batch_idx, train=False)
        if batch_idx == 0:
            return data_dict

    def configure_optimizers(self):
        """Adam over the proxy parameters with ReduceLROnPlateau on val_loss."""
        optimizer = torch.optim.Adam(
            self.proxy.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.hparams.lr_patience,
            verbose=True,
        )
        # BUG FIX: when optimizers are returned as a list, Lightning's
        # documented form places each scheduler config dict inside a list too.
        return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]

    def _make_dataloader(self, subset, length, num_examples_per_epoch):
        """Build a deterministically-seeded DataLoader over one dataset split."""
        dataset = DSPProxyDataset(
            self.hparams.input_dir,
            self.processor,
            self.hparams.processor,  # name
            subset=subset,
            length=length,
            num_examples_per_epoch=num_examples_per_epoch,
            half=self.hparams.precision == 16,
            buffer_size_gb=self.hparams.buffer_size_gb,
            buffer_reload_rate=self.hparams.buffer_reload_rate,
        )
        # Fixed seed + seed_worker keeps example ordering reproducible
        # across runs and workers.
        g = torch.Generator()
        g.manual_seed(0)
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_dataloader(
            "train",
            self.hparams.train_length,
            self.hparams.train_examples_per_epoch,
        )

    def val_dataloader(self):
        return self._make_dataloader(
            "val",
            self.hparams.val_length,
            self.hparams.val_examples_per_epoch,
        )

    @staticmethod
    def count_control_params(plugin_config):
        """Count the ports marked for optimization across all plugins.

        Args:
            plugin_config: dict with a "plugins" list; each plugin has a
                "ports" list whose entries carry a boolean "optim" flag.

        Returns:
            Total number of optimizable control parameters.
        """
        return sum(
            1
            for plugin in plugin_config["plugins"]
            for port in plugin["ports"]
            if port["optim"]
        )

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Attach this system's hyperparameters to an existing CLI parser."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- Model ---
        parser.add_argument("--causal", action="store_true")
        parser.add_argument("--output_gain", action="store_true")
        parser.add_argument("--dilation_growth", type=int, default=8)
        parser.add_argument("--nblocks", type=int, default=4)
        parser.add_argument("--kernel_size", type=int, default=13)
        # BUG FIX: default was 13 (a copy-paste of kernel_size); the
        # constructor default is 64.
        parser.add_argument("--channel_width", type=int, default=64)
        # --- Training ---
        parser.add_argument("--input_dir", type=str)
        parser.add_argument("--processor", type=str)
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--lr_patience", type=int, default=20)
        parser.add_argument("--patience", type=int, default=10)
        parser.add_argument("--preload", action="store_true")
        parser.add_argument("--sample_rate", type=int, default=24000)
        # BUG FIX: `type=bool` parses ANY non-empty string (including
        # "False") as True; parse common false spellings explicitly.
        parser.add_argument(
            "--shuffle",
            type=lambda s: str(s).lower() not in ("false", "0", "no"),
            default=True,
        )
        parser.add_argument("--train_length", type=int, default=65536)
        parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
        parser.add_argument("--val_length", type=int, default=131072)
        parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
        # NOTE(review): default 8 here disagrees with the constructor default
        # of 16 — confirm which is intended before aligning.
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--buffer_reload_rate", type=int, default=1000)
        parser.add_argument("--buffer_size_gb", type=float, default=1.0)
        return parser