RemFx / train.py
mattricesound's picture
Initial ptl model and training script for umx
8949a8c
raw
history blame
1.06 kB
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from datasets import GuitarFXDataset
from models import DiffusionGenerationModel, OpenUnmixModel
# Sample rate passed to GuitarFXDataset as `sample_rate` (presumably Hz — confirm
# against the dataset implementation).
SAMPLE_RATE = 22050
# Fraction of the dataset assigned to the training split; the remainder is
# used for validation.
TRAIN_SPLIT = 0.8
def main() -> None:
    """Train an OpenUnmixModel on the Phaser subset of the guitar-FX dataset.

    Builds a Weights & Biases logger (project ``RemFX``) and a Lightning
    trainer capped at 10 epochs, loads ``GuitarFXDataset`` restricted to the
    ``"Phaser"`` effect type, splits it into train/validation subsets
    according to ``TRAIN_SPLIT``, and fits the model.
    """
    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)

    # NOTE(review): the dataset root is a machine-specific absolute path;
    # it must exist on the host running this script.
    guitfx = GuitarFXDataset(
        root="/Users/matthewrice/Developer/remfx/data/egfx",
        sample_rate=SAMPLE_RATE,
        effect_type=["Phaser"],
    )

    # Validation gets whatever is left after the integer-truncated training
    # share, so the two sizes always sum to len(guitfx).
    train_size = int(TRAIN_SPLIT * len(guitfx))
    val_size = len(guitfx) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        guitfx, [train_size, val_size]
    )

    # random_split only randomizes which samples land in each subset; without
    # shuffle=True the training batches would be presented in the same order
    # every epoch. Validation order is left deterministic.
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2)

    # DiffusionGenerationModel is also imported by this file and appears to
    # be an alternative model that could be instantiated here instead.
    model = OpenUnmixModel()
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()