# NOTE(review): removed stray paste artifacts ("Spaces:" / "Runtime error" lines)
# that preceded the module — they were not Python and broke parsing.
import auraloss
import gin
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

from .modules.dynamic import TimeDistributedMLP
from .modules.generators import FIRNoiseSynth, HarmonicOscillator
from .modules.shaping import NEWT, Reverb

# Register torch layers with gin so their constructor arguments can be
# supplied from gin config files (the sub-modules below are built arg-free).
gin.external_configurable(nn.GRU, module="torch.nn")
gin.external_configurable(nn.Conv1d, module="torch.nn")
class ControlModule(nn.Module):
    """Encode a control signal into a per-frame embedding.

    A single-layer GRU summarises the control sequence over time, and a 1x1
    convolution projects the recurrent features to the embedding size.
    Input and output are both channels-first: (batch, channels, time).
    """

    def __init__(self, control_size: int, hidden_size: int, embedding_size: int):
        super().__init__()
        self.gru = nn.GRU(control_size, hidden_size, batch_first=True)
        self.proj = nn.Conv1d(hidden_size, embedding_size, 1)

    def forward(self, x):
        """Return an embedding of shape (batch, embedding_size, time)."""
        # GRU wants (batch, time, features); conv wants (batch, channels, time).
        sequence_first = x.transpose(1, 2)
        recurrent_out, _final_state = self.gru(sequence_first)
        channels_first = recurrent_out.transpose(1, 2)
        return self.proj(channels_first)
class NeuralWaveshaping(pl.LightningModule):
    """Neural waveshaping synthesiser (NEWT) as a Lightning module.

    A harmonic exciter driven by f0 is shaped by learned waveshapers
    conditioned on a control embedding, mixed with filtered noise, summed
    across channels, and passed through a learned reverb. Trained against a
    multi-resolution STFT loss.

    The sub-modules (ControlModule, HarmonicOscillator, NEWT,
    TimeDistributedMLP, FIRNoiseSynth, Reverb) are constructed without
    arguments and are expected to be configured via gin.
    """

    def __init__(
        self,
        n_waveshapers: int,
        control_hop: int,
        sample_rate: float = 16000,
        learning_rate: float = 1e-3,
        lr_decay: float = 0.9,
        lr_decay_interval: int = 10000,
        log_audio: bool = False,
    ):
        """Build the synthesis graph.

        Args:
            n_waveshapers: number of parallel waveshaping channels.
            control_hop: upsampling factor from control frames to samples.
            sample_rate: audio sample rate used when logging audio to wandb.
            learning_rate: Adam learning rate.
            lr_decay: multiplicative StepLR decay factor.
            lr_decay_interval: decay period in optimiser steps.
            log_audio: whether to log example audio to wandb during eval.
        """
        super().__init__()
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        self.lr_decay_interval = lr_decay_interval
        self.control_hop = control_hop
        self.log_audio = log_audio
        self.sample_rate = sample_rate

        # Argument-free constructors: hyperparameters come from gin config.
        self.embedding = ControlModule()
        self.osc = HarmonicOscillator()
        self.harmonic_mixer = nn.Conv1d(self.osc.n_harmonics, n_waveshapers, 1)
        self.newt = NEWT()
        with gin.config_scope("noise_synth"):
            self.h_generator = TimeDistributedMLP()
            self.noise_synth = FIRNoiseSynth()
        self.reverb = Reverb()

    def render_exciter(self, f0):
        """Render the harmonic exciter from sample-rate f0 and mix its
        harmonics down to the waveshaper channels."""
        # f0[:, 0] drops the channel dim — assumes f0 is (batch, 1, samples);
        # TODO(review) confirm against the oscillator's expected input.
        sig = self.osc(f0[:, 0])
        sig = self.harmonic_mixer(sig)
        return sig

    def get_embedding(self, control):
        """Embed the first two control channels (f0 and one auxiliary)."""
        # NOTE(review): the cat of the two unit slices equals control[:, 0:2];
        # kept explicit to mirror the original intent of naming each channel.
        f0, other = control[:, 0:1], control[:, 1:2]
        control = torch.cat((f0, other), dim=1)
        return self.embedding(control)

    def forward(self, f0, control):
        """Synthesise audio from frame-rate f0 and control signals."""
        # FIX: F.upsample is deprecated (removed in recent PyTorch);
        # F.interpolate is the documented drop-in replacement.
        f0_upsampled = F.interpolate(
            f0, f0.shape[-1] * self.control_hop, mode="linear"
        )
        x = self.render_exciter(f0_upsampled)
        control_embedding = self.get_embedding(control)
        x = self.newt(x, control_embedding)
        # Noise branch: per-frame filter coefficients H drive the FIR synth.
        H = self.h_generator(control_embedding)
        noise = self.noise_synth(H)
        x = torch.cat((x, noise), dim=1)
        x = x.sum(1)  # mix waveshaper + noise channels down to mono
        x = self.reverb(x)
        return x

    def configure_optimizers(self):
        """Create the Adam optimiser and a per-step StepLR schedule.

        The STFT loss is instantiated here (rather than __init__) so it is
        created lazily at fit time, as in the original implementation.
        """
        self.stft_loss = auraloss.freq.MultiResolutionSTFTLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, self.lr_decay_interval, self.lr_decay
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _run_step(self, batch):
        """Run one forward pass and return (loss, reconstruction, target)."""
        audio = batch["audio"].float()
        f0 = batch["f0"].float()
        control = batch["control"].float()
        recon = self(f0, control)
        loss = self.stft_loss(recon, audio)
        return loss, recon, audio

    def _log_loss(self, name, loss):
        """Log a scalar loss with the options shared by all step methods."""
        self.log(
            name,
            loss.item(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

    def _log_audio(self, name, audio):
        """Log an audio example to wandb without committing the step."""
        wandb.log(
            {
                "audio/%s"
                % name: wandb.Audio(audio, sample_rate=self.sample_rate, caption=name)
            },
            commit=False,
        )

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._run_step(batch)
        self._log_loss("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, recon, audio = self._run_step(batch)
        self._log_loss("val/loss", loss)
        if batch_idx == 0 and self.log_audio:
            self._log_audio("original", audio[0].detach().cpu().squeeze())
            self._log_audio("recon", recon[0].detach().cpu().squeeze())
        return loss

    def test_step(self, batch, batch_idx):
        loss, recon, audio = self._run_step(batch)
        self._log_loss("test/loss", loss)
        # FIX: respect the log_audio flag here too — the original logged
        # unconditionally, inconsistent with validation_step, and wandb.log
        # fails when wandb was never initialised (log_audio=False runs).
        if batch_idx == 0 and self.log_audio:
            self._log_audio("original", audio[0].detach().cpu().squeeze())
            self._log_audio("recon", recon[0].detach().cpu().squeeze())
        return loss