Spaces:
Runtime error
Runtime error
File size: 2,195 Bytes
607ecc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import click
import gin
import pytorch_lightning as pl
from neural_waveshaping_synthesis.data.general import GeneralDataModule
from neural_waveshaping_synthesis.data.urmp import URMPDataModule
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
@gin.configurable
def get_model(model, with_wandb):
    """Build the model object selected by the gin configuration.

    ``model`` is not passed by ``main`` and is therefore expected to be
    bound by gin from the parsed config file (presumably a model class such
    as ``NeuralWaveshaping`` — confirm against the gin files).

    Args:
        model: Callable invoked with a ``log_audio`` keyword argument.
        with_wandb: Forwarded unchanged as ``log_audio``.

    Returns:
        Whatever ``model(log_audio=with_wandb)`` returns.
    """
    instance = model(log_audio=with_wandb)
    return instance
@gin.configurable
def trainer_kwargs(**kwargs):
    """Return the keyword arguments gin binds to this function as a dict.

    ``main`` calls this with no arguments and splats the result into
    ``pl.Trainer``, so every key must be a valid ``Trainer`` keyword
    argument.
    """
    collected = {**kwargs}
    return collected
def _parse_gpu_ids(device):
    """Convert the ``--device`` string into an explicit ``gpus`` value.

    PyTorch Lightning's interpretation of a *string* ``gpus`` argument has
    changed between releases; depending on the installed version a bare
    ``"0"`` can mean "zero GPUs" (CPU) rather than "GPU index 0". To make
    the meaning independent of the Lightning version, the string is parsed
    here:

    * ``"-1"`` -> ``-1`` (all available GPUs);
    * otherwise a comma-separated list of indices, e.g. ``"0"`` -> ``[0]``,
      ``"0,2"`` -> ``[0, 2]``; an empty string yields ``[]``.

    Raises:
        click.BadParameter: if an entry is not an integer.
    """
    device = device.strip()
    if device == "-1":
        return -1
    try:
        return [int(index) for index in device.split(",") if index.strip()]
    except ValueError as err:
        raise click.BadParameter(
            f"expected a comma-separated list of GPU indices, got {device!r}",
            param_hint="--device",
        ) from err


@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--dataset-path", prompt="Dataset root")
@click.option("--urmp", is_flag=True)
@click.option("--device", default="0")
@click.option("--instrument", default="vn")
@click.option("--load-data-to-memory", is_flag=True)
@click.option("--with-wandb", is_flag=True)
@click.option("--restore-checkpoint", default="")
@click.option("--num-workers", default=16, type=int, show_default=True)
def main(
    gin_file,
    dataset_path,
    urmp,
    device,
    instrument,
    load_data_to_memory,
    with_wandb,
    restore_checkpoint,
    num_workers=16,
):
    """Train a model described by a gin config file.

    Parses ``gin_file``, builds the model via ``get_model`` and a datamodule
    (``URMPDataModule`` when ``--urmp`` is given, else
    ``GeneralDataModule``), then fits it with a ``pl.Trainer`` whose extra
    keyword arguments come from the gin-configured ``trainer_kwargs``.

    Options:
        --device: GPU index list such as ``0`` or ``0,1``, or ``-1`` for all
            GPUs (see ``_parse_gpu_ids``).
        --with-wandb: enable the Weights & Biases logger, learning-rate
            monitoring and audio logging in the model.
        --restore-checkpoint: checkpoint path to resume from; empty means
            start fresh.
        --num-workers: data-loader worker count passed to the datamodule.
    """
    gin.parse_config_file(gin_file)
    model = get_model(with_wandb=with_wandb)

    if urmp:
        data = URMPDataModule(
            dataset_path,
            instrument,
            load_to_memory=load_data_to_memory,
            num_workers=num_workers,
            shuffle=True,
        )
    else:
        data = GeneralDataModule(
            dataset_path,
            load_to_memory=load_data_to_memory,
            num_workers=num_workers,
            shuffle=True,
        )

    # Keep the single best checkpoint by validation loss, plus the latest one.
    checkpointing = pl.callbacks.ModelCheckpoint(
        monitor="val/loss", save_top_k=1, save_last=True
    )
    callbacks = [checkpointing]

    # Without wandb no logger is attached; the LR monitor is only registered
    # together with the wandb logger.
    logger = None
    if with_wandb:
        lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_logger)
        logger = pl.loggers.WandbLogger(project="neural-waveshaping-synthesis")
        logger.watch(model, log="parameters")

    kwargs = trainer_kwargs()
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        gpus=_parse_gpu_ids(device),
        # An empty --restore-checkpoint means "start from scratch".
        resume_from_checkpoint=restore_checkpoint or None,
        **kwargs,
    )
    trainer.fit(model, data)
# Entry point when run as a script; the stray " |" that followed main() was
# a syntax error and has been removed.
if __name__ == "__main__":
    main()