from lightning import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar from acfg.modelconfig import ModelConfig from ml.app.anomaly import DiseaseOODModule from ml.app.data import ImageDataModule ckpt_callback = ModelCheckpoint( filename="ood" + "_{epoch:02d}_{VL:.2f}", save_top_k=1, mode="min", monitor=ModelConfig.VAL_LOSS, ) tqdm_callback = TQDMProgressBar(refresh_rate=10) datamodule = ImageDataModule( train_path=ModelConfig.TRAIN_DATA_PATH, val_path=ModelConfig.VAL_DATA_PATH, test_path=ModelConfig.TEST_DATA_PATH, batch_size=ModelConfig.BATCH_SIZE, img_size=ModelConfig.IMG_SIZE, ) l_module = DiseaseOODModule() seed_everything(42) trainer = Trainer( max_epochs=100, callbacks=[ckpt_callback, tqdm_callback], num_sanity_val_steps=2, ) if __name__ == "__main__": trainer.fit(model=l_module, datamodule=datamodule)