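"""Train a NeuralWaveshaping model with PyTorch Lightning.

Model and trainer hyperparameters are bound through a gin config file;
the dataset location and runtime options are passed on the command line.
"""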
import click
import gin
import pytorch_lightning as pl

from neural_waveshaping_synthesis.data.general import GeneralDataModule
from neural_waveshaping_synthesis.data.urmp import URMPDataModule
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping


@gin.configurable
def get_model(model, with_wandb):
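    """Construct the model; the `model` class itself is bound in the gin config."""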
    return model(log_audio=with_wandb)


@gin.configurable
def trainer_kwargs(**kwargs):
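    """Collect arbitrary pl.Trainer keyword arguments from the gin config."""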
    return kwargs
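
# A minimal sketch of matching gin bindings (illustrative only: the repo
# ships its own config files; the parameter names follow the
# @gin.configurable signatures above, and the values are assumptions):
#
#   get_model.model = @NeuralWaveshaping
#   trainer_kwargs.max_epochs = 1000
#   trainer_kwargs.gradient_clip_val = 1.0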


@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--dataset-path", prompt="Dataset root")
@click.option("--urmp", is_flag=True)
@click.option("--device", default="0")
@click.option("--instrument", default="vn")
@click.option("--load-data-to-memory", is_flag=True)
@click.option("--with-wandb", is_flag=True)
@click.option("--restore-checkpoint", default="")
def main(
    gin_file,
    dataset_path,
    urmp,
    device,
    instrument,
    load_data_to_memory,
    with_wandb,
    restore_checkpoint,
):
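    # Parse the gin file first so the @gin.configurable factories below are bound.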
    gin.parse_config_file(gin_file)
    model = get_model(with_wandb=with_wandb)

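    # Choose the per-instrument URMP loader or the generic directory loader.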
    if urmp:
        data = URMPDataModule(
            dataset_path,
            instrument,
            load_to_memory=load_data_to_memory,
            num_workers=16,
            shuffle=True,
        )
    else:
        data = GeneralDataModule(
            dataset_path,
            load_to_memory=load_data_to_memory,
            num_workers=16,
            shuffle=True,
        )

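    # Keep the single best checkpoint by validation loss, and always the latest.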
    checkpointing = pl.callbacks.ModelCheckpoint(
        monitor="val/loss", save_top_k=1, save_last=True
    )
    callbacks = [checkpointing]
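    # Optionally monitor the learning rate and log to Weights & Biases.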
    if with_wandb:
        lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_logger)
        logger = pl.loggers.WandbLogger(project="neural-waveshaping-synthesis")
        logger.watch(model, log="parameters")
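
    # Assemble the Trainer from gin-supplied kwargs plus command-line options.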
    kwargs = trainer_kwargs()
    trainer = pl.Trainer(
        logger=logger if with_wandb else None,
        callbacks=callbacks,
        gpus=device,
        resume_from_checkpoint=restore_checkpoint or None,
        **kwargs
    )
    trainer.fit(model, data)


if __name__ == "__main__":
    main()
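
# Example invocation (script path, gin file, and dataset path are placeholders):
#   python train.py --gin-file gin/train.gin --dataset-path /data/urmp \
#       --urmp --instrument vn --with-wandb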