"""
Simple Gradio interface to showcase the ML pipeline outputs in a web app.
"""
import math
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image
from ultralytics import YOLO
import pipeline
from identification import IdentificationModel, generate_visualization
from utils import bgr_to_rgb, select_best_device
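# Index of the example image preselected in the input widget, and number of top
# matches (k) shown in the "Top K" tab.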
DEFAULT_IMAGE_INDEX = 0
K = 5
DIR_INSTALLED_PIPELINE = Path("./data/pipeline/")
DIR_EXAMPLES = Path("./data/images/")
FILEPATH_IDENTIFICATION_LIGHTGLUE_CONFIG = (
DIR_INSTALLED_PIPELINE / "models/identification/config.yaml"
)
FILEPATH_IDENTIFICATION_DB = DIR_INSTALLED_PIPELINE / "db/db.csv"
FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES = (
DIR_INSTALLED_PIPELINE / "models/identification/features.pt"
)
FILEPATH_WEIGHTS_SEGMENTATION_MODEL = (
DIR_INSTALLED_PIPELINE / "models/segmentation/weights.pt"
)
FILEPATH_WEIGHTS_POSE_MODEL = DIR_INSTALLED_PIPELINE / "models/pose/weights.pt"
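# In-memory caches keyed by the raw image bytes; see pipeline_run_fn and
# generate_visualization_fn below.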
CACHE_PIPELINE_RUN = {}
CACHE_VISUALIZATION_GENERATION = {}
def pipeline_run_fn(
loaded_models: dict[str, YOLO | IdentificationModel],
pil_image: Image.Image,
cache: dict,
k: int,
) -> dict[str, Any]:
"""
A simple cached version of the pipeline.run function.
    Note: the cache is kept in memory and can grow unbounded.
"""
bytes_image = pil_image.tobytes()
if bytes_image in cache:
return cache[bytes_image]
else:
results = pipeline.run(
loaded_models=loaded_models, pil_image=pil_image, param_k=k
)
cache[bytes_image] = results
return results
def generate_visualization_fn(
pil_image: Image.Image,
prediction: dict[str, Any],
cache: dict,
) -> dict:
"""
A simple cached version of the generate_visualization function.
    Note: the cache is kept in memory and can grow unbounded.
"""
bytes_image = pil_image.tobytes()
if bytes_image in cache:
return cache[bytes_image]
else:
results = generate_visualization(
pil_image=pil_image,
prediction=prediction,
)
cache[bytes_image] = results
return results
def examples(dir_examples: Path) -> list[Path]:
"""
    Retrieve the default example images from the provided directory.
    Args:
        dir_examples (Path): directory containing the example images.
    Returns:
        examples (list[Path]): list of image filepaths.
"""
return list(dir_examples.glob("*.jpg"))
def make_ui(loaded_models: dict[str, YOLO | IdentificationModel], k: int):
"""
    Main entrypoint to wire up the Gradio interface.
    Args:
        loaded_models (dict[str, YOLO | IdentificationModel]): loaded models ready to run inference with.
        k (int): number of top matches displayed in the "Top K" tab.
    Returns:
        gradio_ui (gr.Blocks): the assembled Gradio interface.
"""
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
image_input = gr.Image(
type="pil",
value=default_value_input,
label="input image",
sources=["upload", "clipboard"],
)
gr.Examples(
examples=example_filepaths,
inputs=image_input,
)
submit_btn = gr.Button(value="Identify", variant="primary")
with gr.Column():
with gr.Tab("Prediction"):
with gr.Row():
pit_prediction = gr.Text(label="predicted individual")
name_prediction = gr.Text(label="fish name", visible=False)
image_feature_matching = gr.Image(
label="pattern matching", visible=False
)
image_extracted_keypoints = gr.Image(
label="extracted keypoints", visible=False
)
with gr.Tab(f"Top {K}", visible=False) as tab_top_k:
textbox_dataset_name_template = gr.Textbox(visible=False)
textbox_dataset_pit_template = gr.Textbox(visible=False)
textbox_dataset_wasserstein_template = gr.Textbox(visible=False)
textbox_dataset_is_match_template = gr.Textbox(visible=False)
image_dataset_orig_template = gr.Image(visible=False)
image_dataset_keypoints_template = gr.Image(visible=False)
image_dataset_matches_template = gr.Image(visible=False)
image_source_orig = gr.Image(visible=False)
gr.Markdown(
"Select a row from the table below to see the matches and extracted keypoints."
)
dataset = gr.Dataset(
type="values",
label=f"dataset top {k}",
headers=[
"Match?",
"Image",
"Keypoints",
"Matches",
"Name",
"PIT",
"Wasserstein Distance",
],
components=[
textbox_dataset_is_match_template,
image_dataset_orig_template,
image_dataset_keypoints_template,
image_dataset_matches_template,
textbox_dataset_name_template,
textbox_dataset_pit_template,
textbox_dataset_wasserstein_template,
],
samples=[],
)
with gr.Row():
textbox_is_match = gr.Textbox(
label="Match?",
visible=False,
scale=1,
)
textbox_pit = gr.Textbox(label="PIT", visible=False, scale=5)
textbox_name = gr.Textbox(label="Name", visible=False, scale=5)
with gr.Row():
image_orig_selected = gr.Image(visible=False)
with gr.Column():
image_matches_selected = gr.Image(visible=False)
image_keypoints_selected = gr.Image(visible=False)
with gr.Tab("Details", visible=False) as tab_details:
with gr.Column():
with gr.Row():
text_rotation_angle = gr.Text(
label="correction angle (degrees)"
)
text_side = gr.Text(label="predicted side")
image_pose_keypoints = gr.Image(
type="pil", label="pose keypoints"
)
image_rotated_keypoints = gr.Image(
type="pil", label="rotated keypoints"
)
image_segmentation_mask = gr.Image(type="pil", label="mask")
image_masked = gr.Image(type="pil", label="masked")
def select_fn(inputs: list, evt: gr.EventData):
"""
        Display the selected row from the dataset of the top-k individuals.
"""
return [
gr.Textbox(inputs[0], label="Match", visible=True),
gr.Image(label="Image", value=inputs[1], visible=True),
gr.Image(label="keypoints", value=inputs[2], visible=True),
gr.Image(label="matches", value=inputs[3], visible=True),
gr.Textbox(inputs[4], label="Name", visible=True),
gr.Textbox(inputs[5], label="PIT", visible=True),
]
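        # Wire row selection in the top-k table: clicking a row reveals that
        # candidate's images and metadata below the table.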
dataset.click(
fn=select_fn,
queue=False,
inputs=[dataset],
outputs=[
textbox_is_match,
image_orig_selected,
image_keypoints_selected,
image_matches_selected,
textbox_name,
textbox_pit,
],
)
def submit_fn(
loaded_models: dict[str, YOLO | IdentificationModel],
orig_image: Image.Image,
):
"""
        Main function wired to the "Identify" button of the Gradio interface.
        Args:
            loaded_models (dict[str, YOLO | IdentificationModel]): loaded models.
            orig_image (PIL): original image picked by the user.
        Returns:
            A dict mapping each Gradio output component to its updated value, including:
            fish side (str): predicted fish side.
            correction angle (str): rotation, in degrees, to apply to realign the image.
            keypoints image (PIL): image displaying the bbox and keypoints from the
                pose estimation model.
            rotated image (PIL): rotated image after applying the correction angle.
            segmentation mask (PIL): segmentation mask predicted by the segmentation model.
            segmented image (PIL): segmented orig_image using the segmentation mask
                and the crop.
            predicted_individual (str): the identified individual.
            pil_image_extracted_keypoints (PIL): the extracted keypoints overlaid on the image.
            feature_matching_image (PIL): the matching of the source with the identified individual.
"""
model_identification = loaded_models["identification"]
results = pipeline_run_fn(
loaded_models=loaded_models,
pil_image=orig_image,
cache=CACHE_PIPELINE_RUN,
k=K,
)
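            # Unpack the per-stage pipeline outputs: pose (side + correction angle),
            # rotation, segmentation mask and the masked crop.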
side = results["stages"]["pose"]["output"]["side"]
theta = results["stages"]["pose"]["output"]["theta"]
pil_image_keypoints = Image.fromarray(
bgr_to_rgb(results["stages"]["pose"]["output"]["prediction"].plot())
)
pil_image_rotated = Image.fromarray(
results["stages"]["rotation"]["output"]["array_image"]
)
pil_image_mask = results["stages"]["segmentation"]["output"]["mask"]
pil_image_masked_cropped = results["stages"]["crop"]["output"]["pil_image"]
viz_dict = generate_visualization_fn(
pil_image=pil_image_masked_cropped,
prediction=results["stages"]["identification"]["output"],
cache=CACHE_VISUALIZATION_GENERATION,
)
is_new_individual = (
results["stages"]["identification"]["output"]["type"] == "new"
)
def is_match_fn(
wasserstein_distance: float,
threshold_wasserstein: float,
epsilon: float = 0.1,
) -> str:
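                # Distances below the threshold are flagged as a non-match, distances
                # above threshold + epsilon as a match, and values inside the epsilon
                # band as uncertain.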
if wasserstein_distance < threshold_wasserstein:
return "❌"
elif wasserstein_distance > threshold_wasserstein + epsilon:
return "✅"
else:
return "❓"
samples_top_k_dataset = [
[
is_match_fn(
wasserstein_distance=wasserstein_distance[1],
threshold_wasserstein=model_identification.threshold_wasserstein,
),
Image.open(entry["filepath_crop"]),
viz["keypoints_target"],
viz["matches"],
entry["name"],
entry["pit"],
f"{wasserstein_distance[1]:.2f}",
]
for viz, entry, wasserstein_distance in zip(
viz_dict["top_k"],
results["stages"]["identification"]["output"]["top_k"]["sorted"],
results["stages"]["identification"]["output"]["sorted_wasserstein"],
)
]
return {
text_rotation_angle: f"{math.degrees(theta):0.1f}",
text_side: side.value,
image_pose_keypoints: pil_image_keypoints,
image_rotated_keypoints: pil_image_rotated,
image_segmentation_mask: pil_image_mask,
image_masked: pil_image_masked_cropped,
pit_prediction: (
"New Fish!"
if is_new_individual
else gr.Text(
results["stages"]["identification"]["output"]["match"]["pit"],
visible=True,
)
),
name_prediction: (
gr.Text(visible=False)
if is_new_individual
else gr.Text(
results["stages"]["identification"]["output"]["match"]["name"],
visible=True,
)
),
            tab_details: gr.update(visible=True),
            tab_top_k: gr.update(visible=True),
image_source_orig: gr.Image(pil_image_masked_cropped, visible=True),
image_extracted_keypoints: gr.Image(
viz_dict["source"]["keypoints"], visible=True
),
image_feature_matching: (
gr.Image(visible=False)
if is_new_individual
else gr.Image(viz_dict["top_k"][0]["matches"], visible=True)
),
dataset: gr.Dataset(samples=samples_top_k_dataset, visible=True),
textbox_is_match: gr.Textbox(visible=False),
textbox_pit: gr.Textbox(visible=False),
textbox_name: gr.Textbox(visible=False),
image_orig_selected: gr.Image(visible=False),
image_keypoints_selected: gr.Image(visible=False),
image_matches_selected: gr.Image(visible=False),
}
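        # Wire the "Identify" button: the lambda captures loaded_models so the
        # handler only receives the uploaded image.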
submit_btn.click(
fn=lambda pil_image: submit_fn(
loaded_models=loaded_models,
orig_image=pil_image,
),
inputs=image_input,
outputs=[
text_rotation_angle,
text_side,
image_pose_keypoints,
image_rotated_keypoints,
image_segmentation_mask,
image_masked,
pit_prediction,
name_prediction,
tab_details,
tab_top_k,
image_source_orig,
image_feature_matching,
image_extracted_keypoints,
dataset,
textbox_is_match,
textbox_pit,
textbox_name,
image_orig_selected,
image_keypoints_selected,
image_matches_selected,
],
)
return demo
if __name__ == "__main__":
device = select_best_device()
# FIXME: get this from the config instead
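    # Identification hyperparameters: keypoint extractor type, number of keypoints
    # per image, and the Wasserstein-distance threshold used when deciding whether
    # a query matches a known individual.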
extractor_type = "aliked"
n_keypoints = 1024
threshold_wasserstein = 0.084
loaded_models = pipeline.load_models(
device=device,
filepath_weights_segmentation_model=FILEPATH_WEIGHTS_SEGMENTATION_MODEL,
filepath_weights_pose_model=FILEPATH_WEIGHTS_POSE_MODEL,
filepath_identification_lightglue_features=FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES,
filepath_identification_db=FILEPATH_IDENTIFICATION_DB,
extractor_type=extractor_type,
n_keypoints=n_keypoints,
threshold_wasserstein=threshold_wasserstein,
)
model_segmentation = loaded_models["segmentation"]
example_filepaths = examples(dir_examples=DIR_EXAMPLES)
default_value_input = Image.open(example_filepaths[DEFAULT_IMAGE_INDEX])
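    # Note: make_ui reads example_filepaths and default_value_input as
    # module-level globals.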
demo = make_ui(loaded_models=loaded_models, k=K)
demo.launch()