# Spaces: Runtime error  (page-header residue from extraction; not part of the script)
import lightning as pl
import torch
import webdataset as wds
from lightning import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy

from data_preparation import augment, collation_fn, my_split_by_node
from model import Onset_picker, Updated_onset_picker
# Fix the global RNG seed so runs are repeatable. workers=False leaves
# DataLoader worker seeding at its default instead of having Lightning
# derive per-worker seeds.
seed_everything(42, workers=False)

# Trade some float32 matmul precision for speed on hardware that supports it.
torch.set_float32_matmul_precision("medium")
# Number of samples collated into every batch by the dataset pipelines.
batch_size = 256
# DataLoader worker processes used for training; int(os.cpu_count()) is the
# alternative when every available core should be used.
num_workers = 16
# Target number of training batches per epoch, summed over all workers.
n_iters_in_epoch = 5000
# Training pipeline: read shards -> decode -> augment -> shuffle -> batch.
train_dataset = (
    wds.WebDataset(
        "data/sample/shard-00{0000..0001}.tar",
        # Only node-level shard splitting is overridden here; a custom
        # worker-level splitter (my_split_by_worker) is currently disabled.
        nodesplitter=my_split_by_node,
    )
    .decode()
    .map(augment)
    # Sample-level shuffle through a 5000-element buffer.
    .shuffle(5000)
    .batched(
        batchsize=batch_size,
        collation_fn=collation_fn,
        # Drop a trailing incomplete batch so every batch has batch_size items.
        partial=False,
    )
    # The epoch length is set on the dataset, so it applies per loader worker:
    # split the global target across workers. max(1, ...) keeps this valid
    # when num_workers is 0 (loading in the main process).
).with_epoch(n_iters_in_epoch // max(1, num_workers))
# Validation pipeline: a single shard, repeated, cut to a fixed epoch length.
# NOTE(review): validation samples also pass through `augment` -- confirm that
# this is intended (e.g. it performs required preprocessing) and not a
# copy-paste from the training pipeline.
val_dataset = (
    wds.WebDataset(
        "data/sample/shard-00{0000..0000}.tar",
        # Only node-level shard splitting is overridden here; a custom
        # worker-level splitter (my_split_by_worker) is currently disabled.
        nodesplitter=my_split_by_node,
    )
    .decode()
    .map(augment)
    # Loop over the shard so the fixed epoch length below can always be met.
    .repeat()
    .batched(
        batchsize=batch_size,
        collation_fn=collation_fn,
        # Drop a trailing incomplete batch so every batch has batch_size items.
        partial=False,
    )
    # One validation epoch is 100 batches.
).with_epoch(100)
# Loaders wrap the pipelines above. batch_size=None because the datasets
# already yield collated batches; shuffle=False because any shuffling is done
# inside the dataset pipeline itself.
train_loader = wds.WebLoader(
    train_dataset,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True,
    batch_size=None,
)

# Validation is loaded in the main process (num_workers=0).
val_loader = wds.WebLoader(
    val_dataset,
    num_workers=0,
    shuffle=False,
    pin_memory=True,
    batch_size=None,
)
# Model: the Lightning module wraps the picker network.
model = Onset_picker(
    picker=Updated_onset_picker(),
    learning_rate=3e-4,
)
# Optional: model = torch.compile(model, mode="reduce-overhead")

# Logging and callbacks.
logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
# Keep only the single best checkpoint according to the "Loss/val" metric.
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="Loss/val",
    filename="chkp-{epoch:02d}",
)
# Record the learning rate once per epoch.
lr_callback = LearningRateMonitor(logging_interval="epoch")
# Optional (StochasticWeightAveraging is not imported in this file):
# swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
# Train the model.
trainer = pl.Trainer(
    precision="16-mixed",
    callbacks=[checkpoint_callback, lr_callback],
    devices="auto",
    accelerator="auto",
    strategy=DDPStrategy(
        find_unused_parameters=False,
        static_graph=True,
        gradient_as_bucket_view=True,
    ),
    benchmark=True,
    gradient_clip_val=0.5,
    # fast_dev_run=True,  # uncomment for a quick smoke test
    logger=logger,
    log_every_n_steps=50,
    enable_progress_bar=True,
    max_epochs=300,
)

trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    # To resume from a checkpoint, pass the path here (it is an argument of
    # fit(), not of the Trainer constructor):
    # ckpt_path="path/to/saved/checkpoints/chkp.ckpt",
)