---
license: cc-by-4.0
datasets:
- ilyassmoummad/Xeno-Canto-6s-16khz
pipeline_tag: feature-extraction
tags:
- Bioacoustics
- pytorch
---
# ProtoCLR
This repository contains a CvT-13 [Convolutional Vision Transformer](https://arxiv.org/abs/2103.15808) model trained from scratch on the [Xeno-Canto dataset](https://huggingface.co/datasets/ilyassmoummad/Xeno-Canto-6s-16khz), specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using ProtoCLR [(Prototypical Contrastive Loss)](https://arxiv.org/abs/2409.08589) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.
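The exact ProtoCLR objective is defined in the paper. The sketch below is only an illustrative, simplified prototype-based contrastive loss, not the authors' implementation, and you do not need it to use the pre-trained weights as a feature extractor. It assumes a batch of embeddings with integer class labels. It builds one prototype per class as the mean of the L2-normalized embeddings in the batch, then applies a softmax cross-entropy over temperature-scaled cosine similarities to those prototypes. The function name and the default `temperature=0.1` are hypothetical.

```python
import torch
import torch.nn.functional as F


def prototype_contrastive_loss(embeddings, labels, temperature=0.1):
    """Illustrative prototype-based contrastive loss (hypothetical sketch).

    embeddings: (N, D) projected features for a batch
    labels:     (N,) integer class ids
    """
    z = F.normalize(embeddings, dim=1)
    # Map arbitrary class ids to 0..C-1 so they can serve as cross-entropy targets
    classes, targets = torch.unique(labels, return_inverse=True)
    # One prototype per class present in the batch: mean of its normalized embeddings
    prototypes = torch.stack([z[targets == k].mean(dim=0) for k in range(len(classes))])
    prototypes = F.normalize(prototypes, dim=1)
    # (N, C) cosine similarities scaled by the temperature
    logits = z @ prototypes.T / temperature
    # Pull each embedding toward its own class prototype and away from the others
    return F.cross_entropy(logits, targets)
```

In this sketch each embedding is compared with the C class prototypes in the batch rather than with all N other samples, as in SupCon. The similarity matrix is therefore N×C instead of N×N.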
## Files
- `cvt.py`: Defines the CvT-13 model architecture.
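- `protoclr.pth`: Pre-trained model weights for ProtoCLR.
- `ce.pth`, `simclr.pth`, `supcon.pth`: Pre-trained weights for the Cross-Entropy, SimCLR, and SupCon models compared in the performance table.
- `config/`: Configuration files for CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms, the input format expected by the model.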
## Setup
1. **Clone this repository**:
Clone the repository and navigate into the project directory:
```bash
git clone https://huggingface.co/ilyassmoummad/ProtoCLR
cd ProtoCLR/
```
2. **Install dependencies**:
Ensure you have the required Python packages, including `torch` and any other dependencies listed in `requirements.txt`.
```bash
pip install -r requirements.txt
```
## Usage
1. **Prepare the Audio**:
To ensure compatibility with the model, follow these preprocessing steps for your audio files:
- **Mono Channel (Mandatory)**:
If the audio has multiple channels, convert it to a single mono channel by averaging the channels.
- **Sample Rate (Mandatory)**:
Resample the audio to 16 kHz, the sample rate used during training.
- **Padding (Recommended)**:
For audio files shorter than 6 seconds, pad with zeros or repeat the audio until it reaches a length of 6 seconds.
- **Chunking (Recommended)**:
For audio files longer than 6 seconds, split them into 6-second chunks so that each chunk matches the segment length used during training.
2. **Process the Audio**:
Use the `MelSpectrogramProcessor` (from `melspectrogram.py`) to transform the prepared audio into a Mel spectrogram, a format suitable for model input, as demonstrated in the following example.
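The preparation steps in item 1 (mono, 16 kHz, padding, chunking) can be sketched in PyTorch as follows. This is a minimal, hypothetical helper: `prepare_waveform` is not part of this repository. It assumes the waveform is already loaded as a float tensor of shape `(channels, samples)` or `(samples,)`, and it imports `torchaudio` only when resampling is needed. Non-overlapping chunks and zero-padding of the final partial chunk are assumptions, not requirements stated for the model.

```python
import torch

TARGET_SR = 16_000                 # sample rate expected by the model
SEGMENT_SAMPLES = 6 * TARGET_SR    # 6-second segments = 96,000 samples


def prepare_waveform(waveform, sample_rate, pad_mode="repeat"):
    """Return a (num_chunks, 96000) tensor of 6 s mono chunks at 16 kHz (hypothetical helper)."""
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)            # treat as a single channel
    # 1. Mono (mandatory): average the channels
    mono = waveform.mean(dim=0)
    # 2. Sample rate (mandatory): resample to 16 kHz; torchaudio is only needed here
    if sample_rate != TARGET_SR:
        import torchaudio.functional as AF
        mono = AF.resample(mono, orig_freq=sample_rate, new_freq=TARGET_SR)
    n = mono.shape[-1]
    if n == 0:
        raise ValueError("empty waveform")
    # 3. Padding (recommended): bring clips shorter than 6 s up to 6 s
    if n < SEGMENT_SAMPLES:
        if pad_mode == "repeat":
            reps = -(-SEGMENT_SAMPLES // n)         # ceiling division
            mono = mono.repeat(reps)[:SEGMENT_SAMPLES]
        else:                                       # zero padding
            mono = torch.nn.functional.pad(mono, (0, SEGMENT_SAMPLES - n))
    # 4. Chunking (recommended): zero-pad the last partial chunk, then split
    remainder = mono.shape[-1] % SEGMENT_SAMPLES
    if remainder:
        mono = torch.nn.functional.pad(mono, (0, SEGMENT_SAMPLES - remainder))
    return mono.reshape(-1, SEGMENT_SAMPLES)
```

Each row of the result is one 96,000-sample chunk that can then be passed to `MelSpectrogramProcessor`. A very short trailing remainder becomes mostly silence after zero-padding, so dropping such a chunk is a reasonable alternative.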
## Example Code
The following example demonstrates loading, processing, and running inference on an audio file:
```python
import torch
import torchaudio
from cvt import cvt13  # Import model architecture
from melspectrogram import MelSpectrogramProcessor  # Import Mel spectrogram processor

# Initialize the preprocessor and model
preprocessor = MelSpectrogramProcessor()
model = cvt13()

# Load ONE set of weights; each load_state_dict call overwrites the previous one.
# ProtoCLR (supervised contrastive learning using prototypes):
model.load_state_dict(torch.load("protoclr.pth", map_location="cpu"))
# Alternatively, uncomment one of the following:
# Cross-Entropy (weights are stored under the 'encoder' key):
# model.load_state_dict(torch.load("ce.pth", map_location="cpu")["encoder"])
# SimCLR (self-supervised contrastive learning):
# model.load_state_dict(torch.load("simclr.pth", map_location="cpu"))
# SupCon (supervised contrastive learning):
# model.load_state_dict(torch.load("supcon.pth", map_location="cpu"))

# Optional: use a GPU for faster processing if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# Example loader using torchaudio; replace with your own audio loading function if needed
def load_waveform(file_path, target_sr=16000):
    """Load an audio file as a mono 16 kHz waveform of shape (1, samples)."""
    waveform, sr = torchaudio.load(file_path)  # (channels, samples)
    waveform = waveform.mean(dim=0, keepdim=True)  # average channels to mono
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
    return waveform

waveform = load_waveform("path/to/audio.wav")  # Load your audio file here
# Pad or chunk the waveform to 6 s segments before this step (see "Prepare the Audio")
input_tensor = preprocessor.process(waveform).unsqueeze(0).to(device)  # Add batch dimension

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)
print("Model output shape:", output.shape)
```
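For recordings split into several 6-second chunks, one option is to embed every chunk and average the results into a single recording-level vector. The helper below is a hypothetical sketch; `embed_chunks` and mean pooling are assumptions, not part of this repository. It takes the preprocessing step as a callable that maps one chunk to a model-ready tensor *without* a batch dimension, and it expects the model to already be in eval mode.

```python
import torch


def embed_chunks(chunks, preprocess, model, batch_size=16):
    """Embed (num_chunks, samples) audio chunks and mean-pool them (hypothetical helper).

    preprocess: callable mapping one chunk to a tensor without a batch dimension
    model:      encoder already in eval mode
    Returns (recording_embedding, per_chunk_embeddings).
    """
    outputs = []
    with torch.no_grad():
        for start in range(0, chunks.shape[0], batch_size):
            # Preprocess each chunk, then stack them into one batch
            batch = torch.stack([preprocess(c) for c in chunks[start:start + batch_size]])
            outputs.append(model(batch))
    per_chunk = torch.cat(outputs, dim=0)
    return per_chunk.mean(dim=0), per_chunk
```

With the objects from the example, a call might look like `embed_chunks(chunks, lambda c: preprocessor.process(c.unsqueeze(0)).to(device), model)`. This assumes `process` accepts a `(1, samples)` waveform and returns a tensor without a batch dimension; adjust the lambda if your version differs. Mean pooling is one simple choice; max pooling or keeping per-chunk embeddings may suit detection-style tasks better.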
## Model Performance Comparison
The following table reports the classification accuracy (%) of various models on one-shot and five-shot bird sound classification tasks, evaluated across several [soundscape datasets](https://zenodo.org/records/13994373).
| Model | Model Size | PER | NES | UHH | HSN | SSW | SNE | Mean |
|---------------------------|------------|-------------|-------------|-------------|-------------|-------------|-------------|-------|
| Random Guessing | - | 0.75 | 1.12 | 3.70 | 5.26 | 1.04 | 1.78 | 2.22 |
| | | | | | | | | |
| **1-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 7.41±1.0 | 26.4±2.3 | 13.2±3.1 | 9.84±3.5 | 8.74±0.6 | 14.1±3.1 | 13.2 |
| BirdAVES-bioxn-large | 300M | 7.59±0.8 | 27.2±3.6 | 13.7±2.9 | 12.5±3.6 | 10.0±1.4 | 14.5±3.2 | 14.2 |
| BioLingual | 28M | 6.21±1.1 | 37.5±2.9 | 17.8±3.5 | 17.6±5.1 | 22.5±4.0 | 26.4±3.4 | 21.3 |
| Perch | 80M | 9.10±5.3 | 42.4±4.9 | 19.8±5.0 | 26.7±9.8 | 22.3±3.3 | 29.1±5.9 | 24.9 |
| CE (Ours) | 19M | 9.55±1.5 | 41.3±3.6 | 19.7±4.7 | 25.2±5.7 | 17.8±1.4 | 31.5±5.4 | 24.2 |
| SimCLR (Ours) | 19M | 7.85±1.1 | 31.2±2.4 | 14.9±2.9 | 19.0±3.8 | 10.6±1.1 | 24.0±4.1 | 17.9 |
| SupCon (Ours) | 19M | 8.53±1.1 | 39.8±6.0 | 18.8±3.0 | 20.4±6.9 | 12.6±1.6 | 23.2±3.1 | 20.5 |
| ProtoCLR (Ours) | 19M | 9.23±1.6 | 38.6±5.1 | 18.4±2.3 | 21.2±7.3 | 15.5±2.3 | 25.8±5.2 | 21.4 |
| | | | | | | | | |
| **5-Shot Classification** | | | | | | | | |
| BirdAVES-biox-base | 90M | 11.6±0.8 | 39.7±1.8 | 22.5±2.4 | 22.1±3.3 | 16.1±1.7 | 28.3±2.3 | 23.3 |
| BirdAVES-bioxn-large | 300M | 15.0±0.9 | 42.6±2.7 | 23.7±3.8 | 28.4±2.4 | 18.3±1.8 | 27.3±2.3 | 25.8 |
| BioLingual | 28M | 13.6±1.3 | 65.2±1.4 | 31.0±2.9 | 34.3±3.5 | 43.9±0.9 | 49.9±2.3 | 39.6 |
| Perch | 80M | 21.2±1.2 | 71.7±1.5 | 39.5±3.0 | 52.5±5.9 | 48.0±1.9 | 59.7±1.8 | 48.7 |
| CE (Ours) | 19M | 21.4±1.3 | 69.2±1.8 | 35.6±3.4 | 48.2±5.5 | 39.9±1.1 | 57.5±2.3 | 45.3 |
| SimCLR (Ours) | 19M | 15.4±1.0 | 54.0±1.8 | 23.0±2.3 | 32.8±4.0 | 22.0±1.2 | 40.7±2.4 | 31.3 |
| SupCon (Ours) | 19M | 17.2±1.3 | 64.6±2.4 | 34.1±2.9 | 42.5±2.9 | 30.8±0.8 | 48.1±2.4 | 39.5 |
| ProtoCLR (Ours) | 19M | 19.2±1.1 | 67.9±2.8 | 36.1±4.3 | 48.0±4.3 | 34.6±2.3 | 48.6±2.8 | 42.4 |
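The evaluation protocol behind these numbers is described in the paper. As a hedged illustration only, one common way to run k-shot classification on frozen embeddings is nearest-prototype (nearest-centroid) classification. The k support embeddings of each class are averaged into a prototype, and each query is assigned to the prototype with the highest cosine similarity. The functions below (`prototype_classify`, `accuracy`) are hypothetical and are not claimed to reproduce the table.

```python
import torch
import torch.nn.functional as F


def prototype_classify(support, support_labels, query):
    """Nearest-prototype few-shot classification on frozen embeddings.

    support:        (N, D) embeddings of the labeled shots (k per class)
    support_labels: (N,) integer class ids
    query:          (M, D) embeddings to classify
    Returns a (M,) tensor of predicted class ids.
    """
    classes = torch.unique(support_labels)  # sorted class ids
    support = F.normalize(support, dim=1)
    query = F.normalize(query, dim=1)
    # One prototype per class: mean of its normalized support embeddings
    prototypes = torch.stack([support[support_labels == c].mean(dim=0) for c in classes])
    prototypes = F.normalize(prototypes, dim=1)
    similarities = query @ prototypes.T  # (M, C) cosine similarities
    return classes[similarities.argmax(dim=1)]


def accuracy(predictions, targets):
    """Accuracy in percent, the unit used in the table above."""
    return (predictions == targets).float().mean().item() * 100
```

In a 1-shot setting each prototype is a single support embedding. In a 5-shot setting it averages five, which tends to give a more stable class representative.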
For additional details, please see the [pre-print on arXiv](https://arxiv.org/abs/2409.08589) and the [official GitHub repository](https://github.com/ilyassmoummad/ProtoCLR).
## Citation
If you use our model in your research, please cite the following paper:
```bibtex
@misc{moummad2024dirlbs,
title={Domain-Invariant Representation Learning of Bird Sounds},
author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
year={2024},
eprint={2409.08589},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2409.08589},
}
```