|
""" |
|
Gradio app to showcase the elephant rumbles detector. |
|
""" |
|
|
|
from pathlib import Path |
|
from typing import Tuple |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
from PIL import Image |
|
from ultralytics import YOLO |
|
|
|
from utils import ( |
|
bgr_to_rgb, |
|
chunk, |
|
get_concat_v, |
|
inference, |
|
load_audio, |
|
to_dataframe, |
|
waveform_to_np_image, |
|
yaml_read, |
|
) |
|
|
|
|
|
def prediction_to_str(df: pd.DataFrame) -> str:
    """
    Turn the post-processed predictions into a human friendly string.

    Args:
        df (pd.DataFrame): one row per detected rumble.

    Returns:
        str: a one-sentence summary with the number of detected rumbles,
        using the singular form when exactly one rumble is detected.
    """
    n = len(df)
    noun = "rumble" if n == 1 else "rumbles"
    return f"{n} elephant {noun} detected in the audio sequence."
|
|
|
|
|
def interface_fn(
    model: YOLO,
    audio_filepath: str,
    config_model: dict[str, float | int],
    overlap: float = 10.0,
    batch_size: int = 16,
) -> Tuple[list[Image.Image], pd.DataFrame, str]:
    """
    Main interface function that runs the model on the provided audio_filepath and
    returns the expected tuple to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        audio_filepath (str): audio to run inference on.
        config_model (dict[str, float | int]): config of the model. The keys read
            here are duration, width, height, freq_min, freq_max, n_fft and
            hop_length.
        overlap (float): overlap forwarded identically to chunk, inference and
            to_dataframe so that the chunking stays consistent across the three.
        batch_size (int): batch size forwarded to inference.

    Returns:
        pil_image_spectrogram_with_predictions (list[PIL.Image.Image]): one
            spectrogram per audio chunk with the model predictions overlaid.
        df (pd.DataFrame): predictions post-processed by to_dataframe, restricted
            to the CSV_COLUMNS columns.
        prediction_str (str): human friendly summary of the predictions.

    Raises:
        gr.Error: when no audio file is provided (e.g. the input was cleared).
    """
    if not audio_filepath:
        # Surface a readable message in the UI instead of a TypeError from Path(None).
        raise gr.Error("Please provide an audio file.")

    audio_path = Path(audio_filepath)
    duration = config_model["duration"]
    n_fft = config_model["n_fft"]
    hop_length = config_model["hop_length"]
    freq_max = config_model["freq_max"]
    width = config_model["width"]
    height = config_model["height"]

    waveform, sample_rate = load_audio(audio_path)
    waveforms = chunk(
        waveform=waveform,
        sample_rate=sample_rate,
        duration=duration,
        overlap=overlap,
    )

    yolov8_predictions = inference(
        model=model,
        audio_filepath=audio_path,
        duration=duration,
        overlap=overlap,
        width=width,
        height=height,
        freq_max=freq_max,
        n_fft=n_fft,
        hop_length=hop_length,
        batch_size=batch_size,
        output_dir=Path("."),
        save_spectrograms=False,
        save_predictions=False,
        verbose=True,
    )
    df = to_dataframe(
        yolov8_predictions=yolov8_predictions,
        duration=duration,
        overlap=overlap,
        freq_min=config_model["freq_min"],
        freq_max=freq_max,
    )

    # Render one spectrogram per chunk so predictions can be drawn on top of them.
    spectrograms_pil_images = [
        Image.fromarray(
            waveform_to_np_image(
                waveform=chunk_waveform,
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                freq_max=freq_max,
                width=width,
                height=height,
            )
        )
        for chunk_waveform in waveforms
    ]

    predictions = model.predict(spectrograms_pil_images)
    # p.plot() output is passed through bgr_to_rgb before building the PIL image.
    pil_image_spectrogram_with_predictions = [
        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
    ]

    return (
        pil_image_spectrogram_with_predictions,
        df[CSV_COLUMNS],
        prediction_to_str(df=df),
    )
|
|
|
|
|
def examples(dir_examples: Path) -> list[Path]:
    """
    List the .wav sound filepaths from the dir_examples directory.

    The result is sorted so that the order (and hence the default example picked
    by index) does not depend on the filesystem's directory listing order.

    Args:
        dir_examples (Path): directory containing the example sounds.

    Returns:
        filepaths (list[Path]): sorted list of .wav filepaths found directly in
            dir_examples.
    """
    return sorted(dir_examples.glob("*.wav"))
|
|
|
|
|
def load_model(filepath_weights: Path) -> YOLO:
    """
    Build an ultralytics YOLO model from the weights file at filepath_weights.

    Args:
        filepath_weights (Path): location of the trained weights.

    Returns:
        YOLO: the instantiated model.
    """
    model = YOLO(filepath_weights)
    return model
|
|
|
|
|
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
MODEL_FILEPATH_CONFIG = Path("data/model/config.yaml")
# Backward-compatible alias for the previously misspelled constant name.
MODEL_FILEPTAH_CONFIG = MODEL_FILEPATH_CONFIG
DIR_EXAMPLES = Path("data/sounds/raw")
DEFAULT_VALUE_INDEX = 0
CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]


with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    sound_filepaths = examples(dir_examples=DIR_EXAMPLES)
    config_model = yaml_read(MODEL_FILEPATH_CONFIG)
    print(config_model)

    # Preload an example only when one exists: an empty examples directory must
    # not crash the app with an IndexError at startup.
    default_value_input = (
        sound_filepaths[DEFAULT_VALUE_INDEX] if sound_filepaths else None
    )
    # Named input_audio (not `input`) to avoid shadowing the builtin.
    input_audio = gr.Audio(
        value=default_value_input,
        sources=["upload"],
        type="filepath",
        label="input audio",
    )
    output_gallery = gr.Gallery(label="model predictions")
    output_raw = gr.Text(label="raw prediction")
    output_dataframe = gr.DataFrame(
        headers=CSV_COLUMNS,
        label="prediction as CSV",
    )

    def fn(audio_filepath):
        """Run interface_fn with the model and config loaded above.

        A named wrapper (rather than functools.partial) keeps audio_filepath
        positional, which is how gradio passes the input value.
        """
        return interface_fn(
            model=model,
            audio_filepath=audio_filepath,
            config_model=config_model,
        )

    gr.Interface(
        title="ML model for forest elephant rumble detection π",
        fn=fn,
        inputs=input_audio,
        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths or None,
        flagging_mode="never",
    )


# Only start the server when run as a script, so importing this module (e.g. by
# gradio's reload CLI or by tests) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|