gradio_demo_CatDogClassifier / src /train_optuna_callbacks.py
Soutrik
optuna added as base
8d4131e
raw
history blame
7.35 kB
"""
Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
"""
import os
import shutil
from pathlib import Path
from typing import List
import torch
import lightning as L
from dotenv import load_dotenv, find_dotenv
import hydra
from omegaconf import DictConfig, OmegaConf
from src.utils.logging_utils import setup_logger, task_wrapper
from loguru import logger
import rootutils
from lightning.pytorch.loggers import Logger
import optuna
from lightning.pytorch import Trainer
# Load environment variables
load_dotenv(find_dotenv(".env"))
# Setup root directory
root = rootutils.setup_root(__file__, indicator=".project-root")
def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
"""Instantiate and return a list of callbacks from the configuration."""
callbacks: List[L.Callback] = []
if not callback_cfg:
logger.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callback_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
for _, cb_conf in callback_cfg.items():
if "_target_" in cb_conf:
logger.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiate and return a list of loggers from the configuration."""
loggers_ls: List[Logger] = []
if not logger_cfg or isinstance(logger_cfg, bool):
logger.warning("No valid logger configs found! Skipping..")
return loggers_ls
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
logger.info(f"Instantiating logger <{lg_conf._target_}>")
try:
loggers_ls.append(hydra.utils.instantiate(lg_conf))
except Exception as e:
logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
return loggers_ls
def load_checkpoint_if_available(ckpt_path: str) -> str:
"""Return the checkpoint path if available, else None."""
if ckpt_path and Path(ckpt_path).exists():
logger.info(f"Using checkpoint: {ckpt_path}")
return ckpt_path
logger.warning(f"Checkpoint not found at {ckpt_path}. Using current model weights.")
return None
def clear_checkpoint_directory(ckpt_dir: str):
"""Clear checkpoint directory contents without removing the directory."""
ckpt_dir_path = Path(ckpt_dir)
if not ckpt_dir_path.exists():
logger.info(f"Creating checkpoint directory: {ckpt_dir}")
ckpt_dir_path.mkdir(parents=True, exist_ok=True)
else:
logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
for item in ckpt_dir_path.iterdir():
try:
item.unlink() if item.is_file() else shutil.rmtree(item)
except Exception as e:
logger.error(f"Failed to delete {item}: {e}")
@task_wrapper
def train_module(
data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
):
"""Train the model, return validation accuracy for each epoch."""
logger.info("Starting training with custom pruning")
trainer.fit(model, data_module)
val_accuracies = []
for epoch in range(trainer.current_epoch):
val_acc = trainer.callback_metrics.get("val_acc")
if val_acc is not None:
val_accuracies.append(val_acc.item())
logger.info(f"Epoch {epoch}: val_acc={val_acc}")
return val_accuracies
@task_wrapper
def run_test_module(
cfg: DictConfig,
datamodule: L.LightningDataModule,
model: L.LightningModule,
trainer: L.Trainer,
):
"""Test the model using the best checkpoint or current model weights."""
logger.info("Starting testing")
datamodule.setup(stage="test")
test_metrics = trainer.test(
model, datamodule, ckpt_path=load_checkpoint_if_available(cfg.ckpt_path)
)
logger.info(f"Test metrics: {test_metrics}")
return test_metrics[0] if test_metrics else {}
def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
"""Objective function for Optuna hyperparameter tuning."""
# Sample hyperparameters for the model
cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
cfg.model.depth = trial.suggest_int("depth", 2, 6)
cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
# Initialize data module and model
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
# Set up logger
loggers = instantiate_loggers(cfg.logger)
# Trainer configuration with passed callbacks
trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
# Clear checkpoint directory
clear_checkpoint_directory(cfg.paths.ckpt_dir)
# Train and get val_acc for each epoch
val_accuracies = train_module(data_module, model, trainer)
# Report validation accuracy and prune if necessary
for epoch, val_acc in enumerate(val_accuracies):
trial.report(val_acc, step=epoch)
# Check if the trial should be pruned at this epoch
if trial.should_prune():
logger.info(f"Pruning trial at epoch {epoch}")
raise optuna.TrialPruned()
# Return the final validation accuracy as the objective metric
return val_accuracies[-1] if val_accuracies else 0.0
@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
def setup_trainer(cfg: DictConfig):
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
setup_logger(
Path(cfg.paths.log_dir)
/ ("train.log" if cfg.task_name == "train" else "eval.log")
)
# Instantiate callbacks
callbacks = instantiate_callbacks(cfg.callbacks)
logger.info(f"Callbacks: {callbacks}")
if cfg.get("train", False):
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(
direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
)
study.optimize(
lambda trial: objective(trial, cfg, callbacks),
n_trials=5,
show_progress_bar=True,
)
# Log best trial results
best_trial = study.best_trial
logger.info(f"Best trial number: {best_trial.number}")
logger.info(f"Best trial value (val_acc): {best_trial.value}")
for key, value in best_trial.params.items():
logger.info(f" {key}: {value}")
if cfg.get("test", False):
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
test_metrics = run_test_module(cfg, data_module, model, trainer)
logger.info(f"Test metrics: {test_metrics}")
return cfg.model if not cfg.get("test", False) else test_metrics
if __name__ == "__main__":
setup_trainer()