File size: 1,173 Bytes
32cc554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from acfg.modelconfig import ModelConfig
from ml.app.data import ImageDataModule
from ml.app.lm import ClassificationModule
from ml.app.models.classification import DiseaseClassificationModel
# Checkpointing: retain a single checkpoint, chosen as the one with the
# smallest value of the metric named by ModelConfig.VAL_LOSS.
# NOTE(review): the filename template interpolates a metric called "VA",
# which is not the monitored metric (ModelConfig.VAL_LOSS). Confirm the
# LightningModule actually logs "VA"; otherwise the number in the checkpoint
# filename will not reflect the metric used to rank checkpoints.
ckpt_callback = ModelCheckpoint(
    monitor=ModelConfig.VAL_LOSS,
    mode="min",
    save_top_k=1,
    filename="classification_{epoch:02d}_{VA:.2f}",
)

# Console progress bar; refresh_rate=10 throttles how often it redraws.
tqdm_callback = TQDMProgressBar(refresh_rate=10)
# Seed every RNG *before* constructing the model, data and Lightning module.
# The original script seeded only after these objects were built, so any
# random initialisation they perform (presumably e.g. a freshly created
# classification head on top of the pretrained backbone -- confirm in
# DiseaseClassificationModel) was not covered by the seed. workers=True also
# derives per-worker seeds for DataLoader worker processes.
seed_everything(42, workers=True)

# Network built from the configured pretrained model name.
model = DiseaseClassificationModel(ModelConfig.PRETRAINED_MODEL_NAME)

# Train/validation/test image data, batched and resized per ModelConfig.
datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)

# Lightning wrapper around the network, configured with the number of
# output classes.
l_module = ClassificationModule(
    model=model,
    num_classes=ModelConfig.NUM_OUTPUT_CLASSES,
)

# Trainer: up to 25 epochs, with checkpointing and progress-bar callbacks,
# and 2 validation sanity-check steps before training starts.
trainer = Trainer(
    max_epochs=25,
    callbacks=[ckpt_callback, tqdm_callback],
    num_sanity_val_steps=2,
)
if __name__ == "__main__":
    # Launch training only when this file is executed as a script; the body
    # must be indented under the guard (it was at column 0, an IndentationError).
    trainer.fit(
        model=l_module,
        datamodule=datamodule,
    )
|