mmenendezg committed
Commit c5c5181
1 Parent(s): 913eca1

Add files for the Gradio app

__init__.py ADDED
@@ -0,0 +1 @@
+ from . import data, models, tools, visualization
app.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import gradio as gr
+ from tools.predict import single_prediction
+
+ KAGGLE_NOTEBOOK = "[![Static Badge](https://img.shields.io/badge/Open_Notebook_in_Kaggle-gray?logo=kaggle&logoColor=white&labelColor=20BEFF)](https://www.kaggle.com/code/mmenendezg/mobilevit-fluorescent-neuronal-cells/notebook)"
+ GITHUB_REPOSITORY = "[![Static Badge](https://img.shields.io/badge/Git_Repository-gray?logo=github&logoColor=white&labelColor=181717)](https://github.com/mmenendezg/mobilevit-fluorescent-cells)"
+ HF_SPACE = "[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md-dark.svg)](https://huggingface.co/spaces/mmenendezg/mobilevit-fluorescent-neuronal-cells)"
+
+ # Gradio interface
+ demo = gr.Blocks()
+ with demo:
+     gr.Markdown(
+         f"""
+ # Fluorescent Neuronal Cells Segmentation
+
+ This model extracts a segmentation mask of the neuronal cells in an image.
+
+ {KAGGLE_NOTEBOOK}
+
+ {GITHUB_REPOSITORY}
+
+ {HF_SPACE}
+ """
+     )
+     with gr.Tab("Image Segmentation"):
+         with gr.Row():
+             with gr.Column():
+                 uploaded_image = gr.Image(
+                     label="Neuronal Cells Image",
+                     sources=["upload", "clipboard"],
+                     type="pil",
+                     height=550,
+                 )
+             with gr.Column():
+                 mask_image = gr.Image(label="Segmented Neurons", height=550)
+         with gr.Row():
+             classify_btn = gr.Button("Segment the image", variant="primary")
+             clear_btn = gr.ClearButton(components=[uploaded_image, mask_image])
+         classify_btn.click(
+             fn=single_prediction, inputs=uploaded_image, outputs=[mask_image]
+         )
+         gr.Examples(
+             examples=[
+                 os.path.join(os.path.dirname(__file__), "examples/example_1.png"),
+                 os.path.join(os.path.dirname(__file__), "examples/example_2.png"),
+             ],
+             inputs=uploaded_image,
+         )
+ demo.launch(show_error=True)
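To try the interface locally, run "python app.py" from the repository root so that tools.predict and the examples/ images resolve; Gradio serves the UI at http://127.0.0.1:7860 by default. Passing share=True to demo.launch() would additionally create a temporary public link, an optional Gradio feature not used in this commit.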
config/fluorescent_mobilevit_hps.yaml ADDED
@@ -0,0 +1,3 @@
+ learning_rate: 0.0005480015685663855
+ weight_decay: 1.544480236681167e-05
+ batch_size: 2
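These three values are the Optuna-tuned hyperparameters consumed by tools/train_model.py below. A minimal sketch of reading them back, mirroring the loading logic in that script:

import yaml

# Load the tuned hyperparameters from the default config path
with open("config/fluorescent_mobilevit_hps.yaml", "r") as file:
    best_params = yaml.safe_load(file)
print(best_params["learning_rate"], best_params["batch_size"])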
data/.gitkeep ADDED
File without changes
data/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import data_preprocessing
data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (250 Bytes)
data/__pycache__/data_preprocessing.cpython-311.pyc ADDED
Binary file (8.85 kB)
data/data_preprocessing.py ADDED
@@ -0,0 +1,155 @@
+ import os
+ import numpy as np
+ from PIL import Image
+ import cv2
+ import torch
+ from torch.utils.data import DataLoader
+ import albumentations as A
+ import pytorch_lightning as pl
+ from transformers import AutoImageProcessor
+ from datasets import Dataset, DatasetDict
+
+ # Checkpoint of the model used in the project
+ MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"
+ # Size of the images used to train the model
+ IMG_SIZE = [256, 256]
+
+
+ class FluorescentNeuronalDataModule(pl.LightningDataModule):
+     def __init__(self, batch_size, data_dir, dataset_size=1.0):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.image_processor = AutoImageProcessor.from_pretrained(
+             MODEL_CHECKPOINT, do_reduce_labels=False
+         )
+         self.image_resizer = A.Compose(
+             [
+                 A.Resize(
+                     width=IMG_SIZE[0],
+                     height=IMG_SIZE[1],
+                     interpolation=cv2.INTER_NEAREST,
+                 )
+             ]
+         )
+         self.image_augmentator = A.Compose(
+             [
+                 A.HorizontalFlip(p=0.6),
+                 A.VerticalFlip(p=0.6),
+                 A.RandomBrightnessContrast(p=0.6),
+                 A.RandomGamma(p=0.6),
+                 A.HueSaturationValue(p=0.6),
+             ]
+         )
+
+         # Fraction of the dataset to use
+         self.dataset_size = dataset_size
+
+     def _create_dataset(self):
+         images_path = os.path.join(self.data_dir, "all_images", "images")
+         masks_path = os.path.join(self.data_dir, "all_masks", "masks")
+         list_images = os.listdir(images_path)
+
+         # Optionally subsample the dataset
+         if self.dataset_size < 1.0:
+             n_images = int(len(list_images) * self.dataset_size)
+             list_images = list_images[:n_images]
+
+         images = []
+         masks = []
+         for image_filename in list_images:
+             image_path = os.path.join(images_path, image_filename)
+             mask_path = os.path.join(masks_path, image_filename)
+
+             image = np.array(Image.open(image_path).convert("RGB"), dtype=np.uint8)
+             mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8)
+             # Binarize the mask: 0 = background, 1 = neuron
+             mask = (mask / 255).astype(np.uint8)
+
+             images.append(image)
+             masks.append(mask)
+
+         dataset = Dataset.from_dict({"image": images, "mask": masks})
+
+         # Split the dataset into train, validation, and test sets
+         dataset = dataset.train_test_split(test_size=0.1)
+         train_val = dataset["train"]
+         test_ds = dataset["test"]
+         del dataset
+
+         train_val = train_val.train_test_split(test_size=0.2)
+         train_ds = train_val["train"]
+         valid_ds = train_val["test"]
+         del train_val
+
+         dataset = DatasetDict(
+             {"train": train_ds, "validation": valid_ds, "test": test_ds}
+         )
+         del train_ds, valid_ds, test_ds
+         return dataset
+
+     def _transform_train_data(self, batch):
+         # Preprocess the images
+         images, masks = [], []
+         for i, m in zip(batch["image"], batch["mask"]):
+             img = np.asarray(i, dtype=np.uint8)
+             mask = np.asarray(m, dtype=np.uint8)
+             # First resize the image and mask
+             resized_outputs = self.image_resizer(image=img, mask=mask)
+             images.append(resized_outputs["image"])
+             masks.append(resized_outputs["mask"])
+
+             # Then append an augmented copy of the resized pair
+             augmented_outputs = self.image_augmentator(
+                 image=resized_outputs["image"], mask=resized_outputs["mask"]
+             )
+             images.append(augmented_outputs["image"])
+             masks.append(augmented_outputs["mask"])
+
+         inputs = self.image_processor(
+             images=images,
+             return_tensors="pt",
+         )
+         inputs["labels"] = torch.tensor(masks, dtype=torch.long)
+         return inputs
+
+     def _transform_data(self, batch):
+         # Preprocess the images
+         images, masks = [], []
+         for i, m in zip(batch["image"], batch["mask"]):
+             img = np.asarray(i, dtype=np.uint8)
+             mask = np.asarray(m, dtype=np.uint8)
+             # Resize the image and mask
+             resized_outputs = self.image_resizer(image=img, mask=mask)
+             images.append(resized_outputs["image"])
+             masks.append(resized_outputs["mask"])
+
+         inputs = self.image_processor(
+             images=images,
+             return_tensors="pt",
+         )
+         inputs["labels"] = torch.tensor(masks, dtype=torch.long)
+         return inputs
+
+     def setup(self, stage=None):
+         dataset = self._create_dataset()
+         train_ds = dataset["train"]
+         valid_ds = dataset["validation"]
+         test_ds = dataset["test"]
+
+         if stage is None or stage == "fit":
+             self.train_ds = train_ds.with_transform(self._transform_train_data)
+             self.valid_ds = valid_ds.with_transform(self._transform_data)
+         if stage is None or stage == "test" or stage == "predict":
+             self.test_ds = test_ds.with_transform(self._transform_data)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)
+
+     def val_dataloader(self):
+         return DataLoader(self.valid_ds, batch_size=self.batch_size)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_ds, batch_size=self.batch_size)
+
+     def predict_dataloader(self):
+         return DataLoader(self.test_ds, batch_size=self.batch_size)
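A minimal standalone sketch of the data module, assuming the Fluorescent Neuronal Cells dataset is unpacked under data/raw/ with the all_images/images and all_masks/masks layout expected by _create_dataset:

from data.data_preprocessing import FluorescentNeuronalDataModule

# Load 10% of the data for a quick smoke test
data_module = FluorescentNeuronalDataModule(
    batch_size=2, data_dir="data/raw/", dataset_size=0.1
)
data_module.setup(stage="fit")
batch = next(iter(data_module.train_dataloader()))
print(batch["pixel_values"].shape, batch["labels"].shape)

Note that _transform_train_data emits both a resized and an augmented copy of every raw sample, so a training batch carries twice as many image/mask pairs as raw examples.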
models/.gitkeep ADDED
File without changes
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import mobilevit
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (243 Bytes)
models/__pycache__/hyperparameters_tuning.cpython-311.pyc ADDED
Binary file (5.68 kB)
models/__pycache__/mobilevit.cpython-311.pyc ADDED
Binary file (7.37 kB)
models/__pycache__/train_model.cpython-311.pyc ADDED
Binary file (3.21 kB)
models/mobilevit.py ADDED
@@ -0,0 +1,120 @@
+ import numpy as np
+ import torch
+ from torch import nn
+ import pytorch_lightning as pl
+ from transformers import MobileViTForSemanticSegmentation
+ import evaluate
+
+ MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"
+ CLASSES = {0: "Background", 1: "Neuron"}
+
+
+ class MobileVIT(pl.LightningModule):
+     def __init__(self, learning_rate=None, weight_decay=None):
+         super().__init__()
+         self.id2label = CLASSES
+         self.label2id = {v: k for k, v in self.id2label.items()}
+         self.num_classes = len(self.id2label)
+         self.model = MobileViTForSemanticSegmentation.from_pretrained(
+             MODEL_CHECKPOINT,
+             num_labels=self.num_classes,
+             id2label=self.id2label,
+             label2id=self.label2id,
+             ignore_mismatched_sizes=True,
+         )
+         self.metric = evaluate.load("mean_iou")
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+
+     def forward(self, pixel_values, labels):
+         return self.model(pixel_values=pixel_values, labels=labels)
+
+     def common_step(self, batch, batch_idx):
+         pixel_values = batch["pixel_values"]
+         labels = batch["labels"]
+
+         outputs = self.model(pixel_values=pixel_values, labels=labels)
+
+         loss = outputs.loss
+         logits = outputs.logits
+         return loss, logits
+
+     def compute_metric(self, logits, labels):
+         # Upsample the logits to the label resolution before taking the argmax
+         logits_tensor = nn.functional.interpolate(
+             logits,
+             size=labels.shape[-2:],
+             mode="bilinear",
+             align_corners=False,
+         ).argmax(dim=1)
+         pred_labels = logits_tensor.detach().cpu().numpy()
+         metrics = self.metric.compute(
+             predictions=pred_labels,
+             references=labels.detach().cpu().numpy(),
+             num_labels=self.num_classes,
+             ignore_index=255,
+             reduce_labels=False,
+         )
+
+         return metrics
+
+     def training_step(self, batch, batch_idx):
+         labels = batch["labels"]
+
+         # Calculate and log the loss
+         loss, logits = self.common_step(batch, batch_idx)
+         self.log("train_loss", loss)
+
+         # Calculate and log the metrics
+         metrics = self.compute_metric(logits, labels)
+         metrics = {key: np.float32(value) for key, value in metrics.items()}
+
+         self.log("train_mean_iou", metrics["mean_iou"])
+         self.log("train_mean_accuracy", metrics["mean_accuracy"])
+         self.log("train_overall_accuracy", metrics["overall_accuracy"])
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         labels = batch["labels"]
+
+         # Calculate and log the loss
+         loss, logits = self.common_step(batch, batch_idx)
+         self.log("val_loss", loss)
+
+         # Calculate and log the metrics
+         metrics = self.compute_metric(logits, labels)
+         metrics = {key: np.float32(value) for key, value in metrics.items()}
+         self.log("val_mean_iou", metrics["mean_iou"])
+         self.log("val_mean_accuracy", metrics["mean_accuracy"])
+         self.log("val_overall_accuracy", metrics["overall_accuracy"])
+
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         labels = batch["labels"]
+
+         # Calculate and log the loss
+         loss, logits = self.common_step(batch, batch_idx)
+         self.log("test_loss", loss)
+
+         # Calculate and log the metrics
+         metrics = self.compute_metric(logits, labels)
+         metrics = {key: np.float32(value) for key, value in metrics.items()}
+         self.log("test_mean_iou", metrics["mean_iou"])
+         self.log("test_mean_accuracy", metrics["mean_accuracy"])
+         self.log("test_overall_accuracy", metrics["overall_accuracy"])
+
+         return loss
+
+     def configure_optimizers(self):
+         param_dicts = [
+             {
+                 "params": [p for n, p in self.named_parameters()],
+                 "lr": self.learning_rate,
+             }
+         ]
+         return torch.optim.AdamW(
+             param_dicts, lr=self.learning_rate, weight_decay=self.weight_decay
+         )
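A minimal sketch of exercising the LightningModule outside a Trainer; the random tensors are placeholders for a processed batch, and the pretrained weights are downloaded from the mmenendezg/mobilevit-fluorescent-neuronal-cells checkpoint on first use:

import torch
from models.mobilevit import MobileVIT

model = MobileVIT(learning_rate=5e-4, weight_decay=1e-5)
pixel_values = torch.rand(1, 3, 256, 256)            # placeholder image batch
labels = torch.zeros(1, 256, 256, dtype=torch.long)  # placeholder mask
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss, outputs.logits.shape)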
tools/.gitkeep ADDED
File without changes
tools/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import hyperparameters_tuning, train_model
tools/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (283 Bytes)
tools/__pycache__/hyperparameters_tuning.cpython-311.pyc ADDED
Binary file (5.67 kB)
tools/__pycache__/predict.cpython-311.pyc ADDED
Binary file (2.43 kB)
tools/__pycache__/train_model.cpython-311.pyc ADDED
Binary file (3.26 kB)
tools/hyperparameters_tuning.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from pathlib import Path
+ import yaml
+ import torch
+ import optuna
+ import pytorch_lightning as pl
+ import click
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
+ from models.mobilevit import MobileVIT
+ from data.data_preprocessing import FluorescentNeuronalDataModule
+
+
+ MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"
+
+ # Define the accelerator
+ if torch.backends.mps.is_available():
+     DEVICE = torch.device("mps:0")
+     ACCELERATOR = "mps"
+ elif torch.cuda.is_available():
+     DEVICE = torch.device("cuda")
+     ACCELERATOR = "gpu"
+ else:
+     DEVICE = torch.device("cpu")
+     ACCELERATOR = "cpu"
+
+ RAW_DATA_PATH = "./data/raw/"
+ DEFAULT_CONFIG_FILE = "./config/fluorescent_mobilevit_hps.yaml"
+
+ CLASSES = {0: "Background", 1: "Neuron"}
+
+ IMG_SIZE = [256, 256]
+
+
+ @click.command()
+ @click.option(
+     "--data_dir",
+     type=click.Path(exists=True, file_okay=True, path_type=Path),
+     default=RAW_DATA_PATH,
+ )
+ @click.option(
+     "--config_file",
+     type=click.Path(exists=True, file_okay=True, path_type=Path),
+     default=DEFAULT_CONFIG_FILE,
+ )
+ @click.option("--dataset_size", type=click.FLOAT, default=0.25)
+ @click.option("--force-tune/--no-force-tune", default=False)
+ def get_best_params(data_dir, config_file, dataset_size, force_tune):
+     def objective(trial: optuna.Trial, dataset_size=dataset_size) -> float:
+         # Suggest values of the hyperparameters for the trial
+         learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
+         weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
+         batch_size = trial.suggest_int("batch_size", 2, 4, log=True)
+
+         # Define the callbacks of the model
+         early_stopping_cb = EarlyStopping(monitor="val_loss", patience=2)
+
+         # Create the model
+         model = MobileVIT(learning_rate=learning_rate, weight_decay=weight_decay)
+
+         # Instantiate the data module
+         data_module = FluorescentNeuronalDataModule(
+             batch_size=batch_size, dataset_size=dataset_size, data_dir=data_dir
+         )
+         data_module.setup()
+
+         # Train the model
+         trainer = pl.Trainer(
+             devices=1,
+             accelerator=ACCELERATOR,
+             precision="16-mixed",
+             max_epochs=5,
+             log_every_n_steps=5,
+             callbacks=[early_stopping_cb],
+         )
+         trainer.fit(
+             model,
+             train_dataloaders=data_module.train_dataloader(),
+             val_dataloaders=data_module.val_dataloader(),
+         )
+         return trainer.callback_metrics["val_loss"].item()
+
+     if os.path.exists(config_file) and force_tune:
+         os.remove(config_file)
+         pruner = optuna.pruners.MedianPruner()
+         # The objective returns the validation loss, so it must be minimized
+         study = optuna.create_study(direction="minimize", pruner=pruner)
+
+         study.optimize(objective, n_trials=25)
+         best_params = study.best_params
+         with open(config_file, "w") as file:
+             yaml.dump(best_params, file)
+     elif os.path.exists(config_file):
+         with open(config_file, "r") as file:
+             best_params = yaml.safe_load(file)
+     else:
+         pruner = optuna.pruners.MedianPruner()
+         study = optuna.create_study(direction="minimize", pruner=pruner)
+
+         study.optimize(objective, n_trials=25)
+         best_params = study.best_params
+         with open(config_file, "w") as file:
+             yaml.dump(best_params, file)
+
+     click.echo(f"The best parameters are:\n{best_params}")
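get_best_params is a Click command, but this file ships without an if __name__ == "__main__": guard, so one way to run it is programmatically from the repository root (a sketch; standalone_mode=False keeps Click from calling sys.exit):

from tools.hyperparameters_tuning import get_best_params

# --force-tune deletes a cached config file and re-runs the 25-trial search;
# without it, existing tuned parameters are loaded from the YAML file.
get_best_params.main(
    ["--data_dir", "data/raw/", "--dataset_size", "0.25", "--force-tune"],
    standalone_mode=False,
)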
tools/predict.py ADDED
@@ -0,0 +1,47 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor
+
+ from models.mobilevit import MobileVIT
+
+ # Checkpoint of the model used in the project
+ MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"
+
+ # Define the accelerator
+ if torch.backends.mps.is_available():
+     DEVICE = torch.device("mps:0")
+     ACCELERATOR = "mps"
+ elif torch.cuda.is_available():
+     DEVICE = torch.device("cuda")
+     ACCELERATOR = "gpu"
+ else:
+     DEVICE = torch.device("cpu")
+     ACCELERATOR = "cpu"
+
+
+ def single_prediction(image):
+     # Instantiate the model from the checkpoint
+     mobilevit_model = MobileVIT()
+     mobilevit_model.to(DEVICE)
+     # Instantiate the image processor
+     image_processor = AutoImageProcessor.from_pretrained(
+         MODEL_CHECKPOINT, do_reduce_labels=False
+     )
+     # Load the image and convert it to a numpy array
+     image = image.convert("RGB")
+     np_image = np.asarray(image, dtype=np.uint8)
+     # Preprocess the image and move it to the selected device
+     processed_image = image_processor(images=np_image, return_tensors="pt")
+     processed_image.to(DEVICE)
+     # Make the prediction and resize the predicted mask to the input size
+     with torch.no_grad():
+         outputs = mobilevit_model.model(pixel_values=processed_image["pixel_values"])
+     post_processed_image = image_processor.post_process_semantic_segmentation(
+         outputs=outputs, target_sizes=[(np_image.shape[0], np_image.shape[1])]
+     )
+     # Convert the predicted mask to a black-and-white PIL image
+     mask = post_processed_image[0].cpu().numpy().astype(np.uint8) * 255
+     mask = Image.fromarray(mask)
+
+     return mask
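A minimal usage sketch; the sample path is hypothetical, and any RGB fluorescence image opened with PIL works:

from PIL import Image
from tools.predict import single_prediction

image = Image.open("examples/example_1.png")  # hypothetical sample image
mask = single_prediction(image)
mask.save("predicted_mask.png")  # binary mask: 0 = background, 255 = neuron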
tools/train_model.py ADDED
@@ -0,0 +1,70 @@
+ import yaml
+ from pathlib import Path
+ import click
+ import torch
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+ from pytorch_lightning.loggers import TensorBoardLogger
+
+ from models.mobilevit import MobileVIT
+ from data.data_preprocessing import FluorescentNeuronalDataModule
+
+ CONFIG_FILE = "config/fluorescent_mobilevit_hps.yaml"
+ DATA_DIR = "data/raw/"
+ LOGS_DIR = "reports/logs/FluorescentMobileVIT"
+ MODEL_DIR = "models/FluorescentMobileVIT"
+
+ # Define the accelerator
+ if torch.backends.mps.is_available():
+     DEVICE = torch.device("mps:0")
+     ACCELERATOR = "mps"
+ elif torch.cuda.is_available():
+     DEVICE = torch.device("cuda")
+     ACCELERATOR = "gpu"
+ else:
+     DEVICE = torch.device("cpu")
+     ACCELERATOR = "cpu"
+
+
+ @click.command()
+ @click.option(
+     "--data_dir",
+     type=click.Path(exists=True, file_okay=True, path_type=Path),
+     default=DATA_DIR,
+ )
+ @click.option(
+     "--config_file",
+     type=click.Path(exists=True, file_okay=True, path_type=Path),
+     default=CONFIG_FILE,
+ )
+ def train_model(data_dir, config_file):
+     # Load the best hyperparameters found by the tuning script
+     with open(config_file, "r") as file:
+         best_params = yaml.safe_load(file)
+     # Instantiate the model
+     model = MobileVIT(
+         learning_rate=best_params["learning_rate"],
+         weight_decay=best_params["weight_decay"],
+     )
+     # Define the callbacks of the model
+     model_checkpoint_cb = ModelCheckpoint(
+         save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss"
+     )
+     logger = TensorBoardLogger(save_dir=LOGS_DIR)
+
+     # Create the trainer with its parameters
+     trainer = pl.Trainer(
+         logger=logger,
+         devices=1,
+         accelerator=ACCELERATOR,
+         precision="16-mixed",
+         max_epochs=100,
+         log_every_n_steps=20,
+         callbacks=[model_checkpoint_cb],
+     )
+     data_module = FluorescentNeuronalDataModule(
+         data_dir=data_dir, batch_size=best_params["batch_size"]
+     )
+     trainer.fit(model=model, datamodule=data_module)
+     trainer.test(model=model, datamodule=data_module)
+     click.echo("\n\n==========The Training has Finished!==========")
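Like the tuning script, train_model is a Click command without a __main__ guard in this commit, so a programmatic call is one way to launch it from the repository root (a sketch; the empty argument list falls back to the data/raw/ and config defaults defined above):

from tools.train_model import train_model

train_model.main([], standalone_mode=False)

Checkpoints are written to models/FluorescentMobileVIT and TensorBoard logs to reports/logs/FluorescentMobileVIT, per the constants at the top of the file.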