# tools/hyperparameters_tuning.py
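"""Hyperparameter tuning for the MobileViT fluorescent-neuronal-cells model.

Runs short Optuna trials over learning rate, weight decay, and batch size,
then stores the best values in a YAML config file for later training runs.

Usage (the paths shown are the script defaults):
    python tools/hyperparameters_tuning.py \
        --data_dir ./data/raw/ \
        --config_file ./config/fluorescent_mobilevit_hps.yaml \
        --dataset_size 0.25
"""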
import os
from pathlib import Path

import click
import optuna
import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from data.data_preprocessing import FluorescentNeuronalDataModule
from models.mobilevit import MobileVIT

MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"
# Define the accelerator
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps:0")
    ACCELERATOR = "mps"
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    ACCELERATOR = "gpu"
else:
    DEVICE = torch.device("cpu")
    ACCELERATOR = "cpu"
RAW_DATA_PATH = "./data/raw/"
DEFAULT_CONFIG_FILE = "./config/fluorescent_mobilevit_hps.yaml"
CLASSES = {0: "Background", 1: "Neuron"}
IMG_SIZE = [256, 256]
@click.command()
@click.option(
    "--data_dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    default=RAW_DATA_PATH,
)
@click.option(
    "--config_file",
    # No exists=True here: the function below creates this file when it is
    # missing, so click must accept a path that does not exist yet
    type=click.Path(dir_okay=False, path_type=Path),
    default=DEFAULT_CONFIG_FILE,
)
@click.option("--dataset_size", type=click.FLOAT, default=0.25)
@click.option("--force-tune/--no-force-tune", default=False)
def get_best_params(data_dir, config_file, dataset_size, force_tune) -> dict:
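    """Load the best hyperparameters from config_file when it exists (unless
    --force-tune is passed); otherwise run an Optuna study and persist the
    result to config_file.
    """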
    def objective(trial: optuna.Trial, dataset_size=dataset_size) -> float:
        # Suggest values of the hyperparameters for this trial
        learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        batch_size = trial.suggest_int("batch_size", 2, 4, log=True)
        # Define the callbacks of the model
        early_stopping_cb = EarlyStopping(monitor="val_loss", patience=2)
        # Create the model
        model = MobileVIT(learning_rate=learning_rate, weight_decay=weight_decay)
        # Instantiate the data module
        data_module = FluorescentNeuronalDataModule(
            batch_size=batch_size, dataset_size=dataset_size, data_dir=data_dir
        )
        data_module.setup()
        # Train the model
        trainer = pl.Trainer(
            devices=1,
            accelerator=ACCELERATOR,
            precision="16-mixed",
            max_epochs=5,
            log_every_n_steps=5,
            callbacks=[early_stopping_cb],
        )
        trainer.fit(
            model,
            train_dataloaders=data_module.train_dataloader(),
            val_dataloaders=data_module.val_dataloader(),
        )
        # The final validation loss is the value the study optimizes
        return trainer.callback_metrics["val_loss"].item()
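    # A minimal sketch of how the MedianPruner below could actually prune
    # trials, assuming the Optuna Lightning integration is installed (its
    # import path varies across Optuna versions); without reporting
    # intermediate values, the pruner never fires:
    #
    #     from optuna.integration import PyTorchLightningPruningCallback
    #     pruning_cb = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    #     # then pass callbacks=[early_stopping_cb, pruning_cb] to pl.Trainer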
    if os.path.exists(config_file) and not force_tune:
        # Reuse the hyperparameters found by a previous study
        with open(config_file, "r") as file:
            best_params = yaml.safe_load(file)
    else:
        # Tune from scratch; the objective returns a validation loss,
        # so the study must minimize it
        pruner = optuna.pruners.MedianPruner()
        study = optuna.create_study(direction="minimize", pruner=pruner)
        study.optimize(objective, n_trials=25)
        best_params = study.best_params
        # Writing in "w" mode overwrites any previous config file
        with open(config_file, "w") as file:
            yaml.dump(best_params, file)
    click.echo(f"The best parameters are:\n{best_params}")
    return best_params
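

# Entry point so the click command can be run directly as a script
# (assumed standard boilerplate; not part of the excerpt above)
if __name__ == "__main__":
    get_best_params()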