File size: 1,173 Bytes
32cc554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

from acfg.modelconfig import ModelConfig
from ml.app.data import ImageDataModule
from ml.app.lm import ClassificationModule
from ml.app.models.classification import DiseaseClassificationModel


# Checkpointing: retain only the single best checkpoint (save_top_k=1), ranked
# by the quantity named in ModelConfig.VAL_LOSS, where smaller is better.
# NOTE(review): the filename template formats a metric called "VA" —
# presumably validation accuracy; confirm the LightningModule logs that key.
ckpt_callback = ModelCheckpoint(
    monitor=ModelConfig.VAL_LOSS,
    mode="min",
    save_top_k=1,
    filename="classification" + "_{epoch:02d}_{VA:.2f}",
)

# Progress bar that is refreshed every 10 batches rather than on every batch.
tqdm_callback = TQDMProgressBar(refresh_rate=10)


# Seed the global RNGs *before* any object is constructed. Previously this was
# called only right before the Trainer was created, so any randomly
# initialised state created while building the model / LightningModule below
# was unseeded and differed from run to run.
seed_everything(42)

# Network identified by ModelConfig.PRETRAINED_MODEL_NAME.
model = DiseaseClassificationModel(ModelConfig.PRETRAINED_MODEL_NAME)

# Data pipeline: train / val / test locations, batch size and image size all
# come from ModelConfig.
datamodule = ImageDataModule(
    train_path=ModelConfig.TRAIN_DATA_PATH,
    val_path=ModelConfig.VAL_DATA_PATH,
    test_path=ModelConfig.TEST_DATA_PATH,
    batch_size=ModelConfig.BATCH_SIZE,
    img_size=ModelConfig.IMG_SIZE,
)

# LightningModule wrapping the network for training.
l_module = ClassificationModule(
    model=model,
    num_classes=ModelConfig.NUM_OUTPUT_CLASSES,
)

# Train for at most 25 epochs with the checkpoint and progress-bar callbacks;
# two validation batches are run as a sanity check before training starts.
trainer = Trainer(
    max_epochs=25,
    callbacks=[ckpt_callback, tqdm_callback],
    num_sanity_val_steps=2,
)


if __name__ == "__main__":
    # Start training only when executed as a script; merely importing this
    # module constructs the objects above but does not launch a fit.
    trainer.fit(model=l_module, datamodule=datamodule)