import pytorch_lightning as pl
import torch
import wandb
from einops import rearrange
from pytorch_lightning.callbacks import Callback
from torch import Tensor

from remfx import effects

# All pedalboard effect classes known to remfx, indexed to decode one-hot effect labels
ALL_EFFECTS = effects.Pedalboard_Effects

class AudioCallback(Callback):
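    """Lightning callback that logs example audio clips to Weights & Biases.

    One batch of (effected) input and (clean) target training audio is logged
    once at the start of training; a prediction/input/target comparison clip
    is logged on the first batch of each validation (and test) epoch.
    """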
    def __init__(self, sample_rate, log_audio, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_rate = sample_rate
        self.log_audio = log_audio
        # Log training audio only once, and only when audio logging is enabled
        self.log_train_audio = self.log_audio

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
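        """Log the first training batch's input/target audio to the dashboard."""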
        # Log initial audio
        if self.log_train_audio:
            x, y, _, _ = batch
            # Concat samples together for easier viewing in dashboard
            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="input_effected_audio",
                samples=input_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Training Data",
            )
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="target_audio",
                samples=target_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Target Data",
            )
            self.log_train_audio = False

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
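        """Sample from the model on the first batch and log prediction vs. input vs. target."""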
        x, target, _, rem_fx_labels = batch
        # Only run on first batch
        if batch_idx == 0 and self.log_audio:
            with torch.no_grad():
                # Avoids circular import
                from remfx.models import RemFXChainInference

                if isinstance(pl_module, RemFXChainInference):
                    y = pl_module.sample(batch)
                    effects_present_name = [
                        [
                            ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
                            for i, effect in enumerate(effect_label)
                            if effect == 1.0
                        ]
                        for effect_label in rem_fx_labels
                    ]
                    for label in effects_present_name:
                        # Placeholder metric so each effect combination appears in the dashboard
                        self.log(f"{'_'.join(label)}", 0.0)
                else:
                    y = pl_module.model.sample(x)
            # Concat samples together for easier viewing in dashboard,
            # with 2 seconds of silence between each sample. Slice the time
            # (last) axis: slicing dim 1 would trim channels, not time.
            silence = torch.zeros_like(x)[..., : self.sample_rate * 2]
            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="prediction_input_target",
                samples=concat_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption=f"Epoch {trainer.current_epoch}",
            )

    def on_test_batch_start(self, *args):
        self.on_validation_batch_start(*args)


def log_wandb_audio_batch(
    logger: pl.loggers.WandbLogger,
    id: str,
    samples: Tensor,
    sampling_rate: int,
    caption: str = "",
    max_items: int = 10,
):
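    """Log up to max_items audio clips from samples to Weights & Biases.

    samples is expected as (batch, channels, time); this is a no-op when the
    trainer's logger is not a WandbLogger.
    """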
    if not isinstance(logger, pl.loggers.WandbLogger):
        return
    num_items = samples.shape[0]
    # wandb.Audio expects audio as (time, channels)
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(min(num_items, max_items)):
        logger.experiment.log(
            {
                f"{id}_{idx}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
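

# Example usage (a minimal sketch, not part of the original module): attaching
# the callback to a Lightning Trainer with a WandbLogger so the logging above
# is active. The project name and the 48 kHz sample rate are illustrative
# assumptions, as are the `model` and `datamodule` objects.
#
#   from pytorch_lightning.loggers import WandbLogger
#
#   logger = WandbLogger(project="remfx")  # hypothetical project name
#   trainer = pl.Trainer(
#       logger=logger,
#       callbacks=[AudioCallback(sample_rate=48000, log_audio=True)],
#   )
#   trainer.fit(model, datamodule=datamodule)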