---
license: cc-by-4.0
datasets:
- ilyassmoummad/Xeno-Canto-6s-16khz
pipeline_tag: feature-extraction
tags:
- Bioacoustics
- pytorch
---

# ProtoCLR

This repository contains a CvT-13 [Convolutional Vision Transformer](https://arxiv.org/abs/2103.15808) model trained from scratch on the [Xeno-Canto dataset](https://huggingface.co/datasets/ilyassmoummad/Xeno-Canto-6s-16khz), specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using ProtoCLR [(Prototypical Contrastive Loss)](https://arxiv.org/abs/2409.08589) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.

## Files

- `cvt.py`: Defines the CvT-13 model architecture.
- `protoclr.pth`: Pre-trained model weights for ProtoCLR.
- `config/`: Configuration files for the CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms, the input format expected by the model.

## Setup

1. **Clone this repository** and navigate into the project directory:
   ```bash
   git clone https://huggingface.co/ilyassmoummad/ProtoCLR
   cd ProtoCLR/
   ```
2. **Install dependencies**: Install `torch` and the other Python packages listed in `requirements.txt`:
   ```bash
   pip install -r requirements.txt
   ```

## Usage

1. **Prepare the Audio**: To ensure compatibility with the model, apply the following preprocessing steps to your audio files:
   - **Mono Channel (Mandatory)**: If the audio has multiple channels, convert it to a single mono channel by averaging the channels.
   - **Sample Rate (Mandatory)**: Resample the audio to 16 kHz.
   - **Padding (Recommended)**: For audio files shorter than 6 seconds, pad with zeros or repeat the audio until it reaches a length of 6 seconds.
   - **Chunking (Recommended)**: For audio files longer than 6 seconds, split them into 6-second chunks so that each chunk matches the length of the training segments.
2. **Process the Audio**: Use the `MelSpectrogramProcessor` (from `melspectrogram.py`) to transform the prepared audio into a Mel spectrogram, as demonstrated in the following example.

## Example Code

The following example demonstrates loading, processing, and running inference on an audio file:

```python
import torch
import torch.nn.functional as F
import torchaudio  # assumed to be installed; used here to load and resample audio

from cvt import cvt13  # Import model architecture
from melspectrogram import MelSpectrogramProcessor  # Import Mel spectrogram processor

# Initialize the preprocessor and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocessor = MelSpectrogramProcessor(device=device)
model = cvt13().to(device)

# Load ONE checkpoint; loading several in a row only keeps the last one:
#   "ce.pth"       : trained using Cross-Entropy
#   "simclr.pth"   : trained using SimCLR (self-supervised contrastive learning)
#   "supcon.pth"   : trained using SupCon (supervised contrastive learning)
#   "protoclr.pth" : trained using ProtoCLR (supervised contrastive learning using prototypes)
checkpoint = "protoclr.pth"
state_dict = torch.load(checkpoint, map_location="cpu")
if checkpoint == "ce.pth":
    state_dict = state_dict["encoder"]  # the CE checkpoint stores the encoder weights under "encoder"
model.load_state_dict(state_dict)
model.eval()

SAMPLE_RATE = 16000          # the model expects 16 kHz audio
CHUNK_LEN = 6 * SAMPLE_RATE  # 6-second segments = 96000 samples

def load_waveform(file_path):
    """Load an audio file as a mono 16 kHz waveform of shape (1, num_samples)."""
    waveform, sr = torchaudio.load(file_path)      # (channels, num_samples), sample rate
    waveform = waveform.mean(dim=0, keepdim=True)  # mono: average the channels
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)
    return waveform

def split_into_chunks(waveform):
    """Split into 6-second chunks and zero-pad the last (or only) chunk to 6 seconds."""
    chunks = list(waveform.split(CHUNK_LEN, dim=-1))
    chunks[-1] = F.pad(chunks[-1], (0, CHUNK_LEN - chunks[-1].shape[-1]))
    return chunks

waveform = load_waveform("path/to/audio.wav").to(device)  # Replace with the path to your audio file
chunks = split_into_chunks(waveform)

# Convert each chunk to a Mel spectrogram, add a batch dimension, and run the model.
# Assumption: process() accepts a (1, num_samples) waveform; adapt the shape if your
# version of MelSpectrogramProcessor expects a 1-D tensor.
with torch.no_grad():
    outputs = [model(preprocessor.process(chunk).unsqueeze(0)) for chunk in chunks]
output = torch.cat(outputs)  # one row per 6-second chunk

print("Model output shape:", output.shape)
```

## Model Performance Comparison

The following table presents the classification accuracy (%) of various models on one-shot and five-shot bird sound classification tasks, evaluated across different [soundscape datasets](https://zenodo.org/records/13994373).

| Model                     | Model Size | PER         | NES         | UHH         | HSN         | SSW         | SNE         | Mean  |
|---------------------------|------------|-------------|-------------|-------------|-------------|-------------|-------------|-------|
| Random Guessing           | -          | 0.75        | 1.12        | 3.70        | 5.26        | 1.04        | 1.78        | 2.22  |
|                           |            |             |             |             |             |             |             |       |
| **1-Shot Classification** |            |             |             |             |             |             |             |       |
| BirdAVES-biox-base        | 90M        | 7.41±1.0    | 26.4±2.3    | 13.2±3.1    | 9.84±3.5    | 8.74±0.6    | 14.1±3.1    | 13.2  |
| BirdAVES-bioxn-large      | 300M       | 7.59±0.8    | 27.2±3.6    | 13.7±2.9    | 12.5±3.6    | 10.0±1.4    | 14.5±3.2    | 14.2  |
| BioLingual                | 28M        | 6.21±1.1    | 37.5±2.9    | 17.8±3.5    | 17.6±5.1    | 22.5±4.0    | 26.4±3.4    | 21.3  |
| Perch                     | 80M        | 9.10±5.3    | 42.4±4.9    | 19.8±5.0    | 26.7±9.8    | 22.3±3.3    | 29.1±5.9    | 24.9  |
| CE (Ours)                 | 23M        | 9.55±1.5    | 41.3±3.6    | 19.7±4.7    | 25.2±5.7    | 17.8±1.4    | 31.5±5.4    | 24.2  |
| SimCLR (Ours)             | 19M        | 7.85±1.1    | 31.2±2.4    | 14.9±2.9    | 19.0±3.8    | 10.6±1.1    | 24.0±4.1    | 17.9  |
| SupCon (Ours)             | 19M        | 8.53±1.1    | 39.8±6.0    | 18.8±3.0    | 20.4±6.9    | 12.6±1.6    | 23.2±3.1    | 20.5  |
| ProtoCLR (Ours)           | 19M        | 9.23±1.6    | 38.6±5.1    | 18.4±2.3    | 21.2±7.3    | 15.5±2.3    | 25.8±5.2    | 21.4  |
|                           |            |             |             |             |             |             |             |       |
| **5-Shot Classification** |            |             |             |             |             |             |             |       |
| BirdAVES-biox-base        | 90M        | 11.6±0.8    | 39.7±1.8    | 22.5±2.4    | 22.1±3.3    | 16.1±1.7    | 28.3±2.3    | 23.3  |
| BirdAVES-bioxn-large      | 300M       | 15.0±0.9    | 42.6±2.7    | 23.7±3.8    | 28.4±2.4    | 18.3±1.8    | 27.3±2.3    | 25.8  |
| BioLingual                | 28M        | 13.6±1.3    | 65.2±1.4    | 31.0±2.9    | 34.3±3.5    | 43.9±0.9    | 49.9±2.3    | 39.6  |
| Perch                     | 80M        | 21.2±1.2    | 71.7±1.5    | 39.5±3.0    | 52.5±5.9    | 48.0±1.9    | 59.7±1.8    | 48.7  |
| CE (Ours)                 | 23M        | 21.4±1.3    | 69.2±1.8    | 35.6±3.4    | 48.2±5.5    | 39.9±1.1    | 57.5±2.3    | 45.3  |
| SimCLR (Ours)             | 19M        | 15.4±1.0    | 54.0±1.8    | 23.0±2.3    | 32.8±4.0    | 22.0±1.2    | 40.7±2.4    | 31.3  |
| SupCon (Ours)             | 19M        | 17.2±1.3    | 64.6±2.4    | 34.1±2.9    | 42.5±2.9    | 30.8±0.8    | 48.1±2.4    | 39.5  |
| ProtoCLR (Ours)           | 19M        | 19.2±1.1    | 67.9±2.8    | 36.1±4.3    | 48.0±4.3    | 34.6±2.3    | 48.6±2.8    | 42.4  |

For additional details, please see the [pre-print on arXiv](https://arxiv.org/abs/2409.08589) and the [official GitHub repository](https://github.com/ilyassmoummad/ProtoCLR).

## Citation

If you use our model in your research, please cite the following paper:

```bibtex
@misc{moummad2024dirlbs,
      title={Domain-Invariant Representation Learning of Bird Sounds},
      author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
      year={2024},
      eprint={2409.08589},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2409.08589},
}
```