File size: 1,055 Bytes
14ae0ea
 
 
 
 
8949a8c
 
14ae0ea
 
 
 
 
 
8949a8c
 
14ae0ea
8949a8c
14ae0ea
 
 
 
 
 
 
 
 
 
8949a8c
 
 
 
14ae0ea
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import GuitarFXDataset
from models import DiffusionGenerationModel, OpenUnmixModel


# Value passed to GuitarFXDataset as ``sample_rate`` (presumably Hz — confirm
# against the dataset class).
SAMPLE_RATE: int = 22_050
# Fraction of the dataset assigned to the training subset; the remainder is
# used for validation.
TRAIN_SPLIT: float = 0.8


def main(root: str = "/Users/matthewrice/Developer/remfx/data/egfx") -> None:
    """Train an ``OpenUnmixModel`` on ``GuitarFXDataset`` and log to W&B.

    Creates a Weights & Biases logger (project ``RemFX``) and a Lightning
    trainer capped at 10 epochs, loads the dataset from ``root`` with
    ``effect_type=["Phaser"]``, randomly splits it into training and
    validation subsets according to ``TRAIN_SPLIT``, and fits the model.

    Args:
        root: Directory the dataset is loaded from. The default is the
            location this script originally hard-coded; pass another path
            to run on a different machine.
    """
    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
    guitfx = GuitarFXDataset(
        root=root,
        sample_rate=SAMPLE_RATE,
        effect_type=["Phaser"],
    )

    # Validation takes whatever training leaves over, so the two sizes always
    # sum to len(guitfx) even when the product above is truncated by int().
    train_size = int(TRAIN_SPLIT * len(guitfx))
    val_size = len(guitfx) - train_size
    # NOTE: no generator is passed, so the split draws from torch's global
    # RNG and membership differs between runs unless that RNG is seeded.
    train_dataset, val_dataset = torch.utils.data.random_split(
        guitfx, [train_size, val_size]
    )

    # Reshuffle the training samples every epoch; without shuffle=True the
    # loader replays the subset in the same order each time. Validation order
    # is left fixed.
    train = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val = DataLoader(val_dataset, batch_size=2)

    # DiffusionGenerationModel is the alternative model for this script; swap
    # it in here to train it instead.
    model = OpenUnmixModel()

    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)


# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()