Soutrik committed
Commit 8d4131e · 1 Parent(s): de7d21e

optuna added as base

configs/experiment/catdog_experiment.yaml CHANGED
@@ -39,7 +39,7 @@ model:

 trainer:
   min_epochs: 1
-  max_epochs: 6
+  max_epochs: 5

 callbacks:
   model_checkpoint:
configs/trainer/default.yaml CHANGED
@@ -1,4 +1,3 @@
-_target_: lightning.Trainer

 default_root_dir: ${paths.output_dir}
 min_epochs: 1
@@ -10,8 +9,7 @@ devices: auto
 # mixed precision for extra speed-up
 # precision: 16

-# set True to to ensure deterministic results
-# makes training slower but gives more reproducibility than just setting seeds
+# set True to ensure deterministic results; makes training slower but gives more reproducibility than just setting seeds
 deterministic: True

 # Log every N steps in training and validation
@@ -19,3 +17,4 @@ log_every_n_steps: 10
 fast_dev_run: False

 gradient_clip_val: 1.0
+gradient_clip_algorithm: 'norm'
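
With _target_: lightning.Trainer removed from the defaults, this config is no longer routed through hydra.utils.instantiate; the training scripts below unpack it directly into the Trainer constructor. A minimal sketch of that pattern, using illustrative values rather than the repo's full default.yaml:

# Sketch: consuming the trainer config once the _target_ key is gone.
# The values below are illustrative assumptions, not the repo's full file.
from omegaconf import OmegaConf
from lightning.pytorch import Trainer

trainer_cfg = OmegaConf.create(
    {
        "min_epochs": 1,
        "max_epochs": 5,
        "deterministic": True,
        "log_every_n_steps": 10,
        "fast_dev_run": False,
        "gradient_clip_val": 1.0,
        "gradient_clip_algorithm": "norm",
    }
)

# Converting to a plain dict keeps the kwargs as native Python types,
# mirroring Trainer(**cfg.trainer, ...) in the scripts below.
trainer = Trainer(**OmegaConf.to_container(trainer_cfg, resolve=True))

Every key in the config must therefore be a valid Trainer argument; anything extra (like the old _target_) would raise a TypeError at construction.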
src/train_new.py CHANGED
@@ -1,7 +1,5 @@
 """
-Train and evaluate a model using PyTorch Lightning.
-Initializes the DataModule, Model, Trainer, and runs training and testing.
-Initializes loggers and callbacks from the configuration using Hydra and target paths from the configuration.
+Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
 """

 import os
@@ -17,51 +15,34 @@ from src.utils.logging_utils import setup_logger, task_wrapper
 from loguru import logger
 import rootutils
 from lightning.pytorch.loggers import Logger
-from lightning.pytorch.callbacks import Callback
+import optuna
+from lightning.pytorch import Trainer

 # Load environment variables
 load_dotenv(find_dotenv(".env"))

 # Setup root directory
-
 root = rootutils.setup_root(__file__, indicator=".project-root")


-def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
-    """Instantiate and return a list of callbacks from the configuration."""
-    callbacks_ls: List[L.Callback] = []
-
-    if not callback_cfg:
-        logger.warning("No callback configs found! Skipping..")
-        return None
-
-    if not isinstance(callback_cfg, DictConfig):
-        raise TypeError("Callbacks config must be a DictConfig!")
-
-    for _, cb_conf in callback_cfg.items():
-        if "_target_" in cb_conf:
-            logger.info(f"Instantiating callback <{cb_conf._target_}>")
-            callbacks_ls.append(hydra.utils.instantiate(cb_conf))
-
-    return callbacks_ls
-
-
 def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
     """Instantiate and return a list of loggers from the configuration."""
     loggers_ls: List[Logger] = []

-    if not logger_cfg:
-        logger.warning("No logger configs found! Skipping..")
+    if not logger_cfg or isinstance(logger_cfg, bool):
+        logger.warning("No valid logger configs found! Skipping..")
         return loggers_ls

     if not isinstance(logger_cfg, DictConfig):
         raise TypeError("Logger config must be a DictConfig!")

     for _, lg_conf in logger_cfg.items():
-        if "_target_" in lg_conf:
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
             logger.info(f"Instantiating logger <{lg_conf._target_}>")
-            loggers_ls.append(hydra.utils.instantiate(lg_conf))
-
+            try:
+                loggers_ls.append(hydra.utils.instantiate(lg_conf))
+            except Exception as e:
+                logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
     return loggers_ls
@@ -93,16 +74,19 @@ def clear_checkpoint_directory(ckpt_dir: str):
 def train_module(
     data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
 ):
-    """Train the model and log metrics."""
-    logger.info("Starting training")
+    """Train the model, return validation accuracy for each epoch."""
+    logger.info("Starting training with custom pruning")
+
     trainer.fit(model, data_module)
-    train_metrics = trainer.callback_metrics
-    train_acc = train_metrics.get("train_acc")
-    val_acc = train_metrics.get("val_acc")
-    logger.info(
-        f"Training completed. Metrics - train_acc: {train_acc}, val_acc: {val_acc}"
-    )
-    return train_metrics
+    val_accuracies = []
+
+    for epoch in range(trainer.current_epoch):
+        val_acc = trainer.callback_metrics.get("val_acc")
+        if val_acc is not None:
+            val_accuracies.append(val_acc.item())
+            logger.info(f"Epoch {epoch}: val_acc={val_acc}")
+
+    return val_accuracies


 @task_wrapper
@@ -122,77 +106,78 @@ def run_test_module(
     return test_metrics[0] if test_metrics else {}


-@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
-def setup_run_trainer(cfg: DictConfig):
-    """Set up and run the Trainer for training and testing."""
-    # Display configuration
-    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+def objective(trial: optuna.trial.Trial, cfg: DictConfig):
+    """Objective function for Optuna hyperparameter tuning."""

-    # Initialize logger
-    log_path = Path(cfg.paths.log_dir) / (
-        "train.log" if cfg.task_name == "train" else "eval.log"
-    )
-    setup_logger(log_path)
-
-    # Display key paths
-    for path_name in [
-        "root_dir",
-        "data_dir",
-        "log_dir",
-        "ckpt_dir",
-        "artifact_dir",
-        "output_dir",
-    ]:
-        logger.info(
-            f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
-        )
+    # Sample hyperparameters for the model
+    cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
+    cfg.model.depth = trial.suggest_int("depth", 2, 6)
+    cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
+    cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)

-    # Initialize DataModule and Model
-    logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
-    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
-    logger.info(f"Instantiating model <{cfg.model._target_}>")
+    # Initialize data module and model
+    data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
     model: L.LightningModule = hydra.utils.instantiate(cfg.model)

-    # Check GPU availability and set seed for reproducibility
-    logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
-    L.seed_everything(cfg.seed, workers=True)
-
-    # Set up callbacks, loggers, and Trainer
-    callbacks = instantiate_callbacks(cfg.callbacks)
-    logger.info(f"Callbacks: {callbacks}")
+    # Set up logger
     loggers = instantiate_loggers(cfg.logger)
-    logger.info(f"Loggers: {loggers}")
-    trainer: L.Trainer = hydra.utils.instantiate(
-        cfg.trainer, callbacks=callbacks, logger=loggers
-    )
+
+    # Trainer configuration without pruning callback
+    trainer = Trainer(**cfg.trainer, logger=loggers)
+
+    # Clear checkpoint directory
+    clear_checkpoint_directory(cfg.paths.ckpt_dir)
+
+    # Train and get val_acc for each epoch
+    val_accuracies = train_module(data_module, model, trainer)
+
+    # Report validation accuracy and prune if necessary
+    for epoch, val_acc in enumerate(val_accuracies):
+        trial.report(val_acc, step=epoch)
+
+        # Check if the trial should be pruned at this epoch
+        if trial.should_prune():
+            logger.info(f"Pruning trial at epoch {epoch}")
+            raise optuna.TrialPruned()
+
+    # Return the final validation accuracy as the objective metric
+    return val_accuracies[-1] if val_accuracies else 0.0
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
+def setup_trainer(cfg: DictConfig):
+    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+
+    setup_logger(
+        Path(cfg.paths.log_dir)
+        / ("train.log" if cfg.task_name == "train" else "eval.log")
+    )

-    # Training phase
-    train_metrics = {}
-    if cfg.get("train"):
-        clear_checkpoint_directory(cfg.paths.ckpt_dir)
-        train_metrics = train_module(datamodule, model, trainer)
-        (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
-            "Training completed.\n"
-        )
+    if cfg.get("train", False):
+        pruner = optuna.pruners.MedianPruner()
+        study = optuna.create_study(
+            direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
+        )
+        study.optimize(
+            lambda trial: objective(trial, cfg), n_trials=3, show_progress_bar=True
+        )

-    # Testing phase
-    test_metrics = {}
-    if cfg.get("test"):
-        test_metrics = run_test_module(cfg, datamodule, model, trainer)
-
-    # Combine metrics and extract optimization metric
-    all_metrics = {**train_metrics, **test_metrics}
-    optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
-    (
-        logger.warning(
-            f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
-        )
-        if optimization_metric == 0.0
-        else logger.info(f"Optimization metric: {optimization_metric}")
-    )
+        # Log best trial results
+        best_trial = study.best_trial
+        logger.info(f"Best trial number: {best_trial.number}")
+        logger.info(f"Best trial value (val_acc): {best_trial.value}")
+        for key, value in best_trial.params.items():
+            logger.info(f"  {key}: {value}")

-    return optimization_metric
+    if cfg.get("test", False):
+        data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+        model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+        trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
+        test_metrics = run_test_module(cfg, data_module, model, trainer)
+        logger.info(f"Test metrics: {test_metrics}")
+
+    return cfg.model if not cfg.get("test", False) else test_metrics


 if __name__ == "__main__":
-    setup_run_trainer()
+    setup_trainer()
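
A note on the flow above: train_new.py reports the per-epoch scores only after trainer.fit has returned, and trainer.callback_metrics holds just the most recently logged value of val_acc, so each loop iteration reports the same final accuracy and MedianPruner can no longer save training time. The standard Optuna pattern reports inside the training loop so a pruned trial stops early; a self-contained sketch with an invented stand-in for one epoch of training:

# Minimal sketch of the report-then-prune pattern this commit builds on.
# Everything here is generic Optuna usage, not the repo's training loop.
import optuna


def objective(trial: optuna.trial.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    score = 0.0
    for epoch in range(5):
        score += lr * 100  # invented stand-in for one epoch of train + validation
        # Report the intermediate value while training is still running...
        trial.report(score, step=epoch)
        # ...so the pruner can abort a poor trial before it finishes.
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10)
print(study.best_trial.params)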
src/train_old.py ADDED
@@ -0,0 +1,260 @@
+import os
+import shutil
+from pathlib import Path
+import torch
+import lightning as L
+from lightning.pytorch.loggers import Logger
+from typing import List
+from src.datamodules.dogbreed_datamodule import main_dataloader
+from src.utils.logging_utils import setup_logger, task_wrapper
+from loguru import logger
+from dotenv import load_dotenv, find_dotenv
+import rootutils
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+# Load environment variables
+load_dotenv(find_dotenv(".env"))
+
+# Setup root directory
+root = rootutils.setup_root(__file__, indicator=".project-root")
+
+
+def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
+    """Instantiate and return a list of callbacks from the configuration."""
+    callbacks: List[L.Callback] = []
+
+    if not callback_cfg:
+        logger.warning("No callback configs found! Skipping..")
+        return callbacks
+
+    if not isinstance(callback_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callback_cfg.items():
+        if "_target_" in cb_conf:
+            logger.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiate and return a list of loggers from the configuration."""
+    loggers_ls: List[Logger] = []
+
+    if not logger_cfg:
+        logger.warning("No logger configs found! Skipping..")
+        return loggers_ls
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if "_target_" in lg_conf:
+            logger.info(f"Instantiating logger <{lg_conf._target_}>")
+            loggers_ls.append(hydra.utils.instantiate(lg_conf))
+
+    return loggers_ls
+
+
+def load_checkpoint_if_available(ckpt_path: str) -> str:
+    """Check if the specified checkpoint exists and return the valid checkpoint path."""
+    if ckpt_path and Path(ckpt_path).exists():
+        logger.info(f"Checkpoint found: {ckpt_path}")
+        return ckpt_path
+    else:
+        logger.warning(
+            f"No checkpoint found at {ckpt_path}. Using current model weights."
+        )
+        return None
+
+
+def clear_checkpoint_directory(ckpt_dir: str):
+    """Clear all contents of the checkpoint directory without deleting the directory itself."""
+    ckpt_dir_path = Path(ckpt_dir)
+    if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():
+        logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
+        # Iterate over all files and directories in the checkpoint directory and remove them
+        for item in ckpt_dir_path.iterdir():
+            try:
+                if item.is_file() or item.is_symlink():
+                    item.unlink()  # Remove file or symlink
+                elif item.is_dir():
+                    shutil.rmtree(item)  # Remove directory
+            except Exception as e:
+                logger.error(f"Failed to delete {item}: {e}")
+        logger.info(f"Checkpoint directory cleared: {ckpt_dir}")
+    else:
+        logger.info(
+            f"Checkpoint directory does not exist. Creating directory: {ckpt_dir}"
+        )
+        os.makedirs(ckpt_dir_path, exist_ok=True)
+
+
+@task_wrapper
+def train_module(
+    cfg: DictConfig,
+    data_module: L.LightningDataModule,
+    model: L.LightningModule,
+    trainer: L.Trainer,
+):
+    """Train the model using the provided Trainer and DataModule."""
+    logger.info("Training the model")
+    trainer.fit(model, data_module)
+    train_metrics = trainer.callback_metrics
+    try:
+        logger.info(
+            f"Training completed with the following metrics- train_acc: {train_metrics['train_acc'].item()} and val_acc: {train_metrics['val_acc'].item()}"
+        )
+    except KeyError:
+        logger.info(f"Training completed with the following metrics:{train_metrics}")
+
+    return train_metrics
+
+
+@task_wrapper
+def run_test_module(
+    cfg: DictConfig,
+    datamodule: L.LightningDataModule,
+    model: L.LightningModule,
+    trainer: L.Trainer,
+):
+    """Test the model using the best checkpoint or the current model weights."""
+    logger.info("Testing the model")
+    datamodule.setup(stage="test")
+
+    ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)
+
+    # If no checkpoint is available, Lightning will use current model weights
+    test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)
+    logger.info(f"Test metrics:\n{test_metrics}")
+
+    return test_metrics[0] if test_metrics else {}
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
+def setup_run_trainer(cfg: DictConfig):
+    """Set up and run the Trainer for training and testing the model."""
+    # show me the entire config
+    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+    # Initialize logger
+    if cfg.task_name == "train":
+        log_path = Path(cfg.paths.log_dir) / "train.log"
+    else:
+        log_path = Path(cfg.paths.log_dir) / "eval.log"
+    setup_logger(log_path)
+
+    # the path to the checkpoint directory
+    root_dir = cfg.paths.root_dir
+    logger.info(f"Root directory: {root_dir}")
+
+    logger.info(f"Current working directory: {os.listdir(root_dir)}")
+
+    ckpt_dir = cfg.paths.ckpt_dir
+    logger.info(f"Checkpoint directory: {ckpt_dir}")
+
+    # the path to the data directory
+    data_dir = cfg.paths.data_dir
+    logger.info(f"Data directory: {data_dir}")
+
+    # the path to the log directory
+    log_dir = cfg.paths.log_dir
+    logger.info(f"Log directory: {log_dir}")
+
+    # the path to the artifact directory
+    artifact_dir = cfg.paths.artifact_dir
+    logger.info(f"Artifact directory: {artifact_dir}")
+
+    # output directory
+    output_dir = cfg.paths.output_dir
+    logger.info(f"Output directory: {output_dir}")
+
+    # name of the experiment
+    experiment_name = cfg.name
+    logger.info(f"Experiment name: {experiment_name}")
+
+    # Initialize DataModule
+    if experiment_name == "dogbreed_experiment":
+        logger.info("Setting up the DataModule")
+        dataset_df, datamodule = main_dataloader(cfg)
+        labels = dataset_df.label.nunique()
+        logger.info(f"Number of classes: {labels}")
+
+        os.makedirs(cfg.paths.artifact_dir, exist_ok=True)
+        dataset_df.to_csv(
+            Path(cfg.paths.artifact_dir) / "dogbreed_dataset.csv", index=False
+        )
+    elif (
+        experiment_name == "catdog_experiment"
+        or experiment_name == "catdog_experiment_convnext"
+    ):
+        # Initialize DataModule
+        logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
+        datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    # Check for GPU availability
+    logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
+
+    # Set seed for reproducibility
+    L.seed_everything(cfg.seed, workers=True)
+
+    # Initialize model
+    logger.info(f"Instantiating model <{cfg.model._target_}>")
+    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+
+    logger.info(f"Model summary:\n{model}")
+
+    # Set up callbacks and loggers
+    logger.info("Setting up callbacks and loggers")
+    callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
+    logger.info(f"Callbacks: {callbacks}")
+    loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))
+    logger.info(f"Loggers: {loggers}")
+
+    # Initialize Trainer
+    logger.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: L.Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=loggers
+    )
+
+    # Train and test the model based on config settings
+    train_metrics = {}
+    if cfg.get("train"):
+        # clear the checkpoint directory
+        clear_checkpoint_directory(cfg.paths.ckpt_dir)
+
+        logger.info("Training the model")
+        train_metrics = train_module(cfg, datamodule, model, trainer)
+
+        # Write training done flag using Hydra paths config
+        done_flag_path = Path(cfg.paths.ckpt_dir) / "train_done.flag"
+        with done_flag_path.open("w") as f:
+            f.write("Training completed.\n")
+        logger.info(f"Training completion flag written to: {done_flag_path}")
+
+        logger.info(
+            f"Training completed. Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}"
+        )
+
+    test_metrics = {}
+    if cfg.get("test"):
+        logger.info(f"Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}")
+        test_metrics = run_test_module(cfg, datamodule, model, trainer)
+
+    # Combine metrics
+    all_metrics = {**train_metrics, **test_metrics}
+
+    # Extract and return the optimization metric
+    optimization_metric = all_metrics.get(cfg.get("optimization_metric"))
+    if optimization_metric is None:
+        logger.warning(
+            f"Optimization metric '{cfg.get('optimization_metric')}' not found in metrics. Returning 0."
+        )
+        return 0.0
+
+    return optimization_metric
+
+
+if __name__ == "__main__":
+    setup_run_trainer()
src/{train.py → train_optuna_callbacks.py} RENAMED
@@ -1,7 +1,5 @@
 """
-Train and evaluate a model using PyTorch Lightning.
-Initializes the DataModule, Model, Trainer, and runs training and testing.
-Initializes loggers and callbacks from the configuration using Hydra configuration but with a more modular approach without direct instantiation.
+Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
 """

 import os
@@ -10,47 +8,61 @@ from pathlib import Path
 from typing import List
 import torch
 import lightning as L
-from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
-from lightning.pytorch.callbacks import (
-    ModelCheckpoint,
-    EarlyStopping,
-    RichModelSummary,
-    RichProgressBar,
-)
 from dotenv import load_dotenv, find_dotenv
 import hydra
 from omegaconf import DictConfig, OmegaConf
-from src.datamodules.catdog_datamodule import CatDogImageDataModule
 from src.utils.logging_utils import setup_logger, task_wrapper
 from loguru import logger
 import rootutils
+from lightning.pytorch.loggers import Logger
+import optuna
+from lightning.pytorch import Trainer

 # Load environment variables
 load_dotenv(find_dotenv(".env"))

 # Setup root directory
-
 root = rootutils.setup_root(__file__, indicator=".project-root")


-def initialize_callbacks(cfg: DictConfig) -> List[L.Callback]:
-    """Initialize callbacks based on configuration."""
-    callback_classes = {
-        "model_checkpoint": ModelCheckpoint,
-        "early_stopping": EarlyStopping,
-        "rich_model_summary": RichModelSummary,
-        "rich_progress_bar": RichProgressBar,
-    }
-    return [callback_classes[name](**params) for name, params in cfg.callbacks.items()]
+def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
+    """Instantiate and return a list of callbacks from the configuration."""
+    callbacks: List[L.Callback] = []
+
+    if not callback_cfg:
+        logger.warning("No callback configs found! Skipping..")
+        return callbacks
+
+    if not isinstance(callback_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callback_cfg.items():
+        if "_target_" in cb_conf:
+            logger.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks


-def initialize_loggers(cfg: DictConfig) -> List[Logger]:
-    """Initialize loggers based on configuration."""
-    logger_classes = {
-        "tensorboard": TensorBoardLogger,
-        "csv": CSVLogger,
-    }
-    return [logger_classes[name](**params) for name, params in cfg.logger.items()]
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiate and return a list of loggers from the configuration."""
+    loggers_ls: List[Logger] = []
+
+    if not logger_cfg or isinstance(logger_cfg, bool):
+        logger.warning("No valid logger configs found! Skipping..")
+        return loggers_ls
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            logger.info(f"Instantiating logger <{lg_conf._target_}>")
+            try:
+                loggers_ls.append(hydra.utils.instantiate(lg_conf))
+            except Exception as e:
+                logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
+    return loggers_ls


 def load_checkpoint_if_available(ckpt_path: str) -> str:
@@ -81,16 +93,19 @@ def clear_checkpoint_directory(ckpt_dir: str):
 def train_module(
     data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
 ):
-    """Train the model and log metrics."""
-    logger.info("Starting training")
+    """Train the model, return validation accuracy for each epoch."""
+    logger.info("Starting training with custom pruning")
+
     trainer.fit(model, data_module)
-    train_metrics = trainer.callback_metrics
-    train_acc = train_metrics.get("train_acc")
-    val_acc = train_metrics.get("val_acc")
-    logger.info(
-        f"Training completed. Metrics - train_acc: {train_acc}, val_acc: {val_acc}"
-    )
-    return train_metrics
+    val_accuracies = []
+
+    for epoch in range(trainer.current_epoch):
+        val_acc = trainer.callback_metrics.get("val_acc")
+        if val_acc is not None:
+            val_accuracies.append(val_acc.item())
+            logger.info(f"Epoch {epoch}: val_acc={val_acc}")
+
+    return val_accuracies


 @task_wrapper
@@ -110,77 +125,84 @@ def run_test_module(
     return test_metrics[0] if test_metrics else {}


-@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
-def setup_run_trainer(cfg: DictConfig):
-    """Set up and run the Trainer for training and testing."""
-    # Display configuration
-    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
+    """Objective function for Optuna hyperparameter tuning."""

-    # Initialize logger
-    log_path = Path(cfg.paths.log_dir) / (
-        "train.log" if cfg.task_name == "train" else "eval.log"
-    )
-    setup_logger(log_path)
-
-    # Display key paths
-    for path_name in [
-        "root_dir",
-        "data_dir",
-        "log_dir",
-        "ckpt_dir",
-        "artifact_dir",
-        "output_dir",
-    ]:
-        logger.info(
-            f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
-        )
+    # Sample hyperparameters for the model
+    cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
+    cfg.model.depth = trial.suggest_int("depth", 2, 6)
+    cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
+    cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)

-    # Initialize DataModule and Model
-    logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
-    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
-    logger.info(f"Instantiating model <{cfg.model._target_}>")
+    # Initialize data module and model
+    data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
     model: L.LightningModule = hydra.utils.instantiate(cfg.model)

-    # Check GPU availability and set seed for reproducibility
-    logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
-    L.seed_everything(cfg.seed, workers=True)
+    # Set up logger
+    loggers = instantiate_loggers(cfg.logger)

-    # Set up callbacks, loggers, and Trainer
-    callbacks = initialize_callbacks(cfg)
-    logger.info(f"Callbacks: {callbacks}")
-    loggers = initialize_loggers(cfg)
-    logger.info(f"Loggers: {loggers}")
-    trainer: L.Trainer = hydra.utils.instantiate(
-        cfg.trainer, callbacks=callbacks, logger=loggers
-    )
+    # Trainer configuration with passed callbacks
+    trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
+
+    # Clear checkpoint directory
+    clear_checkpoint_directory(cfg.paths.ckpt_dir)
+
+    # Train and get val_acc for each epoch
+    val_accuracies = train_module(data_module, model, trainer)
+
+    # Report validation accuracy and prune if necessary
+    for epoch, val_acc in enumerate(val_accuracies):
+        trial.report(val_acc, step=epoch)
+
+        # Check if the trial should be pruned at this epoch
+        if trial.should_prune():
+            logger.info(f"Pruning trial at epoch {epoch}")
+            raise optuna.TrialPruned()
+
+    # Return the final validation accuracy as the objective metric
+    return val_accuracies[-1] if val_accuracies else 0.0
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
+def setup_trainer(cfg: DictConfig):
+    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+
+    setup_logger(
+        Path(cfg.paths.log_dir)
+        / ("train.log" if cfg.task_name == "train" else "eval.log")
+    )

-    # Training phase
-    train_metrics = {}
-    if cfg.get("train"):
-        clear_checkpoint_directory(cfg.paths.ckpt_dir)
-        train_metrics = train_module(datamodule, model, trainer)
-        (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
-            "Training completed.\n"
-        )
+    # Instantiate callbacks
+    callbacks = instantiate_callbacks(cfg.callbacks)
+    logger.info(f"Callbacks: {callbacks}")

-    # Testing phase
-    test_metrics = {}
-    if cfg.get("test"):
-        test_metrics = run_test_module(cfg, datamodule, model, trainer)
-
-    # Combine metrics and extract optimization metric
-    all_metrics = {**train_metrics, **test_metrics}
-    optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
-    (
-        logger.warning(
-            f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
-        )
-        if optimization_metric == 0.0
-        else logger.info(f"Optimization metric: {optimization_metric}")
-    )
+    if cfg.get("train", False):
+        pruner = optuna.pruners.MedianPruner()
+        study = optuna.create_study(
+            direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
+        )
+        study.optimize(
+            lambda trial: objective(trial, cfg, callbacks),
+            n_trials=5,
+            show_progress_bar=True,
+        )
+
+        # Log best trial results
+        best_trial = study.best_trial
+        logger.info(f"Best trial number: {best_trial.number}")
+        logger.info(f"Best trial value (val_acc): {best_trial.value}")
+        for key, value in best_trial.params.items():
+            logger.info(f"  {key}: {value}")

-    return optimization_metric
+    if cfg.get("test", False):
+        data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+        model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+        trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
+        test_metrics = run_test_module(cfg, data_module, model, trainer)
+        logger.info(f"Test metrics: {test_metrics}")
+
+    return cfg.model if not cfg.get("test", False) else test_metrics


 if __name__ == "__main__":
-    setup_run_trainer()
+    setup_trainer()
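
For reference, the name train_optuna_callbacks.py points at the callback-driven variant of this setup: Optuna ships a PyTorchLightningPruningCallback that reports the monitored metric and prunes during trainer.fit instead of afterwards. A sketch under the assumption that a compatible pruning callback is importable (recent Optuna versions move it to the separate optuna-integration package, and version pairing between lightning and pytorch_lightning matters); the tiny model and random data are placeholders for this repo's Hydra-instantiated ones:

# Sketch: in-flight pruning via Optuna's Lightning integration callback.
import optuna
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L
from lightning.pytorch import Trainer
from optuna.integration import PyTorchLightningPruningCallback


class TinyClassifier(L.LightningModule):
    """Placeholder model standing in for the repo's cfg.model."""

    def __init__(self, lr: float):
        super().__init__()
        self.lr = lr
        self.net = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self.net(x).argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)  # the callback monitors this metric

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def objective(trial: optuna.trial.Trial) -> float:
    x, y = torch.randn(64, 4), torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(x, y), batch_size=16)
    model = TinyClassifier(lr=trial.suggest_float("lr", 1e-4, 1e-1, log=True))
    trainer = Trainer(
        max_epochs=3,
        logger=False,
        enable_checkpointing=False,
        # Reports val_acc to the trial at each validation epoch and raises
        # TrialPruned inside fit, so a bad trial stops mid-training.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
    return trainer.callback_metrics["val_acc"].item()


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=5)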