achouffe commited on
Commit
b5f64dc
Β·
verified Β·
1 Parent(s): 1a67461

feat: mvp of the space

Browse files
.gitattributes CHANGED
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ runs/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10.12
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
  title: Salmon Vision
3
- emoji: πŸš€
4
  colorFrom: red
5
- colorTo: red
 
6
  sdk: gradio
7
  sdk_version: 5.5.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: wild salmon migration monitoring
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Salmon Vision
3
+ emoji: 🐟
4
  colorFrom: red
5
+ colorTo: blue
6
+ python_version: 3.10.12
7
  sdk: gradio
8
  sdk_version: 5.5.0
9
  app_file: app.py
10
  pinned: false
11
+ short_description: Wild salmon migration monitoring
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app to showcase the pyronear model for early forest fire detection.
3
+ """
4
+
5
+ from collections import Counter
6
+ from pathlib import Path
7
+ from typing import Any, Tuple
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ from ultralytics import YOLO
12
+
13
+
14
+ def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
15
+ """
16
+ Turn a BGR numpy array into a RGB numpy array when the array `a` represents
17
+ an image.
18
+ """
19
+ return a[:, :, ::-1]
20
+
21
+
22
+ def analyze_predictions(yolo_predictions) -> dict[str, Any]:
23
+ """
24
+ Analyze the raw `yolo_predictions` and outputs a dict containg information.
25
+
26
+ Args:
27
+ yolo_predictions: result of calling model.track() on a video
28
+
29
+ Returns:
30
+ counts (int): number of distinct identifiers.
31
+ ids (set[int]): all the assigned identifiers.
32
+ detected_species (dict[int, int]): mapping from identifier to instance class
33
+ names (list[str]): the class names used by the model
34
+ """
35
+ if len(yolo_predictions) == 0:
36
+ return {
37
+ "counts": 0,
38
+ "ids": set(),
39
+ "detected_species": {},
40
+ "names": None,
41
+ }
42
+ else:
43
+ names = yolo_predictions[0].names
44
+ ids = set()
45
+ for prediction in yolo_predictions:
46
+ if prediction.boxes.id:
47
+ for id in prediction.boxes.id.numpy().astype("int"):
48
+ ids.add(id.item())
49
+ detected_species = {}
50
+ for id in ids:
51
+ counter = Counter()
52
+ for prediction in yolo_predictions:
53
+ if prediction.boxes.id:
54
+ for idd, klass in zip(
55
+ prediction.boxes.id.numpy().astype("int"),
56
+ prediction.boxes.cls.numpy().astype("int"),
57
+ ):
58
+ if idd.item() == id:
59
+ counter[klass.item()] += 1
60
+ selected_class = counter.most_common(1)[0][0]
61
+ detected_species[id] = selected_class
62
+ return {
63
+ "counts": len(ids),
64
+ "ids": ids,
65
+ "detected_species": detected_species,
66
+ "names": names,
67
+ }
68
+
69
+
70
+ def prediction_to_str(yolo_predictions) -> str:
71
+ """
72
+ Turn the yolo_predictions into a human friendly string.
73
+ """
74
+ if len(yolo_predictions) == 0:
75
+ return "No prediction"
76
+ else:
77
+ result = analyze_predictions(yolo_predictions=yolo_predictions)
78
+ names = result["names"]
79
+ detected_species = result["detected_species"]
80
+ ids = result["ids"]
81
+ summary_str = "\n".join(
82
+ [
83
+ f"- The fish with id {id} is a {names.get(klass, 'Unknown')}"
84
+ for id, klass in detected_species.items()
85
+ ]
86
+ )
87
+ print(summary_str)
88
+ return f"Detected {len(ids)} salmons in the video clip with ids {ids}:\n{summary_str}"
89
+
90
+
91
+ def predict(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
92
+ """
93
+ Main interface function that runs the model on the provided pil_image and
94
+ returns the exepected tuple to populate the gradio interface.
95
+
96
+ Args:
97
+ model (YOLO): Loaded ultralytics YOLO model.
98
+ pil_image (PIL): image to run inference on.
99
+
100
+ Returns:
101
+ pil_image_with_prediction (PIL): image with prediction from the model.
102
+ raw_prediction_str (str): string representing the raw prediction from the
103
+ model.
104
+ """
105
+ project = "runs/track/"
106
+ name = video_filepath.stem
107
+ predictions = model.track(
108
+ source=video_filepath,
109
+ save=True,
110
+ tracker="bytetrack.yaml",
111
+ exist_ok=True,
112
+ project=project,
113
+ name=name,
114
+ )
115
+ filepath_video_prediction = Path(f"{project}/{name}/{name}.avi")
116
+ raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
117
+ return (filepath_video_prediction, raw_prediction_str)
118
+
119
+
120
+ def examples(dir_examples: Path) -> list[Path]:
121
+ """
122
+ List the images from the dir_examples directory.
123
+
124
+ Returns:
125
+ filepaths (list[Path]): list of image filepaths.
126
+ """
127
+ return list(dir_examples.glob("*.mp4"))
128
+
129
+
130
+ def load_model(filepath_weights: Path) -> YOLO:
131
+ """
132
+ Load the YOLO model given the filepath_weights.
133
+ """
134
+ return YOLO(filepath_weights)
135
+
136
+
137
+ # Main Gradio interface
138
+
139
+ MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
140
+ DIR_EXAMPLES = Path("data/videos/")
141
+ DEFAULT_IMAGE_INDEX = 0
142
+
143
+ with gr.Blocks() as demo:
144
+ model = load_model(MODEL_FILEPATH_WEIGHTS)
145
+ videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
146
+ print(f"videos_filepaths: {videos_filepaths}")
147
+ default_value_input = videos_filepaths[DEFAULT_IMAGE_INDEX]
148
+ input = gr.Video(
149
+ value=default_value_input,
150
+ format="mp4",
151
+ label="input video",
152
+ sources=["upload"],
153
+ )
154
+ output_video = gr.Video(format="mp4", label="model prediction")
155
+ output_raw = gr.Text(label="raw prediction")
156
+
157
+ fn = lambda video_filepath: predict(
158
+ model=model, video_filepath=Path(video_filepath)
159
+ )
160
+ gr.Interface(
161
+ title="ML model for wild salmon migration monitoring 🐟",
162
+ fn=fn,
163
+ inputs=input,
164
+ outputs=[output_video, output_raw],
165
+ examples=videos_filepaths,
166
+ flagging_mode="never",
167
+ )
168
+
169
+ demo.launch()
data/model/weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33d5e39e94c3f94badb476743ec9773f5df09b3f2755379f43b3a594fa755bd2
3
+ size 6239129
data/videos/video1-clip.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fb22eeeb3be60b9cb2a65d54a2d6c5379e9e65c47d7f2ce88eab0986f069167
3
+ size 2966504
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ultralytics==8.3.*
2
+ gradio==5.4.*