|
"""Gradio demo: bird species audio classification with a 5-fold BirdAST ensemble."""

import os

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

from config import Config
from model import BirdAST


# Shared feature extractor used by all five folds (AST log-mel frontend)
FEATURE_EXTRACTOR = ASTFeatureExtractor()
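
# Note: ASTFeatureExtractor defaults to 16 kHz input and pads/truncates to 1024
# frames of 128-bin log-mel features (the AST pretraining setup); audio at other
# sample rates is resampled in preprocess_for_inference below.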
|
|
|
def plot_mel(sr, x):
    """Plot a mel spectrogram (224 mel bands, up to 10 kHz) of the audio."""
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=224, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    # Normalise to [0, 1] for a consistent colour scale
    mel_spec_db = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min())

    fig, ax = plt.subplots(nrows=1, ncols=1)
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax)
    return fig


def plot_wave(sr, x):
    """Plot the original waveform alongside a noise-reduced version."""
    ry = nr.reduce_noise(y=x, sr=sr)

    fig, ax = plt.subplots(2, 1, figsize=(12, 8))

    librosa.display.waveshow(x, sr=sr, ax=ax[0])
    ax[0].set(title='Original Waveform')
    ax[0].set_xlabel('Time (s)')
    ax[0].set_ylabel('Amplitude')

    librosa.display.waveshow(ry, sr=sr, ax=ax[1])
    ax[1].set(title='Noise-Reduced Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')

    plt.tight_layout()
    return fig


def predict(audio, start, end):
    """Classify the selected segment of `audio` and build diagnostic plots."""
    sr, x = audio
    # Gradio delivers raw 16-bit PCM samples; scale to floats in [-1, 1]
    x = np.array(x, dtype=np.float32) / 32768.0

    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")

    if x.shape[0] < start * sr:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({x.shape[0] / sr:.0f}s)")

    # Clamp `end` to the actual duration of the clip
    if x.shape[0] < end * sr:
        end = x.shape[0] / sr

    # Run inference on the selected [start, end] window only
    segment = x[int(start * sr):int(end * sr)]
    res = preprocess_for_inference(segment, sr)

    fig_wave = plot_wave(sr, x)
    fig_mel = plot_mel(sr, x)

    # The same top-10 table feeds both prediction views
    return res, res, fig_wave, fig_mel
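
# gr.Audio supplies `audio` as a (sample_rate, ndarray) tuple of raw PCM samples
# (16-bit by default, hence the 32768.0 scaling above); the four return values
# map one-to-one onto the Dataframe and Plot outputs wired up in the UI below.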
|
|
|
def download_model(url, model_path):
    """Download a checkpoint to `model_path` unless it is already cached."""
    if not os.path.exists(model_path):
        response = requests.get(url)
        response.raise_for_status()
        with open(model_path, 'wb') as f:
            f.write(response.content)


# Fetch the five cross-validation fold checkpoints from the Hugging Face Hub
model_urls = [f'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_5folds_fold_{i}.pth' for i in range(5)]
model_paths = [f'BirdAST_Baseline_5folds_fold_{i}.pth' for i in range(5)]

for model_url, model_path in zip(model_urls, model_paths):
    download_model(model_url, model_path)
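
# Alternative worth considering: huggingface_hub.hf_hub_download could manage
# download and caching of these checkpoints instead of the manual fetch above.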
|
|
|
|
|
# Build the 5-fold ensemble, load each fold's weights, and switch to eval mode
cfg = Config()
eval_models = [BirdAST(cfg.backbone_name, cfg.n_classes, n_mlp_layers=1, activation='silu') for _ in range(5)]
for model, model_path in zip(eval_models, model_paths):
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

# Map integer species ids to scientific names
label_mapping = pd.read_csv('BirdAST_Baseline_5folds_label_map.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for _, row in label_mapping.iterrows()}


def preprocess_for_inference(audio_arr, sr):
    """Extract AST features and return the ensemble's top-10 species with probabilities."""
    # The AST feature extractor expects 16 kHz audio; resample if needed
    if sr != 16000:
        audio_arr = resample(torch.from_numpy(audio_arr).float(), sr, 16000).numpy()
        sr = 16000

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    model_outputs = []
    with torch.no_grad():
        for model in eval_models:
            output = model(input_values)
            predict_score = F.softmax(output['logits'], dim=1)
            model_outputs.append(predict_score)

    # Average softmax scores across the five folds, then take the top 10 classes
    avg_predictions = torch.mean(torch.stack(model_outputs), dim=0)
    topk_values, topk_indices = torch.topk(avg_predictions, 10, dim=1)

    results = []
    for idx, score in zip(topk_indices[0], topk_values[0]):
        results.append([species_id_to_name[idx.item()], score.item()])

    return results
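
# Each call yields rows like ["Pionus fuscus", 0.42] (name, averaged probability)
# for the ten most likely species; Gradio renders the list directly as a Dataframe.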
|
|
|
DESCRIPTION = """
Bird species audio classification using Voice of Jungle technology: a 5-fold
BirdAST (Audio Spectrogram Transformer) ensemble.
"""


css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Bird Species Audio Classification")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time (s)", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time (s)", value=1, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Button("Predict").click(
        predict,
        inputs=[audio_input, start_time_input, end_time_input],
        outputs=[raw_class_output, species_output, waveform_output, spectrogram_output],
    )

demo.launch(share=True)  # share=True exposes a temporary public URL