"""
Gradio app to showcase the elephant rumbles detector.
"""

from pathlib import Path

import gradio as gr
import pandas as pd
from PIL import Image
from ultralytics import YOLO

from utils import (
    bgr_to_rgb,
    chunk,
    inference,
    load_audio,
    to_dataframe,
    waveform_to_np_image,
    yaml_read,
)


def prediction_to_str(df: pd.DataFrame) -> str:
    """
    Turn the YOLO predictions into a human-friendly string.
    """
    n = len(df)
    return f"{n} elephant rumbles detected in the audio sequence."


def interface_fn(
    model: YOLO,
    audio_filepath: str,
    config_model: dict[str, float | int],
) -> tuple[list[Image.Image], pd.DataFrame, str]:
    """
    Main interface function that runs the model on the provided audio_filepath and
    returns the exepected tuple to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        audio_filepath (str): audio to run inference on.
        config_model (dict[str, float | int]): config of the model.

    Returns:
        pil_image_spectrogram_with_prediction (PIL): spectrogram with overlaid
        predictions
        df (pd.DataFrame): results postprocessed as a pd.DataFrame
        predition_str (str): some raw prediction for the string.
    """
    overlap = 10.0

    waveform, sample_rate = load_audio(Path(audio_filepath))
    waveforms = chunk(
        waveform=waveform,
        sample_rate=sample_rate,
        duration=config_model["duration"],
        overlap=overlap,
    )

    print(f"waveforms: {waveforms}")

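    # Run the full inference pipeline (spectrogram rendering + YOLO detection on
    # each chunk) over the audio file to collect the raw predictions.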
    yolov8_predictions = inference(
        model=model,
        audio_filepath=Path(audio_filepath),
        duration=config_model["duration"],
        overlap=overlap,
        width=config_model["width"],
        height=config_model["height"],
        freq_max=config_model["freq_max"],
        n_fft=config_model["n_fft"],
        hop_length=config_model["hop_length"],
        batch_size=16,
        output_dir=Path("."),
        save_spectrograms=False,
        save_predictions=False,
        verbose=True,
    )
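    # Postprocess the raw predictions into one row per detection
    # (time range, frequency range, probability).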
    df = to_dataframe(
        yolov8_predictions=yolov8_predictions,
        duration=config_model["duration"],
        overlap=overlap,
        freq_min=config_model["freq_min"],
        freq_max=config_model["freq_max"],
    )

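    # Render each audio chunk as a spectrogram image so the model can be run on
    # them below to produce annotated plots for the gallery.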
    spectrograms_array_images = [
        waveform_to_np_image(
            waveform=waveform,
            sample_rate=sample_rate,
            n_fft=config_model["n_fft"],
            hop_length=config_model["hop_length"],
            freq_max=config_model["freq_max"],
            width=config_model["width"],
            height=config_model["height"],
        )
        for waveform in waveforms
    ]

    spectrograms_pil_images = [Image.fromarray(a) for a in spectrograms_array_images]

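    # Second model pass, purely to get annotated spectrograms: p.plot() returns
    # a BGR array, hence the bgr_to_rgb conversion.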
    predictions = model.predict(spectrograms_pil_images)
    pil_image_spectrogram_with_predictions = [
        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
    ]

    return (
        pil_image_spectrogram_with_predictions,
        df[CSV_COLUMNS],
        prediction_to_str(df=df),
    )


def examples(dir_examples: Path) -> list[Path]:
    """
    List the sound filepaths from the dir_examples directory.

    Returns:
        filepaths (list[Path]): list of sound filepaths.
    """
    return list(dir_examples.glob("*.wav"))


def load_model(filepath_weights: Path) -> YOLO:
    """
    Load the YOLO model given the filepath_weights.
    """
    return YOLO(filepath_weights)


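# Module-level constants: filepaths to the model weights/config, the example
# sounds directory, and the columns exposed in the output dataframe.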
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
MODEL_FILEPATH_CONFIG = Path("data/model/config.yaml")
DIR_EXAMPLES = Path("data/sounds/raw")
DEFAULT_VALUE_INDEX = 0
CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]

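# Build the Gradio UI: the model, config, and example sounds are loaded once at
# startup.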
with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    sound_filepaths = examples(dir_examples=DIR_EXAMPLES)
    config_model = yaml_read(MODEL_FILEPATH_CONFIG)
    print(f"loaded model config: {config_model}")
    default_value_input = sound_filepaths[DEFAULT_VALUE_INDEX]
    input_audio = gr.Audio(
        value=default_value_input,
        sources=["upload"],
        type="filepath",
        label="input audio",
    )
    output_gallery = gr.Gallery(label="model predictions", preview=True)
    output_raw = gr.Text(label="raw prediction")
    output_dataframe = gr.DataFrame(
        headers=CSV_COLUMNS,
        label="prediction as CSV",
    )

    def fn(audio_filepath: str) -> tuple[list[Image.Image], pd.DataFrame, str]:
        return interface_fn(
            model=model,
            audio_filepath=audio_filepath,
            config_model=config_model,
        )
    gr.Interface(
        title="ML model for forest elephant rumble detection 🐘",
        fn=fn,
        inputs=input_audio,
        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths,
        flagging_mode="never",
    )

if __name__ == "__main__":
    demo.launch()