File size: 2,121 Bytes
c5c5181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import yaml
from pathlib import Path
import click
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from models.mobilevit import MobileVIT
from data.data_preprocessing import FluorescentNeuronalDataModule

# Default locations used by the training command. All paths are relative, so
# they resolve against the working directory the script is launched from.
CONFIG_FILE = "config/fluorescent_mobilevit_hps.yaml"  # hyperparameter YAML
DATA_DIR = "data/raw/"  # default for --data_dir
LOGS_DIR = "reports/logs/FluorescentMobileVIT"  # TensorBoard save_dir
MODEL_DIR = "models/FluorescentMobileVIT"  # checkpoint dirpath

# Define the accelerator: prefer Apple MPS, then CUDA, otherwise fall back to CPU.
# `torch.backends.mps` is missing on PyTorch builds that predate the MPS backend,
# so look it up defensively instead of failing with AttributeError at import time.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps:0")
    ACCELERATOR = "mps"
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    ACCELERATOR = "gpu"
else:
    DEVICE = torch.device("cpu")
    ACCELERATOR = "cpu"


@click.command()
@click.option(
    "--data_dir",
    type=click.Path(exists=True, file_okay=True, path_type=Path),
    default=DATA_DIR,
    help="Location of the input data handed to the data module.",
)
@click.option(
    "--config_file",
    # The value is opened as a file below, so reject directories up front
    # with a clean CLI error instead of failing later inside open().
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
    default=CONFIG_FILE,
    help="YAML file providing learning_rate, weight_decay and batch_size.",
)
def train_model(data_dir: Path, config_file: Path) -> None:
    """Train and test the MobileVIT model with hyperparameters from a YAML file.

    The command reads ``learning_rate``, ``weight_decay`` and ``batch_size``
    from ``config_file``, fits the model for at most 100 epochs on a single
    device, keeps the single best checkpoint (monitored on ``val_loss``) under
    ``MODEL_DIR``, writes TensorBoard logs under ``LOGS_DIR`` and finally runs
    the test loop with the same model and data module.

    Args:
        data_dir: Path forwarded to ``FluorescentNeuronalDataModule``.
        config_file: Path of the YAML hyperparameter file.

    Raises:
        click.ClickException: If the YAML content is not a mapping or lacks
            one of the required hyperparameters.
    """
    # Load the best parameters
    with open(config_file, "r", encoding="utf-8") as file:
        best_params = yaml.safe_load(file)

    # An empty file yields None and a list/scalar yields a non-dict; report
    # both as a readable CLI error rather than a TypeError/KeyError traceback.
    if not isinstance(best_params, dict):
        raise click.ClickException(
            f"{config_file} must contain a YAML mapping of hyperparameters."
        )
    required_keys = ("learning_rate", "weight_decay", "batch_size")
    missing_keys = [key for key in required_keys if key not in best_params]
    if missing_keys:
        raise click.ClickException(
            f"{config_file} is missing required keys: {', '.join(missing_keys)}"
        )

    # Instantiate the model
    model = MobileVIT(
        learning_rate=best_params["learning_rate"],
        weight_decay=best_params["weight_decay"],
    )
    # Define the callbacks of the model
    model_checkpoint_cb = ModelCheckpoint(
        save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss"
    )
    logger = TensorBoardLogger(save_dir=LOGS_DIR)

    # fp16 mixed precision is not run on CPU by Lightning, so request full
    # precision explicitly when no hardware accelerator was detected.
    precision = 32 if ACCELERATOR == "cpu" else 16

    # Create the trainer with its parameters
    trainer = pl.Trainer(
        logger=logger,
        devices=1,
        accelerator=ACCELERATOR,
        precision=precision,
        max_epochs=100,
        log_every_n_steps=20,
        callbacks=[model_checkpoint_cb],
    )
    data_module = FluorescentNeuronalDataModule(
        data_dir=data_dir, batch_size=best_params["batch_size"]
    )
    trainer.fit(model=model, datamodule=data_module)
    # NOTE(review): this evaluates the weights from the end of training, not
    # the checkpoint saved by ModelCheckpoint — confirm that is intended.
    trainer.test(model=model, datamodule=data_module)
    click.echo("\n\n==========The Training has Finished!==========")