achouffe committed
Commit bf6c3c2 · verified · 1 Parent(s): c9f62fe

feat: add crop and identification stages

Files changed (3)
  1. app.py +45 -27
  2. requirements.txt +7 -0
  3. utils.py +863 -7
app.py CHANGED
@@ -3,24 +3,22 @@ Gradio app to showcase the pyronear model for early forest fire detection.
3
  """
4
 
5
  from pathlib import Path
6
- from typing import Tuple
7
 
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
11
- import subprocess
12
- import shutil
13
- import logging
14
- import os
15
- import torch
16
- import pandas as pd
17
  from ultralytics import YOLO
18
 
19
  from utils import (
20
  bgr_to_rgb,
 
21
  get_best_device,
22
- load_segmentation_model,
23
- setup
24
  )
25
 
26
 
@@ -29,14 +27,12 @@ def prediction_to_str(yolo_prediction) -> str:
29
  Turn the yolo_prediction into a human friendly string.
30
  """
31
  boxes = yolo_prediction.boxes
32
- classes = boxes.cls.cpu().numpy().astype(np.int8)
33
- n_bear = len([c for c in classes if c == 0])
34
- n_soft_coral = len([c for c in classes if c == 1])
35
-
36
  return f"""{len(boxes.conf)} bear detected! Trigger the bear repellent 🐻"""
37
 
38
 
39
- def interface_fn(model_segmentation: YOLO, pil_image: Image.Image) -> Tuple[Image.Image, str]:
40
  """
41
  Main interface function that runs the model on the provided pil_image and
42
  returns the expected tuple to populate the gradio interface.
@@ -50,12 +46,22 @@ def interface_fn(model_segmentation: YOLO, pil_image: Image.Image) -> Tuple[Imag
50
  raw_prediction_str (str): string representing the raw prediction from the
51
  model.
52
  """
53
- predictions = model_segmentation(pil_image)
54
- prediction = predictions[0]
55
- pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
56
- raw_prediction_str = prediction_to_str(prediction)
57
 
58
- return (pil_image_with_prediction, raw_prediction_str)
59
 
60
 
61
  def examples(dir_examples: Path) -> list[Path]:
@@ -78,15 +84,23 @@ setup(
78
  )
79
 
80
  # Main Gradio interface
81
- METRIC_LEARNING_MODEL_FILEPATH = Path("./data/06_models/pipeline/metriclearning/bearidentification/model.pt")
82
- METRIC_LEARNING_KNN_INDEX_FILEPATH = Path("./data/06_models/pipeline/metriclearning/bearidentification/knn.index")
83
- INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH = Path("./data/06_models/pipeline/metriclearning/bearfacesegmentation/model.pt")
84
  DIR_EXAMPLES = Path("data/images/")
85
  DEFAULT_IMAGE_INDEX = 0
86
 
87
  with gr.Blocks() as demo:
88
- model_segmentation = load_segmentation_model(INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH)
89
- model_segmentation.info()
90
  image_filepaths = examples(dir_examples=DIR_EXAMPLES)
91
  default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
92
  input = gr.Image(
@@ -95,15 +109,19 @@ with gr.Blocks() as demo:
95
  label="input image",
96
  sources=["upload", "clipboard"],
97
  )
98
- output_image = gr.Image(type="pil", label="model prediction")
 
99
  output_raw = gr.Text(label="raw prediction")
100
 
101
- fn = lambda pil_image: interface_fn(model_segmentation=model_segmentation, pil_image=pil_image)
102
  gr.Interface(
103
  title="ML pipeline for identifying bears from their faces 🐻",
104
  fn=fn,
105
  inputs=input,
106
- outputs=[output_image, output_raw],
107
  examples=image_filepaths,
108
  flagging_mode="never",
109
  )
 
3
  """
4
 
5
  from pathlib import Path
6
+ from typing import Any, Tuple
7
 
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
11
  from ultralytics import YOLO
12
 
13
  from utils import (
14
  bgr_to_rgb,
15
+ crop_from_yolov8,
16
  get_best_device,
17
+ load_models,
18
+ resize,
19
+ run_pipeline,
20
+ setup,
21
+ square_pad,
22
  )
23
 
24
 
 
27
  Turn the yolo_prediction into a human friendly string.
28
  """
29
  boxes = yolo_prediction.boxes
30
  return f"""{len(boxes.conf)} bear detected! Trigger the bear repellent 🐻"""
31
 
32
 
33
+ def interface_fn(
34
+ loaded_models: dict[str, Any], pil_image: Image.Image
35
+ ) -> Tuple[Image.Image, Image.Image, str]:
36
  """
37
  Main interface function that runs the model on the provided pil_image and
38
  returns the expected tuple to populate the gradio interface.
 
46
  raw_prediction_str (str): string representing the raw prediction from the
47
  model.
48
  """
49
+ PARAM_SQUARE_DIM = 300
50
+ result = run_pipeline(
51
+ loaded_models=loaded_models,
52
+ pil_image=pil_image,
53
+ param_square_dim=PARAM_SQUARE_DIM,
54
+ param_k=5,
55
+ param_n_samples_per_individual=4,
56
+ knn_index_filepath=METRIC_LEARNING_KNN_INDEX_FILEPATH,
57
+ )
58
+ pil_image_segmented_head = result["stages"]["segmentation"]["output"]["pil_image"]
59
+ pil_image_cropped_head = result["stages"]["crop"]["output"]["pil_images"]["resized"]
60
 
61
+ # raw_prediction_str = prediction_to_str(yolov8_segmentation_prediction)
62
+
63
+ return (pil_image_segmented_head, pil_image_cropped_head, str(result))
65
 
66
 
67
  def examples(dir_examples: Path) -> list[Path]:
 
84
  )
85
 
86
  # Main Gradio interface
87
+ METRIC_LEARNING_MODEL_FILEPATH = Path(
88
+ "./data/06_models/pipeline/metriclearning/bearidentification/model.pt"
89
+ )
90
+ METRIC_LEARNING_KNN_INDEX_FILEPATH = Path(
91
+ "./data/06_models/pipeline/metriclearning/bearidentification/knn.index"
92
+ )
93
+ INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH = Path(
94
+ "./data/06_models/pipeline/metriclearning/bearfacesegmentation/model.pt"
95
+ )
96
  DIR_EXAMPLES = Path("data/images/")
97
  DEFAULT_IMAGE_INDEX = 0
98
 
99
  with gr.Blocks() as demo:
100
+ loaded_models = load_models(
101
+ filepath_metric_learning_weights=METRIC_LEARNING_MODEL_FILEPATH,
102
+ filepath_segmentation_weights=INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH,
103
+ )
104
  image_filepaths = examples(dir_examples=DIR_EXAMPLES)
105
  default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
106
  input = gr.Image(
 
109
  label="input image",
110
  sources=["upload", "clipboard"],
111
  )
112
+ output_segmentation_image = gr.Image(type="pil", label="model prediction")
113
+ output_cropped_image = gr.Image(type="pil", label="cropped bear face")
114
  output_raw = gr.Text(label="raw prediction")
115
 
116
+ fn = lambda pil_image: interface_fn(
117
+ loaded_models=loaded_models,
118
+ pil_image=pil_image,
119
+ )
120
  gr.Interface(
121
  title="ML pipeline for identifying bears from their faces 🐻",
122
  fn=fn,
123
  inputs=input,
124
+ outputs=[output_segmentation_image, output_cropped_image, output_raw],
125
  examples=image_filepaths,
126
  flagging_mode="never",
127
  )
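For reference, a minimal sketch (not part of the commit) of how the nested dict returned by run_pipeline is navigated to feed the three Gradio outputs above; the key paths mirror the "stages" structure built in utils.py, while the helper name extract_outputs is hypothetical:

# Sketch: pulling the Gradio outputs out of the run_pipeline result dict.
# Key paths mirror the "stages" dict built in utils.py; extract_outputs is a hypothetical helper.
def extract_outputs(result: dict) -> tuple:
    stages = result["stages"]
    segmented = stages["segmentation"]["output"]["pil_image"]      # annotated YOLO prediction
    cropped = stages["crop"]["output"]["pil_images"]["resized"]    # square, resized face chip
    bear_ids = stages["identification"]["output"]["bear_ids"]      # k nearest individuals
    return segmented, cropped, bear_ids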
requirements.txt CHANGED
@@ -1,4 +1,11 @@
1
  gradio==5.4.*
2
  pandas==2.2.*
3
  torch==2.5.*
 
4
  ultralytics==8.3.*
1
+ # faiss-cpu==1.7.4
2
+ faiss-cpu==1.9.*
3
+
4
  gradio==5.4.*
5
  pandas==2.2.*
6
+ # pytorch-metric-learning==2.4.1
7
+ pytorch-metric-learning==2.7.*
8
  torch==2.5.*
9
+ tqdm==4.66.1
10
  ultralytics==8.3.*
11
+ umap-learn==0.5.5
utils.py CHANGED
@@ -1,15 +1,642 @@
1
  from pathlib import Path
2
- from typing import Tuple
3
 
4
- import logging
5
  import numpy as np
6
- import os
7
  import pandas as pd
8
- import subprocess
9
- import shutil
10
  import torch
11
  from ultralytics import YOLO
12
13
 
14
  def get_best_device() -> torch.device:
15
  """Returns the best torch device depending on the hardware it is running
@@ -51,14 +678,14 @@ def _setup_ml_pipeline(input_packaged_pipeline: Path, install_path: Path) -> Non
51
  dirs_exist_ok=True,
52
  )
53
 
 
54
  def setup(input_packaged_pipeline: Path, install_path: Path) -> None:
55
  """
56
  Full setup of the project.
57
  """
58
  _setup_chips()
59
  _setup_ml_pipeline(
60
- input_packaged_pipeline=input_packaged_pipeline,
61
- install_path=install_path
62
  )
63
 
64
 
@@ -74,4 +701,233 @@ def load_segmentation_model(filepath_weights: Path) -> YOLO:
74
  """
75
  Load the YOLO model given the filepath_weights.
76
  """
 
77
  return YOLO(filepath_weights)
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ from collections import Counter
6
  from pathlib import Path
7
+ from typing import Any, Optional, OrderedDict
8
 
9
+ import cv2
10
  import numpy as np
 
11
  import pandas as pd
 
 
12
  import torch
13
+ import torch.nn as nn
14
+ import torchvision
15
+ import torchvision.models as models
16
+ from PIL import Image
17
+ from pytorch_metric_learning.utils.common_functions import logging
18
+ from pytorch_metric_learning.utils.inference import InferenceModel
19
+ from torch.utils.data import DataLoader, Dataset
20
+ from torchvision import transforms
21
+ from torchvision.transforms import v2
22
  from ultralytics import YOLO
23
 
24
+ # TODO: move metric learning functions into their own namespace
25
+
26
+ def sample_chips_from_bearid(
27
+ bear_id: str,
28
+ df_split: pd.DataFrame,
29
+ n: int = 4,
30
+ ) -> list[Path]:
31
+ xs = df_split[df_split["bear_id"] == bear_id].sample(n=n)["path"].tolist()
32
+ return [Path(x) for x in xs]
33
+
34
+
35
+ def make_indexed_samples(
36
+ bear_ids: list[str],
37
+ df_split: pd.DataFrame,
38
+ n: int = 4,
39
+ ) -> dict[str, list[Path]]:
40
+ return {
41
+ bear_id: sample_chips_from_bearid(bear_id=bear_id, df_split=df_split, n=n)
42
+ for bear_id in bear_ids
43
+ }
44
+
45
+
46
+ def _aux_get_k_nearest_individuals(
47
+ model: InferenceModel,
48
+ k_neighbors: int,
49
+ k_individuals: int,
50
+ query,
51
+ id_to_label: dict,
52
+ dataset: Dataset,
53
+ ) -> dict:
54
+ """Auxiliary helper function to get k nearest individuals.
55
+
56
+ Returns a dict with the following keys:
57
+ - k_neighbors: int - number of neighbors the KNN search extends to in order to find at least k_individuals
58
+ - dataset_indices: list[int] - list of indices to call get_item on the dataset
59
+ - dataset_labels: list[int] - labels of the dataset for the given dataset_indices
60
+ - dataset_images: list[torch.tensor] - chips of the bears
61
+ - distances: list[float] - distances from the query
62
+
63
+ Note: it can return more than k_individuals as it progressively extends the
64
+ KNN search to find at least k_individuals.
65
+ """
66
+ assert k_individuals <= 20, f"Keep a small k_individuals: {k_individuals}"
67
+
68
+ distances, indices = model.get_nearest_neighbors(query=query, k=k_neighbors)
69
+ indices_on_cpu = indices.cpu()[0].tolist()
70
+ distances_on_cpu = distances.cpu()[0].tolist()
71
+ nearest_images, nearest_ids = list(zip(*[dataset[idx] for idx in indices_on_cpu]))
72
+ bearids = [id_to_label.get(nearest_id, "unknown") for nearest_id in nearest_ids]
73
+ counter = Counter(nearest_ids)
74
+ if len(counter.keys()) >= k_individuals:
75
+ return {
76
+ "k_neighbors": k_neighbors,
77
+ "dataset_indices": indices_on_cpu,
78
+ "dataset_labels": list(nearest_ids),
79
+ "dataset_images": list(nearest_images),
80
+ "bearids": bearids,
81
+ "distances": distances_on_cpu,
82
+ }
83
+ else:
84
+ new_k_neighbors = k_neighbors * 2
85
+ return _aux_get_k_nearest_individuals(
86
+ model,
87
+ k_neighbors=new_k_neighbors,
88
+ k_individuals=k_individuals,
89
+ query=query,
90
+ id_to_label=id_to_label,
91
+ dataset=dataset,
92
+ )
93
+
94
+
95
+ def _find_cutoff_index(k: int, dataset_labels: list[str]) -> Optional[int]:
96
+ """Returns the index for dataset_labels that retrieves exactly k
97
+ individuals."""
98
+ if not dataset_labels:
99
+ return None
100
+ else:
101
+ selected_labels = set()
102
+ cutoff_index = -1
103
+ for idx, label in enumerate(dataset_labels):
104
+ if len(selected_labels) == k:
105
+ break
106
+ else:
107
+ selected_labels.add(label)
108
+ cutoff_index = idx + 1
109
+ return cutoff_index
110
+
111
+
112
+ def get_k_nearest_individuals(
113
+ model: InferenceModel,
114
+ k: int,
115
+ query,
116
+ id_to_label: dict,
117
+ dataset: Dataset,
118
+ ) -> dict:
119
+ """Returns the k nearest individuals using the inference model and a query.
120
+
121
+ A dict is returned with the following keys:
122
+ - dataset_indices: list[int] - list of indices to call get_item on the dataset
123
+ - dataset_labels: list[int] - labels of the dataset for the given dataset_indices
124
+ - dataset_images: list[torch.tensor] - chips of the bears
125
+ - distances: list[float] - distances from the query
126
+ """
127
+ k_neighbors = k * 5
128
+ k_individuals = k
129
+ result = _aux_get_k_nearest_individuals(
130
+ model=model,
131
+ k_neighbors=k_neighbors,
132
+ k_individuals=k_individuals,
133
+ query=query,
134
+ id_to_label=id_to_label,
135
+ dataset=dataset,
136
+ )
137
+ cutoff_index = _find_cutoff_index(
138
+ k=k,
139
+ dataset_labels=result["dataset_labels"],
140
+ )
141
+ return {
142
+ "dataset_indices": result["dataset_indices"][:cutoff_index],
143
+ "dataset_labels": result["dataset_labels"][:cutoff_index],
144
+ "dataset_images": result["dataset_images"][:cutoff_index],
145
+ "bearids": result["bearids"][:cutoff_index],
146
+ "distances": result["distances"][:cutoff_index],
147
+ }
148
+
149
+
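To illustrate the retrieval logic above: the search starts at k_neighbors = 5 * k and doubles it until at least k distinct individuals appear, then _find_cutoff_index trims the neighbor lists so that exactly k individuals remain. A small sketch (not part of the commit), assuming utils.py is importable:

# Sketch: trimming a neighbor list down to k = 3 distinct individuals.
from utils import _find_cutoff_index

labels = ["bf_480", "bf_480", "bf_755", "bf_480", "bf_902", "bf_107"]
cutoff = _find_cutoff_index(k=3, dataset_labels=labels)
print(labels[:cutoff])  # ['bf_480', 'bf_480', 'bf_755', 'bf_480', 'bf_902'] -> 3 distinct ids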
150
+ def index_by_bearid(k_nearest_individuals: dict) -> dict:
151
+ """Returns a dict where keys are bearid labels (eg. 'bf_480') and the
152
+ values are lists of dicts with the following shape:
153
+
154
+ - dataset_label: int
155
+ - dataset_image: torch.tensor
156
+ - distance: float
157
+ - dataset_index: int
158
+ """
159
+ result = {}
160
+ for dataset_label, dataset_image, distance, bearid, dataset_index in zip(
161
+ k_nearest_individuals["dataset_labels"],
162
+ k_nearest_individuals["dataset_images"],
163
+ k_nearest_individuals["distances"],
164
+ k_nearest_individuals["bearids"],
165
+ k_nearest_individuals["dataset_indices"],
166
+ ):
167
+ row = {
168
+ "dataset_label": dataset_label,
169
+ "dataset_image": dataset_image,
170
+ "distance": distance,
171
+ "dataset_index": dataset_index,
172
+ }
173
+ if bearid not in result:
174
+ result[bearid] = [row]
175
+ else:
176
+ result[bearid].append(row)
177
+ return result
178
+
179
+
180
+ def prefix_keys_with(weights: OrderedDict, prefix: str = "module.") -> OrderedDict:
181
+ """Returns the new weights where each key is prefixed with the provided
182
+ `prefix`.
183
+
184
+ Note: Useful when using DataParallel to account for the `module.` prefix on the keys.
185
+ """
186
+ weights_copy = weights.copy()
187
+ for k, v in weights.items():
188
+ weights_copy[f"{prefix}{k}"] = v
189
+ del weights_copy[k]
190
+ return weights_copy
191
+
192
+
193
+ def load_weights(
194
+ network: torch.nn.Module,
195
+ weights_filepath: Optional[Path] = None,
196
+ weights: Optional[OrderedDict] = None,
197
+ prefix: str = "",
198
+ ) -> torch.nn.Module:
199
+ """Loads the network weights.
200
+
201
+ Returns the network.
202
+ """
203
+ if weights:
204
+ prefixed_weights = prefix_keys_with(weights, prefix=prefix)
205
+ network.load_state_dict(state_dict=prefixed_weights)
206
+ return network
207
+ elif weights_filepath:
208
+ assert weights_filepath.exists(), f"Invalid model_filepath {weights_filepath}"
209
+ weights = torch.load(weights_filepath)
210
+ prefixed_weights = prefix_keys_with(weights, prefix=prefix)
211
+ network.load_state_dict(state_dict=prefixed_weights)
212
+ return network
213
+ else:
214
+ raise Exception("Must provide either weights or weights_filepath")
215
+
216
+
217
+ class MLP(nn.Module):
218
+ # layer_sizes[0] is the dimension of the input
219
+ # layer_sizes[-1] is the dimension of the output
220
+ def __init__(self, layer_sizes, final_relu=False):
221
+ super().__init__()
222
+ layer_list = []
223
+ layer_sizes = [int(x) for x in layer_sizes]
224
+ num_layers = len(layer_sizes) - 1
225
+ final_relu_layer = num_layers if final_relu else num_layers - 1
226
+ for i in range(len(layer_sizes) - 1):
227
+ input_size = layer_sizes[i]
228
+ curr_size = layer_sizes[i + 1]
229
+ if i <= final_relu_layer:
230
+ layer_list.append(nn.ReLU(inplace=False))
231
+ layer_list.append(nn.BatchNorm1d(input_size))
232
+ layer_list.append(nn.Linear(input_size, curr_size))
233
+ self.net = nn.Sequential(*layer_list)
234
+ self.last_linear = self.net[-1]
235
+
236
+ def forward(self, x):
237
+ return self.net(x)
238
+
239
+
240
+ def check_backbone(pretrained_backbone: str) -> None:
241
+ allowed_backbones = {
242
+ "resnet18",
243
+ "resnet50",
244
+ "convnext_tiny",
245
+ "convnext_base",
246
+ "convnext_large",
247
+ "efficientnet_v2_s",
248
+ # "squeezenet1_1",
249
+ "vit_b_16",
250
+ }
251
+ assert (
252
+ pretrained_backbone in allowed_backbones
253
+ ), f"pretrained_backbone {pretrained_backbone} is not implemented, only {allowed_backbones}"
254
+
255
+
256
+ def make_trunk(pretrained_backbone: str = "resnet18") -> nn.Module:
257
+ """Returns a nn.Module with pretrained weights using a given
258
+ pretrained_backbone.
259
+
260
+ Note: The currently available backbones are resnet18, resnet50,
261
+ convnext_tiny, convnext_base, convnext_large, efficientnet_v2_s and vit_b_16.
262
+ """
263
+
264
+ check_backbone(pretrained_backbone)
265
+
266
+ if pretrained_backbone == "resnet18":
267
+ return torchvision.models.resnet18(
268
+ weights=models.ResNet18_Weights.IMAGENET1K_V1
269
+ )
270
+ elif pretrained_backbone == "resnet50":
271
+ return torchvision.models.resnet50(
272
+ weights=models.ResNet50_Weights.IMAGENET1K_V1
273
+ )
274
+ elif pretrained_backbone == "convnext_tiny":
275
+ return torchvision.models.convnext_tiny(
276
+ weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
277
+ )
278
+ elif pretrained_backbone == "convnext_base":
279
+ return torchvision.models.convnext_base(
280
+ weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1
281
+ )
282
+ elif pretrained_backbone == "convnext_large":
283
+ return torchvision.models.convnext_large(
284
+ weights=models.ConvNeXt_Large_Weights.IMAGENET1K_V1
285
+ )
286
+ elif pretrained_backbone == "efficientnet_v2_s":
287
+ return torchvision.models.efficientnet_v2_s(
288
+ weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
289
+ )
290
+ elif pretrained_backbone == "squeezenet1_1":
291
+ return torchvision.models.squeezenet1_1(
292
+ weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1
293
+ )
294
+ elif pretrained_backbone == "vit_b_16":
295
+ return torchvision.models.vit_b_16(
296
+ weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
297
+ )
298
+ else:
299
+ raise Exception(f"Cannot make trunk with backbone {pretrained_backbone}")
300
+
301
+
302
+ def make_embedder(
303
+ pretrained_backbone: str,
304
+ trunk: nn.Module,
305
+ embedding_size: int,
306
+ hidden_layer_sizes: list[int],
307
+ ) -> nn.Module:
308
+ check_backbone(pretrained_backbone)
309
+
310
+ if pretrained_backbone in ["resnet18", "resnet50"]:
311
+ trunk_output_size = trunk.fc.in_features
312
+ trunk.fc = nn.Identity()
313
+ return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
314
+ if pretrained_backbone in ["convnext_tiny", "convnext_base", "convnext_large"]:
315
+ trunk_output_size = trunk.classifier[-1].in_features
316
+ trunk.classifier[-1] = nn.Identity()
317
+ return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
318
+ elif pretrained_backbone == "efficientnet_v2_s":
319
+ trunk_output_size = trunk.classifier[-1].in_features
320
+ trunk.classifier[-1] = nn.Identity()
321
+ return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
322
+ elif pretrained_backbone == "vit_b_16":
323
+ trunk_output_size = trunk.heads.head.in_features
324
+ trunk.heads.head = nn.Identity()
325
+ return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
326
+ else:
327
+ raise Exception(f"{pretrained_backbone} embedder not implemented yet")
328
+
329
+
330
+ def make_model_dict(
331
+ device: torch.device,
332
+ pretrained_backbone: str = "resnet18",
333
+ embedding_size: int = 128,
334
+ hidden_layer_sizes: list[int] = [1024],
335
+ ) -> dict[str, nn.Module]:
336
+ """
337
+ Returns a dict with the following keys:
338
+ - embedder: nn.Module - embedder model, usually an MLP.
339
+ - trunk: nn.Module - the backbone model, usually a pretrained model (like a ResNet).
340
+ """
341
+
342
+ trunk = make_trunk(pretrained_backbone=pretrained_backbone)
343
+ embedder = make_embedder(
344
+ pretrained_backbone=pretrained_backbone,
345
+ embedding_size=embedding_size,
346
+ hidden_layer_sizes=hidden_layer_sizes,
347
+ trunk=trunk,
348
+ )
349
+
350
+ trunk = torch.nn.DataParallel(trunk.to(device))
351
+ embedder = torch.nn.DataParallel(embedder.to(device))
352
+
353
+ return {
354
+ "trunk": trunk,
355
+ "embedder": embedder,
356
+ }
357
+
358
+
359
+ class BearDataset(Dataset):
360
+ def __init__(self, dataframe, id_mapping, transform=None):
361
+ self.dataframe = dataframe
362
+ self.id_mapping = id_mapping
363
+ self.transform = transform
364
+
365
+ def __len__(self):
366
+ return len(self.dataframe)
367
+
368
+ def __getitem__(self, idx):
369
+ sample = self.dataframe.iloc[idx]
370
+ image_path = sample.path
371
+ bear_id = sample.bear_id
372
+
373
+ id_value = self.id_mapping.loc[self.id_mapping["label"] == bear_id, "id"].iloc[
374
+ 0
375
+ ]
376
+
377
+ image = Image.open(image_path)
378
+ if self.transform:
379
+ image = self.transform(image)
380
+
381
+ return image, id_value
382
+
383
+
384
+ def make_dataloaders(
385
+ batch_size: int,
386
+ df_split: pd.DataFrame,
387
+ transforms: dict,
388
+ ) -> dict:
389
+ """Returns a dict with top level keys in {dataset and loader}.
390
+
391
+ Each returns a dict with the train, val and test objects associated.
392
+ """
393
+
394
+ df_train = df_split[df_split["split"] == "train"]
395
+ df_val = df_split[df_split["split"] == "val"]
396
+ df_test = df_split[df_split["split"] == "test"]
397
+ id_mapping = make_id_mapping(df=df_split)
398
+
399
+ train_dataset = BearDataset(
400
+ df_train,
401
+ id_mapping,
402
+ transform=transforms["train"],
403
+ )
404
+ train_loader = DataLoader(
405
+ train_dataset,
406
+ batch_size=batch_size,
407
+ shuffle=True,
408
+ drop_last=True,
409
+ )
410
+
411
+ val_dataset = BearDataset(
412
+ df_val,
413
+ id_mapping,
414
+ transform=transforms["val"],
415
+ )
416
+ val_loader = DataLoader(
417
+ val_dataset,
418
+ batch_size=batch_size,
419
+ )
420
+
421
+ test_dataset = BearDataset(
422
+ df_test,
423
+ id_mapping,
424
+ transform=transforms["test"],
425
+ )
426
+ test_loader = DataLoader(
427
+ test_dataset,
428
+ batch_size=batch_size,
429
+ )
430
+
431
+ viz_dataset = BearDataset(
432
+ df_train,
433
+ id_mapping,
434
+ transform=transforms["viz"],
435
+ )
436
+ viz_loader = DataLoader(
437
+ viz_dataset,
438
+ batch_size=batch_size,
439
+ shuffle=True,
440
+ drop_last=True,
441
+ )
442
+ full_dataset = BearDataset(
443
+ df_split,
444
+ id_mapping,
445
+ transform=transforms["val"],
446
+ )
447
+
448
+ return {
449
+ "dataset": {
450
+ "viz": viz_dataset,
451
+ "train": train_dataset,
452
+ "val": val_dataset,
453
+ "test": test_dataset,
454
+ "full": full_dataset,
455
+ },
456
+ "loader": {
457
+ "viz": viz_loader,
458
+ "train": train_loader,
459
+ "val": val_loader,
460
+ "test": test_loader,
461
+ },
462
+ }
463
+
464
+
465
+ def make_id_mapping(df: pd.DataFrame) -> pd.DataFrame:
466
+ """Returns a dataframe that maps a bear label (eg.
467
+
468
+ bf_755) to a unique natural number (eg. 0). The dataFrame contains
469
+ two columns, namely id and label.
470
+ """
471
+ return pd.DataFrame(
472
+ list(enumerate(df["bear_id"].unique())), columns=["id", "label"]
473
+ )
474
+
475
+
476
+ def filter_none(xs: list) -> list:
477
+ return [x for x in xs if x is not None]
478
+
479
+
480
+ def get_dtype(dtype_str: str) -> torch.dtype:
481
+ if dtype_str == "float32":
482
+ return torch.float32
483
+ elif dtype_str == "int64":
484
+ return torch.int64
485
+ else:
486
+ logging.warning(
487
+ f"dtype_str {dtype_str} not implemented, returning default value"
488
+ )
489
+ return torch.float32
490
+
491
+
492
+ def get_transforms(
493
+ data_augmentation: dict = {},
494
+ trunk_preprocessing: dict = {},
495
+ ) -> dict:
496
+ """Returns a dict containing the transforms for the following splits:
497
+ train, val, test and viz (the latter is used for batch visualization).
498
+ """
499
+ logging.info(f"data_augmentation config: {data_augmentation}")
500
+ logging.info(f"trunk preprocessing config: {trunk_preprocessing}")
501
+
502
+ DEFAULT_CROP_SIZE = 224
503
+ crop_size = (
504
+ trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE),
505
+ trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE),
506
+ )
507
+
508
+ # transform to persist a batch of data as an artefact
509
+ transform_viz = transforms.Compose(
510
+ [
511
+ transforms.Resize(crop_size),
512
+ transforms.ToTensor(),
513
+ ]
514
+ )
515
+
516
+ mdtype: Optional[torch.dtype] = (
517
+ get_dtype(trunk_preprocessing["values"].get("dtype", None))
518
+ if trunk_preprocessing.get("values", None)
519
+ else None
520
+ )
521
+ mscale: Optional[bool] = (
522
+ trunk_preprocessing["values"].get("scale", None)
523
+ if trunk_preprocessing.get("values", None)
524
+ else None
525
+ )
526
+
527
+ mmean: Optional[list[float]] = (
528
+ trunk_preprocessing["normalization"].get("mean", None)
529
+ if trunk_preprocessing.get("normalization", None)
530
+ else None
531
+ )
532
+
533
+ mstd: Optional[list[float]] = (
534
+ trunk_preprocessing["normalization"].get("std", None)
535
+ if trunk_preprocessing.get("normalization", None)
536
+ else None
537
+ )
538
+
539
+ hue = (
540
+ data_augmentation["colorjitter"].get("hue", 0)
541
+ if data_augmentation.get("colorjitter", 0)
542
+ else 0
543
+ )
544
+ saturation = (
545
+ data_augmentation["colorjitter"].get("saturation", 0)
546
+ if data_augmentation.get("colorjitter", 0)
547
+ else 0
548
+ )
549
+ degrees = (
550
+ data_augmentation["rotation"].get("degrees", 0)
551
+ if data_augmentation.get("rotation", 0)
552
+ else 0
553
+ )
554
+
555
+ transformations_plain = [
556
+ transforms.Resize(crop_size),
557
+ transforms.ToTensor(),
558
+ v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None,
559
+ transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None,
560
+ ]
561
+
562
+ transformations_train = [
563
+ transforms.Resize(crop_size),
564
+ (
565
+ transforms.ColorJitter(
566
+ hue=hue,
567
+ saturation=saturation,
568
+ )
569
+ if data_augmentation.get("colorjitter", None)
570
+ else None
571
+ ), # Taken from Dolphin ID
572
+ (
573
+ v2.RandomRotation(degrees=degrees)
574
+ if data_augmentation.get("rotation", None)
575
+ else None
576
+ ), # Taken from Dolphin ID
577
+ transforms.ToTensor(),
578
+ v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None,
579
+ transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None,
580
+ ]
581
+
582
+ # Filtering out None transforms
583
+ transform_plain = transforms.Compose(filter_none(transformations_plain))
584
+ transform_train = transforms.Compose(filter_none(transformations_train))
585
+
586
+ return {
587
+ "viz": transform_viz,
588
+ "train": transform_train,
589
+ "val": transform_plain,
590
+ "test": transform_plain,
591
+ }
592
+
593
+
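A sketch (not part of the commit) of the config shapes that get_transforms expects, inferred from the lookups above; the numeric values are illustrative (ImageNet-style normalization), not taken from the pipeline config:

# Sketch: example configs matching the keys read by get_transforms.
from utils import get_transforms

example_trunk_preprocessing = {
    "crop_size": 224,
    "values": {"dtype": "float32", "scale": True},
    "normalization": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
}
example_data_augmentation = {
    "colorjitter": {"hue": 0.1, "saturation": 0.1},
    "rotation": {"degrees": 10},
}
transforms_by_split = get_transforms(
    data_augmentation=example_data_augmentation,
    trunk_preprocessing=example_trunk_preprocessing,
)  # dict with "train", "val", "test" and "viz" transforms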
594
+ def resize(
595
+ mask: np.ndarray,
596
+ dim: tuple[int, int],
597
+ interpolation: int = cv2.INTER_LINEAR,
598
+ ):
599
+ """Resize the mask to the provided `dim` using the interpolation method.
600
+
601
+ `dim`: (W, H) format
602
+ """
603
+ return cv2.resize(mask, dsize=dim, interpolation=interpolation)
604
+
605
+
606
+ def crop_from_yolov8(prediction_yolov8) -> np.ndarray:
607
+ """Given a yolov8 prediction, returns an image containing the cropped bear
608
+ head."""
609
+ H, W = prediction_yolov8.orig_shape
610
+ predictions_masks = prediction_yolov8.masks.data.to("cpu").numpy()
611
+ idx = np.argmax(prediction_yolov8.boxes.conf.to("cpu").numpy())
612
+ predictions_mask = predictions_masks[idx]
613
+ prediction_resized = resize(predictions_mask, dim=(W, H))
614
+ masked_image = prediction_yolov8.orig_img.copy()
615
+ black_pixel = [0, 0, 0]
616
+ masked_image[~prediction_resized.astype(bool)] = black_pixel
617
+ x0, y0, x1, y1 = prediction_yolov8.boxes[idx].xyxy[0].to("cpu").numpy()
618
+ return masked_image[int(y0) : int(y1), int(x0) : int(x1)]
619
+
620
+
621
+ def square_pad(img: np.ndarray):
622
+ """Returns an image with dimension max(W, H) x max(W, H), padded with black
623
+ pixels."""
624
+ H, W, _ = img.shape
625
+ K = max(H, W)
626
+ top = (K - H) // 2
627
+ bottom = K - H - top
628
+ left = (K - W) // 2
629
+ right = K - W - left
630
+
631
+ return cv2.copyMakeBorder(
632
+ img.copy(),
633
+ top,
634
+ bottom,
635
+ left,
636
+ right,
637
+ cv2.BORDER_CONSTANT,
638
+ )
639
+
640
 
641
  def get_best_device() -> torch.device:
642
  """Returns the best torch device depending on the hardware it is running
 
678
  dirs_exist_ok=True,
679
  )
680
 
681
+
682
  def setup(input_packaged_pipeline: Path, install_path: Path) -> None:
683
  """
684
  Full setup of the project.
685
  """
686
  _setup_chips()
687
  _setup_ml_pipeline(
688
+ input_packaged_pipeline=input_packaged_pipeline, install_path=install_path
 
689
  )
690
 
691
 
 
701
  """
702
  Load the YOLO model given the filepath_weights.
703
  """
704
+ assert filepath_weights.exists()
705
  return YOLO(filepath_weights)
706
+
707
+
708
+ def load_metric_learning_model(device: torch.device, filepath_weights: Path) -> Any:
709
+ assert filepath_weights.exists()
710
+ return torch.load(filepath_weights, map_location=device)
711
+
712
+
713
+ def load_models(
714
+ filepath_segmentation_weights: Path,
715
+ filepath_metric_learning_weights: Path,
716
+ ) -> dict[str, Any]:
717
+ assert filepath_segmentation_weights.exists()
718
+ assert filepath_metric_learning_weights.exists()
719
+
720
+ device = get_best_device()
721
+ model_segmentation = load_segmentation_model(filepath_segmentation_weights)
722
+ model_metric_learning = load_metric_learning_model(
723
+ device=device,
724
+ filepath_weights=filepath_metric_learning_weights,
725
+ )
726
+
727
+ return {
728
+ "segmentation": model_segmentation,
729
+ "metric_learning": model_metric_learning,
730
+ }
731
+
732
+
733
+ def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
734
+ predictions = model(pil_image)
735
+ if len(predictions) > 0:
736
+ prediction = predictions[0]
737
+ pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
738
+ return {"pil_image": pil_image_with_prediction, "prediction": prediction}
739
+ else:
740
+ return {}
741
+
742
+
743
+ def run_crop(square_dim: int, yolo_prediction) -> dict[str, Any]:
744
+ """
745
+ Run the crop stage on the yolo_prediction.
746
+
747
+ It resizes a square bear face based on `square_dim`.
748
+ """
749
+ cropped_bear_head = crop_from_yolov8(prediction_yolov8=yolo_prediction)
750
+ padded_cropped_head = square_pad(cropped_bear_head)
751
+ resized_padded_cropped_head = resize(
752
+ padded_cropped_head, dim=(square_dim, square_dim)
753
+ )
754
+ pil_image_cropped_bear_head = Image.fromarray(bgr_to_rgb(cropped_bear_head))
755
+ pil_image_padded_cropped_head = Image.fromarray(
756
+ bgr_to_rgb(resized_padded_cropped_head)
757
+ )
758
+ pil_image_resized_padded_cropped_head = Image.fromarray(
759
+ bgr_to_rgb(resized_padded_cropped_head)
760
+ )
761
+ return {
762
+ "pil_images": {
763
+ "cropped": pil_image_cropped_bear_head,
764
+ "padded": pil_image_padded_cropped_head,
765
+ "resized": pil_image_resized_padded_cropped_head,
766
+ }
767
+ }
768
+
769
+
770
+ def make_id_to_label(id_mapping: pd.DataFrame) -> dict[int, str]:
771
+ return id_mapping.set_index("id")["label"].to_dict()
772
+
773
+
774
+ def run_identification(
775
+ loaded_model,
776
+ k: int,
777
+ knn_index_filepath: Path,
778
+ pil_image_chip: Image.Image,
779
+ n_samples_per_individual: int = 5,
780
+ ) -> dict[str, Any]:
781
+ """
782
+ Run the identification stage.
783
+ """
784
+ device = get_best_device()
785
+ args = loaded_model["args"]
786
+ config = args.copy()
787
+ del config["run"]
788
+
789
+ transforms = get_transforms(
790
+ data_augmentation=config.get("data_augmentation", {}),
791
+ trunk_preprocessing=config["model"]["trunk"].get("preprocessing", {}),
792
+ )
793
+
794
+ logging.info("loading the df_split")
795
+ df_split = pd.DataFrame(loaded_model["data_split"])
796
+ df_split.info()
797
+
798
+ id_mapping = make_id_mapping(df=df_split)
799
+
800
+ dataloaders = make_dataloaders(
801
+ batch_size=config["batch_size"],
802
+ df_split=df_split,
803
+ transforms=transforms,
804
+ )
805
+
806
+ model_dict = make_model_dict(
807
+ device=device,
808
+ pretrained_backbone=config["model"]["trunk"]["backbone"],
809
+ embedding_size=config["model"]["embedder"]["embedding_size"],
810
+ hidden_layer_sizes=config["model"]["embedder"]["hidden_layer_sizes"],
811
+ )
812
+
813
+ trunk_weights = loaded_model["trunk"]
814
+ trunk = model_dict["trunk"]
815
+ trunk = load_weights(
816
+ network=trunk,
817
+ weights=trunk_weights,
818
+ prefix="module.",
819
+ )
820
+
821
+ embedder_weights = loaded_model["embedder"]
822
+ embedder = model_dict["embedder"]
823
+ embedder = load_weights(
824
+ network=embedder,
825
+ weights=embedder_weights,
826
+ prefix="module.",
827
+ )
828
+
829
+ model = InferenceModel(
830
+ trunk=trunk,
831
+ embedder=embedder,
832
+ )
833
+
834
+ dataset_full = dataloaders["dataset"]["full"]
835
+
836
+ assert (
837
+ knn_index_filepath.exists()
838
+ ), f"knn_index_filepath invalid filepath: {knn_index_filepath}"
839
+ model.load_knn_func(filename=str(knn_index_filepath))
840
+
841
+ image = pil_image_chip
842
+ transform_test = transforms["test"]
843
+ model_input = transform_test(image)
844
+ query = model_input.unsqueeze(0)
845
+ id_to_label = make_id_to_label(id_mapping=id_mapping)
846
+
847
+ k_nearest_individuals = get_k_nearest_individuals(
848
+ model=model,
849
+ k=k,
850
+ query=query,
851
+ id_to_label=id_to_label,
852
+ dataset=dataset_full,
853
+ )
854
+ indexed_k_nearest_individuals = index_by_bearid(
855
+ k_nearest_individuals=k_nearest_individuals
856
+ )
857
+ bear_ids = list(indexed_k_nearest_individuals.keys())
858
+ indexed_samples = make_indexed_samples(
859
+ bear_ids=bear_ids,
860
+ df_split=df_split,
861
+ n=n_samples_per_individual,
862
+ )
863
+ return {
864
+ "bear_ids": bear_ids,
865
+ "k_nearest_individuals": k_nearest_individuals,
866
+ "indexed_k_nearest_individuals": indexed_k_nearest_individuals,
867
+ "indexed_samples": indexed_samples,
868
+ }
869
+
870
+
871
+ def run_pipeline(
872
+ loaded_models: dict[str, Any],
873
+ param_square_dim: int,
874
+ param_k: int,
875
+ param_n_samples_per_individual: int,
876
+ knn_index_filepath: Path,
877
+ pil_image: Image.Image,
878
+ ) -> dict[str, Any]:
879
+ """
880
+ Run the full pipeline on pil_image, using `pil_image` as an input.
881
+
882
+ Args:
883
+ loaded_models (dict[str, Any]): dict of all the loaded models needed to
884
+ run the pipeline. Usually loaded via the `load_model` function.
885
+ param_square_dim (int): size of the square chip.
886
+ param_k (int): how many closest individuals to query to compare it to
887
+ the chip
888
+ param_n_samples_per_individual (int): How many chips from each
889
+ individual do we want to compare it to?
890
+ knn_index_filepath (Path): filepath to the KNN index of the embedded
891
+ chips.
892
+ pil_image (PIL): Main input image of the pipeline
893
+ """
894
+ results_segmentation = run_segmentation(
895
+ model=loaded_models["segmentation"], pil_image=pil_image
896
+ )
897
+ results_crop = run_crop(
898
+ square_dim=param_square_dim,
899
+ yolo_prediction=results_segmentation["prediction"],
900
+ )
901
+ pil_image_chip = results_crop["pil_images"]["resized"]
902
+ results_identification = run_identification(
903
+ loaded_model=loaded_models["metric_learning"],
904
+ k=param_k,
905
+ knn_index_filepath=knn_index_filepath,
906
+ pil_image_chip=pil_image_chip,
907
+ n_samples_per_individual=5,
908
+ )
909
+ return {
910
+ "order": ["segmentation", "crop", "identification"],
911
+ "stages": {
912
+ "segmentation": {
913
+ "input": {"pil_image": pil_image},
914
+ "output": results_segmentation,
915
+ },
916
+ "crop": {
917
+ "input": {
918
+ "square_dim": param_square_dim,
919
+ "yolo_prediction": results_segmentation["prediction"],
920
+ },
921
+ "output": results_crop,
922
+ },
923
+ "identification": {
924
+ "input": {
925
+ "k": param_k,
926
+ "n_samples_per_individual": param_n_samples_per_individual,
927
+ "knn_index_filepath": knn_index_filepath,
928
+ "pil_image_chip": pil_image_chip,
929
+ },
930
+ "output": results_identification,
931
+ },
932
+ },
933
+ }
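A minimal end-to-end usage sketch of the new utils.py API (not part of the commit), mirroring the paths and parameters used in app.py; it assumes setup() has already unpacked the packaged pipeline under data/06_models/ and the example image path is hypothetical:

# Sketch: running the pipeline outside of Gradio.
from pathlib import Path
from PIL import Image
from utils import load_models, run_pipeline

loaded_models = load_models(
    filepath_segmentation_weights=Path("./data/06_models/pipeline/metriclearning/bearfacesegmentation/model.pt"),
    filepath_metric_learning_weights=Path("./data/06_models/pipeline/metriclearning/bearidentification/model.pt"),
)
result = run_pipeline(
    loaded_models=loaded_models,
    pil_image=Image.open("data/images/example.jpg"),  # hypothetical example image
    param_square_dim=300,
    param_k=5,
    param_n_samples_per_individual=4,
    knn_index_filepath=Path("./data/06_models/pipeline/metriclearning/bearidentification/knn.index"),
)
print(result["stages"]["identification"]["output"]["bear_ids"])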