import logging import os import shutil import subprocess from collections import Counter from pathlib import Path from typing import Any, Optional, OrderedDict import cv2 import numpy as np import pandas as pd import torch import torch.nn as nn import torchvision import torchvision.models as models from PIL import Image from pytorch_metric_learning.utils.common_functions import logging from pytorch_metric_learning.utils.inference import InferenceModel from torch.utils.data import DataLoader, Dataset from torchvision import transforms from torchvision.transforms import v2 from ultralytics import YOLO # TODO: move metric learning functions into their own namespace def sample_chips_from_bearid( bear_id: str, df_split: pd.DataFrame, n: int = 4, ) -> list[Path]: xs = df_split[df_split["bear_id"] == bear_id].sample(n=n)["path"].tolist() return [Path(x) for x in xs] def make_indexed_samples( bear_ids: list[str], df_split: pd.DataFrame, n: int = 4, ) -> dict[str, list[Path]]: return { bear_id: sample_chips_from_bearid(bear_id=bear_id, df_split=df_split, n=n) for bear_id in bear_ids } def _aux_get_k_nearest_individuals( model: InferenceModel, k_neighbors: int, k_individuals: int, query, id_to_label: dict, dataset: Dataset, ) -> dict: """Auxiliary helper function to get k nearest individuals. Returns a dict with the following keys: - k_neighbors: int - number of neighbors the KNN search extends to in order to find at least k_individuals - dataset_indices: list[int] - list of indices to call get_item on the dataset - dataset_labels: list[int] - labels of the dataset for the given dataset_indices - dataset_images: list[torch.tensor] - chips of the bears - distances: list[float] - distances from the query Note: it can return more than k_individuals as it extends progressively the KNN search to find at least k_individuals. """ assert k_individuals <= 20, f"Keep a small k_individuals: {k_individuals}" distances, indices = model.get_nearest_neighbors(query=query, k=k_neighbors) indices_on_cpu = indices.cpu()[0].tolist() distances_on_cpu = distances.cpu()[0].tolist() nearest_images, nearest_ids = list(zip(*[dataset[idx] for idx in indices_on_cpu])) bearids = [id_to_label.get(nearest_id, "unknown") for nearest_id in nearest_ids] counter = Counter(nearest_ids) if len(counter.keys()) >= k_individuals: return { "k_neighbors": k_neighbors, "dataset_indices": indices_on_cpu, "dataset_labels": list(nearest_ids), "dataset_images": list(nearest_images), "bearids": bearids, "distances": distances_on_cpu, } else: new_k_neighbors = k_neighbors * 2 return _aux_get_k_nearest_individuals( model, k_neighbors=new_k_neighbors, k_individuals=k_individuals, query=query, id_to_label=id_to_label, dataset=dataset, ) def _find_cutoff_index(k: int, dataset_labels: list[str]) -> Optional[int]: """Returns the index for dataset_labels that retrieves exactly k individuals.""" if not dataset_labels: return None else: selected_labels = set() cutoff_index = -1 for idx, label in enumerate(dataset_labels): if len(selected_labels) == k: break else: selected_labels.add(label) cutoff_index = idx + 1 return cutoff_index def get_k_nearest_individuals( model: InferenceModel, k: int, query, id_to_label: dict, dataset: Dataset, ) -> dict: """Returns the k nearest individuals using the inference model and a query. A dict is returned with the following keys: - dataset_indices: list[int] - list of indices to call get_item on the dataset - dataset_labels: list[int] - labels of the dataset for the given dataset_indices - dataset_images: list[torch.tensor] - chips of the bears - distances: list[float] - distances from the query """ k_neighbors = k * 5 k_individuals = k result = _aux_get_k_nearest_individuals( model=model, k_neighbors=k_neighbors, k_individuals=k_individuals, query=query, id_to_label=id_to_label, dataset=dataset, ) cutoff_index = _find_cutoff_index( k=k, dataset_labels=result["dataset_labels"], ) return { "dataset_indices": result["dataset_indices"][:cutoff_index], "dataset_labels": result["dataset_labels"][:cutoff_index], "dataset_images": result["dataset_images"][:cutoff_index], "bearids": result["bearids"][:cutoff_index], "distances": result["distances"][:cutoff_index], } def index_by_bearid(k_nearest_individuals: dict) -> dict: """Returns a dict where keys are bearid labels (eg. 'bf_480') and the values are list of the following dict shapes: - dataset_label: int - dataset_image: torch.tensor - distance: float - dataset_index: int """ result = {} for dataset_label, dataset_image, distance, bearid, dataset_index in zip( k_nearest_individuals["dataset_labels"], k_nearest_individuals["dataset_images"], k_nearest_individuals["distances"], k_nearest_individuals["bearids"], k_nearest_individuals["dataset_indices"], ): row = { "dataset_label": dataset_label, "dataset_image": dataset_image, "distance": distance, "dataset_index": dataset_index, } if bearid not in result: result[bearid] = [row] else: result[bearid].append(row) return result def prefix_keys_with(weights: OrderedDict, prefix: str = "module.") -> OrderedDict: """Returns the new weights where each key is prefixed with the provided `prefix`. Note: Useful when using DataParallel to account for the module. prefix key. """ weights_copy = weights.copy() for k, v in weights.items(): weights_copy[f"{prefix}{k}"] = v del weights_copy[k] return weights_copy def load_weights( network: torch.nn.Module, weights_filepath: Optional[Path] = None, weights: Optional[OrderedDict] = None, prefix: str = "", ) -> torch.nn.Module: """Loads the network weights. Returns the network. """ if weights: prefixed_weights = prefix_keys_with(weights, prefix=prefix) network.load_state_dict(state_dict=prefixed_weights) return network elif weights_filepath: assert weights_filepath.exists(), f"Invalid model_filepath {weights_filepath}" weights = torch.load(weights_filepath) prefixed_weights = prefix_keys_with(weights, prefix=prefix) network.load_state_dict(state_dict=prefixed_weights) return network else: raise Exception(f"Should provide at least weights or weights_filepath") class MLP(nn.Module): # layer_sizes[0] is the dimension of the input # layer_sizes[-1] is the dimension of the output def __init__(self, layer_sizes, final_relu=False): super().__init__() layer_list = [] layer_sizes = [int(x) for x in layer_sizes] num_layers = len(layer_sizes) - 1 final_relu_layer = num_layers if final_relu else num_layers - 1 for i in range(len(layer_sizes) - 1): input_size = layer_sizes[i] curr_size = layer_sizes[i + 1] if i <= final_relu_layer: layer_list.append(nn.ReLU(inplace=False)) layer_list.append(nn.BatchNorm1d(input_size)) layer_list.append(nn.Linear(input_size, curr_size)) self.net = nn.Sequential(*layer_list) self.last_linear = self.net[-1] def forward(self, x): return self.net(x) def check_backbone(pretrained_backbone: str) -> None: allowed_backbones = { "resnet18", "resnet50", "convnext_tiny", "convnext_base", "convnext_large", "efficientnet_v2_s", # "squeezenet1_1", "vit_b_16", } assert ( pretrained_backbone in allowed_backbones ), f"pretrained_backbone {pretrained_backbone} is not implemented, only {allowed_backbones}" def make_trunk(pretrained_backbone: str = "resnet18") -> nn.Module: """Returns a nn.Module with pretrained weights using a given pretrained_backbone. Note: The currently available backbones are resnet18, resnet50, convnext_tiny, convnext_bas, efficientnet_v2_s, squeezenet1_1, vit_b_16 """ check_backbone(pretrained_backbone) if pretrained_backbone == "resnet18": return torchvision.models.resnet18( weights=models.ResNet18_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "resnet50": return torchvision.models.resnet50( weights=models.ResNet50_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "convnext_tiny": return torchvision.models.convnext_tiny( weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "convnext_base": return torchvision.models.convnext_base( weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "convnext_large": return torchvision.models.convnext_large( weights=models.ConvNeXt_Large_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "efficientnet_v2_s": return torchvision.models.efficientnet_v2_s( weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "squeezenet1_1": return torchvision.models.squeezenet1_1( weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1 ) elif pretrained_backbone == "vit_b_16": return torchvision.models.vit_b_16( weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 ) else: raise Exception(f"Cannot make trunk with backbone {pretrained_backbone}") def make_embedder( pretrained_backbone: str, trunk: nn.Module, embedding_size: int, hidden_layer_sizes: list[int], ) -> nn.Module: check_backbone(pretrained_backbone) if pretrained_backbone in ["resnet18", "resnet50"]: trunk_output_size = trunk.fc.in_features trunk.fc = nn.Identity() return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size]) if pretrained_backbone in ["convnext_tiny", "convnext_base", "convnext_large"]: trunk_output_size = trunk.classifier[-1].in_features trunk.classifier[-1] = nn.Identity() return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size]) elif pretrained_backbone == "efficientnet_v2_s": trunk_output_size = trunk.classifier[-1].in_features trunk.classifier[-1] = nn.Identity() return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size]) elif pretrained_backbone == "vit_b_16": trunk_output_size = trunk.heads.head.in_features trunk.heads.head = nn.Identity() return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size]) else: raise Exception(f"{pretrained_backbone} embedder not implemented yet") def make_model_dict( device: torch.device, pretrained_backbone: str = "resnet18", embedding_size: int = 128, hidden_layer_sizes: list[int] = [1024], ) -> dict[str, nn.Module]: """ Returns a dict with the following keys: - embedder: nn.Module - embedder model, usually an MLP. - trunk: nn.Module - the backbone model, usually a pretrained model (like a ResNet). """ trunk = make_trunk(pretrained_backbone=pretrained_backbone) embedder = make_embedder( pretrained_backbone=pretrained_backbone, embedding_size=embedding_size, hidden_layer_sizes=hidden_layer_sizes, trunk=trunk, ) trunk = torch.nn.DataParallel(trunk.to(device)) embedder = torch.nn.DataParallel(embedder.to(device)) return { "trunk": trunk, "embedder": embedder, } class BearDataset(Dataset): def __init__(self, dataframe, id_mapping, transform=None): self.dataframe = dataframe self.id_mapping = id_mapping self.transform = transform def __len__(self): return len(self.dataframe) def __getitem__(self, idx): sample = self.dataframe.iloc[idx] image_path = sample.path bear_id = sample.bear_id id_value = self.id_mapping.loc[self.id_mapping["label"] == bear_id, "id"].iloc[ 0 ] image = Image.open(image_path) if self.transform: image = self.transform(image) return image, id_value def make_dataloaders( batch_size: int, df_split: pd.DataFrame, transforms: dict, ) -> dict: """Returns a dict with top level keys in {dataset and loader}. Each returns a dict with the train, val and test objects associated. """ df_train = df_split[df_split["split"] == "train"] df_val = df_split[df_split["split"] == "val"] df_test = df_split[df_split["split"] == "test"] id_mapping = make_id_mapping(df=df_split) train_dataset = BearDataset( df_train, id_mapping, transform=transforms["train"], ) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, ) val_dataset = BearDataset( df_val, id_mapping, transform=transforms["val"], ) val_loader = DataLoader( val_dataset, batch_size=batch_size, ) test_dataset = BearDataset( df_test, id_mapping, transform=transforms["test"], ) test_loader = DataLoader( test_dataset, batch_size=batch_size, ) viz_dataset = BearDataset( df_train, id_mapping, transform=transforms["viz"], ) viz_loader = DataLoader( viz_dataset, batch_size=batch_size, shuffle=True, drop_last=True, ) full_dataset = BearDataset( df_split, id_mapping, transform=transforms["val"], ) return { "dataset": { "viz": viz_dataset, "train": train_dataset, "val": val_dataset, "test": test_dataset, "full": full_dataset, }, "loader": { "viz": viz_loader, "train": train_loader, "val": val_loader, "test": test_loader, }, } def make_id_mapping(df: pd.DataFrame) -> pd.DataFrame: """Returns a dataframe that maps a bear label (eg. bf_755) to a unique natural number (eg. 0). The dataFrame contains two columns, namely id and label. """ return pd.DataFrame( list(enumerate(df["bear_id"].unique())), columns=["id", "label"] ) def filter_none(xs: list) -> list: return [x for x in xs if x is not None] def get_dtype(dtype_str: str) -> torch.dtype: if dtype_str == "float32": return torch.float32 elif dtype_str == "int64": return torch.int64 else: logging.warning( f"dtype_str {dtype_str} not implemented, returning default value" ) return torch.float32 def get_transforms( data_augmentation: dict = {}, trunk_preprocessing: dict = {}, ) -> dict: """Returns a dict containing the transforms for the following splits: train, val, test and viz (the latter is used for batch visualization). """ logging.info(f"data_augmentation config: {data_augmentation}") logging.info(f"trunk preprocessing config: {trunk_preprocessing}") DEFAULT_CROP_SIZE = 224 crop_size = ( trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE), trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE), ) # transform to persist a batch of data as an artefact transform_viz = transforms.Compose( [ transforms.Resize(crop_size), transforms.ToTensor(), ] ) mdtype: Optional[torch.dtype] = ( get_dtype(trunk_preprocessing["values"].get("dtype", None)) if trunk_preprocessing.get("values", None) else None ) mscale: Optional[bool] = ( trunk_preprocessing["values"].get("scale", None) if trunk_preprocessing.get("values", None) else None ) mmean: Optional[list[float]] = ( trunk_preprocessing["normalization"].get("mean", None) if trunk_preprocessing.get("normalization", None) else None ) mstd: Optional[list[float]] = ( trunk_preprocessing["normalization"].get("std", None) if trunk_preprocessing.get("normalization", None) else None ) hue = ( data_augmentation["colorjitter"].get("hue", 0) if data_augmentation.get("colorjitter", 0) else 0 ) saturation = ( data_augmentation["colorjitter"].get("saturation", 0) if data_augmentation.get("colorjitter", 0) else 0 ) degrees = ( data_augmentation["rotation"].get("degrees", 0) if data_augmentation.get("rotation", 0) else 0 ) transformations_plain = [ transforms.Resize(crop_size), transforms.ToTensor(), v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None, transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None, ] transformations_train = [ transforms.Resize(crop_size), ( transforms.ColorJitter( hue=hue, saturation=saturation, ) if data_augmentation.get("colorjitter", None) else None ), # Taken from Dolphin ID ( v2.RandomRotation(degrees=degrees) if data_augmentation.get("rotation", None) else None ), # Taken from Dolphin ID transforms.ToTensor(), v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None, transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None, ] # Filtering out None transforms transform_plain = transforms.Compose(filter_none(transformations_plain)) transform_train = transforms.Compose(filter_none(transformations_train)) return { "viz": transform_viz, "train": transform_train, "val": transform_plain, "test": transform_plain, } def resize( mask: np.ndarray, dim: tuple[int, int], interpolation: int = cv2.INTER_LINEAR, ): """Resize the mask to the provided `dim` using the interpolation method. `dim`: (W, H) format """ return cv2.resize(mask, dsize=dim, interpolation=interpolation) def crop_from_yolov8(prediction_yolov8) -> np.ndarray: """Given a yolov8 prediction, returns an image containing the cropped bear head.""" H, W = prediction_yolov8.orig_shape predictions_masks = prediction_yolov8.masks.data.to("cpu").numpy() idx = np.argmax(prediction_yolov8.boxes.conf.to("cpu").numpy()) predictions_mask = predictions_masks[idx] prediction_resized = resize(predictions_mask, dim=(W, H)) masked_image = prediction_yolov8.orig_img.copy() black_pixel = [0, 0, 0] masked_image[~prediction_resized.astype(bool)] = black_pixel x0, y0, x1, y1 = prediction_yolov8.boxes[idx].xyxy[0].to("cpu").numpy() return masked_image[int(y0) : int(y1), int(x0) : int(x1)] def square_pad(img: np.ndarray): """Returns an image with dimension max(W, H) x max(W, H), padded with black pixels.""" H, W, _ = img.shape K = max(H, W) top = (K - H) // 2 bottom = (K - H) // 2 left = (K - W) // 2 right = (K - W) // 2 return cv2.copyMakeBorder( img.copy(), top, bottom, left, right, cv2.BORDER_CONSTANT, ) def get_best_device() -> torch.device: """Returns the best torch device depending on the hardware it is running on.""" return torch.device("cuda" if torch.cuda.is_available() else "cpu") def _setup_chips() -> None: """ Setup the Database of chips used for the face recognition. """ subprocess.run(["./scripts/chips/install.sh"]) def _setup_ml_pipeline(input_packaged_pipeline: Path, install_path: Path) -> None: """ Setup the ML pipeline, installing the model weights into their folders. """ logging.info(f"Installing the packaged pipeline in {install_path}") os.makedirs(install_path, exist_ok=True) packaged_pipeline_archive_filepath = input_packaged_pipeline shutil.unpack_archive( filename=packaged_pipeline_archive_filepath, extract_dir=install_path, ) metriclearning_model_filepath = install_path / "bearidentification" / "model.pt" device = get_best_device() bearidentification_model = torch.load( metriclearning_model_filepath, map_location=device, ) df_split = pd.DataFrame(bearidentification_model["data_split"]) chips_root_dir = Path("/".join(df_split.iloc[0]["path"].split("/")[:-4])) logging.info(f"Retrieved chips_root_dir: {chips_root_dir}") os.makedirs(chips_root_dir, exist_ok=True) shutil.copytree( src=install_path / "chips", dst=chips_root_dir, dirs_exist_ok=True, ) def setup(input_packaged_pipeline: Path, install_path: Path) -> None: """ Full setup of the project. """ _setup_chips() _setup_ml_pipeline( input_packaged_pipeline=input_packaged_pipeline, install_path=install_path ) def bgr_to_rgb(a: np.ndarray) -> np.ndarray: """ Turn a BGR numpy array into a RGB numpy array when the array `a` represents an image. """ return a[:, :, ::-1] def load_segmentation_model(filepath_weights: Path) -> YOLO: """ Load the YOLO model given the filepath_weights. """ assert filepath_weights.exists() return YOLO(filepath_weights) def load_metric_learning_model(device: torch.device, filepath_weights: Path) -> Any: assert filepath_weights.exists() return torch.load(filepath_weights, map_location=device) def load_models( filepath_segmentation_weights: Path, filepath_metric_learning_weights: Path, ) -> dict[str, Any]: assert filepath_segmentation_weights.exists() assert filepath_metric_learning_weights.exists() device = get_best_device() model_segmentation = load_segmentation_model(filepath_segmentation_weights) model_metric_learning = load_metric_learning_model( device=device, filepath_weights=filepath_metric_learning_weights, ) return { "segmentation": model_segmentation, "metric_learning": model_metric_learning, } def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]: predictions = model(pil_image) if len(predictions) > 0: prediction = predictions[0] pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot())) return {"pil_image": pil_image_with_prediction, "prediction": prediction} else: return {} def run_crop(square_dim: int, yolo_prediction) -> dict[str, Any]: """ Run the crop stage on the yolo_prediction. It resizes a square bear face based on `square_dim`. """ cropped_bear_head = crop_from_yolov8(prediction_yolov8=yolo_prediction) padded_cropped_head = square_pad(cropped_bear_head) resized_padded_cropped_head = resize( padded_cropped_head, dim=(square_dim, square_dim) ) pil_image_cropped_bear_head = Image.fromarray(bgr_to_rgb(cropped_bear_head)) pil_image_padded_cropped_head = Image.fromarray( bgr_to_rgb(resized_padded_cropped_head) ) pil_image_resized_padded_cropped_head = Image.fromarray( bgr_to_rgb(resized_padded_cropped_head) ) return { "pil_images": { "cropped": pil_image_cropped_bear_head, "padded": pil_image_padded_cropped_head, "resized": pil_image_resized_padded_cropped_head, } } def make_id_to_label(id_mapping: pd.DataFrame) -> dict[int, str]: return id_mapping.set_index("id")["label"].to_dict() def run_identification( loaded_model, k: int, knn_index_filepath: Path, pil_image_chip: Image.Image, n_samples_per_individual: int = 5, ) -> dict[str, Any]: """ Run the identification stage. """ device = get_best_device() args = loaded_model["args"] config = args.copy() del config["run"] transforms = get_transforms( data_augmentation=config.get("data_augmentation", {}), trunk_preprocessing=config["model"]["trunk"].get("preprocessing", {}), ) logging.info("loading the df_split") df_split = pd.DataFrame(loaded_model["data_split"]) df_split.info() id_mapping = make_id_mapping(df=df_split) dataloaders = make_dataloaders( batch_size=config["batch_size"], df_split=df_split, transforms=transforms, ) model_dict = make_model_dict( device=device, pretrained_backbone=config["model"]["trunk"]["backbone"], embedding_size=config["model"]["embedder"]["embedding_size"], hidden_layer_sizes=config["model"]["embedder"]["hidden_layer_sizes"], ) trunk_weights = loaded_model["trunk"] trunk = model_dict["trunk"] trunk = load_weights( network=trunk, weights=trunk_weights, prefix="module.", ) embedder_weights = loaded_model["embedder"] embedder = model_dict["embedder"] embedder = load_weights( network=embedder, weights=embedder_weights, prefix="module.", ) model = InferenceModel( trunk=trunk, embedder=embedder, ) dataset_full = dataloaders["dataset"]["full"] assert ( knn_index_filepath.exists() ), f"knn_index_filepath invalid filepath: {knn_index_filepath}" model.load_knn_func(filename=str(knn_index_filepath)) image = pil_image_chip transform_test = transforms["test"] model_input = transform_test(image) query = model_input.unsqueeze(0) id_to_label = make_id_to_label(id_mapping=id_mapping) k_nearest_individuals = get_k_nearest_individuals( model=model, k=k, query=query, id_to_label=id_to_label, dataset=dataset_full, ) indexed_k_nearest_individuals = index_by_bearid( k_nearest_individuals=k_nearest_individuals ) bear_ids = list(indexed_k_nearest_individuals.keys()) indexed_samples = make_indexed_samples( bear_ids=bear_ids, df_split=df_split, n=n_samples_per_individual, ) return { "bear_ids": bear_ids, "k_nearest_individuals": k_nearest_individuals, "indexed_k_nearest_individuals": indexed_k_nearest_individuals, "indexed_samples": indexed_samples, } def run_pipeline( loaded_models: dict[str, Any], param_square_dim: int, param_k: int, param_n_samples_per_individual: int, knn_index_filepath: Path, pil_image: Image.Image, ) -> dict[str, Any]: """ Run the full pipeline on pil_image, using `pil_image` as an input. Args: loaded_models (dict[str, Any]): dict of all the loaded models needed to run the pipeline. Usually loaded via the `load_model` function. param_square_dim (int): size of the square chip. param_k (int): how many closest individuals to query to compare it to the chip param_n_samples_per_individual (int): How many chips from each individual do we want to compare it to? knn_index_filepath (Path): filepath to the KNN index of the embedded chips. pil_image (PIL): Main input image of the pipeline """ results_segmentation = run_segmentation( model=loaded_models["segmentation"], pil_image=pil_image ) results_crop = run_crop( square_dim=param_square_dim, yolo_prediction=results_segmentation["prediction"], ) pil_image_chip = results_crop["pil_images"]["resized"] results_identification = run_identification( loaded_model=loaded_models["metric_learning"], k=param_k, knn_index_filepath=knn_index_filepath, pil_image_chip=pil_image_chip, n_samples_per_individual=5, ) return { "order": ["segmentation", "crop", "identification"], "stages": { "segmentation": { "input": {"pil_image": pil_image}, "output": results_segmentation, }, "crop": { "input": { "square_dim": param_square_dim, "yolo_prediction": results_segmentation["prediction"], }, "output": results_crop, }, "identification": { "input": { "k": param_k, "n_samples_per_individual": param_n_samples_per_individual, "knn_index_filepath": knn_index_filepath, "pil_image_chip": pil_image_chip, }, "output": results_identification, }, }, }