---
license: cc-by-4.0
datasets:
- ilyassmoummad/Xeno-Canto-6s-16khz
pipeline_tag: feature-extraction
tags:
- Bioacoustics
- pytorch
---

# ProtoCLR

This repository contains a CvT-13 [Convolutional Vision Transformer](https://arxiv.org/abs/2103.15808) model trained from scratch on the [Xeno-Canto dataset](https://huggingface.co/datasets/ilyassmoummad/Xeno-Canto-6s-16khz), specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using ProtoCLR ([Prototypical Contrastive Loss](https://arxiv.org/abs/2409.08589)) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.

## Files

- `cvt.py`: Defines the CvT-13 model architecture.
- `protoclr.pth`: Pre-trained ProtoCLR model weights.
- `config/`: Configuration files for the CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms for model input.

## Setup

1. **Clone this repository**:

   Clone the repository and navigate into the project directory. The weight files are stored with Git LFS, so make sure Git LFS is installed (`git lfs install`) before cloning:

   ```bash
   git clone https://huggingface.co/ilyassmoummad/ProtoCLR
   cd ProtoCLR/
   ```

2. **Install dependencies**:

   Install the required Python packages, including `torch` and the other dependencies listed in `requirements.txt`:

   ```bash
   pip install -r requirements.txt
   ```

## Usage

1. **Prepare the Audio**:

   To ensure compatibility with the model, apply these preprocessing steps to your audio files:

   - **Mono Channel (Mandatory)**: If the audio has multiple channels, convert it to a single mono channel by averaging the channels.
   - **Sample Rate (Mandatory)**: Resample the audio to 16 kHz, the sample rate of the training data.
   - **Padding (Recommended)**: Pad audio shorter than 6 seconds with zeros, or repeat it, until it reaches a length of 6 seconds.
   - **Chunking (Recommended)**: Split audio longer than 6 seconds into 6-second chunks, matching the segment length used during training.

2. **Process the Audio**:

   Use the `MelSpectrogramProcessor` (from `melspectrogram.py`) to transform the prepared audio into a Mel spectrogram for model input, as demonstrated in the Example Code section below.
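
The preparation steps listed under **Prepare the Audio** can be sketched in plain PyTorch as follows. This is a minimal sketch under stated assumptions, not code shipped with this repository: the helper names `to_mono`, `pad_to_length` and `prepare_segments` are hypothetical, resampling assumes `torchaudio` is installed, and a short final chunk of a long recording is padded rather than dropped. One 6-second segment at 16 kHz is 6 × 16,000 = 96,000 samples.

```python
import torch
import torch.nn.functional as F

SAMPLE_RATE = 16_000                             # sample rate expected by the model
SEGMENT_SECONDS = 6                              # segment length used during training
SEGMENT_SAMPLES = SAMPLE_RATE * SEGMENT_SECONDS  # 96,000 samples per segment


def to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Average a (channels, time) waveform into a 1-D (time,) mono signal."""
    return waveform.mean(dim=0) if waveform.dim() == 2 else waveform


def pad_to_length(x: torch.Tensor, length: int, mode: str = "zeros") -> torch.Tensor:
    """Pad a 1-D signal to `length` samples with zeros, or by repeating it."""
    if x.numel() >= length:
        return x[:length]
    if mode == "repeat" and x.numel() > 0:
        repeats = -(-length // x.numel())  # ceiling division
        return x.repeat(repeats)[:length]
    return F.pad(x, (0, length - x.numel()))  # zero-pad at the end


def prepare_segments(waveform: torch.Tensor, sample_rate: int, mode: str = "zeros") -> torch.Tensor:
    """Return a (num_segments, SEGMENT_SAMPLES) tensor of mono, 16 kHz, 6 s segments."""
    x = to_mono(waveform)
    if x.numel() == 0:
        raise ValueError("empty waveform")
    if sample_rate != SAMPLE_RATE:
        # Imported lazily (assumption: torchaudio is installed) so 16 kHz input only needs torch
        import torchaudio.functional as AF
        x = AF.resample(x, orig_freq=sample_rate, new_freq=SAMPLE_RATE)
    chunks = list(torch.split(x, SEGMENT_SAMPLES))  # the last chunk may be shorter
    chunks[-1] = pad_to_length(chunks[-1], SEGMENT_SAMPLES, mode)
    return torch.stack(chunks)
```

Each row of the returned tensor is one 6-second segment that can then be passed to `MelSpectrogramProcessor`. Whether its `process` method accepts a batch of segments at once is not documented here, so processing the rows one at a time is the safer assumption.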

## Example Code

The following example demonstrates loading, processing, and running inference on an audio file:

```python
import torch
import torchaudio  # used to load and resample audio (pip install torchaudio)

from cvt import cvt13  # Import model architecture
from melspectrogram import MelSpectrogramProcessor  # Import Mel spectrogram processor

SAMPLE_RATE = 16_000               # the model expects 16 kHz audio
SEGMENT_SAMPLES = 6 * SAMPLE_RATE  # 6-second segments = 96,000 samples

# Initialize the preprocessor and model on the same device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocessor = MelSpectrogramProcessor(device=device)
model = cvt13().to(device)

# Load weights trained using ProtoCLR (supervised contrastive learning using prototypes)
model.load_state_dict(torch.load("protoclr.pth", map_location="cpu"))

# Load only one checkpoint: each load_state_dict call overwrites the previous weights.
# If the baseline checkpoints are available, they are loaded in the same way:
#   Cross-Entropy: model.load_state_dict(torch.load("ce.pth", map_location="cpu")["encoder"])
#   SimCLR (self-supervised contrastive learning): model.load_state_dict(torch.load("simclr.pth", map_location="cpu"))
#   SupCon (supervised contrastive learning): model.load_state_dict(torch.load("supcon.pth", map_location="cpu"))

model.eval()


def load_waveform(file_path):
    """Load an audio file as a 1-D mono waveform at 16 kHz, padded or cut to 6 seconds."""
    waveform, sr = torchaudio.load(file_path)  # shape: (channels, time)
    waveform = waveform.mean(dim=0)            # average channels to mono: (time,)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)
    # Zero-pad clips shorter than 6 s; keep only the first 6 s of longer clips
    # (for long recordings, split into 6 s chunks and process each chunk instead)
    if waveform.numel() < SEGMENT_SAMPLES:
        waveform = torch.nn.functional.pad(waveform, (0, SEGMENT_SAMPLES - waveform.numel()))
    return waveform[:SEGMENT_SAMPLES]


waveform = load_waveform("path/to/audio.wav").to(device)

# Convert the waveform to a Mel spectrogram and add a batch dimension
input_tensor = preprocessor.process(waveform).unsqueeze(0)

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)

print("Model output shape:", output.shape)
```

## Model Performance Comparison

The following table reports the classification accuracy (%) of various models on one-shot and five-shot bird sound classification tasks, evaluated on several [soundscape datasets](https://zenodo.org/records/13994373).

| Model | Model Size | PER | NES | UHH | HSN | SSW | SNE | Mean |
|---------------------------|------------|-------------|-------------|-------------|-------------|-------------|-------------|-------|
| Random Guessing | - | 0.75 | 1.12 | 3.70 | 5.26 | 1.04 | 1.78 | 2.22 |
| | | | | | | | | |
| **1-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 7.41±1.0 | 26.4±2.3 | 13.2±3.1 | 9.84±3.5 | 8.74±0.6 | 14.1±3.1 | 13.2 |
| BirdAVES-bioxn-large | 300M | 7.59±0.8 | 27.2±3.6 | 13.7±2.9 | 12.5±3.6 | 10.0±1.4 | 14.5±3.2 | 14.2 |
| BioLingual | 28M | 6.21±1.1 | 37.5±2.9 | 17.8±3.5 | 17.6±5.1 | 22.5±4.0 | 26.4±3.4 | 21.3 |
| Perch | 80M | 9.10±5.3 | 42.4±4.9 | 19.8±5.0 | 26.7±9.8 | 22.3±3.3 | 29.1±5.9 | 24.9 |
| CE (Ours) | 23M | 9.55±1.5 | 41.3±3.6 | 19.7±4.7 | 25.2±5.7 | 17.8±1.4 | 31.5±5.4 | 24.2 |
| SimCLR (Ours) | 19M | 7.85±1.1 | 31.2±2.4 | 14.9±2.9 | 19.0±3.8 | 10.6±1.1 | 24.0±4.1 | 17.9 |
| SupCon (Ours) | 19M | 8.53±1.1 | 39.8±6.0 | 18.8±3.0 | 20.4±6.9 | 12.6±1.6 | 23.2±3.1 | 20.5 |
| ProtoCLR (Ours) | 19M | 9.23±1.6 | 38.6±5.1 | 18.4±2.3 | 21.2±7.3 | 15.5±2.3 | 25.8±5.2 | 21.4 |
| | | | | | | | | |
| **5-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 11.6±0.8 | 39.7±1.8 | 22.5±2.4 | 22.1±3.3 | 16.1±1.7 | 28.3±2.3 | 23.3 |
| BirdAVES-bioxn-large | 300M | 15.0±0.9 | 42.6±2.7 | 23.7±3.8 | 28.4±2.4 | 18.3±1.8 | 27.3±2.3 | 25.8 |
| BioLingual | 28M | 13.6±1.3 | 65.2±1.4 | 31.0±2.9 | 34.3±3.5 | 43.9±0.9 | 49.9±2.3 | 39.6 |
| Perch | 80M | 21.2±1.2 | 71.7±1.5 | 39.5±3.0 | 52.5±5.9 | 48.0±1.9 | 59.7±1.8 | 48.7 |
| CE (Ours) | 23M | 21.4±1.3 | 69.2±1.8 | 35.6±3.4 | 48.2±5.5 | 39.9±1.1 | 57.5±2.3 | 45.3 |
| SimCLR (Ours) | 19M | 15.4±1.0 | 54.0±1.8 | 23.0±2.3 | 32.8±4.0 | 22.0±1.2 | 40.7±2.4 | 31.3 |
| SupCon (Ours) | 19M | 17.2±1.3 | 64.6±2.4 | 34.1±2.9 | 42.5±2.9 | 30.8±0.8 | 48.1±2.4 | 39.5 |
| ProtoCLR (Ours) | 19M | 19.2±1.1 | 67.9±2.8 | 36.1±4.3 | 48.0±4.3 | 34.6±2.3 | 48.6±2.8 | 42.4 |
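
In k-shot classification, each class is represented by only k labelled examples (the support set). As an illustration only, the sketch below shows one common way to turn extracted embeddings into few-shot predictions: a nearest-centroid (prototype) classifier with cosine similarity. The function names `nearest_centroid_predict` and `few_shot_accuracy` are hypothetical, the embeddings are assumed to be 2-D tensors of model outputs, and the evaluation protocol behind the reported numbers may differ; see the paper for the exact setup.

```python
import torch
import torch.nn.functional as F


def nearest_centroid_predict(support: torch.Tensor, support_labels: torch.Tensor,
                             query: torch.Tensor) -> torch.Tensor:
    """Assign each query embedding to the class whose mean support embedding is closest.

    support: (N, D) embeddings, support_labels: (N,) integer labels, query: (M, D) embeddings.
    """
    support = F.normalize(support, dim=-1)
    query = F.normalize(query, dim=-1)
    classes = torch.unique(support_labels)  # sorted class ids
    # One centroid (prototype) per class: the mean of its normalized support embeddings
    centroids = torch.stack([support[support_labels == c].mean(dim=0) for c in classes])
    centroids = F.normalize(centroids, dim=-1)
    similarities = query @ centroids.T       # (M, C) cosine similarities
    return classes[similarities.argmax(dim=1)]


def few_shot_accuracy(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of correctly classified queries."""
    return (predictions == targets).float().mean().item()
```

With k support examples per class and C classes, `support` has shape `(k * C, D)`; the class centroids play the role of prototypes, which matches the k = 1 and k = 5 settings in the table.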

For additional details, please see the [pre-print on arXiv](https://arxiv.org/abs/2409.08589) and the [official GitHub repository](https://github.com/ilyassmoummad/ProtoCLR).

## Citation

If you use our model in your research, please cite the following paper:

```bibtex
@misc{moummad2024dirlbs,
  title={Domain-Invariant Representation Learning of Bird Sounds},
  author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
  year={2024},
  eprint={2409.08589},
  archivePrefix={arXiv},
  primaryClass={cs.SD},
  url={https://arxiv.org/abs/2409.08589},
}
```