bhimrazy committed on
Commit 2bb6467
1 Parent(s): 356b6f2

Adds callbacks for early stopping and updates other params

Files changed (1)
  1. train.py +21 -9
train.py CHANGED
@@ -1,6 +1,10 @@
 import lightning as L
 import torch
-from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
+from lightning.pytorch.callbacks import (
+    ModelCheckpoint,
+    LearningRateMonitor,
+    EarlyStopping,
+)
 from lightning.pytorch.loggers import TensorBoardLogger
 
 from src.dataset import DRDataModule
@@ -13,36 +17,44 @@ torch.set_float32_matmul_precision("high")
 
 
 # Init DataModule
-dm = DRDataModule(batch_size=96, num_workers=8)
+dm = DRDataModule(batch_size=128, num_workers=24)
 dm.setup()
 
 # Init model from datamodule's attributes
 model = DRModel(
-    num_classes=dm.num_classes, learning_rate=3e-5, class_weights=dm.class_weights
+    num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
 )
 
 # Init logger
-logger = TensorBoardLogger("lightning_logs", name="dr_model")
-
+logger = TensorBoardLogger(save_dir="artifacts")
 # Init callbacks
 checkpoint_callback = ModelCheckpoint(
     monitor="val_loss",
     mode="min",
-    save_top_k=3,
-    dirpath="checkpoints",
+    save_top_k=2,
+    dirpath="artifacts/checkpoints",
+    filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
 )
 
 # Init LearningRateMonitor
 lr_monitor = LearningRateMonitor(logging_interval="step")
 
+# early stopping
+early_stopping = EarlyStopping(
+    monitor="val_loss",
+    patience=5,
+    verbose=True,
+    mode="min",
+)
+
 # Init trainer
 trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    logger=logger,
-    callbacks=[checkpoint_callback, lr_monitor],
-    enable_checkpointing=True,
+    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
+    # check_val_every_n_epoch=4,
 )
 
 # Pass the datamodule as arg to trainer.fit to override model hooks :)
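For context, here is a minimal, self-contained sketch of how the callbacks added in this commit wire together. DRModel and DRDataModule live in this repo's src/ package, so the TinyModel stand-in and random dataloaders below are hypothetical placeholders for illustration only; the logger, callback, and Trainer arguments mirror the new train.py, and the closing trainer.fit call follows the pattern the trailing comment in the diff refers to.

# Minimal sketch of the callback wiring introduced in this commit.
# TinyModel and the random dataloaders are hypothetical stand-ins for
# this repo's DRModel / DRDataModule; callback and Trainer args mirror train.py.
import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger


class TinyModel(L.LightningModule):
    def __init__(self, learning_rate: float = 3e-4):
        super().__init__()
        self.layer = torch.nn.Linear(16, 2)
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # Both EarlyStopping and ModelCheckpoint below monitor this key.
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


def loader():
    # Random stand-in data; train.py uses DRDataModule instead.
    x, y = torch.randn(256, 16), torch.randint(0, 2, (256,))
    return DataLoader(TensorDataset(x, y), batch_size=32)


logger = TensorBoardLogger(save_dir="artifacts")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    dirpath="artifacts/checkpoints",
    filename="{epoch}-{step}-{val_loss:.2f}",  # train.py also templates val_acc / val_kappa
)
lr_monitor = LearningRateMonitor(logging_interval="step")
early_stopping = EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min")

trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
)
trainer.fit(TinyModel(), train_dataloaders=loader(), val_dataloaders=loader())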