Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
·
8d4131e
1
Parent(s):
de7d21e
optuna added as base
Browse files
- configs/experiment/catdog_experiment.yaml +1 -1
- configs/trainer/default.yaml +2 -3
- src/train_new.py +84 -99
- src/train_old.py +260 -0
- src/{train.py → train_optuna_callbacks.py} +120 -98
configs/experiment/catdog_experiment.yaml
CHANGED
@@ -39,7 +39,7 @@ model:
|
|
39 |
|
40 |
trainer:
|
41 |
min_epochs: 1
|
42 |
-
max_epochs:
|
43 |
|
44 |
callbacks:
|
45 |
model_checkpoint:
|
|
|
39 |
|
40 |
trainer:
|
41 |
min_epochs: 1
|
42 |
+
max_epochs: 5
|
43 |
|
44 |
callbacks:
|
45 |
model_checkpoint:
|
configs/trainer/default.yaml
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
_target_: lightning.Trainer
|
2 |
|
3 |
default_root_dir: ${paths.output_dir}
|
4 |
min_epochs: 1
|
@@ -10,8 +9,7 @@ devices: auto
|
|
10 |
# mixed precision for extra speed-up
|
11 |
# precision: 16
|
12 |
|
13 |
-
# set True to to ensure deterministic results
|
14 |
-
# makes training slower but gives more reproducibility than just setting seeds
|
15 |
deterministic: True
|
16 |
|
17 |
# Log every N steps in training and validation
|
@@ -19,3 +17,4 @@ log_every_n_steps: 10
|
|
19 |
fast_dev_run: False
|
20 |
|
21 |
gradient_clip_val: 1.0
|
|
|
|
|
|
1 |
|
2 |
default_root_dir: ${paths.output_dir}
|
3 |
min_epochs: 1
|
|
|
9 |
# mixed precision for extra speed-up
|
10 |
# precision: 16
|
11 |
|
12 |
+
# set True to ensure deterministic results; makes training slower but gives more reproducibility than just setting seeds
|
|
|
13 |
deterministic: True
|
14 |
|
15 |
# Log every N steps in training and validation
|
|
|
17 |
fast_dev_run: False
|
18 |
|
19 |
gradient_clip_val: 1.0
|
20 |
+
gradient_clip_algorithm: 'norm'
|
src/train_new.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
"""
|
2 |
-
Train and evaluate a model using PyTorch Lightning.
|
3 |
-
Initializes the DataModule, Model, Trainer, and runs training and testing.
|
4 |
-
Initializes loggers and callbacks from the configuration using Hydra and target paths from the configuration.
|
5 |
"""
|
6 |
|
7 |
import os
|
@@ -17,51 +15,34 @@ from src.utils.logging_utils import setup_logger, task_wrapper
|
|
17 |
from loguru import logger
|
18 |
import rootutils
|
19 |
from lightning.pytorch.loggers import Logger
|
20 |
-
|
|
|
21 |
|
22 |
# Load environment variables
|
23 |
load_dotenv(find_dotenv(".env"))
|
24 |
|
25 |
# Setup root directory
|
26 |
-
|
27 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
28 |
|
29 |
|
30 |
-
def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
|
31 |
-
"""Instantiate and return a list of callbacks from the configuration."""
|
32 |
-
callbacks_ls: List[L.Callback] = []
|
33 |
-
|
34 |
-
if not callback_cfg:
|
35 |
-
logger.warning("No callback configs found! Skipping..")
|
36 |
-
return None
|
37 |
-
|
38 |
-
if not isinstance(callback_cfg, DictConfig):
|
39 |
-
raise TypeError("Callbacks config must be a DictConfig!")
|
40 |
-
|
41 |
-
for _, cb_conf in callback_cfg.items():
|
42 |
-
if "_target_" in cb_conf:
|
43 |
-
logger.info(f"Instantiating callback <{cb_conf._target_}>")
|
44 |
-
callbacks_ls.append(hydra.utils.instantiate(cb_conf))
|
45 |
-
|
46 |
-
return callbacks_ls
|
47 |
-
|
48 |
-
|
49 |
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
|
50 |
"""Instantiate and return a list of loggers from the configuration."""
|
51 |
loggers_ls: List[Logger] = []
|
52 |
|
53 |
-
if not logger_cfg:
|
54 |
-
logger.warning("No logger configs found! Skipping..")
|
55 |
return loggers_ls
|
56 |
|
57 |
if not isinstance(logger_cfg, DictConfig):
|
58 |
raise TypeError("Logger config must be a DictConfig!")
|
59 |
|
60 |
for _, lg_conf in logger_cfg.items():
|
61 |
-
if "_target_" in lg_conf:
|
62 |
logger.info(f"Instantiating logger <{lg_conf._target_}>")
|
63 |
-
|
64 |
-
|
|
|
|
|
65 |
return loggers_ls
|
66 |
|
67 |
|
@@ -93,16 +74,19 @@ def clear_checkpoint_directory(ckpt_dir: str):
|
|
93 |
def train_module(
|
94 |
data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
|
95 |
):
|
96 |
-
"""Train the model
|
97 |
-
logger.info("Starting training")
|
|
|
98 |
trainer.fit(model, data_module)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
106 |
|
107 |
|
108 |
@task_wrapper
|
@@ -122,77 +106,78 @@ def run_test_module(
|
|
122 |
return test_metrics[0] if test_metrics else {}
|
123 |
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
"""Set up and run the Trainer for training and testing."""
|
128 |
-
# Display configuration
|
129 |
-
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
130 |
|
131 |
-
#
|
132 |
-
|
133 |
-
|
134 |
-
)
|
135 |
-
|
136 |
-
|
137 |
-
# Display key paths
|
138 |
-
for path_name in [
|
139 |
-
"root_dir",
|
140 |
-
"data_dir",
|
141 |
-
"log_dir",
|
142 |
-
"ckpt_dir",
|
143 |
-
"artifact_dir",
|
144 |
-
"output_dir",
|
145 |
-
]:
|
146 |
-
logger.info(
|
147 |
-
f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
|
148 |
-
)
|
149 |
|
150 |
-
# Initialize
|
151 |
-
|
152 |
-
datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
153 |
-
logger.info(f"Instantiating model <{cfg.model._target_}>")
|
154 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
155 |
|
156 |
-
#
|
157 |
-
logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
|
158 |
-
L.seed_everything(cfg.seed, workers=True)
|
159 |
-
|
160 |
-
# Set up callbacks, loggers, and Trainer
|
161 |
-
callbacks = instantiate_callbacks(cfg.callbacks)
|
162 |
-
logger.info(f"Callbacks: {callbacks}")
|
163 |
loggers = instantiate_loggers(cfg.logger)
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
)
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
train_metrics = train_module(datamodule, model, trainer)
|
174 |
-
(Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
|
175 |
-
"Training completed.\n"
|
176 |
)
|
177 |
-
|
178 |
-
|
179 |
-
test_metrics = {}
|
180 |
-
if cfg.get("test"):
|
181 |
-
test_metrics = run_test_module(cfg, datamodule, model, trainer)
|
182 |
-
|
183 |
-
# Combine metrics and extract optimization metric
|
184 |
-
all_metrics = {**train_metrics, **test_metrics}
|
185 |
-
optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
|
186 |
-
(
|
187 |
-
logger.warning(
|
188 |
-
f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
|
189 |
)
|
190 |
-
if optimization_metric == 0.0
|
191 |
-
else logger.info(f"Optimization metric: {optimization_metric}")
|
192 |
-
)
|
193 |
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
|
197 |
if __name__ == "__main__":
|
198 |
-
|
|
|
1 |
"""
|
2 |
+
Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
|
|
|
|
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
15 |
from loguru import logger
|
16 |
import rootutils
|
17 |
from lightning.pytorch.loggers import Logger
|
18 |
+
import optuna
|
19 |
+
from lightning.pytorch import Trainer
|
20 |
|
21 |
# Load environment variables
|
22 |
load_dotenv(find_dotenv(".env"))
|
23 |
|
24 |
# Setup root directory
|
|
|
25 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
26 |
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
|
29 |
"""Instantiate and return a list of loggers from the configuration."""
|
30 |
loggers_ls: List[Logger] = []
|
31 |
|
32 |
+
if not logger_cfg or isinstance(logger_cfg, bool):
|
33 |
+
logger.warning("No valid logger configs found! Skipping..")
|
34 |
return loggers_ls
|
35 |
|
36 |
if not isinstance(logger_cfg, DictConfig):
|
37 |
raise TypeError("Logger config must be a DictConfig!")
|
38 |
|
39 |
for _, lg_conf in logger_cfg.items():
|
40 |
+
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
|
41 |
logger.info(f"Instantiating logger <{lg_conf._target_}>")
|
42 |
+
try:
|
43 |
+
loggers_ls.append(hydra.utils.instantiate(lg_conf))
|
44 |
+
except Exception as e:
|
45 |
+
logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
|
46 |
return loggers_ls
|
47 |
|
48 |
|
|
|
74 |
def train_module(
|
75 |
data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
|
76 |
):
|
77 |
+
"""Train the model, return validation accuracy for each epoch."""
|
78 |
+
logger.info("Starting training with custom pruning")
|
79 |
+
|
80 |
trainer.fit(model, data_module)
|
81 |
+
val_accuracies = []
|
82 |
+
|
83 |
+
for epoch in range(trainer.current_epoch):
|
84 |
+
val_acc = trainer.callback_metrics.get("val_acc")
|
85 |
+
if val_acc is not None:
|
86 |
+
val_accuracies.append(val_acc.item())
|
87 |
+
logger.info(f"Epoch {epoch}: val_acc={val_acc}")
|
88 |
+
|
89 |
+
return val_accuracies
|
90 |
|
91 |
|
92 |
@task_wrapper
|
|
|
106 |
return test_metrics[0] if test_metrics else {}
|
107 |
|
108 |
|
109 |
+
def objective(trial: optuna.trial.Trial, cfg: DictConfig):
|
110 |
+
"""Objective function for Optuna hyperparameter tuning."""
|
|
|
|
|
|
|
111 |
|
112 |
+
# Sample hyperparameters for the model
|
113 |
+
cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
|
114 |
+
cfg.model.depth = trial.suggest_int("depth", 2, 6)
|
115 |
+
cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
|
116 |
+
cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
+
# Initialize data module and model
|
119 |
+
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
|
|
|
|
120 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
121 |
|
122 |
+
# Set up logger
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
loggers = instantiate_loggers(cfg.logger)
|
124 |
+
|
125 |
+
# Trainer configuration without pruning callback
|
126 |
+
trainer = Trainer(**cfg.trainer, logger=loggers)
|
127 |
+
|
128 |
+
# Clear checkpoint directory
|
129 |
+
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
130 |
+
|
131 |
+
# Train and get val_acc for each epoch
|
132 |
+
val_accuracies = train_module(data_module, model, trainer)
|
133 |
+
|
134 |
+
# Report validation accuracy and prune if necessary
|
135 |
+
for epoch, val_acc in enumerate(val_accuracies):
|
136 |
+
trial.report(val_acc, step=epoch)
|
137 |
+
|
138 |
+
# Check if the trial should be pruned at this epoch
|
139 |
+
if trial.should_prune():
|
140 |
+
logger.info(f"Pruning trial at epoch {epoch}")
|
141 |
+
raise optuna.TrialPruned()
|
142 |
+
|
143 |
+
# Return the final validation accuracy as the objective metric
|
144 |
+
return val_accuracies[-1] if val_accuracies else 0.0
|
145 |
+
|
146 |
+
|
147 |
+
@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
|
148 |
+
def setup_trainer(cfg: DictConfig):
|
149 |
+
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
150 |
+
|
151 |
+
setup_logger(
|
152 |
+
Path(cfg.paths.log_dir)
|
153 |
+
/ ("train.log" if cfg.task_name == "train" else "eval.log")
|
154 |
)
|
155 |
|
156 |
+
if cfg.get("train", False):
|
157 |
+
pruner = optuna.pruners.MedianPruner()
|
158 |
+
study = optuna.create_study(
|
159 |
+
direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
|
|
|
|
|
|
|
160 |
)
|
161 |
+
study.optimize(
|
162 |
+
lambda trial: objective(trial, cfg), n_trials=3, show_progress_bar=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
)
|
|
|
|
|
|
|
164 |
|
165 |
+
# Log best trial results
|
166 |
+
best_trial = study.best_trial
|
167 |
+
logger.info(f"Best trial number: {best_trial.number}")
|
168 |
+
logger.info(f"Best trial value (val_acc): {best_trial.value}")
|
169 |
+
for key, value in best_trial.params.items():
|
170 |
+
logger.info(f" {key}: {value}")
|
171 |
+
|
172 |
+
if cfg.get("test", False):
|
173 |
+
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
174 |
+
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
175 |
+
trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
|
176 |
+
test_metrics = run_test_module(cfg, data_module, model, trainer)
|
177 |
+
logger.info(f"Test metrics: {test_metrics}")
|
178 |
+
|
179 |
+
return cfg.model if not cfg.get("test", False) else test_metrics
|
180 |
|
181 |
|
182 |
if __name__ == "__main__":
|
183 |
+
setup_trainer()
|
src/train_old.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from pathlib import Path
|
4 |
+
import torch
|
5 |
+
import lightning as L
|
6 |
+
from lightning.pytorch.loggers import Logger
|
7 |
+
from typing import List
|
8 |
+
from src.datamodules.dogbreed_datamodule import main_dataloader
|
9 |
+
from src.utils.logging_utils import setup_logger, task_wrapper
|
10 |
+
from loguru import logger
|
11 |
+
from dotenv import load_dotenv, find_dotenv
|
12 |
+
import rootutils
|
13 |
+
import hydra
|
14 |
+
from omegaconf import DictConfig, OmegaConf
|
15 |
+
|
16 |
+
# Load environment variables
|
17 |
+
load_dotenv(find_dotenv(".env"))
|
18 |
+
|
19 |
+
# Setup root directory
|
20 |
+
root = rootutils.setup_root(__file__, indicator=".project-root")
|
21 |
+
|
22 |
+
|
23 |
+
def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
|
24 |
+
"""Instantiate and return a list of callbacks from the configuration."""
|
25 |
+
callbacks: List[L.Callback] = []
|
26 |
+
|
27 |
+
if not callback_cfg:
|
28 |
+
logger.warning("No callback configs found! Skipping..")
|
29 |
+
return callbacks
|
30 |
+
|
31 |
+
if not isinstance(callback_cfg, DictConfig):
|
32 |
+
raise TypeError("Callbacks config must be a DictConfig!")
|
33 |
+
|
34 |
+
for _, cb_conf in callback_cfg.items():
|
35 |
+
if "_target_" in cb_conf:
|
36 |
+
logger.info(f"Instantiating callback <{cb_conf._target_}>")
|
37 |
+
callbacks.append(hydra.utils.instantiate(cb_conf))
|
38 |
+
|
39 |
+
return callbacks
|
40 |
+
|
41 |
+
|
42 |
+
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
|
43 |
+
"""Instantiate and return a list of loggers from the configuration."""
|
44 |
+
loggers_ls: List[Logger] = []
|
45 |
+
|
46 |
+
if not logger_cfg:
|
47 |
+
logger.warning("No logger configs found! Skipping..")
|
48 |
+
return loggers_ls
|
49 |
+
|
50 |
+
if not isinstance(logger_cfg, DictConfig):
|
51 |
+
raise TypeError("Logger config must be a DictConfig!")
|
52 |
+
|
53 |
+
for _, lg_conf in logger_cfg.items():
|
54 |
+
if "_target_" in lg_conf:
|
55 |
+
logger.info(f"Instantiating logger <{lg_conf._target_}>")
|
56 |
+
loggers_ls.append(hydra.utils.instantiate(lg_conf))
|
57 |
+
|
58 |
+
return loggers_ls
|
59 |
+
|
60 |
+
|
61 |
+
def load_checkpoint_if_available(ckpt_path: str) -> str:
|
62 |
+
"""Check if the specified checkpoint exists and return the valid checkpoint path."""
|
63 |
+
if ckpt_path and Path(ckpt_path).exists():
|
64 |
+
logger.info(f"Checkpoint found: {ckpt_path}")
|
65 |
+
return ckpt_path
|
66 |
+
else:
|
67 |
+
logger.warning(
|
68 |
+
f"No checkpoint found at {ckpt_path}. Using current model weights."
|
69 |
+
)
|
70 |
+
return None
|
71 |
+
|
72 |
+
|
73 |
+
def clear_checkpoint_directory(ckpt_dir: str):
|
74 |
+
"""Clear all contents of the checkpoint directory without deleting the directory itself."""
|
75 |
+
ckpt_dir_path = Path(ckpt_dir)
|
76 |
+
if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():
|
77 |
+
logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
|
78 |
+
# Iterate over all files and directories in the checkpoint directory and remove them
|
79 |
+
for item in ckpt_dir_path.iterdir():
|
80 |
+
try:
|
81 |
+
if item.is_file() or item.is_symlink():
|
82 |
+
item.unlink() # Remove file or symlink
|
83 |
+
elif item.is_dir():
|
84 |
+
shutil.rmtree(item) # Remove directory
|
85 |
+
except Exception as e:
|
86 |
+
logger.error(f"Failed to delete {item}: {e}")
|
87 |
+
logger.info(f"Checkpoint directory cleared: {ckpt_dir}")
|
88 |
+
else:
|
89 |
+
logger.info(
|
90 |
+
f"Checkpoint directory does not exist. Creating directory: {ckpt_dir}"
|
91 |
+
)
|
92 |
+
os.makedirs(ckpt_dir_path, exist_ok=True)
|
93 |
+
|
94 |
+
|
95 |
+
@task_wrapper
|
96 |
+
def train_module(
|
97 |
+
cfg: DictConfig,
|
98 |
+
data_module: L.LightningDataModule,
|
99 |
+
model: L.LightningModule,
|
100 |
+
trainer: L.Trainer,
|
101 |
+
):
|
102 |
+
"""Train the model using the provided Trainer and DataModule."""
|
103 |
+
logger.info("Training the model")
|
104 |
+
trainer.fit(model, data_module)
|
105 |
+
train_metrics = trainer.callback_metrics
|
106 |
+
try:
|
107 |
+
logger.info(
|
108 |
+
f"Training completed with the following metrics- train_acc: {train_metrics['train_acc'].item()} and val_acc: {train_metrics['val_acc'].item()}"
|
109 |
+
)
|
110 |
+
except KeyError:
|
111 |
+
logger.info(f"Training completed with the following metrics:{train_metrics}")
|
112 |
+
|
113 |
+
return train_metrics
|
114 |
+
|
115 |
+
|
116 |
+
@task_wrapper
|
117 |
+
def run_test_module(
|
118 |
+
cfg: DictConfig,
|
119 |
+
datamodule: L.LightningDataModule,
|
120 |
+
model: L.LightningModule,
|
121 |
+
trainer: L.Trainer,
|
122 |
+
):
|
123 |
+
"""Test the model using the best checkpoint or the current model weights."""
|
124 |
+
logger.info("Testing the model")
|
125 |
+
datamodule.setup(stage="test")
|
126 |
+
|
127 |
+
ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)
|
128 |
+
|
129 |
+
# If no checkpoint is available, Lightning will use current model weights
|
130 |
+
test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)
|
131 |
+
logger.info(f"Test metrics:\n{test_metrics}")
|
132 |
+
|
133 |
+
return test_metrics[0] if test_metrics else {}
|
134 |
+
|
135 |
+
|
136 |
+
@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
|
137 |
+
def setup_run_trainer(cfg: DictConfig):
|
138 |
+
"""Set up and run the Trainer for training and testing the model."""
|
139 |
+
# show me the entire config
|
140 |
+
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
141 |
+
# Initialize logger
|
142 |
+
if cfg.task_name == "train":
|
143 |
+
log_path = Path(cfg.paths.log_dir) / "train.log"
|
144 |
+
else:
|
145 |
+
log_path = Path(cfg.paths.log_dir) / "eval.log"
|
146 |
+
setup_logger(log_path)
|
147 |
+
|
148 |
+
# the path to the checkpoint directory
|
149 |
+
root_dir = cfg.paths.root_dir
|
150 |
+
logger.info(f"Root directory: {root_dir}")
|
151 |
+
|
152 |
+
logger.info(f"Current working directory: {os.listdir(root_dir)}")
|
153 |
+
|
154 |
+
ckpt_dir = cfg.paths.ckpt_dir
|
155 |
+
logger.info(f"Checkpoint directory: {ckpt_dir}")
|
156 |
+
|
157 |
+
# the path to the data directory
|
158 |
+
data_dir = cfg.paths.data_dir
|
159 |
+
logger.info(f"Data directory: {data_dir}")
|
160 |
+
|
161 |
+
# the path to the log directory
|
162 |
+
log_dir = cfg.paths.log_dir
|
163 |
+
logger.info(f"Log directory: {log_dir}")
|
164 |
+
|
165 |
+
# the path to the artifact directory
|
166 |
+
artifact_dir = cfg.paths.artifact_dir
|
167 |
+
logger.info(f"Artifact directory: {artifact_dir}")
|
168 |
+
|
169 |
+
# output directory
|
170 |
+
output_dir = cfg.paths.output_dir
|
171 |
+
logger.info(f"Output directory: {output_dir}")
|
172 |
+
|
173 |
+
# name of the experiment
|
174 |
+
experiment_name = cfg.name
|
175 |
+
logger.info(f"Experiment name: {experiment_name}")
|
176 |
+
|
177 |
+
# Initialize DataModule
|
178 |
+
if experiment_name == "dogbreed_experiment":
|
179 |
+
logger.info("Setting up the DataModule")
|
180 |
+
dataset_df, datamodule = main_dataloader(cfg)
|
181 |
+
labels = dataset_df.label.nunique()
|
182 |
+
logger.info(f"Number of classes: {labels}")
|
183 |
+
|
184 |
+
os.makedirs(cfg.paths.artifact_dir, exist_ok=True)
|
185 |
+
dataset_df.to_csv(
|
186 |
+
Path(cfg.paths.artifact_dir) / "dogbreed_dataset.csv", index=False
|
187 |
+
)
|
188 |
+
elif (
|
189 |
+
experiment_name == "catdog_experiment"
|
190 |
+
or experiment_name == "catdog_experiment_convnext"
|
191 |
+
):
|
192 |
+
# Initialize DataModule
|
193 |
+
logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
|
194 |
+
datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
195 |
+
|
196 |
+
# Check for GPU availability
|
197 |
+
logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
|
198 |
+
|
199 |
+
# Set seed for reproducibility
|
200 |
+
L.seed_everything(cfg.seed, workers=True)
|
201 |
+
|
202 |
+
# Initialize model
|
203 |
+
logger.info(f"Instantiating model <{cfg.model._target_}>")
|
204 |
+
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
205 |
+
|
206 |
+
logger.info(f"Model summary:\n{model}")
|
207 |
+
|
208 |
+
# Set up callbacks and loggers
|
209 |
+
logger.info("Setting up callbacks and loggers")
|
210 |
+
callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
|
211 |
+
logger.info(f"Callbacks: {callbacks}")
|
212 |
+
loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))
|
213 |
+
logger.info(f"Loggers: {loggers}")
|
214 |
+
|
215 |
+
# Initialize Trainer
|
216 |
+
logger.info(f"Instantiating trainer <{cfg.trainer._target_}>")
|
217 |
+
trainer: L.Trainer = hydra.utils.instantiate(
|
218 |
+
cfg.trainer, callbacks=callbacks, logger=loggers
|
219 |
+
)
|
220 |
+
|
221 |
+
# Train and test the model based on config settings
|
222 |
+
train_metrics = {}
|
223 |
+
if cfg.get("train"):
|
224 |
+
# clear the checkpoint directory
|
225 |
+
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
226 |
+
|
227 |
+
logger.info("Training the model")
|
228 |
+
train_metrics = train_module(cfg, datamodule, model, trainer)
|
229 |
+
|
230 |
+
# Write training done flag using Hydra paths config
|
231 |
+
done_flag_path = Path(cfg.paths.ckpt_dir) / "train_done.flag"
|
232 |
+
with done_flag_path.open("w") as f:
|
233 |
+
f.write("Training completed.\n")
|
234 |
+
logger.info(f"Training completion flag written to: {done_flag_path}")
|
235 |
+
|
236 |
+
logger.info(
|
237 |
+
f"Training completed. Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}"
|
238 |
+
)
|
239 |
+
|
240 |
+
test_metrics = {}
|
241 |
+
if cfg.get("test"):
|
242 |
+
logger.info(f"Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}")
|
243 |
+
test_metrics = run_test_module(cfg, datamodule, model, trainer)
|
244 |
+
|
245 |
+
# Combine metrics
|
246 |
+
all_metrics = {**train_metrics, **test_metrics}
|
247 |
+
|
248 |
+
# Extract and return the optimization metric
|
249 |
+
optimization_metric = all_metrics.get(cfg.get("optimization_metric"))
|
250 |
+
if optimization_metric is None:
|
251 |
+
logger.warning(
|
252 |
+
f"Optimization metric '{cfg.get('optimization_metric')}' not found in metrics. Returning 0."
|
253 |
+
)
|
254 |
+
return 0.0
|
255 |
+
|
256 |
+
return optimization_metric
|
257 |
+
|
258 |
+
|
259 |
+
if __name__ == "__main__":
|
260 |
+
setup_run_trainer()
|
src/{train.py → train_optuna_callbacks.py}
RENAMED
@@ -1,7 +1,5 @@
|
|
1 |
"""
|
2 |
-
Train and evaluate a model using PyTorch Lightning.
|
3 |
-
Initializes the DataModule, Model, Trainer, and runs training and testing.
|
4 |
-
Initializes loggers and callbacks from the configuration using Hydra configuration but with a more modular approach without direct instantiation.
|
5 |
"""
|
6 |
|
7 |
import os
|
@@ -10,47 +8,61 @@ from pathlib import Path
|
|
10 |
from typing import List
|
11 |
import torch
|
12 |
import lightning as L
|
13 |
-
from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
|
14 |
-
from lightning.pytorch.callbacks import (
|
15 |
-
ModelCheckpoint,
|
16 |
-
EarlyStopping,
|
17 |
-
RichModelSummary,
|
18 |
-
RichProgressBar,
|
19 |
-
)
|
20 |
from dotenv import load_dotenv, find_dotenv
|
21 |
import hydra
|
22 |
from omegaconf import DictConfig, OmegaConf
|
23 |
-
from src.datamodules.catdog_datamodule import CatDogImageDataModule
|
24 |
from src.utils.logging_utils import setup_logger, task_wrapper
|
25 |
from loguru import logger
|
26 |
import rootutils
|
|
|
|
|
|
|
27 |
|
28 |
# Load environment variables
|
29 |
load_dotenv(find_dotenv(".env"))
|
30 |
|
31 |
# Setup root directory
|
32 |
-
|
33 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
34 |
|
35 |
|
36 |
-
def
|
37 |
-
"""
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
"
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
|
47 |
-
def
|
48 |
-
"""
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
def load_checkpoint_if_available(ckpt_path: str) -> str:
|
@@ -81,16 +93,19 @@ def clear_checkpoint_directory(ckpt_dir: str):
|
|
81 |
def train_module(
|
82 |
data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
|
83 |
):
|
84 |
-
"""Train the model
|
85 |
-
logger.info("Starting training")
|
|
|
86 |
trainer.fit(model, data_module)
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
94 |
|
95 |
|
96 |
@task_wrapper
|
@@ -110,77 +125,84 @@ def run_test_module(
|
|
110 |
return test_metrics[0] if test_metrics else {}
|
111 |
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
"""Set up and run the Trainer for training and testing."""
|
116 |
-
# Display configuration
|
117 |
-
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
118 |
|
119 |
-
#
|
120 |
-
|
121 |
-
|
122 |
-
)
|
123 |
-
|
124 |
-
|
125 |
-
# Display key paths
|
126 |
-
for path_name in [
|
127 |
-
"root_dir",
|
128 |
-
"data_dir",
|
129 |
-
"log_dir",
|
130 |
-
"ckpt_dir",
|
131 |
-
"artifact_dir",
|
132 |
-
"output_dir",
|
133 |
-
]:
|
134 |
-
logger.info(
|
135 |
-
f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
|
136 |
-
)
|
137 |
|
138 |
-
# Initialize
|
139 |
-
|
140 |
-
datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
141 |
-
logger.info(f"Instantiating model <{cfg.model._target_}>")
|
142 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
143 |
|
144 |
-
#
|
145 |
-
|
146 |
-
L.seed_everything(cfg.seed, workers=True)
|
147 |
|
148 |
-
#
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
)
|
156 |
|
157 |
-
#
|
158 |
-
|
159 |
-
|
160 |
-
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
161 |
-
train_metrics = train_module(datamodule, model, trainer)
|
162 |
-
(Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
|
163 |
-
"Training completed.\n"
|
164 |
-
)
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
# Combine metrics and extract optimization metric
|
172 |
-
all_metrics = {**train_metrics, **test_metrics}
|
173 |
-
optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
|
174 |
-
(
|
175 |
-
logger.warning(
|
176 |
-
f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
|
177 |
)
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
return
|
183 |
|
184 |
|
185 |
if __name__ == "__main__":
|
186 |
-
|
|
|
1 |
"""
|
2 |
+
Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
|
|
|
|
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
8 |
from typing import List
|
9 |
import torch
|
10 |
import lightning as L
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from dotenv import load_dotenv, find_dotenv
|
12 |
import hydra
|
13 |
from omegaconf import DictConfig, OmegaConf
|
|
|
14 |
from src.utils.logging_utils import setup_logger, task_wrapper
|
15 |
from loguru import logger
|
16 |
import rootutils
|
17 |
+
from lightning.pytorch.loggers import Logger
|
18 |
+
import optuna
|
19 |
+
from lightning.pytorch import Trainer
|
20 |
|
21 |
# Load environment variables
|
22 |
load_dotenv(find_dotenv(".env"))
|
23 |
|
24 |
# Setup root directory
|
|
|
25 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
26 |
|
27 |
|
28 |
+
def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
|
29 |
+
"""Instantiate and return a list of callbacks from the configuration."""
|
30 |
+
callbacks: List[L.Callback] = []
|
31 |
+
|
32 |
+
if not callback_cfg:
|
33 |
+
logger.warning("No callback configs found! Skipping..")
|
34 |
+
return callbacks
|
35 |
+
|
36 |
+
if not isinstance(callback_cfg, DictConfig):
|
37 |
+
raise TypeError("Callbacks config must be a DictConfig!")
|
38 |
+
|
39 |
+
for _, cb_conf in callback_cfg.items():
|
40 |
+
if "_target_" in cb_conf:
|
41 |
+
logger.info(f"Instantiating callback <{cb_conf._target_}>")
|
42 |
+
callbacks.append(hydra.utils.instantiate(cb_conf))
|
43 |
+
|
44 |
+
return callbacks
|
45 |
|
46 |
|
47 |
+
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
|
48 |
+
"""Instantiate and return a list of loggers from the configuration."""
|
49 |
+
loggers_ls: List[Logger] = []
|
50 |
+
|
51 |
+
if not logger_cfg or isinstance(logger_cfg, bool):
|
52 |
+
logger.warning("No valid logger configs found! Skipping..")
|
53 |
+
return loggers_ls
|
54 |
+
|
55 |
+
if not isinstance(logger_cfg, DictConfig):
|
56 |
+
raise TypeError("Logger config must be a DictConfig!")
|
57 |
+
|
58 |
+
for _, lg_conf in logger_cfg.items():
|
59 |
+
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
|
60 |
+
logger.info(f"Instantiating logger <{lg_conf._target_}>")
|
61 |
+
try:
|
62 |
+
loggers_ls.append(hydra.utils.instantiate(lg_conf))
|
63 |
+
except Exception as e:
|
64 |
+
logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
|
65 |
+
return loggers_ls
|
66 |
|
67 |
|
68 |
def load_checkpoint_if_available(ckpt_path: str) -> str:
|
|
|
93 |
def train_module(
|
94 |
data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
|
95 |
):
|
96 |
+
"""Train the model, return validation accuracy for each epoch."""
|
97 |
+
logger.info("Starting training with custom pruning")
|
98 |
+
|
99 |
trainer.fit(model, data_module)
|
100 |
+
val_accuracies = []
|
101 |
+
|
102 |
+
for epoch in range(trainer.current_epoch):
|
103 |
+
val_acc = trainer.callback_metrics.get("val_acc")
|
104 |
+
if val_acc is not None:
|
105 |
+
val_accuracies.append(val_acc.item())
|
106 |
+
logger.info(f"Epoch {epoch}: val_acc={val_acc}")
|
107 |
+
|
108 |
+
return val_accuracies
|
109 |
|
110 |
|
111 |
@task_wrapper
|
|
|
125 |
return test_metrics[0] if test_metrics else {}
|
126 |
|
127 |
|
128 |
+
def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
|
129 |
+
"""Objective function for Optuna hyperparameter tuning."""
|
|
|
|
|
|
|
130 |
|
131 |
+
# Sample hyperparameters for the model
|
132 |
+
cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
|
133 |
+
cfg.model.depth = trial.suggest_int("depth", 2, 6)
|
134 |
+
cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
|
135 |
+
cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
+
# Initialize data module and model
|
138 |
+
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
|
|
|
|
139 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
140 |
|
141 |
+
# Set up logger
|
142 |
+
loggers = instantiate_loggers(cfg.logger)
|
|
|
143 |
|
144 |
+
# Trainer configuration with passed callbacks
|
145 |
+
trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
|
146 |
+
|
147 |
+
# Clear checkpoint directory
|
148 |
+
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
149 |
+
|
150 |
+
# Train and get val_acc for each epoch
|
151 |
+
val_accuracies = train_module(data_module, model, trainer)
|
152 |
+
|
153 |
+
# Report validation accuracy and prune if necessary
|
154 |
+
for epoch, val_acc in enumerate(val_accuracies):
|
155 |
+
trial.report(val_acc, step=epoch)
|
156 |
+
|
157 |
+
# Check if the trial should be pruned at this epoch
|
158 |
+
if trial.should_prune():
|
159 |
+
logger.info(f"Pruning trial at epoch {epoch}")
|
160 |
+
raise optuna.TrialPruned()
|
161 |
+
|
162 |
+
# Return the final validation accuracy as the objective metric
|
163 |
+
return val_accuracies[-1] if val_accuracies else 0.0
|
164 |
+
|
165 |
+
|
166 |
+
@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
|
167 |
+
def setup_trainer(cfg: DictConfig):
|
168 |
+
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
169 |
+
|
170 |
+
setup_logger(
|
171 |
+
Path(cfg.paths.log_dir)
|
172 |
+
/ ("train.log" if cfg.task_name == "train" else "eval.log")
|
173 |
)
|
174 |
|
175 |
+
# Instantiate callbacks
|
176 |
+
callbacks = instantiate_callbacks(cfg.callbacks)
|
177 |
+
logger.info(f"Callbacks: {callbacks}")
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
if cfg.get("train", False):
|
180 |
+
pruner = optuna.pruners.MedianPruner()
|
181 |
+
study = optuna.create_study(
|
182 |
+
direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
)
|
184 |
+
study.optimize(
|
185 |
+
lambda trial: objective(trial, cfg, callbacks),
|
186 |
+
n_trials=5,
|
187 |
+
show_progress_bar=True,
|
188 |
+
)
|
189 |
+
|
190 |
+
# Log best trial results
|
191 |
+
best_trial = study.best_trial
|
192 |
+
logger.info(f"Best trial number: {best_trial.number}")
|
193 |
+
logger.info(f"Best trial value (val_acc): {best_trial.value}")
|
194 |
+
for key, value in best_trial.params.items():
|
195 |
+
logger.info(f" {key}: {value}")
|
196 |
+
|
197 |
+
if cfg.get("test", False):
|
198 |
+
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
199 |
+
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
200 |
+
trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
|
201 |
+
test_metrics = run_test_module(cfg, data_module, model, trainer)
|
202 |
+
logger.info(f"Test metrics: {test_metrics}")
|
203 |
|
204 |
+
return cfg.model if not cfg.get("test", False) else test_metrics
|
205 |
|
206 |
|
207 |
if __name__ == "__main__":
|
208 |
+
setup_trainer()
|