# CDIApp / train_classifier.py
# Provenance: Hugging Face repository file view — author sdutta28,
# commit 32cc554 ("HF Changes"), 1.17 kB.
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from acfg.modelconfig import ModelConfig
from ml.app.data import ImageDataModule
from ml.app.lm import ClassificationModule
from ml.app.models.classification import DiseaseClassificationModel
# Checkpointing: keep only the single best checkpoint, where "best" means the
# lowest value of the metric logged under ModelConfig.VAL_LOSS.
# NOTE(review): the filename template embeds a metric called "VA", which is
# not the monitored key unless ModelConfig.VAL_LOSS == "VA". Lightning fills
# a template key that was never logged with 0, so confirm that "VA" is
# actually logged by ClassificationModule.
ckpt_callback = ModelCheckpoint(
    monitor=ModelConfig.VAL_LOSS,
    mode="min",
    save_top_k=1,
    filename="classification_{epoch:02d}_{VA:.2f}",
)

# Progress bar that refreshes every 10 batches instead of on every step.
tqdm_callback = TQDMProgressBar(refresh_rate=10)
# Seed the global RNGs (via Lightning's seed_everything) BEFORE constructing
# the model and data module. Anything their constructors randomise, such as
# newly initialised layers or a random split/shuffle, would otherwise happen
# unseeded and make runs non-reproducible.
# NOTE(review): consider seed_everything(42, workers=True) if the data module
# uses multi-process DataLoader workers.
seed_everything(42)

# Network built from the checkpoint named by ModelConfig.PRETRAINED_MODEL_NAME.
model = DiseaseClassificationModel(ModelConfig.PRETRAINED_MODEL_NAME)

# Image data for the train / val / test splits, batched and resized per config.
datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)

# Training wrapper around the network; this is what trainer.fit() receives.
l_module = ClassificationModule(
    model=model,
    num_classes=ModelConfig.NUM_OUTPUT_CLASSES,
)
# Fixed budget of 25 epochs. The checkpoint and progress-bar callbacks are
# attached, and 2 sanity validation steps run before training begins.
trainer = Trainer(
    callbacks=[ckpt_callback, tqdm_callback],
    max_epochs=25,
    num_sanity_val_steps=2,
)
if __name__ == "__main__":
    # Launch training only when run as a script, not when this module is
    # imported. The body must be indented; at column 0 the original was an
    # IndentationError.
    trainer.fit(
        model=l_module,
        datamodule=datamodule,
    )