bhimrazy commited on
Commit
9f31860
·
1 Parent(s): 23fa981

Add training script with data module setup, model initialization, logger, callbacks, and trainer configuration

Browse files
Files changed (2) hide show
  1. src/__init__.py +0 -0
  2. train.py +46 -0
src/__init__.py ADDED
File without changes
train.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ import torch
3
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
4
+ from lightning.pytorch.loggers import TensorBoardLogger
5
+
6
+ from src.dataset import DRDataModule
7
+ from src.model import DRModel
8
+
9
+ # seed everything for reproducibility
10
+ SEED = 42
11
+ L.seed_everything(SEED, workers=True)
12
+ torch.set_float32_matmul_precision("high")
13
+
14
+
15
+ # Init DataModule
16
+ dm = DRDataModule(batch_size=128, num_workers=8)
17
+ dm.setup()
18
+
19
+ # Init model from datamodule's attributes
20
+ model = DRModel(
21
+ num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
22
+ )
23
+
24
+ # Init logger
25
+ logger = TensorBoardLogger("lightning_logs", name="dr_model")
26
+
27
+ # Init callbacks
28
+ checkpoint_callback = ModelCheckpoint(
29
+ monitor="val_loss",
30
+ mode="min",
31
+ save_top_k=3,
32
+ dirpath="checkpoints",
33
+ )
34
+
35
+ # Init trainer
36
+ trainer = L.Trainer(
37
+ max_epochs=20,
38
+ accelerator="auto",
39
+ devices="auto",
40
+ logger=logger,
41
+ callbacks=[checkpoint_callback],
42
+ enable_checkpointing=True
43
+ )
44
+
45
+ # Pass the datamodule as arg to trainer.fit to override model hooks :)
46
+ trainer.fit(model, dm)