deepafx-st / deepafx_st /probes /probe_system.py
yourusername's picture
:beers: cheers
66a6dc0
import torch
import julius
import torchopenl3
import torchmetrics
import pytorch_lightning as pl
from typing import Tuple, List, Dict
from argparse import ArgumentParser
from deepafx_st.probes.cdpam_encoder import CDPAMEncoder
from deepafx_st.probes.random_mel import RandomMelProjection
import deepafx_st.utils as utils
from deepafx_st.utils import DSPMode
from deepafx_st.system import System
from deepafx_st.data.style import StyleDataset
class ProbeSystem(pl.LightningModule):
    """Train a small classification probe on top of a fixed audio encoder.

    The encoder is always evaluated under ``torch.no_grad()`` and only the
    probe's parameters are handed to the optimizer, so the encoder acts as a
    fixed feature extractor.

    Args:
        audio_dir: Root directory handed to ``StyleDataset``.
        num_classes: Number of output logits produced by the probe.
        task: Probing task. Only ``"style"`` is implemented.
        encoder_type: One of ``"deepafx_st_autodiff"``, ``"deepafx_st_spsa"``,
            ``"deepafx_st_proxy0"``, ``"openl3"``, ``"random_mel"`` or
            ``"cdpam"``.
        deepafx_st_autodiff_ckpt: Checkpoint used when ``encoder_type``
            contains ``"autodiff"``.
        deepafx_st_spsa_ckpt: Checkpoint used when ``encoder_type`` contains
            ``"spsa"``.
        deepafx_st_proxy0_ckpt: Checkpoint used when ``encoder_type`` contains
            ``"proxy0"``.
        probe_type: ``"linear"`` (single linear layer) or ``"mlp"``
            (three linear layers with ReLU in between).
        batch_size: Batch size for both dataloaders.
        lr: Initial learning rate of the AdamW optimizer.
        lr_patience: Stored in ``hparams``; not read by this class.
        patience: Stored in ``hparams``; not read by this class.
        preload: Stored in ``hparams``; not read by this class.
        sample_rate: Passed to ``RandomMelProjection`` for the ``random_mel``
            encoder.
        shuffle: Whether the training dataloader shuffles.
        num_workers: Worker count for both dataloaders.
        **kwargs: Extra hyperparameters that this class reads from
            ``self.hparams`` (e.g. ``max_epochs``, ``encoder_sample_rate``,
            ``openl3_*``, ``random_mel_*``, ``cdpam_ckpt``).

    Raises:
        RuntimeError: If a ``deepafx_st`` encoder variant is unknown or its
            checkpoint path is ``None``.
        ValueError: If ``encoder_type``, ``probe_type`` or ``task`` is not
            supported.
    """

    def __init__(
        self,
        audio_dir=None,
        num_classes=5,
        task="style",
        encoder_type="deepafx_st_autodiff",
        deepafx_st_autodiff_ckpt=None,
        deepafx_st_spsa_ckpt=None,
        deepafx_st_proxy0_ckpt=None,
        probe_type="linear",
        batch_size=32,
        lr=3e-4,
        lr_patience=20,
        patience=10,
        preload=False,
        sample_rate=24000,
        shuffle=True,
        num_workers=16,
        **kwargs,
    ):
        super().__init__()
        # Must be called directly from __init__ so the init arguments
        # (including **kwargs) are captured into self.hparams.
        self.save_hyperparameters()

        self._build_encoder()
        self._build_probe()

        self.accuracy = torchmetrics.Accuracy()
        self.f1_score = torchmetrics.F1Score(self.hparams.num_classes)

    def _build_encoder(self) -> None:
        """Create ``self.encoder`` and record its output size in ``hparams.embed_dim``."""
        encoder_type = self.hparams.encoder_type

        if "deepafx_st" in encoder_type:
            if "autodiff" in encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt
            elif "spsa" in encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt
            elif "proxy0" in encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt
            else:
                raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}")

            if self.hparams.deepafx_st_ckpt is None:
                raise RuntimeError(
                    f"Must supply {self.hparams.encoder_type}_ckpt checkpoint."
                )

            # Only the encoder of the full system is kept.
            system = System.load_from_checkpoint(
                self.hparams.deepafx_st_ckpt,
                use_dsp=DSPMode.NONE,
                batch_size=self.hparams.batch_size,
                spsa_parallel=False,
                proxy_ckpts=[],
                strict=False,
            )
            system.eval()
            self.encoder = system.encoder
            self.hparams.embed_dim = self.encoder.embed_dim

            # freeze weights
            for param in self.encoder.parameters():
                param.requires_grad = False
        elif encoder_type == "openl3":
            self.encoder = torchopenl3.models.load_audio_embedding_model(
                input_repr=self.hparams.openl3_input_repr,
                embedding_size=self.hparams.openl3_embedding_size,
                content_type=self.hparams.openl3_content_type,
            )
            # The probe input size must follow the configured embedding size
            # rather than a hard-coded 6144.
            self.hparams.embed_dim = self.hparams.openl3_embedding_size
        elif encoder_type == "random_mel":
            self.encoder = RandomMelProjection(
                self.hparams.sample_rate,
                self.hparams.random_mel_embedding_size,
                self.hparams.random_mel_n_mels,
                self.hparams.random_mel_n_fft,
                self.hparams.random_mel_hop_size,
            )
            self.hparams.embed_dim = self.hparams.random_mel_embedding_size
        elif encoder_type == "cdpam":
            self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt)
            self.encoder.eval()
            self.hparams.embed_dim = self.encoder.embed_dim
        else:
            raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}")

    def _build_probe(self) -> None:
        """Create ``self.probe``, mapping ``embed_dim`` features to ``num_classes`` logits."""
        # Fail here instead of leaving self.probe undefined, which would only
        # surface later as an AttributeError.
        if self.hparams.task != "style":
            raise ValueError(f"Invalid task: {self.hparams.task}")

        embed_dim = self.hparams.embed_dim
        num_classes = self.hparams.num_classes

        # No softmax layer: common_step feeds the probe output to
        # cross_entropy, which expects unnormalized logits.
        if self.hparams.probe_type == "linear":
            self.probe = torch.nn.Sequential(
                torch.nn.Linear(embed_dim, num_classes),
            )
        elif self.hparams.probe_type == "mlp":
            self.probe = torch.nn.Sequential(
                torch.nn.Linear(embed_dim, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, num_classes),
            )
        else:
            raise ValueError(f"Invalid probe_type: {self.hparams.probe_type}")

    def train(self, mode: bool = True):
        """Switch train/eval mode, but always leave the encoder in eval mode.

        The encoder is put in eval mode at construction time; without this
        override a later ``train()`` call on the whole module would switch it
        back to training mode.
        """
        super().train(mode)
        encoder = getattr(self, "encoder", None)
        if isinstance(encoder, torch.nn.Module):
            encoder.eval()
        return self

    @staticmethod
    def _l2_normalize(e: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Scale ``e`` to unit L2 norm along its last dimension.

        The norm is clamped to ``eps`` so an all-zero embedding yields zeros
        instead of NaN.
        """
        norm = torch.norm(e, p=2, dim=-1, keepdim=True)
        return e / norm.clamp(min=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` without gradients and return the probe's logits.

        Args:
            x: 3-D input tensor; the original code unpacks its size as
                (batch, channels, samples).

        Returns:
            Output of ``self.probe`` applied to the embedding.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected a 3-D input (batch, channels, samples), got {x.dim()}-D."
            )

        encoder_type = self.hparams.encoder_type
        with torch.no_grad():
            if "deepafx_st" in encoder_type:
                # Peak-normalize over the whole batch, out of place so the
                # caller's tensor is not modified, and guard against a silent
                # (all-zero) batch.
                peak = x.abs().max().clamp(min=1e-8)
                x = (x / peak) * 10 ** (-12.0 / 20)  # with min 12 dBFS headroom
                e = self._l2_normalize(self.encoder(x))
            elif encoder_type == "openl3":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 48000)
                # NOTE(review): input_repr/content_type are hard-coded here
                # while the model itself was loaded from the openl3_* hparams;
                # confirm whether they matter when `model` is supplied.
                e, _ = torchopenl3.get_audio_embedding(
                    x,
                    48000,
                    model=self.encoder,
                    input_repr="mel128",
                    content_type="music",
                )
                # Average over axis 1 (presumably the embedding frames).
                e = e.mean(dim=1)
                # normalize by L2 norm
                e = self._l2_normalize(e)
            elif encoder_type == "random_mel":
                e = self._l2_normalize(self.encoder(x))
            elif encoder_type == "cdpam":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 22050)
                # Scale by 32768 and round (presumably to int16-range samples).
                x = torch.round(x * 32768)
                e = self.encoder(x)

        return self.probe(e)

    def common_step(
        self,
        batch: Tuple,
        batch_idx: int,
        optimizer_idx: int = 0,
        train: bool = True,
    ):
        """Run one forward pass, compute the loss and log it.

        Args:
            batch: Pair ``(x, y)`` of inputs and targets.
            batch_idx: Index of the batch (unused here).
            optimizer_idx: Unused; kept for interface compatibility.
            train: Selects ``train_loss`` vs. ``val_loss`` logging; validation
                additionally logs accuracy/F1 and returns the input audio.

        Returns:
            Tuple ``(loss, data_dict)``; ``data_dict`` holds ``"x"`` on CPU
            during validation and is empty during training.
        """
        loss = 0
        x, y = batch
        y_hat = self(x)

        # compute CE
        is_style = self.hparams.task == "style"
        if is_style:
            loss = torch.nn.functional.cross_entropy(y_hat, y)

        # store audio data only for validation
        data_dict = {} if train else {"x": x.float().cpu()}

        self.log(
            "train_loss" if train else "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        if not train and is_style:
            # Calling the metric updates its running state and returns the
            # value for this batch.
            self.log("val_acc_step", self.accuracy(y_hat, y))
            self.log("val_f1_step", self.f1_score(y_hat, y))

        return loss, data_dict

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Return the training loss for ``batch``."""
        loss, _ = self.common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run validation on ``batch``; only the first batch returns its data dict."""
        _, data_dict = self.common_step(batch, batch_idx, train=False)
        if batch_idx == 0:
            return data_dict

    def validation_epoch_end(self, outputs) -> None:
        """Log the accumulated accuracy/F1 for the epoch and reset both metrics."""
        if self.hparams.task == "style":
            self.log("val_acc_epoch", self.accuracy.compute())
            self.log("val_f1_epoch", self.f1_score.compute())
            # Computed values (not the Metric objects) are logged, so the
            # metric state has to be cleared manually; otherwise it keeps
            # accumulating across validation epochs.
            self.accuracy.reset()
            self.f1_score.reset()
        return super().validation_epoch_end(outputs)

    def configure_optimizers(self):
        """Build AdamW over the probe parameters with a two-milestone LR decay.

        The learning rate is multiplied by 0.1 at 80% and again at 95% of
        ``hparams.max_epochs``.
        """
        optimizer = torch.optim.AdamW(
            self.probe.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
        )

        ms1 = int(self.hparams.max_epochs * 0.8)
        ms2 = int(self.hparams.max_epochs * 0.95)
        print(
            "Learning rate schedule:",
            f"0 {self.hparams.lr:0.2e} -> ",
            f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
            f"{ms2} {self.hparams.lr*0.01:0.2e}",
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[ms1, ms2],
            gamma=0.1,
        )

        return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"}

    def _make_dataloader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a DataLoader with a generator seeded to 0."""
        generator = torch.Generator()
        generator.manual_seed(0)
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            worker_init_fn=utils.seed_worker,
            generator=generator,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the training DataLoader (``None`` if the task is not ``"style"``)."""
        if self.hparams.task == "style":
            train_dataset = StyleDataset(
                self.hparams.audio_dir,
                "train",
                sample_rate=self.hparams.encoder_sample_rate,
            )
            # Honor the `shuffle` hyperparameter (default True).
            return self._make_dataloader(train_dataset, shuffle=self.hparams.shuffle)

    def val_dataloader(self):
        """Return the unshuffled validation DataLoader (``None`` if the task is not ``"style"``)."""
        if self.hparams.task == "style":
            val_dataset = StyleDataset(
                self.hparams.audio_dir,
                subset="val",
                sample_rate=self.hparams.encoder_sample_rate,
            )
            return self._make_dataloader(val_dataset, shuffle=False)

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a new parser extending ``parent_parser`` with the probe options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- Model ---
        # Default matches the constructor default.
        parser.add_argument("--encoder_type", type=str, default="deepafx_st_autodiff")
        parser.add_argument("--probe_type", type=str, default="linear")
        parser.add_argument("--task", type=str, default="style")
        parser.add_argument("--encoder_sample_rate", type=int, default=24000)
        # --- deepafx_st ---
        parser.add_argument("--deepafx_st_autodiff_ckpt", type=str)
        parser.add_argument("--deepafx_st_spsa_ckpt", type=str)
        parser.add_argument("--deepafx_st_proxy0_ckpt", type=str)
        # --- cdpam ---
        parser.add_argument("--cdpam_ckpt", type=str)
        # --- openl3 ---
        parser.add_argument("--openl3_input_repr", type=str, default="mel128")
        parser.add_argument("--openl3_content_type", type=str, default="env")
        parser.add_argument("--openl3_embedding_size", type=int, default=6144)
        # --- random_mel ---
        # These are sizes, so parse them as int; with type=str a value given
        # on the command line would arrive as a string.
        parser.add_argument("--random_mel_embedding_size", type=int, default=4096)
        parser.add_argument("--random_mel_n_fft", type=int, default=4096)
        parser.add_argument("--random_mel_hop_size", type=int, default=1024)
        parser.add_argument("--random_mel_n_mels", type=int, default=128)
        # --- Training ---
        parser.add_argument("--audio_dir", type=str)
        parser.add_argument("--num_classes", type=int, default=5)
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--lr_patience", type=int, default=20)
        parser.add_argument("--patience", type=int, default=10)
        parser.add_argument("--preload", action="store_true")
        parser.add_argument("--sample_rate", type=int, default=24000)
        parser.add_argument("--num_workers", type=int, default=8)

        return parser