"""
Gradio app to showcase the elephant rumbles detector.
"""
from pathlib import Path
from typing import Tuple
import gradio as gr
import pandas as pd
from PIL import Image
from ultralytics import YOLO
from utils import (
    bgr_to_rgb,
    chunk,
    get_concat_v,
    inference,
    load_audio,
    to_dataframe,
    waveform_to_np_image,
    yaml_read,
)


def prediction_to_str(df: pd.DataFrame) -> str:
    """
    Turn the YOLO predictions dataframe into a human-friendly string.
    """
    n = len(df)
    return f"""{n} elephant rumbles detected in the audio sequence."""


def interface_fn(
    model: YOLO,
    audio_filepath: str,
    config_model: dict[str, float | int],
) -> Tuple[list[Image.Image], pd.DataFrame, str]:
    """
    Main interface function that runs the model on the provided audio_filepath and
    returns the expected tuple to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        audio_filepath (str): audio to run inference on.
        config_model (dict[str, float | int]): config of the model.

    Returns:
        pil_image_spectrogram_with_predictions (list[Image.Image]): spectrograms with
            overlaid predictions
        df (pd.DataFrame): results postprocessed as a pd.DataFrame
        prediction_str (str): human-readable summary of the predictions.
    """
    overlap = 10.0
    waveform, sample_rate = load_audio(Path(audio_filepath))
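    # Split the waveform into fixed-duration, overlapping chunks (one spectrogram per chunk).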
    waveforms = chunk(
        waveform=waveform,
        sample_rate=sample_rate,
        duration=config_model["duration"],
        overlap=overlap,
    )
    print(f"waveforms: {waveforms}")
    yolov8_predictions = inference(
        model=model,
        audio_filepath=Path(audio_filepath),
        duration=config_model["duration"],
        overlap=overlap,
        width=config_model["width"],
        height=config_model["height"],
        freq_max=config_model["freq_max"],
        n_fft=config_model["n_fft"],
        hop_length=config_model["hop_length"],
        batch_size=16,
        output_dir=Path("."),
        save_spectrograms=False,
        save_predictions=False,
        verbose=True,
    )
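    # Post-process the raw detections into time/frequency intervals with probabilities.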
    df = to_dataframe(
        yolov8_predictions=yolov8_predictions,
        duration=config_model["duration"],
        overlap=overlap,
        freq_min=config_model["freq_min"],
        freq_max=config_model["freq_max"],
    )
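    # Render each audio chunk as a spectrogram image for the gallery preview.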
    spectrograms_array_images = [
        waveform_to_np_image(
            waveform=waveform,
            sample_rate=sample_rate,
            n_fft=config_model["n_fft"],
            hop_length=config_model["hop_length"],
            freq_max=config_model["freq_max"],
            width=config_model["width"],
            height=config_model["height"],
        )
        for waveform in waveforms
    ]
    spectrograms_pil_images = [Image.fromarray(a) for a in spectrograms_array_images]
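    # Predict on the spectrogram images and draw the detected rumbles as bounding boxes.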
    predictions = model.predict(spectrograms_pil_images)
    pil_image_spectrogram_with_predictions = [
        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
    ]
    return (
        pil_image_spectrogram_with_predictions,
        df[CSV_COLUMNS],
        prediction_to_str(df=df),
    )


def examples(dir_examples: Path) -> list[Path]:
    """
    List the sound filepaths from the dir_examples directory.

    Returns:
        filepaths (list[Path]): list of sound filepaths.
    """
    return list(dir_examples.glob("*.wav"))


def load_model(filepath_weights: Path) -> YOLO:
    """
    Load the YOLO model given the filepath_weights.
    """
    return YOLO(filepath_weights)


MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
MODEL_FILEPATH_CONFIG = Path("data/model/config.yaml")
DIR_EXAMPLES = Path("data/sounds/raw")
DEFAULT_VALUE_INDEX = 0
CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]
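

# Build the Gradio demo: load the model and config once, list the example sounds,
# and wire the inference function into an audio-in, gallery/table/text-out interface.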
with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    sound_filepaths = examples(dir_examples=DIR_EXAMPLES)
    config_model = yaml_read(MODEL_FILEPATH_CONFIG)
    print(config_model)
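    # Pre-select the first example sound and declare the UI components.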
    default_value_input = sound_filepaths[DEFAULT_VALUE_INDEX]
    input = gr.Audio(
        value=default_value_input,
        sources=["upload"],
        type="filepath",
        label="input audio",
    )
    output_gallery = gr.Gallery(label="model predictions", preview=True)
    output_raw = gr.Text(label="raw prediction")
    output_dataframe = gr.DataFrame(
        headers=CSV_COLUMNS,
        label="prediction as CSV",
    )
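    # Bind the loaded model and config so the Interface callback only takes the audio filepath.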
    fn = lambda audio_filepath: interface_fn(
        model=model,
        audio_filepath=audio_filepath,
        config_model=config_model,
    )
    gr.Interface(
        title="ML model for forest elephant rumble detection 🐘",
        fn=fn,
        inputs=input,
        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths,
        flagging_mode="never",
    )

demo.launch()