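"""Train the MobileVIT model on the fluorescent neuronal dataset.

Loads the tuned hyperparameters from a YAML config file, builds the MobileVIT
LightningModule and the FluorescentNeuronalDataModule, and trains the model
with PyTorch Lightning, keeping the checkpoint with the lowest validation loss.
Both CLI options default to the paths defined below.
"""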
import yaml
from pathlib import Path
import click
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from models.mobilevit import MobileVIT
from data.data_preprocessing import FluorescentNeuronalDataModule
CONFIG_FILE = "config/fluorescent_mobilevit_hps.yaml"
DATA_DIR = "data/raw/"
LOGS_DIR = "reports/logs/FluorescentMobileVIT"
MODEL_DIR = "models/FluorescentMobileVIT"

# Define the accelerator
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps:0")
    ACCELERATOR = "mps"
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    ACCELERATOR = "gpu"
else:
    DEVICE = torch.device("cpu")
    ACCELERATOR = "cpu"


@click.command()
@click.option(
    "--data_dir",
    type=click.Path(exists=True, file_okay=True, path_type=Path),
    default=DATA_DIR,
)
@click.option(
    "--config_file",
    type=click.Path(exists=True, file_okay=True, path_type=Path),
    default=CONFIG_FILE,
)
def train_model(data_dir, config_file):
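    """Train and evaluate the MobileVIT model.

    Args:
        data_dir: Directory with the raw dataset.
        config_file: YAML file with the tuned hyperparameters
            (learning_rate, weight_decay, batch_size).
    """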
    # Load the best parameters
    with open(config_file, "r") as file:
        best_params = yaml.safe_load(file)
    # Instantiate the model
    model = MobileVIT(
        learning_rate=best_params["learning_rate"],
        weight_decay=best_params["weight_decay"],
    )
    # Define the callbacks of the model
    model_checkpoint_cb = ModelCheckpoint(
        save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss"
    )
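    # Log training metrics to TensorBoard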
    logger = TensorBoardLogger(save_dir=LOGS_DIR)
    # Create the trainer with its parameters
    trainer = pl.Trainer(
        logger=logger,
        devices=1,
        accelerator=ACCELERATOR,
        precision=16,
        max_epochs=100,
        log_every_n_steps=20,
        callbacks=[model_checkpoint_cb],
    )
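    # Set up the data module with the tuned batch size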
    data_module = FluorescentNeuronalDataModule(
        data_dir=data_dir, batch_size=best_params["batch_size"]
    )
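    # Train the model, then evaluate it on the test split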
    trainer.fit(model=model, datamodule=data_module)
    trainer.test(model=model, datamodule=data_module)
    click.echo("\n\n==========The Training has Finished!==========")
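

# Assumed entry point: invoke the click command when the script is run directly.
if __name__ == "__main__":
    train_model()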