# akhaliq3 — spaces demo (commit 607ecc1)
import click
import gin
import pytorch_lightning as pl
from neural_waveshaping_synthesis.data.general import GeneralDataModule
from neural_waveshaping_synthesis.data.urmp import URMPDataModule
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
@gin.configurable
def get_model(model, with_wandb):
    """Instantiate the gin-configured model class.

    Args:
        model: model class injected by the gin config (e.g. ``NeuralWaveshaping``).
        with_wandb: forwarded as ``log_audio`` so audio logging is only
            requested when a W&B logger will be attached.

    Returns:
        A freshly constructed model instance.
    """
    instance = model(log_audio=with_wandb)
    return instance
@gin.configurable
def trainer_kwargs(**kwargs):
    """Collect extra ``pl.Trainer`` keyword arguments from the gin config.

    Returns:
        A dict of keyword arguments, as bound by gin.
    """
    return dict(kwargs)
@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--dataset-path", prompt="Dataset root")
@click.option("--urmp", is_flag=True)
@click.option("--device", default="0")
@click.option("--instrument", default="vn")
@click.option("--load-data-to-memory", is_flag=True)
@click.option("--with-wandb", is_flag=True)
@click.option("--restore-checkpoint", default="")
def main(
    gin_file,
    dataset_path,
    urmp,
    device,
    instrument,
    load_data_to_memory,
    with_wandb,
    restore_checkpoint,
):
    """Train a gin-configured model on a URMP or generic dataset.

    Args:
        gin_file: path to the gin config describing model and trainer.
        dataset_path: root directory of the preprocessed dataset.
        urmp: if set, use ``URMPDataModule`` (per-instrument layout);
            otherwise ``GeneralDataModule``.
        device: passed straight to ``pl.Trainer(gpus=...)``.
        instrument: URMP instrument code (only used with ``--urmp``).
        load_data_to_memory: cache the dataset in RAM for faster epochs.
        with_wandb: enable Weights & Biases logging and LR monitoring.
        restore_checkpoint: checkpoint path to resume from; empty string
            (the default) means start training from scratch.
    """
    gin.parse_config_file(gin_file)
    model = get_model(with_wandb=with_wandb)

    # Both data modules take the same loading options; only URMP needs the
    # extra instrument selector.
    data_kwargs = dict(
        load_to_memory=load_data_to_memory,
        num_workers=16,
        shuffle=True,
    )
    if urmp:
        data = URMPDataModule(dataset_path, instrument, **data_kwargs)
    else:
        data = GeneralDataModule(dataset_path, **data_kwargs)

    # Keep the single best validation checkpoint plus the most recent one.
    checkpointing = pl.callbacks.ModelCheckpoint(
        monitor="val/loss", save_top_k=1, save_last=True
    )
    callbacks = [checkpointing]

    logger = None
    if with_wandb:
        # LR monitoring is only useful when there is a logger to record it.
        callbacks.append(
            pl.callbacks.LearningRateMonitor(logging_interval="epoch")
        )
        logger = pl.loggers.WandbLogger(project="neural-waveshaping-synthesis")
        logger.watch(model, log="parameters")

    kwargs = trainer_kwargs()
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        gpus=device,
        # PL expects None (not "") when no checkpoint should be restored.
        resume_from_checkpoint=restore_checkpoint or None,
        **kwargs,
    )
    trainer.fit(model, data)
# Entry point: delegate to the click-powered CLI when run as a script.
if __name__ == "__main__":
    main()