"""
Gradio app to showcase the pyronear model for salmon vision.
"""

from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Tuple

import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO


def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Turn a BGR numpy array into a RGB numpy array when the array `a`
    represents an image.

    The channel axis is the last of three axes (indexing `a[:, :, ::-1]`);
    the result is a reversed view along that axis, not a copy.
    """
    return a[:, :, ::-1]


def has_values(maybe_tensor: torch.Tensor | None) -> bool:
    """
    Check whether `maybe_tensor` is a tensor that contains at least one item.

    Returns:
        bool: False for None, for anything that is not a `torch.Tensor`,
        and for empty tensors; True otherwise.
    """
    return isinstance(maybe_tensor, torch.Tensor) and len(maybe_tensor) > 0


def analyze_predictions(yolo_predictions) -> dict[str, Any]:
    """
    Analyze the raw `yolo_predictions` and output a dict summarizing them.

    Every track identifier is assigned the class it was most frequently
    predicted as across all frames (on a tie, the class seen first wins).

    Args:
        yolo_predictions: result of calling model.track() on a video
            (presumably one ultralytics `Results` per frame).

    Returns:
        counts (int): number of distinct identifiers.
        ids (set[int]): all the assigned identifiers.
        detected_species (dict[int, int]): mapping from identifier to the
            index of its most frequently predicted class.
        names: the class names used by the model (index -> name), or None
            when there are no predictions.
    """
    if len(yolo_predictions) == 0:
        return {
            "counts": 0,
            "ids": set(),
            "detected_species": {},
            "names": None,
        }

    names = yolo_predictions[0].names

    # One pass over all frames: per track id, tally how often each class was
    # predicted. (Previously every frame was rescanned once per track id.)
    class_counts: defaultdict[int, Counter] = defaultdict(Counter)
    for prediction in yolo_predictions:
        boxes = prediction.boxes
        # Frames without tracked boxes have no ids to attribute.
        if not has_values(boxes.id):
            continue
        # .tolist() works whatever device the tensor lives on, unlike .numpy().
        track_ids = boxes.id.long().tolist()
        class_ids = boxes.cls.long().tolist()
        for track_id, class_id in zip(track_ids, class_ids):
            class_counts[track_id][class_id] += 1

    ids = set(class_counts)
    detected_species = {
        track_id: class_counts[track_id].most_common(1)[0][0]
        for track_id in sorted(ids)
    }
    return {
        "counts": len(ids),
        "ids": ids,
        "detected_species": detected_species,
        "names": names,
    }


def prediction_to_str(yolo_predictions) -> str:
    """
    Turn the yolo_predictions into a human friendly string.

    Args:
        yolo_predictions: result of calling model.track() on a video.

    Returns:
        str: "No prediction" when there is nothing to report, otherwise the
        number of tracked fish followed by one line per fish with its class.
    """
    if len(yolo_predictions) == 0:
        return "No prediction"

    result = analyze_predictions(yolo_predictions=yolo_predictions)
    names = result["names"]
    detected_species = result["detected_species"]
    ids = result["ids"]
    summary_str = "\n".join(
        f"- The fish with id {track_id} is a {names.get(klass, 'Unknown')}"
        for track_id, klass in detected_species.items()
    )
    print(summary_str)
    return f"Detected {len(ids)} salmons in the video clip with ids {ids}:\n{summary_str}"


def interface_fn(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
    """
    Main interface function: track objects in the provided video with the
    model and return the tuple expected to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        video_filepath (Path): video to run tracking on.

    Returns:
        filepath_video_prediction (Path): annotated video saved by the model.
        raw_prediction_str (str): human friendly summary of the tracked fish.
    """
    project = "runs/track/"
    name = video_filepath.stem
    predictions = model.track(
        source=video_filepath,
        save=True,
        tracker="bytetrack.yaml",
        # Reuse runs/track/<name> on reruns instead of creating <name>2, ...
        exist_ok=True,
        project=project,
        name=name,
    )
    # NOTE(review): assumes ultralytics writes the annotated video as
    # <project>/<name>/<stem>.avi (its default container on Linux) — confirm
    # on other platforms.
    filepath_video_prediction = Path(project) / name / f"{name}.avi"
    raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
    return (filepath_video_prediction, raw_prediction_str)


def examples(dir_examples: Path) -> list[Path]:
    """
    List the example videos (`*.mp4`) from the dir_examples directory.

    Returns:
        filepaths (list[Path]): video filepaths, sorted so that the default
        example does not depend on the filesystem's listing order.
    """
    return sorted(dir_examples.glob("*.mp4"))


def load_model(filepath_weights: Path) -> YOLO:
    """
    Load the YOLO model given the filepath_weights.
    """
    return YOLO(filepath_weights)


# Main Gradio interface
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
DIR_EXAMPLES = Path("data/videos/")
# Index, in the sorted list of example videos, of the video shown by default.
DEFAULT_IMAGE_INDEX = 0

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
    print(f"videos_filepaths: {videos_filepaths}")
    # Start with an empty input instead of crashing when no example exists.
    default_value_input = (
        videos_filepaths[DEFAULT_IMAGE_INDEX]
        if len(videos_filepaths) > DEFAULT_IMAGE_INDEX
        else None
    )
    input_video = gr.Video(
        value=default_value_input,
        format="mp4",
        autoplay=True,
        loop=True,
        label="input video",
        sources=["upload"],
    )
    output_video = gr.Video(
        format="mp4",
        label="model prediction",
        autoplay=True,
        loop=True,
    )
    output_raw = gr.Text(label="raw prediction")

    def fn(video_filepath: str | None) -> Tuple[Path, str]:
        """Gradio callback: run the tracking model on the selected video."""
        if video_filepath is None:
            # Surface a readable message instead of a TypeError from Path(None).
            raise gr.Error("Please provide a video.")
        return interface_fn(model=model, video_filepath=Path(video_filepath))

    gr.Interface(
        title="ML model for wild salmon migration monitoring 🐟",
        fn=fn,
        inputs=input_video,
        outputs=[output_video, output_raw],
        examples=videos_filepaths or None,
        flagging_mode="never",
    )

demo.launch()