achouffe committed
Commit 481da55 • 1 Parent(s): ded1e7b

feat: use Gallery element to show predictions

Files changed (3)
  1. app.py +11 -24
  2. requirements.txt +1 -0
  3. utils.py +1 -1
app.py CHANGED
@@ -34,7 +34,7 @@ def interface_fn(
     model: YOLO,
     audio_filepath: str,
     config_model: dict[str, float | int],
-) -> Tuple[Image.Image, pd.DataFrame, str]:
+) -> Tuple[list[Image.Image], pd.DataFrame, str]:
     """
     Main interface function that runs the model on the provided audio_filepath and
     returns the exepected tuple to populate the gradio interface.
@@ -60,6 +60,8 @@ def interface_fn(
         overlap=overlap,
     )
 
+    print(f"waveforms: {waveforms}")
+
     yolov8_predictions = inference(
         model=model,
         audio_filepath=Path(audio_filepath),
@@ -99,28 +101,12 @@ def interface_fn(
 
     spectrograms_pil_images = [Image.fromarray(a) for a in spectrograms_array_images]
 
-    array_image = waveform_to_np_image(
-        waveform=waveforms[0],
-        sample_rate=sample_rate,
-        n_fft=config_model["n_fft"],
-        hop_length=config_model["hop_length"],
-        freq_max=config_model["freq_max"],
-        width=config_model["width"],
-        height=config_model["height"],
-    )
-
     predictions = model.predict(spectrograms_pil_images)
-    pil_image_spectrogram_with_prediction = Image.fromarray(
-        bgr_to_rgb(predictions[0].plot())
-    )
-
-    for i in range(1, len(predictions)):
-        pil_image_spectrogram_with_prediction = get_concat_v(
-            pil_image_spectrogram_with_prediction,
-            Image.fromarray(bgr_to_rgb(predictions[i].plot())),
-        )
+    pil_image_spectrogram_with_predictions = [
+        Image.fromarray(bgr_to_rgb(p.plot())) for p in predictions
+    ]
 
-    return (pil_image_spectrogram_with_prediction, df, prediction_to_str(df=df))
+    return (pil_image_spectrogram_with_predictions, df[CSV_COLUMNS], prediction_to_str(df=df))
 
 
 def examples(dir_examples: Path) -> list[Path]:
@@ -144,6 +130,7 @@ MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/best.pt")
 MODEL_FILEPTAH_CONFIG = Path("data/model/config.yaml")
 DIR_EXAMPLES = Path("data/sounds/raw")
 DEFAULT_VALUE_INDEX = 0
+CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]
 
 with gr.Blocks() as demo:
     model = load_model(MODEL_FILEPATH_WEIGHTS)
@@ -157,10 +144,10 @@ with gr.Blocks() as demo:
         type="filepath",
         label="input audio",
     )
-    output_image = gr.Image(type="pil", label="model prediction")
+    output_gallery = gr.Gallery(label="model predictions")
     output_raw = gr.Text(label="raw prediction")
    output_dataframe = gr.DataFrame(
-        headers=["t_start", "t_end", "freq_start", "freq_end", "probability"],
+        headers=CSV_COLUMNS,
         label="prediction as CSV",
     )
 
@@ -173,7 +160,7 @@ with gr.Blocks() as demo:
        title="ML model for forest elephant rumble detection 🐘",
        fn=fn,
        inputs=input,
-        outputs=[output_image, output_dataframe, output_raw],
+        outputs=[output_gallery, output_dataframe, output_raw],
        examples=sound_filepaths,
        flagging_mode="never",
    )
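Note on the app.py change: gr.Gallery takes a list of images, so interface_fn can return the per-chunk prediction plots directly instead of stitching them into one tall image with get_concat_v. A minimal, self-contained sketch of that contract follows; the dummy predict function, component names, and placeholder images below are illustrative, not the app's real pipeline.

import gradio as gr
from PIL import Image

def predict(audio_filepath: str) -> list[Image.Image]:
    # Stand-in for the real YOLO inference: one annotated spectrogram per audio chunk.
    return [Image.new("RGB", (640, 384), c) for c in ("gray", "darkgray", "black")]

with gr.Blocks() as demo:
    audio_in = gr.Audio(type="filepath", label="input audio")
    gallery = gr.Gallery(label="model predictions")
    run = gr.Button("Run")
    run.click(fn=predict, inputs=audio_in, outputs=gallery)

demo.launch()

Because the Gallery component accepts a list of PIL images (or file paths), the return type annotation widens from Image.Image to list[Image.Image].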
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 gradio==5.4.*
+pandas==2.2.*
 torch==2.5.*
 torchaudio==2.5.*
 torchvision==0.20.*
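pandas is now pinned directly, presumably because app.py works with the prediction DataFrame itself via df[CSV_COLUMNS]. A quick illustration of what that selection does; the row values here are made up.

import pandas as pd

CSV_COLUMNS = ["t_start", "t_end", "freq_start", "freq_end", "probability"]

# Hypothetical prediction row, only to show the column selection used in the new return statement.
df = pd.DataFrame(
    [{"probability": 0.91, "t_start": 1.2, "t_end": 4.8, "freq_start": 10.0, "freq_end": 120.0, "chunk": 0}]
)

# df[CSV_COLUMNS] keeps only the listed columns, in the listed order,
# so the returned frame matches the gr.DataFrame headers.
print(df[CSV_COLUMNS])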
utils.py CHANGED
@@ -49,7 +49,7 @@ def chunk(
     total_seconds = waveform.shape[1] / sample_rate
     number_spectrograms = total_seconds / (duration - overlap)
     offsets = [
-        idx * (duration - overlap) for idx in range(0, math.floor(number_spectrograms))
+        idx * (duration - overlap) for idx in range(0, math.ceil(number_spectrograms))
     ]
     return [
         clip(
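The utils.py fix swaps math.floor for math.ceil when counting chunk offsets, so a trailing partial window is no longer dropped. A worked example with illustrative numbers (not taken from the repo):

import math

# A 25 s recording chunked into 10 s windows with no overlap.
total_seconds, duration, overlap = 25.0, 10.0, 0.0
number_spectrograms = total_seconds / (duration - overlap)  # 2.5

floor_offsets = [i * (duration - overlap) for i in range(math.floor(number_spectrograms))]
ceil_offsets = [i * (duration - overlap) for i in range(math.ceil(number_spectrograms))]

print(floor_offsets)  # [0.0, 10.0]        -> the last 5 s are never scanned
print(ceil_offsets)   # [0.0, 10.0, 20.0]  -> the trailing partial chunk is included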