"""Train the Onset_picker model on WebDataset shards with PyTorch Lightning.

Builds sharded train/val pipelines, wraps them in WebLoaders, and runs a
mixed-precision DDP training loop with checkpointing and LR monitoring.
"""

import torch
import webdataset as wds
import lightning as pl
from lightning import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy

from data_preparation import augment, collation_fn, my_split_by_node
from model import Onset_picker, Updated_onset_picker

BATCH_SIZE = 256
# Loader worker processes; original note suggested tracking os.cpu_count().
NUM_WORKERS = 16
# Nominal batches per training epoch, split across the loader workers below.
N_ITERS_IN_EPOCH = 5000


def _build_train_loader():
    """Return the training WebLoader: sharded, augmented, shuffled, batched.

    Batching happens inside the dataset (``partial=False`` drops the last
    incomplete batch), so the WebLoader itself uses ``batch_size=None``.
    """
    dataset = (
        wds.WebDataset(
            "data/sample/shard-00{0000..0001}.tar",
            nodesplitter=my_split_by_node,
        )
        .decode()
        .map(augment)
        .shuffle(5000)
        .batched(batchsize=BATCH_SIZE, collation_fn=collation_fn, partial=False)
    ).with_epoch(N_ITERS_IN_EPOCH // NUM_WORKERS)  # per-worker epoch length
    return wds.WebLoader(
        dataset,
        num_workers=NUM_WORKERS,
        shuffle=False,  # shuffling already done in the dataset pipeline
        pin_memory=True,
        batch_size=None,
    )


def _build_val_loader():
    """Return the validation WebLoader: single shard, repeated, 100 batches/epoch."""
    dataset = (
        wds.WebDataset(
            "data/sample/shard-00{0000..0000}.tar",
            nodesplitter=my_split_by_node,
        )
        .decode()
        .map(augment)
        .repeat()  # loop the single shard so 100 iterations are always available
        .batched(batchsize=BATCH_SIZE, collation_fn=collation_fn, partial=False)
    ).with_epoch(100)
    return wds.WebLoader(
        dataset,
        num_workers=0,  # validation stays in the main process
        shuffle=False,
        pin_memory=True,
        batch_size=None,
    )


def main():
    """Seed, build data/model/trainer, and run the fit loop."""
    # FIX: workers=True also seeds each dataloader worker's RNG; with 16
    # workers applying random `augment`, workers=False left them unseeded.
    seed_everything(42, workers=True)
    torch.set_float32_matmul_precision('medium')

    train_loader = _build_train_loader()
    val_loader = _build_val_loader()

    model = Onset_picker(picker=Updated_onset_picker(), learning_rate=3e-4)

    logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="Loss/val",
        filename="chkp-{epoch:02d}",
    )
    lr_callback = LearningRateMonitor(logging_interval='epoch')

    trainer = pl.Trainer(
        precision='16-mixed',
        callbacks=[checkpoint_callback, lr_callback],
        devices='auto',
        accelerator='auto',
        strategy=DDPStrategy(
            find_unused_parameters=False,
            static_graph=True,
            gradient_as_bucket_view=True,
        ),
        benchmark=True,  # cudnn autotune; input shapes are fixed-size batches
        gradient_clip_val=0.5,
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        max_epochs=300,
    )
    # To resume, pass ckpt_path='path/to/saved/checkpoints/chkp.ckpt' here.
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )


# FIX: guard the entry point — DDP and multi-process dataloading re-import
# this module in child processes under spawn; unguarded top-level fit()
# would re-launch training from every child.
if __name__ == "__main__":
    main()