""" | |
Simple Gradio Interface to showcase the ML outputs in a UI and webapp. | |
""" | |
import math
from pathlib import Path
from typing import Any

import gradio as gr
from PIL import Image
from ultralytics import YOLO

import pipeline
from identification import IdentificationModel, generate_visualization
from utils import bgr_to_rgb, select_best_device

DEFAULT_IMAGE_INDEX = 0
K = 5

DIR_INSTALLED_PIPELINE = Path("./data/pipeline/")
DIR_EXAMPLES = Path("./data/images/")

FILEPATH_IDENTIFICATION_LIGHTGLUE_CONFIG = (
    DIR_INSTALLED_PIPELINE / "models/identification/config.yaml"
)
FILEPATH_IDENTIFICATION_DB = DIR_INSTALLED_PIPELINE / "db/db.csv"
FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES = (
    DIR_INSTALLED_PIPELINE / "models/identification/features.pt"
)
FILEPATH_WEIGHTS_SEGMENTATION_MODEL = (
    DIR_INSTALLED_PIPELINE / "models/segmentation/weights.pt"
)
FILEPATH_WEIGHTS_POSE_MODEL = DIR_INSTALLED_PIPELINE / "models/pose/weights.pt"

CACHE_PIPELINE_RUN = {}
CACHE_VISUALIZATION_GENERATION = {}
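# Both caches key on the raw image bytes and are never evicted, so memory
# grows with every distinct image processed.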


def pipeline_run_fn(
    loaded_models: dict[str, YOLO | IdentificationModel],
    pil_image: Image.Image,
    cache: dict,
    k: int,
) -> dict[str, Any]:
    """
    A simple cached version of the `pipeline.run` function.
    __Note__: The cache is kept in memory and can grow unbounded.
    """
    bytes_image = pil_image.tobytes()
    if bytes_image in cache:
        return cache[bytes_image]
    else:
        results = pipeline.run(
            loaded_models=loaded_models, pil_image=pil_image, param_k=k
        )
        cache[bytes_image] = results
        return results


def generate_visualization_fn(
    pil_image: Image.Image,
    prediction: dict[str, Any],
    cache: dict,
) -> dict:
    """
    A simple cached version of the `generate_visualization` function.
    __Note__: The cache is kept in memory and can grow unbounded.
    """
    bytes_image = pil_image.tobytes()
    if bytes_image in cache:
        return cache[bytes_image]
    else:
        results = generate_visualization(
            pil_image=pil_image,
            prediction=prediction,
        )
        cache[bytes_image] = results
        return results


def examples(dir_examples: Path) -> list[Path]:
    """
    Retrieve the default example images.

    Returns:
        examples (list[Path]): sorted list of image filepaths.
    """
    # Sort so the example ordering (and hence DEFAULT_IMAGE_INDEX) is
    # deterministic across filesystems.
    return sorted(dir_examples.glob("*.jpg"))


def make_ui(loaded_models: dict[str, YOLO | IdentificationModel], k: int):
    """
    Main entrypoint to wire up the Gradio interface.

    Args:
        loaded_models (dict[str, YOLO | IdentificationModel]): loaded models ready to run inference with.
        k (int): number of top matches to display.

    Returns:
        gradio_ui
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    value=default_value_input,
                    label="input image",
                    sources=["upload", "clipboard"],
                )
                gr.Examples(
                    examples=example_filepaths,
                    inputs=image_input,
                )
                submit_btn = gr.Button(value="Identify", variant="primary")
            with gr.Column():
                with gr.Tab("Prediction"):
                    with gr.Row():
                        pit_prediction = gr.Text(label="predicted individual")
                        name_prediction = gr.Text(label="fish name", visible=False)
                    image_feature_matching = gr.Image(
                        label="pattern matching", visible=False
                    )
                    image_extracted_keypoints = gr.Image(
                        label="extracted keypoints", visible=False
                    )
with gr.Tab(f"Top {K}", visible=False) as tab_top_k: | |
                    textbox_dataset_name_template = gr.Textbox(visible=False)
                    textbox_dataset_pit_template = gr.Textbox(visible=False)
                    textbox_dataset_wasserstein_template = gr.Textbox(visible=False)
                    textbox_dataset_is_match_template = gr.Textbox(visible=False)
                    image_dataset_orig_template = gr.Image(visible=False)
                    image_dataset_keypoints_template = gr.Image(visible=False)
                    image_dataset_matches_template = gr.Image(visible=False)
                    image_source_orig = gr.Image(visible=False)
                    gr.Markdown(
                        "Select a row from the table below to see the matches and extracted keypoints."
                    )
                    dataset = gr.Dataset(
                        type="values",
                        label=f"dataset top {k}",
                        headers=[
                            "Match?",
                            "Image",
                            "Keypoints",
                            "Matches",
                            "Name",
                            "PIT",
                            "Wasserstein Distance",
                        ],
                        components=[
                            textbox_dataset_is_match_template,
                            image_dataset_orig_template,
                            image_dataset_keypoints_template,
                            image_dataset_matches_template,
                            textbox_dataset_name_template,
                            textbox_dataset_pit_template,
                            textbox_dataset_wasserstein_template,
                        ],
                        samples=[],
                    )
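                    # The dataset starts empty; `submit_fn` fills it with the
                    # top-k rows by returning an updated gr.Dataset.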
                    with gr.Row():
                        textbox_is_match = gr.Textbox(
                            label="Match?",
                            visible=False,
                            scale=1,
                        )
                        textbox_pit = gr.Textbox(label="PIT", visible=False, scale=5)
                        textbox_name = gr.Textbox(label="Name", visible=False, scale=5)
                    with gr.Row():
                        image_orig_selected = gr.Image(visible=False)
                        with gr.Column():
                            image_matches_selected = gr.Image(visible=False)
                            image_keypoints_selected = gr.Image(visible=False)
                with gr.Tab("Details", visible=False) as tab_details:
                    with gr.Column():
                        with gr.Row():
                            text_rotation_angle = gr.Text(
                                label="correction angle (degrees)"
                            )
                            text_side = gr.Text(label="predicted side")
                        image_pose_keypoints = gr.Image(
                            type="pil", label="pose keypoints"
                        )
                        image_rotated_keypoints = gr.Image(
                            type="pil", label="rotated keypoints"
                        )
                        image_segmentation_mask = gr.Image(type="pil", label="mask")
                        image_masked = gr.Image(type="pil", label="masked")

        def select_fn(inputs: list, evt: gr.EventData):
            """
            Select a row in the dataset of the top-k individuals.
            """
            return [
                gr.Textbox(inputs[0], label="Match", visible=True),
                gr.Image(label="Image", value=inputs[1], visible=True),
                gr.Image(label="keypoints", value=inputs[2], visible=True),
                gr.Image(label="matches", value=inputs[3], visible=True),
                gr.Textbox(inputs[4], label="Name", visible=True),
                gr.Textbox(inputs[5], label="PIT", visible=True),
            ]

        dataset.click(
            fn=select_fn,
            queue=False,
            inputs=[dataset],
            outputs=[
                textbox_is_match,
                image_orig_selected,
                image_keypoints_selected,
                image_matches_selected,
                textbox_name,
                textbox_pit,
            ],
        )

        def submit_fn(
            loaded_models: dict[str, YOLO | IdentificationModel],
            orig_image: Image.Image,
        ):
            """
            Main function used for the Gradio interface.

            Args:
                loaded_models (dict[str, YOLO | IdentificationModel]): loaded models.
                orig_image (PIL): original image picked by the user.

            Returns:
                fish side (str): predicted fish side.
                correction angle (str): rotation in degrees to apply to re-align the image.
                keypoints image (PIL): image displaying the bbox and keypoints from the
                  pose estimation model.
                rotated image (PIL): rotated image after applying the correction angle.
                segmentation mask (PIL): segmentation mask predicted by the segmentation model.
                segmented image (PIL): segmented orig_image using the segmentation mask
                  and the crop.
                predicted_individual (str): the identified individual.
                pil_image_extracted_keypoints (PIL): the extracted keypoints overlaid on the image.
                feature_matching_image (PIL): the matching of the source with the identified individual.
            """
            model_identification = loaded_models["identification"]
            results = pipeline_run_fn(
                loaded_models=loaded_models,
                pil_image=orig_image,
                cache=CACHE_PIPELINE_RUN,
                k=k,
            )
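            # `pipeline.run` returns per-stage outputs keyed by stage name:
            # pose -> rotation -> segmentation -> crop -> identification.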
            side = results["stages"]["pose"]["output"]["side"]
            theta = results["stages"]["pose"]["output"]["theta"]
            pil_image_keypoints = Image.fromarray(
                bgr_to_rgb(results["stages"]["pose"]["output"]["prediction"].plot())
            )
            pil_image_rotated = Image.fromarray(
                results["stages"]["rotation"]["output"]["array_image"]
            )
            pil_image_mask = results["stages"]["segmentation"]["output"]["mask"]
            pil_image_masked_cropped = results["stages"]["crop"]["output"]["pil_image"]
            viz_dict = generate_visualization_fn(
                pil_image=pil_image_masked_cropped,
                prediction=results["stages"]["identification"]["output"],
                cache=CACHE_VISUALIZATION_GENERATION,
            )
            is_new_individual = (
                results["stages"]["identification"]["output"]["type"] == "new"
            )
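
            # Verdict convention: distances under the model's Wasserstein
            # threshold are non-matches (❌), distances above threshold +
            # epsilon are matches (✅), and the band in between is uncertain (❓).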
            def is_match_fn(
                wasserstein_distance: float,
                threshold_wasserstein: float,
                epsilon: float = 0.1,
            ) -> str:
                if wasserstein_distance < threshold_wasserstein:
                    return "❌"
                elif wasserstein_distance > threshold_wasserstein + epsilon:
                    return "✅"
                else:
                    return "❓"
            samples_top_k_dataset = [
                [
                    is_match_fn(
                        wasserstein_distance=wasserstein_distance[1],
                        threshold_wasserstein=model_identification.threshold_wasserstein,
                    ),
                    Image.open(entry["filepath_crop"]),
                    viz["keypoints_target"],
                    viz["matches"],
                    entry["name"],
                    entry["pit"],
                    f"{wasserstein_distance[1]:.2f}",
                ]
                for viz, entry, wasserstein_distance in zip(
                    viz_dict["top_k"],
                    results["stages"]["identification"]["output"]["top_k"]["sorted"],
                    results["stages"]["identification"]["output"]["sorted_wasserstein"],
                )
            ]
            return {
                text_rotation_angle: f"{math.degrees(theta):0.1f}",
                text_side: side.value,
                image_pose_keypoints: pil_image_keypoints,
                image_rotated_keypoints: pil_image_rotated,
                image_segmentation_mask: pil_image_mask,
                image_masked: pil_image_masked_cropped,
                pit_prediction: (
                    "New Fish!"
                    if is_new_individual
                    else gr.Text(
                        results["stages"]["identification"]["output"]["match"]["pit"],
                        visible=True,
                    )
                ),
                name_prediction: (
                    gr.Text(visible=False)
                    if is_new_individual
                    else gr.Text(
                        results["stages"]["identification"]["output"]["match"]["name"],
                        visible=True,
                    )
                ),
                tab_details: gr.Column(visible=True),
                tab_top_k: gr.Column(visible=True),
                image_source_orig: gr.Image(pil_image_masked_cropped, visible=True),
                image_extracted_keypoints: gr.Image(
                    viz_dict["source"]["keypoints"], visible=True
                ),
                image_feature_matching: (
                    gr.Image(visible=False)
                    if is_new_individual
                    else gr.Image(viz_dict["top_k"][0]["matches"], visible=True)
                ),
                dataset: gr.Dataset(samples=samples_top_k_dataset, visible=True),
                textbox_is_match: gr.Textbox(visible=False),
                textbox_pit: gr.Textbox(visible=False),
                textbox_name: gr.Textbox(visible=False),
                image_orig_selected: gr.Image(visible=False),
                image_keypoints_selected: gr.Image(visible=False),
                image_matches_selected: gr.Image(visible=False),
            }

        submit_btn.click(
            fn=lambda pil_image: submit_fn(
                loaded_models=loaded_models,
                orig_image=pil_image,
            ),
            inputs=image_input,
            outputs=[
                text_rotation_angle,
                text_side,
                image_pose_keypoints,
                image_rotated_keypoints,
                image_segmentation_mask,
                image_masked,
                pit_prediction,
                name_prediction,
                tab_details,
                tab_top_k,
                image_source_orig,
                image_feature_matching,
                image_extracted_keypoints,
                dataset,
                textbox_is_match,
                textbox_pit,
                textbox_name,
                image_orig_selected,
                image_keypoints_selected,
                image_matches_selected,
            ],
        )
    return demo
if __name__ == "__main__": | |
device = select_best_device() | |
# FIXME: get this from the config instead | |
extractor_type = "aliked" | |
n_keypoints = 1024 | |
threshold_wasserstein = 0.084 | |
    loaded_models = pipeline.load_models(
        device=device,
        filepath_weights_segmentation_model=FILEPATH_WEIGHTS_SEGMENTATION_MODEL,
        filepath_weights_pose_model=FILEPATH_WEIGHTS_POSE_MODEL,
        filepath_identification_lightglue_features=FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES,
        filepath_identification_db=FILEPATH_IDENTIFICATION_DB,
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
        threshold_wasserstein=threshold_wasserstein,
    )
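    # `loaded_models` maps stage names to model instances; submit_fn reads the
    # "identification" entry directly, and the pipeline consumes the rest.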
    example_filepaths = examples(dir_examples=DIR_EXAMPLES)
    default_value_input = Image.open(example_filepaths[DEFAULT_IMAGE_INDEX])
    demo = make_ui(loaded_models=loaded_models, k=K)
    demo.launch()