from os.path import join
import hydra
import lightning as L
import torch
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig

from src.data_module import DRDataModule
from src.model import DRModel
from src.utils import generate_run_id
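
# NOTE: conf/config.yaml is expected to provide the keys read below. This
# list is inferred from the cfg.* accesses in train(); the actual values
# live in the repo's config file:
#   seed, train_csv_path, val_csv_path, image_size, batch_size, num_workers,
#   use_class_weighting, use_weighted_sampler, model_name, learning_rate,
#   use_scheduler, max_epochs, logs_dir, checkpoint_dirpath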


@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig) -> None:
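    """Train the DR model.

    Builds the datamodule, model, logger, and callbacks from the Hydra
    config, then runs the Lightning training loop.
    """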
    # Generate a unique run ID based on the current date & time
    run_id = generate_run_id()

    # Seed everything for reproducibility
    L.seed_everything(cfg.seed, workers=True)
    torch.set_float32_matmul_precision("high")

    # Initialize the DataModule
    dm = DRDataModule(
        train_csv_path=cfg.train_csv_path,
        val_csv_path=cfg.val_csv_path,
        image_size=cfg.image_size,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        use_class_weighting=cfg.use_class_weighting,
        use_weighted_sampler=cfg.use_weighted_sampler,
    )
    # Run setup() eagerly so the datamodule's num_classes and class_weights
    # are available before the model is constructed
    dm.setup()

    # Initialize the model from the datamodule's attributes
    model = DRModel(
        num_classes=dm.num_classes,
        model_name=cfg.model_name,
        learning_rate=cfg.learning_rate,
        class_weights=dm.class_weights,
        use_scheduler=cfg.use_scheduler,
    )

    # Initialize the TensorBoard logger; each run is versioned by its run ID
    logger = TensorBoardLogger(save_dir=cfg.logs_dir, name="", version=run_id)

    # Checkpoint callback: keep the two best checkpoints by validation loss
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        dirpath=join(cfg.checkpoint_dirpath, run_id),
        filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
    )

    # Log the learning rate at every optimizer step
    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Stop early if validation loss fails to improve for 10 consecutive epochs
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        verbose=True,
        mode="min",
    )

    # Initialize the Trainer
    trainer = L.Trainer(
        max_epochs=cfg.max_epochs,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stopping],
    )

    # Train the model
    trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    train()
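
# Example usage from the Hydra CLI (a sketch; the script name "train.py" and
# the override values are assumptions — any key defined in conf/config.yaml
# can be overridden on the command line):
#
#   python train.py max_epochs=30 batch_size=32 learning_rate=3e-4
#   python train.py seed=123 use_weighted_sampler=false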