"""
Gradio app to showcase the elephant rumbles detector.
"""

from pathlib import Path
from typing import Tuple

import gradio as gr
import pandas as pd
from PIL import Image
from ultralytics import YOLO

from utils import (
    bgr_to_rgb,
    chunk,
    get_concat_v,
    inference,
    load_audio,
    to_dataframe,
    waveform_to_np_image,
    yaml_read,
)


def prediction_to_str(df: pd.DataFrame) -> str:
    """
    Turn the post-processed predictions into a human friendly string.

    Args:
        df (pd.DataFrame): predictions, one row per detected rumble.

    Returns:
        summary (str): e.g. "3 elephant rumbles detected in the audio sequence."
    """
    n = len(df)
    noun = "rumble" if n == 1 else "rumbles"
    return f"{n} elephant {noun} detected in the audio sequence."


def interface_fn(
    model: YOLO,
    audio_filepath: str,
    config_model: dict[str, float | int],
    overlap: float = 10.0,
    batch_size: int = 16,
) -> Tuple[list[Image.Image], pd.DataFrame, str]:
    """
    Run the model on `audio_filepath` and return the tuple used to populate
    the gradio interface.

    Args:
        model (YOLO): loaded ultralytics YOLO model.
        audio_filepath (str): path of the audio file to run inference on.
        config_model (dict[str, float | int]): config of the model. The keys
            read here are duration, width, height, freq_min, freq_max, n_fft
            and hop_length.
        overlap (float): overlap between consecutive chunks, passed to both
            `chunk` and `inference` (presumably in the same unit as
            `duration` -- TODO confirm against utils).
        batch_size (int): batch size passed to `inference`.

    Returns:
        pil_images_spectrogram_with_predictions (list[PIL.Image.Image]): one
            spectrogram per audio chunk with the predictions overlaid.
        df (pd.DataFrame): predictions post-processed as a dataframe,
            restricted to the CSV_COLUMNS columns.
        prediction_str (str): human friendly summary of the predictions.
    """
    audio_path = Path(audio_filepath)
    duration = config_model["duration"]

    waveform, sample_rate = load_audio(audio_path)
    waveforms = chunk(
        waveform=waveform,
        sample_rate=sample_rate,
        duration=duration,
        overlap=overlap,
    )

    yolov8_predictions = inference(
        model=model,
        audio_filepath=audio_path,
        duration=duration,
        overlap=overlap,
        width=config_model["width"],
        height=config_model["height"],
        freq_max=config_model["freq_max"],
        n_fft=config_model["n_fft"],
        hop_length=config_model["hop_length"],
        batch_size=batch_size,
        output_dir=Path("."),
        save_spectrograms=False,
        save_predictions=False,
        verbose=True,
    )
    df = to_dataframe(
        yolov8_predictions=yolov8_predictions,
        duration=duration,
        overlap=overlap,
        freq_min=config_model["freq_min"],
        freq_max=config_model["freq_max"],
    )

    # Build one spectrogram image per chunk so the detections can be drawn
    # on top of them for the gallery.
    spectrograms_pil_images = [
        Image.fromarray(
            waveform_to_np_image(
                waveform=chunk_waveform,
                sample_rate=sample_rate,
                n_fft=config_model["n_fft"],
                hop_length=config_model["hop_length"],
                freq_max=config_model["freq_max"],
                width=config_model["width"],
                height=config_model["height"],
            )
        )
        for chunk_waveform in waveforms
    ]

    # NOTE(review): the model is run a second time here only to obtain
    # plottable results; if `inference` already returns ultralytics Results
    # for the same spectrograms, they could be plotted directly instead.
    predictions = model.predict(spectrograms_pil_images)
    # `Results.plot()` returns a BGR array, hence the conversion to RGB for PIL.
    pil_images_spectrogram_with_predictions = [
        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
    ]

    return (
        pil_images_spectrogram_with_predictions,
        df[CSV_COLUMNS],
        prediction_to_str(df=df),
    )


def examples(dir_examples: Path) -> list[Path]:
    """
    List the sound filepaths from the dir_examples directory.

    Args:
        dir_examples (Path): directory containing the example .wav files.

    Returns:
        filepaths (list[Path]): .wav filepaths, sorted so that the order
            (and therefore the default example) is deterministic.
    """
    return sorted(dir_examples.glob("*.wav"))


def load_model(filepath_weights: Path) -> YOLO:
    """
    Load the YOLO model given the filepath_weights.

    Args:
        filepath_weights (Path): path to the model weights file.

    Returns:
        model (YOLO): loaded ultralytics YOLO model.
    """
    return YOLO(filepath_weights)


MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
MODEL_FILEPATH_CONFIG = Path("data/model/config.yaml")
# Backward-compatible alias for the previously misspelled constant name.
MODEL_FILEPTAH_CONFIG = MODEL_FILEPATH_CONFIG
DIR_EXAMPLES = Path("data/sounds/raw")
DEFAULT_VALUE_INDEX = 0
CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]

model = load_model(MODEL_FILEPATH_WEIGHTS)
sound_filepaths = examples(dir_examples=DIR_EXAMPLES)
config_model = yaml_read(MODEL_FILEPATH_CONFIG)
print(config_model)

# Guard against an empty examples directory instead of raising IndexError.
default_value_input = (
    sound_filepaths[DEFAULT_VALUE_INDEX]
    if len(sound_filepaths) > DEFAULT_VALUE_INDEX
    else None
)


def fn(audio_filepath: str) -> Tuple[list[Image.Image], pd.DataFrame, str]:
    """Gradio callback: run `interface_fn` with the module-level model/config."""
    return interface_fn(
        model=model,
        audio_filepath=audio_filepath,
        config_model=config_model,
    )


with gr.Blocks() as demo:
    input_audio = gr.Audio(
        value=default_value_input,
        sources=["upload"],
        type="filepath",
        label="input audio",
    )
    output_gallery = gr.Gallery(label="model predictions", preview=True)
    output_raw = gr.Text(label="raw prediction")
    output_dataframe = gr.DataFrame(
        headers=CSV_COLUMNS,
        label="prediction as CSV",
    )

    gr.Interface(
        title="ML model for forest elephant rumble detection 🐘",
        fn=fn,
        inputs=input_audio,
        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths or None,
        flagging_mode="never",
    )

if __name__ == "__main__":
    demo.launch()