# CDIApp / train_ood.py
# Author: sdutta28 — commit 32cc554 ("HF Changes"), 933 bytes.
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from acfg.modelconfig import ModelConfig
from ml.app.anomaly import DiseaseOODModule
from ml.app.data import ImageDataModule
# Training entry point for the OOD (out-of-distribution) disease module.
#
# Seed Python / NumPy / PyTorch RNGs *before* any object is constructed.
# The seed used to be set only after the datamodule and DiseaseOODModule had
# already been built. Any randomness consumed during their construction
# (presumably the model's initial weights — TODO confirm in DiseaseOODModule)
# was therefore not covered by the seed.
SEED = 42
seed_everything(SEED)

# Keep only the single best checkpoint, ranked by the metric named by
# ModelConfig.VAL_LOSS (mode="min": lower is better).
# The filename template is built from that same constant. This keeps the
# value embedded in the checkpoint name in sync with the monitored metric. A
# hard-coded key that does not match a logged metric would not reflect the
# monitored loss. Doubled braces survive the f-string as literal braces for
# Lightning's own template formatting.
ckpt_callback = ModelCheckpoint(
    filename=f"ood_{{epoch:02d}}_{{{ModelConfig.VAL_LOSS}:.2f}}",
    save_top_k=1,
    mode="min",
    monitor=ModelConfig.VAL_LOSS,
)

# Progress bar redrawn every 10 batches.
tqdm_callback = TQDMProgressBar(refresh_rate=10)

# Image dataset splits and loader settings all come from ModelConfig.
datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)

# Constructed after seeding so its initialisation is reproducible.
l_module = DiseaseOODModule()

# Run up to 100 epochs. Execute 2 validation batches as a sanity check before
# training starts.
trainer = Trainer(
    max_epochs=100,
    callbacks=[ckpt_callback, tqdm_callback],
    num_sanity_val_steps=2,
)

if __name__ == "__main__":
    # Only start training when executed as a script, not on import.
    trainer.fit(model=l_module, datamodule=datamodule)