import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
from config import Config
from model import BirdAST
import torch
import librosa
import noisereduce as nr
import timm
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time
import pandas as pd
from classpred import predict_class
import torch.nn.functional as F
import random
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor
def plot_mel(sr, x):
    """Render a min-max normalised mel spectrogram of an audio clip.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Audio samples, passed as ``y`` to ``librosa.feature.melspectrogram``.

    Returns:
        The matplotlib figure containing the spectrogram.
    """
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=128, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalise to [0, 1]. A constant spectrogram (e.g. digital silence) has
    # zero dynamic range; dividing by it would fill the plot with NaN, so fall
    # back to an all-zero image instead.
    lowest = mel_spec_db.min()
    span = mel_spec_db.max() - lowest
    if span > 0:
        mel_spec_db = (mel_spec_db - lowest) / span
    else:
        mel_spec_db = np.zeros_like(mel_spec_db)

    fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    librosa.display.specshow(
        mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax
    )
    return fig
def plot_wave(sr, x):
    """Draw the clip's waveform above its noise-reduced counterpart.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Audio samples to plot and to pass through ``nr.reduce_noise``.

    Returns:
        The matplotlib figure holding both stacked waveform panels.
    """
    denoised = nr.reduce_noise(y=x, sr=sr)
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    # Top panel: input as given; bottom panel: after noise reduction.
    panels = (
        (axes[0], x, 'Original Waveform'),
        (axes[1], denoised, 'Noise Reduced Waveform'),
    )
    for axis, signal, title in panels:
        librosa.display.waveshow(signal, sr=sr, ax=axis)
        axis.set(title=title)
        axis.set_xlabel('Time (s)')
    return fig
def predict(audio, start, end):
    """Classify the ``[start, end]`` window of an uploaded clip.

    The window is validated against the full clip *before* slicing and before
    any model is run, so invalid input fails fast with a user-facing error.

    Args:
        audio: ``(sample_rate, samples)`` pair as delivered by ``gr.Audio``.
        start: Window start in seconds.
        end: Window end in seconds; clamped to the clip duration.

    Returns:
        A 4-tuple: the result of ``predict_class``, the species rows from
        ``preprocess_for_inference``, the mel-spectrogram figure, and the
        waveform figure (in that order).

    Raises:
        gr.Error: If no audio was supplied or the window is invalid.
    """
    if audio is None:
        raise gr.Error("Please provide an audio clip before predicting.")

    sr, x = audio
    # Scale by 2**15 — assumes 16-bit integer PCM samples; TODO confirm.
    x = np.array(x, dtype=np.float32)/32768.0
    if x.ndim > 1:
        # NOTE(review): multi-channel input is assumed to be laid out as
        # (samples, channels); downmix to mono so downstream code sees 1-D audio.
        x = x.mean(axis=1)

    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    if x.shape[0] <= start * sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")
    if x.shape[0] < end * sr:
        # Requested window runs past the clip: clamp to the real duration.
        end = x.shape[0]/(1.0*sr)

    x = x[int(start*sr) : int(end*sr)]

    res = preprocess_for_inference(x, sr)
    fig1 = plot_mel(sr, x)
    fig2 = plot_wave(sr, x)
    return predict_class(x, sr, start, end), res, fig1, fig2
def download_model(url, model_path):
    """Download a model checkpoint unless it is already on disk.

    Args:
        url: Location to fetch the checkpoint from.
        model_path: Local file path to write the checkpoint to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if os.path.exists(model_path):
        return

    # The timeout bounds connect/read stalls so a dead server cannot hang startup.
    response = requests.get(url, timeout=60)
    response.raise_for_status()  # Ensure the request was successful
    with open(model_path, 'wb') as f:
        # NOTE(review): the body of this `with` was missing in the reviewed
        # source; writing the response payload is the presumed intent.
        f.write(response.content)
# Fetch the five fold checkpoints (skipped when already on disk).
# NOTE(review): these URLs carry no scheme/host in the reviewed source — the
# base URL appears to have been stripped and must be restored before use.
model_urls = [f'{i}.pth' for i in range(5)]
model_paths = [f'BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]
for (model_url, model_path) in zip(model_urls, model_paths):
    download_model(model_url, model_path)

# One model instance per fold, each loaded with its own checkpoint on CPU.
eval_models = [BirdAST(Config().backbone_name, Config().n_classes, n_mlp_layers=1, activation='silu') for _ in range(5)]
state_dicts = [torch.load(path, map_location='cpu') for path in model_paths]
# NOTE(review): both loop bodies were missing in the reviewed source. Loading
# each checkpoint into its model and switching to eval mode is the presumed
# intent; confirm the checkpoints are plain state dicts (not wrapped dicts).
for model, sd in zip(eval_models, state_dicts):
    model.load_state_dict(sd)
    model.eval()

# Lookup from the numeric species id to its scientific name.
label_mapping = pd.read_csv('BirdAST_Baseline_GroupKFold_label_map.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for index, row in label_mapping.iterrows()}
def preprocess_for_inference(audio_arr, sr):
    """Run every fold model on one clip and return the top species.

    Args:
        audio_arr: Audio samples handed to ``FEATURE_EXTRACTOR``.
        sr: Sampling rate of ``audio_arr`` in Hz.
            NOTE(review): the extractor may only accept the rate it was
            configured for — confirm whether resampling is needed upstream.

    Returns:
        A list of ``[species_name, probability_percent]`` rows for the (at
        most) 10 highest ensemble-averaged scores, highest first.
    """
    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']  # Get the input values prepared for model input

    model_outputs = []
    with torch.no_grad():
        for model in eval_models:
            output = model(input_values)
            model_outputs.append(F.softmax(output['logits'], dim=1))

    # Assumes each score tensor is (1, n_classes) for a single clip, so the
    # concatenation is (n_models, n_classes) and the mean is per class.
    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)

    # Never ask for more entries than there are classes.
    k = min(10, avg_predictions.shape[0])
    topk_values, topk_indices = torch.topk(avg_predictions, k)

    return [
        [species_id_to_name[idx.item()], score.item() * 100]
        for idx, score in zip(topk_indices, topk_values)
    ]
# Bird audio classification using SOTA Voice of Jungle Technology. \n
# Introduction
It is estimated that 50% of the global economy is threatened by biodiversity loss. As such, concerted efforts have been made to estimate bird biodiversity, as birds are a top indicator of biodiversity in a region. One of these efforts is
finding the bird species in a region using bird species audio classification.
The left table shows the prediction of the type of noise (class), while the right table shows the predicted bird species. If the class prediction does not output bird, then the species prediction is not confident.
# Custom CSS handed to gr.Blocks; the class names match the `elem_classes`
# used on the input widgets.
# NOTE(review): the rule-closing braces and the closing triple quote were
# missing in the reviewed source and have been restored.
css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""
class Seafoam(Base):
    """Custom Gradio theme with an emerald primary hue and large text.

    NOTE(review): the constructor was truncated in the reviewed source. The
    missing pieces (``self``, the ``secondary_hue`` default, the font tuples
    and the ``super().__init__`` call) were reconstructed from the custom-theme
    example in the Gradio theming guide — confirm against the original file.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,  # presumed default
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),  # presumed default
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Forward every setting unchanged to the base theme.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
seafoam = Seafoam()

# Page layout: time-window inputs beside the audio upload, then the two result
# tables, then the two plots, then example clips and the trigger button.
# NOTE(review): nesting was reconstructed because the reviewed source had lost
# its indentation — confirm the row/column structure.
with gr.Blocks(theme=seafoam, css=css) as demo:
    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    # NOTE(review): the `gr.Examples(` wrapper around these rows was missing in
    # the reviewed source; restored as the presumed intent.
    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input]
    )

    # `predict` returns (class rows, species rows, mel figure, waveform figure),
    # so the spectrogram plot must come before the waveform plot here.
    gr.Button("Predict").click(
        predict,
        [audio_input, start_time_input, end_time_input],
        [raw_class_output, species_output, spectrogram_output, waveform_output],
    )

demo.launch(share=True)