# phase-hunter / phasehunter / training.py
# crimeacs — commit 2bbf18f: "Fixed imports" (3.11 kB)
import torch
from data_preparation import augment, collation_fn, my_split_by_node
from model import Onset_picker, Updated_onset_picker
import webdataset as wds
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning import seed_everything
import lightning as pl
# --- Global settings ----------------------------------------------------------
# Seed the RNGs for reproducibility. workers=False: Lightning is not asked to
# install its own worker_init_fn on the dataloaders.
seed_everything(42, workers=False)
# Allow float32 matmuls to trade precision for speed (see
# torch.set_float32_matmul_precision for what 'medium' permits).
torch.set_float32_matmul_precision('medium')

# --- Data-loading hyper-parameters --------------------------------------------
batch_size = 256         # samples per batch (batching happens in the WebDataset pipeline)
num_workers = 16         # training DataLoader processes; int(os.cpu_count()) is an alternative
n_iters_in_epoch = 5000  # target training batches per epoch, split across the workers
# --- Datasets -----------------------------------------------------------------
# Both pipelines read WebDataset tar shards, decode each sample, apply
# `augment` (from data_preparation) and group samples into fixed-size batches
# with `collation_fn`; incomplete trailing batches are dropped (partial=False).
# Because batching happens here, the WebLoaders are built with batch_size=None.

# Training: shards shard-000000 and shard-000001, shuffled with a 5000-sample
# buffer before batching.
# Each DataLoader worker iterates its own copy of this pipeline, so the epoch
# budget is divided by the worker count; with_epoch is applied after
# .batched(), so it counts batches. With the defaults that is 5000 // 16 = 312
# batches per worker (4992 in total). max(..., 1) avoids a ZeroDivisionError
# when num_workers is 0, where the loader iterates a single in-process copy.
train_dataset = (
    wds.WebDataset("data/sample/shard-00{0000..0001}.tar",
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .shuffle(5000)
    .batched(batchsize=batch_size,
             collation_fn=collation_fn,
             partial=False
             )
).with_epoch(n_iters_in_epoch // max(num_workers, 1))

# Validation: a single shard, repeated indefinitely and cut to 100 batches per
# validation pass by with_epoch.
# NOTE(review): `augment` is applied to validation data too; if it performs
# random augmentation (not only deterministic preprocessing) validation
# metrics will be noisy — confirm against data_preparation.augment.
# NOTE(review): shard-000000 is also one of the training shards above, so
# validation overlaps the training data — presumably acceptable only for this
# sample dataset; confirm.
# NOTE(review): with a single shard and nodesplitter=my_split_by_node, a rank
# that receives no shard may yield nothing, and .repeat() over an empty stream
# may never terminate — confirm my_split_by_node under multi-GPU runs.
val_dataset = (
    wds.WebDataset("data/sample/shard-00{0000..0000}.tar",
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .repeat()
    .batched(batchsize=batch_size,
             collation_fn=collation_fn,
             partial=False
             )
).with_epoch(100)
# --- Data loaders -------------------------------------------------------------
# Batching (and, for training, shuffling) already happen inside the WebDataset
# pipelines, so the loaders neither batch (batch_size=None) nor shuffle.
# pin_memory=True places fetched batches in page-locked host memory.
_shared_loader_options = {"shuffle": False, "pin_memory": True, "batch_size": None}

train_loader = wds.WebLoader(train_dataset, num_workers=num_workers,
                             **_shared_loader_options)
# Validation batches are produced in the main process (no worker processes).
# NOTE(review): presumably 0 because validation reads a single shard — confirm.
val_loader = wds.WebLoader(val_dataset, num_workers=0,
                           **_shared_loader_options)
# --- Model --------------------------------------------------------------------
# Onset_picker (from model.py) wraps an Updated_onset_picker network and is
# given a learning rate of 3e-4; how it is used is defined in model.py.
model = Onset_picker(
    picker=Updated_onset_picker(),
    learning_rate=3e-4,
)
# Optional, currently disabled:
# model = torch.compile(model, mode="reduce-overhead")

# --- Logging & callbacks ------------------------------------------------------
logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
# Keep only the single best checkpoint as judged by the logged "Loss/val"
# metric (ModelCheckpoint's default mode is "min").
# NOTE(review): "Loss/val" must be logged by the model's validation step — confirm.
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="Loss/val",
    filename="chkp-{epoch:02d}",
)
# Record the learning rate once per epoch.
lr_callback = LearningRateMonitor(logging_interval='epoch')
# Optional, currently disabled (StochasticWeightAveraging is not imported above):
# swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
# --- Training -----------------------------------------------------------------
# DDP flags: find_unused_parameters=False and static_graph=True declare that
# every parameter gets a gradient and the graph does not change between
# iterations; gradient_as_bucket_view=True lets gradients share memory with
# DDP's communication buckets.
# NOTE(review): revisit these flags if the model gains conditional branches.
_ddp_strategy = DDPStrategy(
    find_unused_parameters=False,
    static_graph=True,
    gradient_as_bucket_view=True,
)

trainer = pl.Trainer(
    # Hardware / distribution.
    accelerator='auto',
    devices='auto',
    strategy=_ddp_strategy,
    precision='16-mixed',
    benchmark=True,
    # Optimisation.
    gradient_clip_val=0.5,
    max_epochs=300,
    # Logging, checkpointing and progress reporting.
    logger=logger,
    log_every_n_steps=50,
    callbacks=[checkpoint_callback, lr_callback],
    enable_progress_bar=True,
    # Uncomment for a quick smoke test on a single batch:
    # fast_dev_run=True,
)

trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    # To resume from a checkpoint, pass it here — pl.Trainer itself does not
    # accept ckpt_path in Lightning 2.x:
    # ckpt_path='path/to/saved/checkpoints/chkp.ckpt',
)