"""Training entry point for the disease image-classification model.

Builds the checkpoint/progress-bar callbacks, the model, the data module and
the Lightning ``Trainer`` at import time, and starts ``trainer.fit`` only when
the file is executed as a script.
"""

from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

from acfg.modelconfig import ModelConfig
from ml.app.data import ImageDataModule
from ml.app.lm import ClassificationModule
from ml.app.models.classification import DiseaseClassificationModel

# Fixed seed and epoch budget for the run.
SEED = 42
MAX_EPOCHS = 25

# Seed all RNGs *before* any model or data object is constructed, so that
# random initialisation performed during construction is reproducible too.
seed_everything(SEED)

# Keep only the single best checkpoint, judged by the lowest value of the
# metric named by ModelConfig.VAL_LOSS.
# NOTE(review): the filename template references a metric called `VA`; it only
# carries a meaningful value if the LightningModule logs a metric under that
# exact name -- confirm against ClassificationModule.
ckpt_callback = ModelCheckpoint(
    filename="classification_{epoch:02d}_{VA:.2f}",
    save_top_k=1,
    mode="min",
    monitor=ModelConfig.VAL_LOSS,
)

# Progress bar configured with a refresh rate of 10.
tqdm_callback = TQDMProgressBar(refresh_rate=10)

model = DiseaseClassificationModel(ModelConfig.PRETRAINED_MODEL_NAME)

datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)

l_module = ClassificationModule(
    model=model,
    num_classes=ModelConfig.NUM_OUTPUT_CLASSES,
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[ckpt_callback, tqdm_callback],
    num_sanity_val_steps=2,
)

if __name__ == "__main__":
    trainer.fit(
        model=l_module,
        datamodule=datamodule,
    )