File size: 2,969 Bytes
4823bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Gradio app to showcase the pyronear model for early forest fire detection.
"""

from pathlib import Path
from typing import Tuple

import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO


def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Convert an image stored as a BGR numpy array into its RGB counterpart.

    The conversion just reverses the order of the last (channel) axis, so it
    also maps RGB back to BGR if applied twice.
    """
    return np.flip(a, axis=-1)


def prediction_to_str(yolo_prediction) -> str:
    """
    Turn the yolo_prediction into a human friendly string.

    Args:
        yolo_prediction: single ultralytics result (one image) whose `boxes`
            attribute carries the predicted classes and confidences.

    Returns:
        str: summary of how many hard and soft corals were detected.
    """
    boxes = yolo_prediction.boxes
    # Class ids follow the model's label mapping: 0 = hard coral, 1 = soft coral.
    classes = boxes.cls.cpu().numpy().astype(np.int8)
    # Vectorized counts instead of building throwaway lists per class.
    n_hard_coral = int((classes == 0).sum())
    n_soft_coral = int((classes == 1).sum())

    return f"""{len(boxes.conf)} corals detected:\n- {n_hard_coral} hard corals\n- {n_soft_coral} soft corals"""


def predict(model: YOLO, pil_image: Image.Image) -> Tuple[Image.Image, str]:
    """
    Main interface function that runs the model on the provided pil_image and
    returns the expected tuple to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        pil_image (PIL): image to run inference on.

    Returns:
        pil_image_with_prediction (PIL): image with prediction from the model.
        raw_prediction_str (str): string representing the raw prediction from the
        model.
    """
    # The model returns one result per input image; we only pass one.
    prediction = model(pil_image)[0]
    # `plot()` renders the detections as a BGR array; convert for display.
    pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
    raw_prediction_str = prediction_to_str(prediction)
    return pil_image_with_prediction, raw_prediction_str


def examples(dir_examples: Path) -> list[Path]:
    """
    Collect the example images (jpg files) found in dir_examples.

    Args:
        dir_examples (Path): directory containing the example images.

    Returns:
        filepaths (list[Path]): list of image filepaths.
    """
    return [filepath for filepath in dir_examples.glob("*.jpg")]


def load_model(filepath_weights: Path) -> YOLO:
    """
    Instantiate a YOLO model from the weights stored at filepath_weights.

    Args:
        filepath_weights (Path): path to the `.pt` weights file.

    Returns:
        YOLO: ultralytics model ready for inference.
    """
    model = YOLO(filepath_weights)
    return model


# Main Gradio interface

MODEL_FILEPATH_WEIGHTS = Path("data/model/best.pt")  # trained YOLO weights
DIR_EXAMPLES = Path("data/images/")  # directory holding the example jpg images
DEFAULT_IMAGE_INDEX = 3  # example shown in the input widget on page load

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    image_filepaths = examples(dir_examples=DIR_EXAMPLES)
    default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
    # Named `input_image` (not `input`) to avoid shadowing the builtin.
    input_image = gr.Image(
        value=default_value_input,
        type="pil",
        label="input image",
        sources=["upload", "clipboard"],
    )
    output_image = gr.Image(type="pil", label="model prediction")
    output_raw = gr.Text(label="raw prediction")

    def fn(pil_image):
        """Adapter binding the loaded model to the single-argument gradio callback."""
        return predict(model=model, pil_image=pil_image)

    gr.Interface(
        title="ML model for benthic imagery segmentation 🪸",
        fn=fn,
        inputs=input_image,
        outputs=[output_image, output_raw],
        examples=image_filepaths,
        allow_flagging="never",
    )

demo.launch()