"""Gradio demo: bird species classification from audio with a 5-fold BirdAST ensemble."""
import json
import os
import random
import tempfile

import gradio as gr
import librosa
import librosa.display  # older librosa releases do not load `display` on `import librosa`
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

from config import Config
from model import BirdAST

#TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
#MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
#LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
#AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())

FEATURE_EXTRACTOR = ASTFeatureExtractor()

MEL_FMAX = 10000  # upper frequency (Hz) for the displayed mel spectrogram
N_FOLDS = 5       # number of cross-validation fold checkpoints in the ensemble
MODEL_REPO_URL = "https://huggingface.co/shiyi-li/BirdAST/resolve/main"


def plot_mel(sr, x):
    """Return a figure showing the mel spectrogram of ``x``, min-max scaled to [0, 1].

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: 1-D float waveform.
    """
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=224, fmax=MEL_FMAX)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    # Min-max scale to [0, 1]; a constant (e.g. silent) input would otherwise divide by zero.
    value_range = mel_spec_db.max() - mel_spec_db.min()
    if value_range > 0:
        mel_spec_db = (mel_spec_db - mel_spec_db.min()) / value_range
    else:
        mel_spec_db = np.zeros_like(mel_spec_db)
    # Build the Figure directly (not via pyplot) so it is not retained in pyplot's
    # global figure registry after every request.
    fig = Figure()
    ax = fig.subplots()
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis="time", y_axis="mel",
                             fmin=0, fmax=MEL_FMAX, ax=ax)
    return fig


def plot_wave(sr, x):
    """Return a two-row figure: the original waveform and its noise-reduced version.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: 1-D float waveform.
    """
    reduced = nr.reduce_noise(y=x, sr=sr)
    fig = Figure(figsize=(12, 8))  # not registered with pyplot, so nothing leaks
    axes = fig.subplots(2, 1)
    panels = ((axes[0], x, "Original Waveform"),
              (axes[1], reduced, "Noise Reduced Waveform"))
    for ax, signal, title in panels:
        librosa.display.waveshow(signal, sr=sr, ax=ax)
        ax.set(title=title, xlabel="Time (s)", ylabel="Amplitude")
    fig.tight_layout()
    return fig


def _to_float_mono(x):
    """Convert a Gradio numpy waveform to a 1-D float64 array in roughly [-1, 1)."""
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):
        info = np.iinfo(x.dtype)
        half_range = (float(info.max) - float(info.min) + 1.0) / 2.0  # 32768 for int16
        offset = float(info.min) + half_range  # 0 for signed PCM, midpoint for unsigned
        x = (x.astype(np.float64) - offset) / half_range
    else:
        x = x.astype(np.float64)
    if x.ndim == 2:
        # Multi-channel audio arrives as (samples, channels); downmix to mono, since a
        # 2-D array would otherwise be interpreted as a batch by the feature extractor.
        x = x.mean(axis=1)
    return x


def predict(audio, start, end):
    """Classify the ``[start, end)`` second range of ``audio`` and plot that segment.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or None.
        start: Segment start in seconds.
        end: Segment end in seconds; clamped to the clip duration.

    Returns:
        ``(predictions, predictions, waveform_figure, spectrogram_figure)``, ordered to
        match the outputs wired to the Predict button.

    Raises:
        gr.Error: on missing audio/times or an invalid time range.
    """
    if audio is None:
        raise gr.Error("Please provide an input audio clip.")
    if start is None or end is None:
        raise gr.Error("Both start and end times are required.")
    sr, x = audio
    x = _to_float_mono(x)
    duration = x.shape[0] / sr

    # Validate the requested window before doing any expensive work.
    if start < 0:
        raise gr.Error(f"`start` ({start}s) must be non-negative")
    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")
    if start >= duration:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({duration:.0f}s)")
    end = min(end, duration)

    segment = x[int(start * sr):int(end * sr)]
    if segment.size == 0:
        raise gr.Error("The selected time range contains no audio samples.")

    res = preprocess_for_inference(segment, sr)
    return res, res, plot_wave(sr, segment), plot_mel(sr, segment)


def download_model(url, model_path, timeout=60):
    """Download ``url`` to ``model_path`` unless the file already exists.

    The payload is streamed to a temporary file in the destination directory and
    renamed into place only after it is complete, so an interrupted download never
    leaves a truncated checkpoint behind that would be skipped on the next start.

    Args:
        url: Source URL.
        model_path: Destination file path.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server returns an error status.
    """
    if os.path.exists(model_path):
        return
    directory = os.path.dirname(os.path.abspath(model_path))
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()  # ensure the request was successful
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(tmp_path, model_path)
        except BaseException:
            os.remove(tmp_path)
            raise


# Model URLs and local checkpoint paths, one per cross-validation fold.
model_paths = [f"BirdAST_Baseline_5folds_fold_{i}.pth" for i in range(N_FOLDS)]
model_urls = [f"{MODEL_REPO_URL}/{path}" for path in model_paths]
for model_url, model_path in zip(model_urls, model_paths):
    download_model(model_url, model_path)

# Build one model per fold, load its weights and switch it to evaluation mode.
# State dicts are not kept around afterwards, so weights are held in memory once.
_config = Config()
eval_models = []
for model_path in model_paths:
    fold_model = BirdAST(_config.backbone_name, _config.n_classes, n_mlp_layers=1, activation="silu")
    # NOTE(review): torch.load unpickles a file downloaded over the network; consider
    # `weights_only=True` (supported by recent torch versions) for state-dict checkpoints.
    fold_model.load_state_dict(torch.load(model_path, map_location="cpu"))
    fold_model.eval()
    eval_models.append(fold_model)

# Load the species mapping.
# label_mapping = pd.read_csv('label_mapping.csv')
label_mapping = pd.read_csv("BirdAST_Baseline_5folds_label_map.csv")
species_id_to_name = dict(zip(label_mapping["species_id"], label_mapping["scientific_name"]))


def preprocess_for_inference(audio_arr, sr, top_k=10):
    """Run the fold ensemble on a mono waveform and return the top-scoring species.

    The waveform is resampled to the feature extractor's sampling rate when needed,
    because the extractor rejects audio at any other rate.

    Args:
        audio_arr: 1-D float waveform.
        sr: Sampling rate of ``audio_arr`` in Hz.
        top_k: Number of predictions to return (capped at the number of classes).

    Returns:
        List of ``[scientific_name, probability]`` pairs, highest probability first,
        where probability is the softmax score averaged across all fold models.
    """
    target_sr = FEATURE_EXTRACTOR.sampling_rate
    if sr != target_sr:
        waveform = torch.as_tensor(audio_arr, dtype=torch.float32)
        audio_arr = resample(waveform, orig_freq=int(sr), new_freq=target_sr).numpy()

    # padding="max_length" pads (and presumably truncates) to the extractor's fixed input length.
    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=target_sr, padding="max_length", return_tensors="pt")
    input_values = spec["input_values"]

    with torch.no_grad():
        model_outputs = [F.softmax(m(input_values)["logits"], dim=1) for m in eval_models]

    # Average the per-fold class probabilities.
    avg_predictions = torch.mean(torch.stack(model_outputs), dim=0)
    k = min(top_k, avg_predictions.shape[1])
    topk_values, topk_indices = torch.topk(avg_predictions, k, dim=1)

    return [
        [species_id_to_name[idx.item()], score.item()]
        for idx, score in zip(topk_indices[0], topk_values[0])
    ]


DESCRIPTION = """ Bird audio classification using SOTA Voice of Jungle Technology. """

css = """ .number-input { height: 100%; padding-bottom: 60px; /* Adust the value as needed for more or less space */ } .full-height { height: 100%; } .column-container { height: 100%; } """

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Bird Species Audio Classification")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=1, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Button("Predict").click(
        predict,
        [audio_input, start_time_input, end_time_input],
        [raw_class_output, species_output, waveform_output, spectrogram_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)