ProtoCLR
This repository contains a CvT-13 Convolutional Vision Transformer model trained from scratch on the Xeno-Canto dataset, specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using ProtoCLR (Prototypical Contrastive Loss) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.
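ProtoCLR is a supervised contrastive objective built around class prototypes. The NumPy function below is a minimal sketch of that general idea, not the authors' implementation. It assumes that each class prototype is the re-normalized mean of the L2-normalized embeddings of that class within a batch, that similarities are cosines divided by a hypothetical temperature of 0.1, and that the loss is a softmax cross-entropy over the prototypes. The function name `prototype_contrastive_loss` is illustrative, and the paper gives the exact formulation.

```python
import numpy as np

def prototype_contrastive_loss(embeddings, labels, temperature=0.1):
    """Illustrative prototype-based supervised contrastive loss (a sketch, not the repo's code).

    embeddings: (N, D) array of features for one batch.
    labels:     (N,) integer class labels.
    Each example is contrasted with the C class prototypes present in the batch
    rather than with the N - 1 other examples.
    """
    # L2-normalize so that dot products are cosine similarities
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    classes, targets = np.unique(labels, return_inverse=True)  # targets index into classes
    # Assumed prototype: mean of the normalized embeddings of each class, re-normalized
    protos = np.stack([z[targets == c].mean(axis=0) for c in range(len(classes))])
    protos /= np.linalg.norm(protos, axis=1, keepdims=True)
    logits = z @ protos.T / temperature  # (N, C) scaled cosine similarities
    # Numerically stable log-softmax over the prototypes
    logits -= logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy: pull each example toward its own class prototype
    return -log_probs[np.arange(len(z)), targets].mean()
```

With two well-separated clusters and matching labels the loss is close to zero; with labels that mix the clusters, the prototypes become nearly identical and the loss rises to roughly log 2.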
Files
- `cvt.py`: Defines the CvT-13 model architecture.
- `protoclr.pth`: Pre-trained model weights for ProtoCLR.
- `config/`: Configuration files for the CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms, the input format expected by the model.
Setup
Clone the repository and navigate into the project directory:
```bash
git clone https://huggingface.co/ilyassmoummad/ProtoCLR
cd ProtoCLR/
```
Install the dependencies: make sure the required Python packages, including `torch`, are installed, along with any other dependencies listed in `requirements.txt`:
```bash
pip install -r requirements.txt
```
Usage
Prepare the Audio:
To ensure compatibility with the model, apply these preprocessing steps to your audio files:
- Mono Channel (Mandatory): If the audio has multiple channels, convert it to a single mono channel by averaging the channels.
- Sample Rate (Mandatory): Resample the audio to 16 kHz.
- Padding (Recommended): For audio shorter than 6 seconds, pad it with zeros or repeat it until it reaches 6 seconds.
- Chunking (Recommended): For audio longer than 6 seconds, split it into 6-second chunks so that each chunk matches the segment length the model was trained on.
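The preparation steps above can be sketched in NumPy as follows. This is an illustrative sketch, not code from this repository: the helper names `to_mono`, `resample_linear` and `pad_or_chunk` are hypothetical, and the linear-interpolation resampler is a deliberately simple stand-in. For real use, a band-limited resampler such as `torchaudio.functional.resample` is preferable, because linear interpolation does not suppress aliasing. The sketch also drops any trailing remainder shorter than 6 seconds when chunking.

```python
import numpy as np

TARGET_SR = 16_000
SEGMENT_SECONDS = 6
SEGMENT_SAMPLES = TARGET_SR * SEGMENT_SECONDS  # 96,000 samples per segment

def to_mono(audio):
    """Average channels: (channels, samples) -> (samples,). 1-D input is returned unchanged."""
    audio = np.asarray(audio, dtype=np.float32)
    return audio.mean(axis=0) if audio.ndim == 2 else audio

def resample_linear(audio, orig_sr, target_sr=TARGET_SR):
    """Crude linear-interpolation resampler (hypothetical helper; no anti-aliasing filter)."""
    if orig_sr == target_sr:
        return audio
    n_out = int(round(len(audio) * target_sr / orig_sr))
    t_out = np.arange(n_out) / target_sr  # output sample times in seconds
    t_in = np.arange(len(audio)) / orig_sr  # input sample times in seconds
    return np.interp(t_out, t_in, audio).astype(np.float32)

def pad_or_chunk(audio, mode="repeat"):
    """Return a (num_chunks, SEGMENT_SAMPLES) array of 6-second segments."""
    if len(audio) == 0:
        raise ValueError("empty audio")
    if len(audio) < SEGMENT_SAMPLES:
        if mode == "repeat":
            reps = -(-SEGMENT_SAMPLES // len(audio))  # ceiling division
            audio = np.tile(audio, reps)[:SEGMENT_SAMPLES]
        else:  # zero-padding at the end
            audio = np.pad(audio, (0, SEGMENT_SAMPLES - len(audio)))
    # Keep only full 6 s chunks; a shorter trailing remainder is dropped
    n_chunks = len(audio) // SEGMENT_SAMPLES
    return audio[: n_chunks * SEGMENT_SAMPLES].reshape(n_chunks, SEGMENT_SAMPLES)
```

For example, a 2-second stereo clip at 44.1 kHz becomes 32,000 mono samples at 16 kHz and is then repeated into a single 96,000-sample segment. Padding the trailing remainder instead of dropping it is an equally reasonable choice when the end of a recording matters.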
Process the Audio:
Use the `MelSpectrogramProcessor` class (from `melspectrogram.py`) to transform the prepared audio into a Mel spectrogram, the input format expected by the model, as demonstrated in the following example.
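To illustrate what a Mel spectrogram front end computes, the NumPy sketch below builds a log-Mel spectrogram from scratch: a framed, Hann-windowed FFT, an HTK-style triangular Mel filterbank, and a log. The parameters (`n_fft=1024`, a 160-sample hop, 128 Mel bands, no centering or padding) are hypothetical placeholders, not the settings of `MelSpectrogramProcessor`. Use the repository's class, not this sketch, to produce model inputs.

```python
import numpy as np

def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)  # HTK Mel formula

def mel_to_hz(m):
    return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)

def mel_filterbank(sr, n_fft, n_mels, fmin=0.0, fmax=None):
    """Triangular filters mapping the (n_fft // 2 + 1) FFT bins to n_mels bands."""
    fmax = fmax if fmax is not None else sr / 2
    fft_freqs = np.linspace(0, sr / 2, n_fft // 2 + 1)
    # n_mels + 2 equally spaced Mel points give the left/center/right edge of each filter
    hz_points = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2))
    fb = np.zeros((n_mels, len(fft_freqs)))
    for i in range(n_mels):
        left, center, right = hz_points[i], hz_points[i + 1], hz_points[i + 2]
        rising = (fft_freqs - left) / (center - left)
        falling = (right - fft_freqs) / (right - center)
        fb[i] = np.maximum(0.0, np.minimum(rising, falling))
    return fb

def log_mel_spectrogram(audio, sr=16_000, n_fft=1024, hop=160, n_mels=128, eps=1e-10):
    """Return a (n_mels, n_frames) log-Mel spectrogram of a 1-D signal (len(audio) >= n_fft)."""
    window = np.hanning(n_fft)
    n_frames = 1 + (len(audio) - n_fft) // hop  # frames fully inside the signal
    frames = np.stack([audio[i * hop : i * hop + n_fft] * window for i in range(n_frames)])
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2  # (n_frames, n_fft // 2 + 1)
    mel = power @ mel_filterbank(sr, n_fft, n_mels).T  # (n_frames, n_mels)
    return np.log(mel + eps).T
```

Under these assumed settings, a 6-second segment at 16 kHz (96,000 samples) yields a 128 × 594 log-Mel spectrogram, and a 1 kHz tone peaks in the band centered closest to 1 kHz.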
Example Code
The following example demonstrates loading, processing, and running inference on an audio file:
```python
import torch
import torch.nn.functional as F
import torchaudio

from cvt import cvt13  # Model architecture (cvt.py)
from melspectrogram import MelSpectrogramProcessor  # Mel spectrogram processor

SAMPLE_RATE = 16_000
SEGMENT_SAMPLES = 6 * SAMPLE_RATE  # the model was trained on 6-second segments

# Initialize the preprocessor and model on the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocessor = MelSpectrogramProcessor(device=device)
model = cvt13().to(device)

# Load ONE set of weights; each load_state_dict call overwrites the previous one.
# Weights trained using ProtoCLR (supervised contrastive learning using prototypes):
model.load_state_dict(torch.load("protoclr.pth", map_location="cpu"))
# Alternatives, if the corresponding checkpoint files are available:
# Cross-Entropy:
#   model.load_state_dict(torch.load("ce.pth", map_location="cpu")["encoder"])
# SimCLR (self-supervised contrastive learning):
#   model.load_state_dict(torch.load("simclr.pth", map_location="cpu"))
# SupCon (supervised contrastive learning):
#   model.load_state_dict(torch.load("supcon.pth", map_location="cpu"))
model.eval()

def load_waveform(file_path):
    """Load audio as a mono, 16 kHz, 6-second 1-D tensor."""
    waveform, sr = torchaudio.load(file_path)  # shape: (channels, samples)
    waveform = waveform.mean(dim=0)  # average channels -> mono
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)
    if waveform.numel() < SEGMENT_SAMPLES:
        waveform = F.pad(waveform, (0, SEGMENT_SAMPLES - waveform.numel()))  # zero-pad to 6 s
    # Keeps only the first 6 s; split longer files into 6 s chunks to cover them fully
    return waveform[:SEGMENT_SAMPLES]

waveform = load_waveform("path/to/audio.wav").to(device)
input_tensor = preprocessor.process(waveform).unsqueeze(0)  # Add batch dimension

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)
print("Model output shape:", output.shape)
```
Model Performance Comparison
The following table reports classification accuracy (%) of various models on one-shot and five-shot bird sound classification tasks, evaluated on several soundscape datasets.
Model | Model Size | PER | NES | UHH | HSN | SSW | SNE | Mean |
---|---|---|---|---|---|---|---|---|
Random Guessing | - | 0.75 | 1.12 | 3.70 | 5.26 | 1.04 | 1.78 | 2.22 |
1-Shot Classification | | | | | | | | |
BirdAVES-biox-base | 90M | 7.41±1.0 | 26.4±2.3 | 13.2±3.1 | 9.84±3.5 | 8.74±0.6 | 14.1±3.1 | 13.2 |
BirdAVES-bioxn-large | 300M | 7.59±0.8 | 27.2±3.6 | 13.7±2.9 | 12.5±3.6 | 10.0±1.4 | 14.5±3.2 | 14.2 |
BioLingual | 28M | 6.21±1.1 | 37.5±2.9 | 17.8±3.5 | 17.6±5.1 | 22.5±4.0 | 26.4±3.4 | 21.3 |
Perch | 80M | 9.10±5.3 | 42.4±4.9 | 19.8±5.0 | 26.7±9.8 | 22.3±3.3 | 29.1±5.9 | 24.9 |
CE (Ours) | 23M | 9.55±1.5 | 41.3±3.6 | 19.7±4.7 | 25.2±5.7 | 17.8±1.4 | 31.5±5.4 | 24.2 |
SimCLR (Ours) | 19M | 7.85±1.1 | 31.2±2.4 | 14.9±2.9 | 19.0±3.8 | 10.6±1.1 | 24.0±4.1 | 17.9 |
SupCon (Ours) | 19M | 8.53±1.1 | 39.8±6.0 | 18.8±3.0 | 20.4±6.9 | 12.6±1.6 | 23.2±3.1 | 20.5 |
ProtoCLR (Ours) | 19M | 9.23±1.6 | 38.6±5.1 | 18.4±2.3 | 21.2±7.3 | 15.5±2.3 | 25.8±5.2 | 21.4 |
5-Shot Classification | | | | | | | | |
BirdAVES-biox-base | 90M | 11.6±0.8 | 39.7±1.8 | 22.5±2.4 | 22.1±3.3 | 16.1±1.7 | 28.3±2.3 | 23.3 |
BirdAVES-bioxn-large | 300M | 15.0±0.9 | 42.6±2.7 | 23.7±3.8 | 28.4±2.4 | 18.3±1.8 | 27.3±2.3 | 25.8 |
BioLingual | 28M | 13.6±1.3 | 65.2±1.4 | 31.0±2.9 | 34.3±3.5 | 43.9±0.9 | 49.9±2.3 | 39.6 |
Perch | 80M | 21.2±1.2 | 71.7±1.5 | 39.5±3.0 | 52.5±5.9 | 48.0±1.9 | 59.7±1.8 | 48.7 |
CE (Ours) | 23M | 21.4±1.3 | 69.2±1.8 | 35.6±3.4 | 48.2±5.5 | 39.9±1.1 | 57.5±2.3 | 45.3 |
SimCLR (Ours) | 19M | 15.4±1.0 | 54.0±1.8 | 23.0±2.3 | 32.8±4.0 | 22.0±1.2 | 40.7±2.4 | 31.3 |
SupCon (Ours) | 19M | 17.2±1.3 | 64.6±2.4 | 34.1±2.9 | 42.5±2.9 | 30.8±0.8 | 48.1±2.4 | 39.5 |
ProtoCLR (Ours) | 19M | 19.2±1.1 | 67.9±2.8 | 36.1±4.3 | 48.0±4.3 | 34.6±2.3 | 48.6±2.8 | 42.4 |
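The 1-shot and 5-shot settings above provide k labelled examples per class. One common way to turn a frozen feature extractor such as this model into a k-shot classifier is nearest-prototype (nearest class mean) classification, sketched below in NumPy. This is an illustrative assumption, not necessarily the protocol behind the table, which is described in the paper. The function name `nearest_prototype_predict` is hypothetical, and the embeddings would come from running the model on preprocessed 6-second segments.

```python
import numpy as np

def nearest_prototype_predict(support, support_labels, queries):
    """Classify queries by cosine similarity to per-class mean embeddings.

    support:        (N_s, D) embeddings of the k labelled examples per class.
    support_labels: (N_s,) class labels (any comparable dtype).
    queries:        (N_q, D) embeddings to classify.
    """
    def normalize(x):
        return x / np.linalg.norm(x, axis=1, keepdims=True)

    support_labels = np.asarray(support_labels)
    classes = np.unique(support_labels)
    s = normalize(np.asarray(support, dtype=float))
    # One prototype per class: the normalized mean of its support embeddings
    protos = normalize(np.stack([s[support_labels == c].mean(axis=0) for c in classes]))
    sims = normalize(np.asarray(queries, dtype=float)) @ protos.T  # (N_q, C)
    return classes[sims.argmax(axis=1)]
```

With k = 1 each prototype is simply the single support embedding, so the method reduces to nearest-neighbor classification by cosine similarity.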
For additional details, see the preprint on arXiv (arXiv:2409.08589) and the official GitHub repository.
Citation
If you use our model in your research, please cite the following paper:
```bibtex
@misc{moummad2024dirlbs,
  title={Domain-Invariant Representation Learning of Bird Sounds},
  author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
  year={2024},
  eprint={2409.08589},
  archivePrefix={arXiv},
  primaryClass={cs.SD},
  url={https://arxiv.org/abs/2409.08589},
}
```