achouffe commited on
Commit
4823bb1
1 Parent(s): f89b98f

feat: initial app

Browse files
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10.12
README.md CHANGED
@@ -1,7 +1,8 @@
1
  ---
2
  title: Coral Segmentation Reef Support
3
- emoji: 👁
4
- colorFrom: red
 
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.4.0
 
1
  ---
2
  title: Coral Segmentation Reef Support
3
+ emoji: 🪸
4
+ python_version: 3.10.12
5
+ colorFrom: blue
6
  colorTo: pink
7
  sdk: gradio
8
  sdk_version: 5.4.0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app to showcase the pyronear model for early forest fire detection.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
+ from ultralytics import YOLO
12
+
13
+
14
+ def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
15
+ """
16
+ Turn a BGR numpy array into a RGB numpy array when the array `a` represents
17
+ an image.
18
+ """
19
+ return a[:, :, ::-1]
20
+
21
+
22
+ def prediction_to_str(yolo_prediction) -> str:
23
+ """
24
+ Turn the yolo_prediction into a human friendly string.
25
+ """
26
+ boxes = yolo_prediction.boxes
27
+ classes = boxes.cls.cpu().numpy().astype(np.int8)
28
+ n_hard_coral = len([c for c in classes if c == 0])
29
+ n_soft_coral = len([c for c in classes if c == 1])
30
+
31
+ return f"""{len(boxes.conf)} corals detected:\n- {n_hard_coral} hard corals\n- {n_soft_coral} soft corals"""
32
+
33
+
34
+ def predict(model: YOLO, pil_image: Image.Image) -> Tuple[Image.Image, str]:
35
+ """
36
+ Main interface function that runs the model on the provided pil_image and
37
+ returns the exepected tuple to populate the gradio interface.
38
+
39
+ Args:
40
+ model (YOLO): Loaded ultralytics YOLO model.
41
+ pil_image (PIL): image to run inference on.
42
+
43
+ Returns:
44
+ pil_image_with_prediction (PIL): image with prediction from the model.
45
+ raw_prediction_str (str): string representing the raw prediction from the
46
+ model.
47
+ """
48
+ predictions = model(pil_image)
49
+ prediction = predictions[0]
50
+ pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
51
+ raw_prediction_str = prediction_to_str(prediction)
52
+
53
+ return (pil_image_with_prediction, raw_prediction_str)
54
+
55
+
56
+ def examples(dir_examples: Path) -> list[Path]:
57
+ """
58
+ List the images from the dir_examples directory.
59
+
60
+ Returns:
61
+ filepaths (list[Path]): list of image filepaths.
62
+ """
63
+ return list(dir_examples.glob("*.jpg"))
64
+
65
+
66
+ def load_model(filepath_weights: Path) -> YOLO:
67
+ """
68
+ Load the YOLO model given the filepath_weights.
69
+ """
70
+ return YOLO(filepath_weights)
71
+
72
+
73
+ # Main Gradio interface
74
+
75
+ MODEL_FILEPATH_WEIGHTS = Path("data/model/best.pt")
76
+ DIR_EXAMPLES = Path("data/images/")
77
+ DEFAULT_IMAGE_INDEX = 3
78
+
79
+ with gr.Blocks() as demo:
80
+ model = load_model(MODEL_FILEPATH_WEIGHTS)
81
+ image_filepaths = examples(dir_examples=DIR_EXAMPLES)
82
+ default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
83
+ input = gr.Image(
84
+ value=default_value_input,
85
+ type="pil",
86
+ label="input image",
87
+ sources=["upload", "clipboard"],
88
+ )
89
+ output_image = gr.Image(type="pil", label="model prediction")
90
+ output_raw = gr.Text(label="raw prediction")
91
+
92
+ fn = lambda pil_image: predict(model=model, pil_image=pil_image)
93
+ gr.Interface(
94
+ title="ML model for benthic imagery segmentation 🪸",
95
+ fn=fn,
96
+ inputs=input,
97
+ outputs=[output_image, output_raw],
98
+ examples=image_filepaths,
99
+ allow_flagging="never",
100
+ )
101
+
102
+ demo.launch()
data/images/1.jpg ADDED
data/images/2.jpg ADDED
data/images/3.jpg ADDED
data/images/4.jpg ADDED
data/images/5.jpg ADDED
data/images/6.jpg ADDED
data/images/7.jpg ADDED
data/images/8.jpg ADDED
data/images/9.jpg ADDED
data/model/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e85ace94202499740f291388e6e58c4f8923094d710d34513ca670f85037caa3
3
+ size 23948067
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ultralytics==8.3.*
2
+ gradio==5.4.*