ProtoCLR

This repository contains a CvT-13 (Convolutional Vision Transformer) model trained from scratch on the Xeno-Canto dataset, using 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds with ProtoCLR (Prototypical Contrastive Loss) for 300 epochs and can be used as a feature extractor for bird audio classification and related downstream tasks.

Files

  • cvt.py: Defines the CvT-13 model architecture.
  • protoclr.pth: Pre-trained model weights for ProtoCLR.
  • config/: Configuration files for CvT-13 setup.
  • melspectrogram.py: Contains the MelSpectrogramProcessor class, which converts audio waveforms into Mel spectrograms, the format expected as model input.

Setup

  1. Clone this repository: Clone the repository and navigate into the project directory:

    git clone https://huggingface.co/ilyassmoummad/ProtoCLR
    cd ProtoCLR/

  2. Install dependencies: Ensure you have the required Python packages, including torch and any other dependencies listed in requirements.txt.

    pip install -r requirements.txt
    

Usage

  1. Prepare the Audio:
    To ensure compatibility with the model, follow these preprocessing steps for your audio files:

    • Mono Channel (Mandatory):
      If the audio has multiple channels, convert it to a single mono channel by averaging the channels.
    • Sample Rate (Mandatory):
      Resample the audio to a consistent sample rate of 16 kHz.
    • Padding (Recommended):
      For audio files shorter than 6 seconds, pad with zeros or repeat the audio until it reaches a length of 6 seconds.
    • Chunking (Recommended):
      For audio files longer than 6 seconds, split them into consecutive 6-second chunks before processing (see the sketch after this list).
  2. Process the Audio:
    Use the MelSpectrogramProcessor (from melspectrogram.py) to transform the prepared audio into a Mel spectrogram, as demonstrated in the Example Code section below.
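
The following is a minimal sketch of the padding and chunking steps, assuming a 1-D mono waveform tensor already resampled to 16 kHz; pad_or_repeat and chunk_waveform are illustrative helpers, not part of this repository:

import torch

TARGET_LEN = 6 * 16000  # 6 seconds at 16 kHz

def pad_or_repeat(waveform):
    # Repeat short audio until it reaches 6 seconds, then trim to exactly 6 seconds
    if waveform.shape[-1] < TARGET_LEN:
        repeats = TARGET_LEN // waveform.shape[-1] + 1
        waveform = waveform.repeat(repeats)[:TARGET_LEN]
    return waveform

def chunk_waveform(waveform):
    # Split long audio into consecutive 6-second chunks; pad the trailing remainder
    chunks = list(torch.split(waveform, TARGET_LEN))
    chunks[-1] = pad_or_repeat(chunks[-1])
    return torch.stack(chunks)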

Example Code

The following example demonstrates loading, processing, and running inference on an audio file:

import torch
import torchaudio  # used by the example load_waveform below
from cvt import cvt13  # Import the CvT-13 model architecture
from melspectrogram import MelSpectrogramProcessor  # Import the Mel spectrogram processor

# Initialize the preprocessor and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocessor = MelSpectrogramProcessor(device=device)
model = cvt13().to(device)

# Load ONE of the following checkpoints (each call would overwrite the previous weights):

# Weights trained using Cross-Entropy (note the 'encoder' key)
# model.load_state_dict(torch.load("ce.pth", map_location="cpu")['encoder'])

# Weights trained using SimCLR (self-supervised contrastive learning)
# model.load_state_dict(torch.load("simclr.pth", map_location="cpu"))

# Weights trained using SupCon (supervised contrastive learning)
# model.load_state_dict(torch.load("supcon.pth", map_location="cpu"))

# Weights trained using ProtoCLR (supervised contrastive learning using prototypes)
model.load_state_dict(torch.load("protoclr.pth", map_location="cpu"))

# The model was already moved to `device` above; set it to evaluation mode for inference
model.eval()

# Load and preprocess a sample audio waveform
def load_waveform(file_path, target_sr=16000):
    # Example using torchaudio to load, convert to mono, and resample to 16 kHz;
    # replace this with your own audio loading function if you prefer
    waveform, sr = torchaudio.load(file_path)
    waveform = waveform.mean(dim=0)  # average channels to a single mono channel
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform

waveform = load_waveform("path/to/audio.wav").to(device)  # 1-D mono waveform tensor at 16 kHz

# Ensure waveform is sampled at 16 kHz, then pad/chunk as needed for 6s length
input_tensor = preprocessor.process(waveform).unsqueeze(0)  # Add batch dimension

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)
    print("Model output shape:", output.shape)

Model Performance Comparison

The following table presents the classification accuracy of various models on one-shot and five-shot bird sound classification tasks, evaluated across different soundscape datasets.

| Model | Model Size | PER | NES | UHH | HSN | SSW | SNE | Mean |
|---|---|---|---|---|---|---|---|---|
| Random Guessing | - | 0.75 | 1.12 | 3.70 | 5.26 | 1.04 | 1.78 | 2.22 |
| **1-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 7.41±1.0 | 26.4±2.3 | 13.2±3.1 | 9.84±3.5 | 8.74±0.6 | 14.1±3.1 | 13.2 |
| BirdAVES-bioxn-large | 300M | 7.59±0.8 | 27.2±3.6 | 13.7±2.9 | 12.5±3.6 | 10.0±1.4 | 14.5±3.2 | 14.2 |
| BioLingual | 28M | 6.21±1.1 | 37.5±2.9 | 17.8±3.5 | 17.6±5.1 | 22.5±4.0 | 26.4±3.4 | 21.3 |
| Perch | 80M | 9.10±5.3 | 42.4±4.9 | 19.8±5.0 | 26.7±9.8 | 22.3±3.3 | 29.1±5.9 | 24.9 |
| CE (Ours) | 23M | 9.55±1.5 | 41.3±3.6 | 19.7±4.7 | 25.2±5.7 | 17.8±1.4 | 31.5±5.4 | 24.2 |
| SimCLR (Ours) | 19M | 7.85±1.1 | 31.2±2.4 | 14.9±2.9 | 19.0±3.8 | 10.6±1.1 | 24.0±4.1 | 17.9 |
| SupCon (Ours) | 19M | 8.53±1.1 | 39.8±6.0 | 18.8±3.0 | 20.4±6.9 | 12.6±1.6 | 23.2±3.1 | 20.5 |
| ProtoCLR (Ours) | 19M | 9.23±1.6 | 38.6±5.1 | 18.4±2.3 | 21.2±7.3 | 15.5±2.3 | 25.8±5.2 | 21.4 |
| **5-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 11.6±0.8 | 39.7±1.8 | 22.5±2.4 | 22.1±3.3 | 16.1±1.7 | 28.3±2.3 | 23.3 |
| BirdAVES-bioxn-large | 300M | 15.0±0.9 | 42.6±2.7 | 23.7±3.8 | 28.4±2.4 | 18.3±1.8 | 27.3±2.3 | 25.8 |
| BioLingual | 28M | 13.6±1.3 | 65.2±1.4 | 31.0±2.9 | 34.3±3.5 | 43.9±0.9 | 49.9±2.3 | 39.6 |
| Perch | 80M | 21.2±1.2 | 71.7±1.5 | 39.5±3.0 | 52.5±5.9 | 48.0±1.9 | 59.7±1.8 | 48.7 |
| CE (Ours) | 23M | 21.4±1.3 | 69.2±1.8 | 35.6±3.4 | 48.2±5.5 | 39.9±1.1 | 57.5±2.3 | 45.3 |
| SimCLR (Ours) | 19M | 15.4±1.0 | 54.0±1.8 | 23.0±2.3 | 32.8±4.0 | 22.0±1.2 | 40.7±2.4 | 31.3 |
| SupCon (Ours) | 19M | 17.2±1.3 | 64.6±2.4 | 34.1±2.9 | 42.5±2.9 | 30.8±0.8 | 48.1±2.4 | 39.5 |
| ProtoCLR (Ours) | 19M | 19.2±1.1 | 67.9±2.8 | 36.1±4.3 | 48.0±4.3 | 34.6±2.3 | 48.6±2.8 | 42.4 |
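
As a rough illustration of how such few-shot scores are typically computed (a sketch under assumed conventions, not the exact evaluation code of the paper): class prototypes are obtained by averaging the support-set embeddings of each class, and every query embedding is assigned to the class of its nearest prototype.

import torch

def few_shot_predict(support_feats, support_labels, query_feats):
    # support_feats: (n_support, d), support_labels: (n_support,), query_feats: (n_query, d)
    classes = support_labels.unique()
    # One prototype per class: the mean of that class's support embeddings
    prototypes = torch.stack([support_feats[support_labels == c].mean(dim=0) for c in classes])
    # Nearest-prototype classification using Euclidean distance
    dists = torch.cdist(query_feats, prototypes)
    return classes[dists.argmin(dim=1)]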

For additional details, please see the pre-print on arXiv and the official GitHub repository.

Citation

If you use our model in your research, please cite the following paper:

@misc{moummad2024dirlbs,
      title={Domain-Invariant Representation Learning of Bird Sounds}, 
      author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
      year={2024},
      eprint={2409.08589},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2409.08589}, 
}