feat: use Gallery element to show predictions
Files changed:
- app.py: +11 -24
- requirements.txt: +1 -0
- utils.py: +1 -1
app.py
CHANGED
@@ -34,7 +34,7 @@ def interface_fn(
     model: YOLO,
     audio_filepath: str,
     config_model: dict[str, float | int],
-) -> Tuple[Image.Image, pd.DataFrame, str]:
+) -> Tuple[list[Image.Image], pd.DataFrame, str]:
     """
     Main interface function that runs the model on the provided audio_filepath and
     returns the exepected tuple to populate the gradio interface.
@@ -60,6 +60,8 @@ def interface_fn(
         overlap=overlap,
     )
 
+    print(f"waveforms: {waveforms}")
+
     yolov8_predictions = inference(
         model=model,
         audio_filepath=Path(audio_filepath),
@@ -99,28 +101,12 @@ def interface_fn(
 
     spectrograms_pil_images = [Image.fromarray(a) for a in spectrograms_array_images]
 
-    array_image = waveform_to_np_image(
-        waveform=waveforms[0],
-        sample_rate=sample_rate,
-        n_fft=config_model["n_fft"],
-        hop_length=config_model["hop_length"],
-        freq_max=config_model["freq_max"],
-        width=config_model["width"],
-        height=config_model["height"],
-    )
-
     predictions = model.predict(spectrograms_pil_images)
-
-    bgr_to_rgb(
-
-
-    for i in range(1, len(predictions)):
-        pil_image_spectrogram_with_prediction = get_concat_v(
-            pil_image_spectrogram_with_prediction,
-            Image.fromarray(bgr_to_rgb(predictions[i].plot())),
-        )
+    pil_image_spectrogram_with_predictions = [
+        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
+    ]
 
-    return (
+    return (pil_image_spectrogram_with_predictions, df[CSV_COLUMNS], prediction_to_str(df=df))
 
 
 def examples(dir_examples: Path) -> list[Path]:
@@ -144,6 +130,7 @@ MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
 MODEL_FILEPTAH_CONFIG = Path("data/model/config.yaml")
 DIR_EXAMPLES = Path("data/sounds/raw")
 DEFAULT_VALUE_INDEX = 0
+CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]
 
 with gr.Blocks() as demo:
     model = load_model(MODEL_FILEPATH_WEIGHTS)
@@ -157,10 +144,10 @@ with gr.Blocks() as demo:
         type="filepath",
         label="input audio",
     )
-
+    output_gallery = gr.Gallery(label="model predictions")
     output_raw = gr.Text(label="raw prediction")
     output_dataframe = gr.DataFrame(
-        headers=
+        headers=CSV_COLUMNS,
         label="prediction as CSV",
     )
 
@@ -173,7 +160,7 @@ with gr.Blocks() as demo:
         title="ML model for forest elephant rumble detection π",
         fn=fn,
         inputs=input,
-        outputs=[
+        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths,
         flagging_mode="never",
     )
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 gradio==5.4.*
+pandas==2.2.*
 torch==2.5.*
 torchaudio==2.5.*
 torchvision==0.20.*
utils.py
CHANGED
@@ -49,7 +49,7 @@ def chunk(
     total_seconds = waveform.shape[1] / sample_rate
     number_spectrograms = total_seconds / (duration - overlap)
     offsets = [
-        idx * (duration - overlap) for idx in range(0, math.
+        idx * (duration - overlap) for idx in range(0, math.ceil(number_spectrograms))
     ]
     return [
         clip(