arnavkumar24 committed
Commit 89040ed • 1 Parent(s): ebbe80d
Addon
This view is limited to 50 files because it contains too many changes. See raw diff.
- AudioSep_Colab.ipynb +128 -0
- CONTRIBUTING.md +92 -0
- Dockerfile +22 -0
- LICENSE +21 -0
- assets/results.png +0 -0
- benchmark.py +116 -0
- callbacks/base.py +35 -0
- checkpoint/audiosep_base_4M_steps.ckpt +3 -0
- checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt +3 -0
- cog.yaml +21 -0
- config/audiosep_base.yaml +41 -0
- data/audiotext_dataset.py +91 -0
- data/datamodules.py +122 -0
- data/waveform_mixers.py +127 -0
- datafiles/template.json +8 -0
- environment.yml +326 -0
- evaluation/evaluate_audiocaps.py +110 -0
- evaluation/evaluate_audioset.py +155 -0
- evaluation/evaluate_clotho.py +102 -0
- evaluation/evaluate_esc50.py +102 -0
- evaluation/evaluate_music.py +118 -0
- evaluation/evaluate_vggsound.py +114 -0
- evaluation/metadata/audiocaps_eval.csv +0 -0
- evaluation/metadata/audioset_eval.csv +0 -0
- evaluation/metadata/class_labels_indices.csv +528 -0
- evaluation/metadata/clotho_eval.csv +0 -0
- evaluation/metadata/esc50_eval.csv +0 -0
- evaluation/metadata/music_eval.csv +0 -0
- evaluation/metadata/vggsound_eval.csv +0 -0
- losses.py +17 -0
- models/CLAP/__init__.py +0 -0
- models/CLAP/__pycache__/__init__.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__init__.py +25 -0
- models/CLAP/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/feature_fusion.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/htsat.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/loss.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/model.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/openai.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/pann_model.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/timm_model.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/transform.cpython-310.pyc +0 -0
- models/CLAP/open_clip/__pycache__/utils.cpython-310.pyc +0 -0
- models/CLAP/open_clip/bert.py +40 -0
- models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- models/CLAP/open_clip/factory.py +277 -0
- models/CLAP/open_clip/feature_fusion.py +192 -0
AudioSep_Colab.ipynb
ADDED
@@ -0,0 +1,128 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "repo_path = Path(\"/content/AudioSep\")\n",
        "if not repo_path.exists():\n",
        "    !git clone https://github.com/Audio-AGI/AudioSep.git\n",
        "\n",
        "%cd /content/AudioSep"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pjIhw5ECS_3_"
      },
      "outputs": [],
      "source": [
        "!pip install torchlibrosa==0.1.0 gradio==3.47.1 gdown lightning transformers==4.28.1 ftfy braceexpand webdataset soundfile wget h5py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t6h9KB3CcjBd"
      },
      "outputs": [],
      "source": [
        "checkpoints_dir = Path(\"checkpoint\")\n",
        "checkpoints_dir.mkdir(exist_ok=True)\n",
        "\n",
        "models = (\n",
        "    (\n",
        "        \"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/audiosep_base_4M_steps.ckpt\",\n",
        "        checkpoints_dir / \"audiosep_base_4M_steps.ckpt\"\n",
        "    ),\n",
        "    (\n",
        "        \"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt\",\n",
        "        checkpoints_dir / \"music_speech_audioset_epoch_15_esc_89.98.pt\"\n",
        "    )\n",
        ")\n",
        "\n",
        "for model_url, model_path in models:\n",
        "    if not model_path.exists():\n",
        "        !wget {model_url} -O {model_path}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3uDrzCQyY58h"
      },
      "outputs": [],
      "source": [
        "!wget \"https://audio-agi.github.io/Separate-Anything-You-Describe/demos/exp31_water drops_mixture.wav\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0nr77CGXTwO1"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from pipeline import build_audiosep, inference\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "model = build_audiosep(\n",
        "    config_yaml='config/audiosep_base.yaml',\n",
        "    checkpoint_path=str(models[0][1]),\n",
        "    device=device)\n",
        "\n",
        "audio_file = 'exp31_water drops_mixture.wav'\n",
        "text = 'water drops'\n",
        "output_file = 'separated_audio.wav'\n",
        "\n",
        "# AudioSep processes the audio at 32 kHz sampling rate\n",
        "inference(model, audio_file, text, output_file, device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kssOe0pbPSWp"
      },
      "outputs": [],
      "source": [
        "print(f\"The separated audio is saved to: '{output_file}' file.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sl35U3dAR6KN"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
CONTRIBUTING.md
ADDED
@@ -0,0 +1,92 @@
# 🎵 Contributing to AudioSep

Welcome to the AudioSep repository, where your contributions can harmonize the world of audio separation. To ensure a harmonious and organized collaboration, please follow the contribution guidelines outlined below.

## **Submitting Contributions**

To contribute to this project, please adhere to the following steps:

### **1. Choose or Create an Issue**

- Start by reviewing the existing issues to identify areas where your contributions can make a significant impact.
- If you have ideas for new features, enhancements, or bug fixes, feel free to create a new issue to propose your contributions. Provide comprehensive details for clarity.

### **2. Fork the Repository**

- To initiate your contribution, fork the primary repository by clicking the "Fork" button. This will create a copy of the repository in your personal GitHub account.

### **3. Clone Your Forked Repository**

- Clone your forked repository to your local development environment using the following command:

```bash
git clone https://github.com/your-username/AudioSep.git
```

### **4. Set Up the Upstream Remote**

- Maintain a reference to the primary project by adding it as the upstream remote:

```bash
cd AudioSep
git remote add upstream https://github.com/Audio-AGI/AudioSep
git remote -v
```

### **5. Create a New Branch**

- Before starting your contribution, establish a new branch dedicated to your specific task:

```bash
git checkout -b my-contribution
```

## **Working on Your Contribution**

Now that your development environment is ready and a new branch is established, you can start working on your contribution. Please ensure you adhere to the following guidelines:

### **6. Make Changes**

- Implement the necessary changes, including code additions, enhancements, or bug fixes. Ensure your contributions are well-structured, documented, and aligned with the project's objectives.

### **7. Commit Your Changes**

- Commit your changes using informative commit messages that clearly convey the purpose of your contributions:

```bash
git commit -m "Add a descriptive message here"
```

### **8. Push Your Changes**

- Push the committed changes to your remote repository on GitHub:

```bash
git push origin my-contribution
```

### **9. Create a Pull Request**

- Visit your repository on GitHub and click the "New Pull Request" button to initiate a pull request from your branch to the primary repository.

### **10. Await Review**

- Your pull request will undergo review, and feedback will be provided by the project maintainers or fellow contributors. Be prepared to address any suggested changes or refinements.

## **Community Engagement**

While contributing, please consider engaging with the community in the following ways:

### **11. Join Discussions**

- Participate in discussions related to audio separation techniques and their applications. Share your insights, experiences, and expertise in the audio field.

### **12. Share Ideas**

- If you have innovative ideas for advancing the project or optimizing audio separation, such as new algorithms or research findings, feel free to open issues to initiate productive discussions.

## **Acknowledgment**

We appreciate your dedication to the world of audio separation. Your contributions play a crucial role in harmonizing audio and improving the listening experience for all. If you have questions or require assistance, please don't hesitate to contact the project maintainers.

Thank you for your valuable contributions, and we eagerly anticipate collaborating with you on AudioSep! 🎶🙌
Dockerfile
ADDED
@@ -0,0 +1,22 @@
FROM python:3.10.11

# Copy the current directory contents into the container at .
COPY . .

# Set the working directory to /
WORKDIR /

# Install requirements.txt
RUN pip install --no-cache-dir --upgrade -r /requirements.txt

RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

# Start the FastAPI app on port 7860, the default port expected by Spaces
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Xubo Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
assets/results.png
ADDED
benchmark.py
ADDED
@@ -0,0 +1,116 @@
import os
from tqdm import tqdm
import numpy as np
from evaluation.evaluate_audioset import AudioSetEvaluator
from evaluation.evaluate_audiocaps import AudioCapsEvaluator
from evaluation.evaluate_vggsound import VGGSoundEvaluator
from evaluation.evaluate_music import MUSICEvaluator
from evaluation.evaluate_esc50 import ESC50Evaluator
from evaluation.evaluate_clotho import ClothoEvaluator
from models.clap_encoder import CLAP_Encoder

from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


def eval(checkpoint_path, config_yaml='config/audiosep_base.yaml'):

    log_dir = 'eval_logs'
    os.makedirs(log_dir, exist_ok=True)

    device = "cuda"

    configs = parse_yaml(config_yaml)

    # AudioSet Evaluators
    audioset_evaluator = AudioSetEvaluator()
    # AudioCaps Evaluator
    audiocaps_evaluator = AudioCapsEvaluator()
    # VGGSound+ Evaluator
    vggsound_evaluator = VGGSoundEvaluator()
    # Clotho Evaluator
    clotho_evaluator = ClothoEvaluator()
    # MUSIC Evaluator
    music_evaluator = MUSICEvaluator()
    # ESC-50 Evaluator
    esc50_evaluator = ESC50Evaluator()

    # Load model
    query_encoder = CLAP_Encoder().eval()

    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=query_encoder
    ).to(device)

    print(f'------- Start Evaluation -------')

    # evaluation on Clotho
    SISDR, SDRi = clotho_evaluator(pl_model)
    msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_clotho)

    # evaluation on VGGSound+ (YAN)
    SISDR, SDRi = vggsound_evaluator(pl_model)
    msg_vgg = "VGGSound Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_vgg)

    # evaluation on MUSIC
    SISDR, SDRi = music_evaluator(pl_model)
    msg_music = "MUSIC Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_music)

    # evaluation on ESC-50
    SISDR, SDRi = esc50_evaluator(pl_model)
    msg_esc50 = "ESC-50 Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_esc50)

    # evaluation on AudioSet
    stats_dict = audioset_evaluator(pl_model=pl_model)
    median_sdris = {}
    median_sisdrs = {}

    for class_id in range(527):
        median_sdris[class_id] = np.nanmedian(stats_dict["sdris_dict"][class_id])
        median_sisdrs[class_id] = np.nanmedian(stats_dict["sisdrs_dict"][class_id])

    SDRi = get_mean_sdr_from_dict(median_sdris)
    SISDR = get_mean_sdr_from_dict(median_sisdrs)
    msg_audioset = "AudioSet Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_audioset)

    # evaluation on AudioCaps
    SISDR, SDRi = audiocaps_evaluator(pl_model)
    msg_audiocaps = "AudioCaps Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_audiocaps)

    # evaluation on Clotho
    SISDR, SDRi = clotho_evaluator(pl_model)
    msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
    print(msg_clotho)

    msgs = [msg_audioset, msg_vgg, msg_audiocaps, msg_clotho, msg_music, msg_esc50]

    # open file in write mode
    log_path = os.path.join(log_dir, 'eval_results.txt')
    with open(log_path, 'w') as fp:
        for msg in msgs:
            fp.write(msg + '\n')
        print(f'Eval log is written to {log_path} ...')
    print('------------------------- Done ---------------------------')


if __name__ == '__main__':
    eval(checkpoint_path='checkpoint/audiosep_base.ckpt')
callbacks/base.py
ADDED
@@ -0,0 +1,35 @@
import os
import lightning.pytorch as pl
from lightning.pytorch.utilities import rank_zero_only


class CheckpointEveryNSteps(pl.Callback):
    def __init__(
        self,
        checkpoints_dir,
        save_step_frequency,
    ) -> None:
        r"""Save a checkpoint every N steps.

        Args:
            checkpoints_dir (str): directory to save checkpoints
            save_step_frequency (int): save checkpoint every N step
        """

        self.checkpoints_dir = checkpoints_dir
        self.save_step_frequency = save_step_frequency

    @rank_zero_only
    def on_train_batch_end(self, *args, **kwargs) -> None:
        r"""Save a checkpoint every N steps."""

        trainer = args[0]
        global_step = trainer.global_step

        if global_step == 1 or global_step % self.save_step_frequency == 0:

            ckpt_path = os.path.join(
                self.checkpoints_dir,
                "step={}.ckpt".format(global_step))
            trainer.save_checkpoint(ckpt_path)
            print("Save checkpoint to {}".format(ckpt_path))
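A minimal usage sketch for the callback above. The `pl.Trainer` arguments and the output directory are illustrative placeholders, not settings taken from this commit:

```python
# Sketch: attach CheckpointEveryNSteps to a Lightning Trainer.
# Trainer arguments and paths below are hypothetical, not the repo's real training setup.
import lightning.pytorch as pl
from callbacks.base import CheckpointEveryNSteps

checkpoint_callback = CheckpointEveryNSteps(
    checkpoints_dir="checkpoints/my_run",   # hypothetical output directory
    save_step_frequency=20000,               # matches save_step_frequency in audiosep_base.yaml
)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_steps=100000,
    callbacks=[checkpoint_callback],
)
# trainer.fit(model, datamodule=data_module) would then write step=20000.ckpt, step=40000.ckpt, ...
```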
checkpoint/audiosep_base_4M_steps.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8cda01bfd0ebd141eef45d41db7a3ada23a56568465840d3cff04b8010ce82c
size 1264844076
checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51c68f12f9d7ea25fdaaccf741ec7f81e93ee594455410f3bca4f47f88d8e006
size 2352471003
cog.yaml
ADDED
@@ -0,0 +1,21 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  python_version: "3.11"
  python_packages:
    - "torchlibrosa==0.1.0"
    - "lightning==2.1.0"
    - "torch==2.0.1"
    - "transformers==4.28.1"
    - "braceexpand==0.1.7"
    - "webdataset==0.2.60"
    - "soundfile==0.12.1"
    - "torchaudio==2.0.2"
    - "torchvision==0.15.2"
    - "h5py==3.10.0"
    - "ftfy==6.1.1"
    - "pandas==2.1.1"
    - "wget==3.2"
predict: "predict.py:Predictor"
config/audiosep_base.yaml
ADDED
@@ -0,0 +1,41 @@
---
task_name: AudioSep

data:
  datafiles:
    - 'datafiles/template.json'

  sampling_rate: 32000
  segment_seconds: 5
  loudness_norm:
    lower_db: -10
    higher_db: 10
  max_mix_num: 2

model:
  query_net: CLAP
  condition_size: 512
  model_type: ResUNet30
  input_channels: 1
  output_channels: 1
  resume_checkpoint: ""
  use_text_ratio: 1.0

train:
  optimizer:
    optimizer_type: AdamW
    learning_rate: 1e-3
    warm_up_steps: 10000
    reduce_lr_steps: 1000000
    lr_lambda_type: constant_warm_up
  num_nodes: 1
  num_workers: 6
  loss_type: l1_wav
  sync_batchnorm: True
  batch_size_per_device: 12
  steps_per_epoch: 10000  # Every 10000 steps is called an `epoch`.
  evaluate_step_frequency: 10000  # Evaluate every #evaluate_step_frequency steps.
  save_step_frequency: 20000  # Save every #save_step_frequency steps.
  early_stop_steps: 10000001
  random_seed: 1234
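The YAML above drives both training and evaluation. A minimal sketch of reading it, assuming `utils.parse_yaml` returns the file as a nested Python dict (this is how benchmark.py appears to consume it; the helper itself is not shown in this commit):

```python
# Sketch: load the training config. Assumes parse_yaml returns a nested dict.
from utils import parse_yaml

configs = parse_yaml('config/audiosep_base.yaml')

sampling_rate = configs['data']['sampling_rate']               # 32000
segment_seconds = configs['data']['segment_seconds']           # 5
max_mix_num = configs['data']['max_mix_num']                   # 2
lower_db = configs['data']['loudness_norm']['lower_db']        # -10
batch_size = configs['train']['batch_size_per_device']         # 12
```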
data/audiotext_dataset.py
ADDED
@@ -0,0 +1,91 @@
import json
import random
import torch
import torchaudio
from torch.utils.data import Dataset


class AudioTextDataset(Dataset):
    """Can sample data from audio-text databases
    Params:
    sampling_rate: audio sampling rate
    max_clip_len: max length (seconds) of audio clip to be sampled
    """
    def __init__(
        self,
        datafiles=[''],
        sampling_rate=32000,
        max_clip_len=5,
    ):
        all_data_json = []
        for datafile in datafiles:
            with open(datafile, 'r') as fp:
                data_json = json.load(fp)['data']
                all_data_json.extend(data_json)
        self.all_data_json = all_data_json

        self.sampling_rate = sampling_rate
        self.max_length = max_clip_len * sampling_rate

    def __len__(self):
        return len(self.all_data_json)

    def _cut_or_randomcrop(self, waveform):
        # waveform: [1, samples]
        # random crop
        if waveform.size(1) > self.max_length:
            random_idx = random.randint(0, waveform.size(1) - self.max_length)
            waveform = waveform[:, random_idx:random_idx + self.max_length]
        else:
            temp_wav = torch.zeros(1, self.max_length)
            temp_wav[:, 0:waveform.size(1)] = waveform
            waveform = temp_wav

        assert waveform.size(1) == self.max_length, \
            f"number of audio samples is {waveform.size(1)}"

        return waveform

    def _read_audio(self, index):
        try:
            audio_path = self.all_data_json[index]['wav']
            audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True)
            text = self.all_data_json[index]['caption']

            # drop short utterance
            if audio_data.size(1) < self.sampling_rate * 1:
                raise Exception(f'{audio_path} is too short, drop it ...')

            return text, audio_data, audio_rate

        except Exception as e:
            print(f'error: {e} occurs, when loading {audio_path}')
            random_index = random.randint(0, len(self.all_data_json) - 1)
            return self._read_audio(index=random_index)

    def __getitem__(self, index):
        # create an audio tensor
        text, audio_data, audio_rate = self._read_audio(index)
        audio_len = audio_data.shape[1] / audio_rate
        # convert stereo to single channel
        if audio_data.shape[0] > 1:
            # audio_data: [samples]
            audio_data = (audio_data[0] + audio_data[1]) / 2
        else:
            audio_data = audio_data.squeeze(0)

        # resample audio clip
        if audio_rate != self.sampling_rate:
            audio_data = torchaudio.functional.resample(audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate)

        audio_data = audio_data.unsqueeze(0)

        audio_data = self._cut_or_randomcrop(audio_data)

        data_dict = {
            'text': text,
            'waveform': audio_data,
            'modality': 'audio_text'
        }

        return data_dict
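A minimal sketch of exercising the dataset above on its own; the datafile path is a hypothetical file following the template.json format:

```python
# Sketch: load one item from AudioTextDataset.
# 'datafiles/my_dataset.json' is a hypothetical datafile in the template.json format.
from data.audiotext_dataset import AudioTextDataset

dataset = AudioTextDataset(
    datafiles=['datafiles/my_dataset.json'],
    sampling_rate=32000,
    max_clip_len=5,
)

item = dataset[0]
print(item['text'])             # the caption string
print(item['waveform'].shape)   # torch.Size([1, 160000]) -> 5 s at 32 kHz
print(item['modality'])         # 'audio_text'
```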
data/datamodules.py
ADDED
@@ -0,0 +1,122 @@
from typing import Dict, List, Optional, NoReturn
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from data.audiotext_dataset import AudioTextDataset


class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int
    ):
        r"""Data module. To get one batch of data:

        .. code-block:: python

            data_module.setup()

            for batch_data_dict in data_module.train_dataloader():
                print(batch_data_dict.keys())
                break

        Args:
            train_sampler: Sampler object
            train_dataset: Dataset object
            num_workers: int
            distributed: bool
        """
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage: Optional[str] = None) -> NoReturn:
        r"""Called on every device."""

        # make assignments here (val/train/test split)
        # called on every process in DDP

        # SegmentSampler is used for selecting segments for training.
        # On multiple devices, each SegmentSampler samples a part of mini-batch
        # data.
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader."""
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True
        )

        return train_loader

    def val_dataloader(self):
        # val_split = Dataset(...)
        # return DataLoader(val_split)
        pass

    def test_dataloader(self):
        # test_split = Dataset(...)
        # return DataLoader(test_split)
        pass

    def teardown(self):
        # clean up after fit or test
        # called on every process in DDP
        pass


def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]
    Returns:
        data_dict: e.g.
            'audio_text': {
                'text': ['a sound of dog', ...]
                'waveform': (batch_size, 1, samples)
            }
    """

    at_list_data_dict = [data_dict for data_dict in list_data_dict if data_dict['modality'] == 'audio_text']

    at_data_dict = {}

    if len(at_list_data_dict) > 0:
        for key in at_list_data_dict[0].keys():
            at_data_dict[key] = [at_data_dict[key] for at_data_dict in at_list_data_dict]
            if key == 'waveform':
                at_data_dict[key] = torch.stack(at_data_dict[key])
            elif key == 'text':
                at_data_dict[key] = [text for text in at_data_dict[key]]

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict
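Putting the two data classes together, a minimal sketch of pulling one collated batch, mirroring the docstring above (the datafile path is hypothetical):

```python
# Sketch: one collated batch from DataModule + AudioTextDataset.
# 'datafiles/my_dataset.json' is a hypothetical datafile in the template.json format.
from data.audiotext_dataset import AudioTextDataset
from data.datamodules import DataModule

train_dataset = AudioTextDataset(datafiles=['datafiles/my_dataset.json'])
data_module = DataModule(train_dataset=train_dataset, batch_size=4, num_workers=0)
data_module.setup()

for batch in data_module.train_dataloader():
    texts = batch['audio_text']['text']          # list of 4 caption strings
    waveforms = batch['audio_text']['waveform']  # tensor of shape (4, 1, 160000)
    print(len(texts), waveforms.shape)
    break
```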
data/waveform_mixers.py
ADDED
@@ -0,0 +1,127 @@
import random
import sre_compile
import numpy as np
import torch
import torch.nn as nn
import pyloudnorm as pyln


class SegmentMixer(nn.Module):
    def __init__(self, max_mix_num, lower_db, higher_db):
        super(SegmentMixer, self).__init__()

        self.max_mix_num = max_mix_num
        self.loudness_param = {
            'lower_db': lower_db,
            'higher_db': higher_db,
        }

    def __call__(self, waveforms):

        batch_size = waveforms.shape[0]

        data_dict = {
            'segment': [],
            'mixture': [],
        }

        for n in range(0, batch_size):

            segment = waveforms[n].clone()

            # create zero tensors as the background template
            noise = torch.zeros_like(segment)

            mix_num = random.randint(2, self.max_mix_num)
            assert mix_num >= 2

            for i in range(1, mix_num):
                next_segment = waveforms[(n + i) % batch_size]
                rescaled_next_segment = dynamic_loudnorm(audio=next_segment, reference=segment, **self.loudness_param)
                noise += rescaled_next_segment

            # randomly normalize background noise
            noise = dynamic_loudnorm(audio=noise, reference=segment, **self.loudness_param)

            # create audio mixture
            mixture = segment + noise

            # declipping if need be
            max_value = torch.max(torch.abs(mixture))
            if max_value > 1:
                segment *= 0.9 / max_value
                mixture *= 0.9 / max_value

            data_dict['segment'].append(segment)
            data_dict['mixture'].append(mixture)

        for key in data_dict.keys():
            data_dict[key] = torch.stack(data_dict[key], dim=0)

        # return data_dict
        return data_dict['mixture'], data_dict['segment']


def rescale_to_match_energy(segment1, segment2):

    ratio = get_energy_ratio(segment1, segment2)
    rescaled_segment1 = segment1 / ratio
    return rescaled_segment1


def get_energy(x):
    return torch.mean(x ** 2)


def get_energy_ratio(segment1, segment2):

    energy1 = get_energy(segment1)
    energy2 = max(get_energy(segment2), 1e-10)
    ratio = (energy1 / energy2) ** 0.5
    ratio = torch.clamp(ratio, 0.02, 50)
    return ratio


def dynamic_loudnorm(audio, reference, lower_db=-10, higher_db=10):
    rescaled_audio = rescale_to_match_energy(audio, reference)

    delta_loudness = random.randint(lower_db, higher_db)

    gain = np.power(10.0, delta_loudness / 20.0)

    return gain * rescaled_audio


def torch_to_numpy(tensor):
    """Convert a PyTorch tensor to a NumPy array."""
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    else:
        raise ValueError("Input must be a PyTorch tensor.")


def numpy_to_torch(array):
    """Convert a NumPy array to a PyTorch tensor."""
    if isinstance(array, np.ndarray):
        return torch.from_numpy(array)
    else:
        raise ValueError("Input must be a NumPy array.")


# decayed
def random_loudness_norm(audio, lower_db=-35, higher_db=-15, sr=32000):
    device = audio.device
    audio = torch_to_numpy(audio.squeeze(0))
    # randomly select a norm volume
    norm_vol = random.randint(lower_db, higher_db)

    # measure the loudness first
    meter = pyln.Meter(sr)  # create BS.1770 meter
    loudness = meter.integrated_loudness(audio)
    # loudness normalize audio
    normalized_audio = pyln.normalize.loudness(audio, loudness, norm_vol)

    normalized_audio = numpy_to_torch(normalized_audio).unsqueeze(0)

    return normalized_audio.to(device)
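A minimal sketch of how the mixer above turns a training batch into (mixture, target) pairs; the random tensor is only a stand-in for a real batch from the dataloader:

```python
# Sketch: build mixtures from a batch of 5-second, 32 kHz clips.
# The random batch is a stand-in for batch['audio_text']['waveform'].
import torch
from data.waveform_mixers import SegmentMixer

mixer = SegmentMixer(max_mix_num=2, lower_db=-10, higher_db=10)  # values from audiosep_base.yaml

waveforms = torch.randn(4, 1, 5 * 32000) * 0.1   # (batch, channels, samples)
mixtures, segments = mixer(waveforms)

# Each mixture is its target segment plus 1..(max_mix_num-1) other clips from the
# same batch, each energy-matched to the target and offset by a random
# -10..+10 dB gain, then rescaled if the sum would clip.
print(mixtures.shape, segments.shape)  # torch.Size([4, 1, 160000]) twice
```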
datafiles/template.json
ADDED
@@ -0,0 +1,8 @@
{
    "data": [
        {
            "wav": "path_to_audio_file",
            "caption": "textual_desciptions"
        }
    ]
}
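The template above is the datafile format consumed by `AudioTextDataset`: each entry pairs a wav path with a caption. A minimal sketch of generating such a file from a hypothetical folder of clips and a caption lookup:

```python
# Sketch: write a datafile in the template.json format.
# 'my_audio/' and the captions dict are hypothetical stand-ins for real data.
import json
from pathlib import Path

captions = {"dog.wav": "a dog barking", "rain.wav": "rain falling on a roof"}

entries = [
    {"wav": str(path), "caption": captions[path.name]}
    for path in Path("my_audio").glob("*.wav")
    if path.name in captions
]

with open("datafiles/my_dataset.json", "w") as fp:
    json.dump({"data": entries}, fp, indent=4)
```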
environment.yml
ADDED
@@ -0,0 +1,326 @@
name: AudioSep
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - backcall=0.2.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - boltons=23.0.0=py310h06a4308_0
  - brotlipy=0.7.0=py310h7f8727e_1002
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.01.10=h06a4308_0
  - certifi=2022.12.7=py310h06a4308_0
  - cffi=1.15.1=py310h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - comm=0.1.2=py310h06a4308_0
  - conda=23.3.1=py310h06a4308_0
  - conda-content-trust=0.1.3=py310h06a4308_0
  - conda-package-handling=2.0.2=py310h06a4308_0
  - conda-package-streaming=0.7.0=py310h06a4308_0
  - cryptography=38.0.4=py310h9ce1e76_0
  - cuda=11.6.1=0
  - cuda-cccl=11.6.55=hf6102b2_0
  - cuda-command-line-tools=11.6.2=0
  - cuda-compiler=11.6.2=0
  - cuda-cudart=11.6.55=he381448_0
  - cuda-cudart-dev=11.6.55=h42ad0f4_0
  - cuda-cuobjdump=11.6.124=h2eeebcb_0
  - cuda-cupti=11.6.124=h86345e5_0
  - cuda-cuxxfilt=11.6.124=hecbf4f6_0
  - cuda-driver-dev=11.6.55=0
  - cuda-gdb=12.1.55=0
  - cuda-libraries=11.6.1=0
  - cuda-libraries-dev=11.6.1=0
  - cuda-memcheck=11.8.86=0
  - cuda-nsight=12.1.55=0
  - cuda-nsight-compute=12.1.0=0
  - cuda-nvcc=11.6.124=hbba6d2d_0
  - cuda-nvdisasm=12.1.55=0
  - cuda-nvml-dev=11.6.55=haa9ef22_0
  - cuda-nvprof=12.1.55=0
  - cuda-nvprune=11.6.124=he22ec0a_0
  - cuda-nvrtc=11.6.124=h020bade_0
  - cuda-nvrtc-dev=11.6.124=h249d397_0
  - cuda-nvtx=11.6.124=h0630a44_0
  - cuda-nvvp=12.1.55=0
  - cuda-runtime=11.6.1=0
  - cuda-samples=11.6.101=h8efea70_0
  - cuda-sanitizer-api=12.1.55=0
  - cuda-toolkit=11.6.1=0
  - cuda-tools=11.6.1=0
  - cuda-visual-tools=11.6.1=0
  - debugpy=1.5.1=py310h295c915_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - flit-core=3.8.0=py310h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - gds-tools=1.6.0.25=0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py310h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - ipykernel=6.19.2=py310h2f386ee_0
  - ipython=8.12.0=py310h06a4308_0
  - jpeg=9e=h5eee18b_1
  - jsonpatch=1.32=pyhd3eb1b0_0
  - jsonpointer=2.1=pyhd3eb1b0_0
  - jupyter_client=8.1.0=py310h06a4308_0
  - jupyter_core=5.3.0=py310h06a4308_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.9.2.110=h5e84587_0
  - libcublas-dev=11.9.2.110=h5c901ab_0
  - libcufft=10.7.1.112=hf425ae0_0
  - libcufft-dev=10.7.1.112=ha5ce4c0_0
  - libcufile=1.6.0.25=0
  - libcufile-dev=1.6.0.25=0
  - libcurand=10.3.2.56=0
  - libcurand-dev=10.3.2.56=0
  - libcusolver=11.3.4.124=h33c3c4e_0
  - libcusparse=11.7.2.124=h7538f96_0
  - libcusparse-dev=11.7.2.124=hbbe9722_0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.2=h6a678d5_6
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libnpp=11.6.3.124=hd2722f0_0
  - libnpp-dev=11.6.3.124=h3c42840_0
  - libnvjpeg=11.6.2.124=hd473ad6_0
  - libnvjpeg-dev=11.6.2.124=hb5906b9_0
  - libpng=1.6.39=h5eee18b_0
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.0=h6a678d5_2
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_0
  - matplotlib-inline=0.1.6=py310h06a4308_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py310h7f8727e_0
  - mkl_fft=1.3.1=py310hd6ae3a3_0
  - mkl_random=1.2.2=py310h00e6091_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.5.6=py310h06a4308_0
  - nettle=3.7.3=hbbd107a_1
  - nsight-compute=2023.1.0.15=0
  - numpy=1.23.5=py310hd5efca6_0
  - numpy-base=1.23.5=py310h8e6c178_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1t=h7f8727e_0
  - packaging=23.0=py310h06a4308_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pip=22.3.1=py310h06a4308_0
  - platformdirs=2.5.2=py310h06a4308_0
  - pluggy=1.0.0=py310h06a4308_1
  - psutil=5.9.0=py310h5eee18b_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pycosat=0.6.4=py310h5eee18b_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.9=h7a1cb2a_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
  - pytorch-cuda=11.6=h867d48c_1
  - pytorch-mutex=1.0=cuda
  - pyzmq=23.2.0=py310h6a678d5_0
  - readline=8.2=h5eee18b_0
  - requests=2.28.1=py310h06a4308_0
  - ruamel.yaml=0.17.21=py310h5eee18b_0
  - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
  - setuptools=65.6.3=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.40.1=h5082296_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - tk=8.6.12=h1ccaba5_0
  - toolz=0.12.0=py310h06a4308_0
  - torchaudio=0.13.1=py310_cu116
  - torchvision=0.14.1=py310_cu116
  - tornado=6.2=py310h5eee18b_0
  - tqdm=4.64.1=py310h06a4308_0
  - typing_extensions=4.4.0=py310h06a4308_0
  - tzdata=2022g=h04d1e81_0
  - urllib3=1.26.14=py310h06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.10=h5eee18b_1
  - zeromq=4.3.4=h2531618_0
  - zlib=1.2.13=h5eee18b_0
  - zstandard=0.18.0=py310h5eee18b_0
  - zstd=1.5.4=hc292b87_0
  - pip:
    - absl-py==1.4.0
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - anyio==3.6.2
    - appdirs==1.4.4
    - arrow==1.2.3
    - asttokens==2.2.1
    - async-generator==1.10
    - async-timeout==4.0.2
    - attrs==22.2.0
    - audioread==3.0.0
    - av==10.0.0
    - beartype==0.12.0
    - beautifulsoup4==4.12.2
    - blessed==1.20.0
    - braceexpand==0.1.7
    - cachetools==5.3.0
    - click==8.1.3
    - contourpy==1.0.7
    - croniter==1.3.10
    - cycler==0.11.0
    - dataclasses-json==0.5.8
    - dateutils==0.6.12
    - decord==0.6.0
    - deepdiff==6.3.0
    - dtk==0.2
    - exceptiongroup==1.1.1
    - executing==1.2.0
    - fastapi==0.88.0
    - ffmpeg==1.4
    - ffmpeg-python==0.2.0
    - filelock==3.12.0
    - fonttools==4.39.3
    - frozenlist==1.3.3
    - fsspec==2023.4.0
    - ftfy==6.1.1
    - future==0.18.3
    - gammatone==1.0
    - google-auth==2.17.3
    - google-auth-oauthlib==1.0.0
    - greenlet==2.0.2
    - grpcio==1.54.0
    - h11==0.14.0
    - h5py==3.8.0
    - hickle==5.0.2
    - huggingface-hub==0.14.1
    - humanize==4.6.0
    - imageio==2.27.0
    - inquirer==3.1.3
    - ipdb==0.13.13
    - itsdangerous==2.1.2
    - jedi==0.18.2
    - jinja2==3.1.2
    - joblib==1.2.0
    - kiwisolver==1.4.4
    - langchain==0.0.216
    - langchainplus-sdk==0.0.17
    - lazy-loader==0.2
    - librosa==0.10.0.post2
    - lightning==2.0.0
    - lightning-cloud==0.5.33
    - lightning-utilities==0.8.0
    - llvmlite==0.39.1
    - markdown==3.4.3
    - markdown-it-py==2.2.0
    - markupsafe==2.1.2
    - marshmallow==3.19.0
    - marshmallow-enum==1.5.1
    - matplotlib==3.7.1
    - mdurl==0.1.2
    - mergedeep==1.3.4
    - mock==5.0.2
    - msgpack==1.0.5
    - msgpack-numpy==0.4.8
    - multidict==6.0.4
    - musdb==0.4.0
    - mypy-extensions==1.0.0
    - networkx==3.1
    - nose==1.3.7
    - numba==0.56.4
    - numexpr==2.8.4
    - oauthlib==3.2.2
    - openai==0.27.8
    - openapi-schema-pydantic==1.2.4
    - opencv-python==4.7.0.72
    - ordered-set==4.1.0
    - outcome==1.2.0
    - pandas==1.5.3
    - panns-inference==0.1.0
    - pesq==0.0.4
    - pillow==9.5.0
    - pooch==1.6.0
    - prompt-toolkit==3.0.38
    - protobuf==4.22.3
    - pyaml==23.5.9
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pydantic==1.10.7
    - pygments==2.14.0
    - pyjwt==2.6.0
    - pyloudnorm==0.1.1
    - pyparsing==3.0.9
    - pystoi==0.3.3
    - python-editor==1.0.4
    - python-multipart==0.0.6
    - pytorch-ignite==0.3.0
    - pytorch-lightning==2.0.1.post0
    - pytz==2023.3
    - pywavelets==1.4.1
    - pyyaml==6.0
    - readchar==4.0.5
    - regex==2023.3.23
    - requests-oauthlib==1.3.1
    - resampy==0.4.2
    - rich==13.3.3
    - rsa==4.9
    - scikit-image==0.20.0
    - scikit-learn==1.2.2
    - scipy==1.10.1
    - selenium==4.8.3
    - simplejpeg==1.6.6
    - sniffio==1.3.0
    - sortedcontainers==2.4.0
    - soundfile==0.12.1
    - soupsieve==2.4
    - soxr==0.3.5
    - sqlalchemy==2.0.17
    - stack-data==0.6.2
    - starlette==0.22.0
    - starsessions==1.3.0
    - stempeg==0.2.3
    - tenacity==8.2.2
    - tensorboard==2.12.2
    - tensorboard-data-server==0.7.0
    - tensorboard-plugin-wit==1.8.1
    - termcolor==1.1.0
    - threadpoolctl==3.1.0
    - tifffile==2023.3.21
    - timm==0.3.2
    - tokenizers==0.13.3
    - tomli==2.0.1
    - torchfile==0.1.0
    - torchlibrosa==0.1.0
    - torchmetrics==0.11.4
    - traitlets==5.9.0
    - transformers==4.28.1
    - trio==0.22.0
    - trio-websocket==0.10.2
    - typeguard==3.0.2
    - typing-extensions==4.5.0
    - typing-inspect==0.9.0
    - uvicorn==0.21.1
    - visdom==0.1.8.9
    - wcwidth==0.2.6
    - webdataset==0.2.48
    - websocket-client==1.5.1
    - websockets==11.0.1
    - werkzeug==2.2.3
    - wget==3.2
    - wsproto==1.2.0
    - yarl==1.8.2
    - zenodo-get==1.3.4
    - zsvision==0.7.8
evaluation/evaluate_audiocaps.py
ADDED
@@ -0,0 +1,110 @@
import os
import sys
import re
from typing import Dict, List

import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder

sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


class AudioCapsEvaluator:
    def __init__(
        self,
        query='caption',
        sampling_rate=32000,
    ) -> None:
        r"""AudioCaps evaluator.

        Args:
            query (str): type of query, 'caption' or 'labels'
        Returns:
            None
        """

        self.query = query
        self.sampling_rate = sampling_rate

        with open(f'evaluation/metadata/audiocaps_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            eval_list = [row for row in csv_reader][1:]

        self.eval_list = eval_list
        self.audio_dir = f'evaluation/data/audiocaps'

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        print(f'Evaluation on AudioCaps with [{self.query}] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = []
        sdris_list = []

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                idx, caption, labels, _, _ = eval_data

                source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                if self.query == 'caption':
                    text = [caption]
                elif self.query == 'labels':
                    text = [labels]

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list.append(sisdr)
                sdris_list.append(sdri)

        mean_sisdr = np.mean(sisdrs_list)
        mean_sdri = np.mean(sdris_list)

        return mean_sisdr, mean_sdri
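The evaluators report SDRi (SDR of the separated signal minus SDR of the unprocessed mixture) and SI-SDR. A hedged sketch of the textbook definitions these metrics normally follow; `utils.calculate_sdr` and `utils.calculate_sisdr` are not part of this commit, so their exact implementation may differ:

```python
# Sketch: standard SDR and SI-SDR, assumed to mirror utils.calculate_sdr /
# utils.calculate_sisdr (helpers not shown in this diff; details may differ).
import numpy as np

def sdr(ref: np.ndarray, est: np.ndarray, eps: float = 1e-10) -> float:
    # Ratio of reference energy to residual energy, in dB.
    noise = est - ref
    return 10 * np.log10((np.sum(ref ** 2) + eps) / (np.sum(noise ** 2) + eps))

def si_sdr(ref: np.ndarray, est: np.ndarray, eps: float = 1e-10) -> float:
    # Scale-invariant SDR: project the estimate onto the reference first.
    alpha = np.dot(est, ref) / (np.dot(ref, ref) + eps)
    target = alpha * ref
    noise = est - target
    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))

# SDRi as used by the evaluators: improvement over the unseparated mixture.
# sdri = sdr(ref=source, est=separated) - sdr(ref=source, est=mixture)
```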
evaluation/evaluate_audioset.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
import pathlib
|
11 |
+
import librosa
|
12 |
+
import lightning.pytorch as pl
|
13 |
+
from models.clap_encoder import CLAP_Encoder
|
14 |
+
|
15 |
+
sys.path.append('../AudioSep/')
|
16 |
+
from utils import (
|
17 |
+
load_ss_model,
|
18 |
+
calculate_sdr,
|
19 |
+
calculate_sisdr,
|
20 |
+
parse_yaml,
|
21 |
+
get_mean_sdr_from_dict,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
meta_csv_file = "evaluation/metadata/class_labels_indices.csv"
|
26 |
+
df = pd.read_csv(meta_csv_file, sep=',')
|
27 |
+
|
28 |
+
IDS = df['mid'].tolist()
|
29 |
+
LABELS = df['display_name'].tolist()
|
30 |
+
|
31 |
+
CLASSES_NUM = len(LABELS)
|
32 |
+
|
33 |
+
IX_TO_LB = {i : label for i, label in enumerate(LABELS)}
|
34 |
+
|
35 |
+
|
36 |
+
class AudioSetEvaluator:
|
37 |
+
def __init__(
|
38 |
+
self,
|
        audios_dir='evaluation/data/audioset',
        classes_num=527,
        sampling_rate=32000,
        number_per_class=10,
    ) -> None:
        r"""AudioSet evaluator.

        Args:
            audios_dir (str): directory of evaluation segments
            classes_num (int): the number of sound classes
            sampling_rate (int): sampling rate of the evaluation audio
            number_per_class (int): the number of samples to evaluate for each sound class

        Returns:
            None
        """

        self.audios_dir = audios_dir
        self.classes_num = classes_num
        self.number_per_class = number_per_class
        self.sampling_rate = sampling_rate

    @torch.no_grad()
    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        pl_model.eval()

        sisdrs_dict = {class_id: [] for class_id in range(self.classes_num)}
        sdris_dict = {class_id: [] for class_id in range(self.classes_num)}

        print('Evaluation on AudioSet with [text label] queries.')

        for class_id in tqdm(range(self.classes_num)):

            sub_dir = os.path.join(
                self.audios_dir,
                "class_id={}".format(class_id))

            audio_names = self._get_audio_names(audios_dir=sub_dir)

            for audio_index, audio_name in enumerate(audio_names):

                if audio_index == self.number_per_class:
                    break

                source_path = os.path.join(
                    sub_dir, "{},source.wav".format(audio_name))
                mixture_path = os.path.join(
                    sub_dir, "{},mixture.wav".format(audio_name))

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                device = pl_model.device

                text = [IX_TO_LB[class_id]]

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_dict[class_id].append(sisdr)
                sdris_dict[class_id].append(sdri)

        stats_dict = {
            "sisdrs_dict": sisdrs_dict,
            "sdris_dict": sdris_dict,
        }

        return stats_dict

    def _get_audio_names(self, audios_dir: str) -> List[str]:
        r"""Get evaluation audio names."""
        audio_names = sorted(os.listdir(audios_dir))

        audio_names = [audio_name for audio_name in audio_names if '.wav' in audio_name]

        audio_names = [
            re.search(
                "(.*),(mixture|source).wav",
                audio_name).group(1) for audio_name in audio_names]

        audio_names = sorted(list(set(audio_names)))

        return audio_names

    @staticmethod
    def get_median_metrics(stats_dict, metric_type):
        class_ids = stats_dict[metric_type].keys()
        median_stats_dict = {
            class_id: np.nanmedian(
                stats_dict[metric_type][class_id]) for class_id in class_ids}
        return median_stats_dict
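The stats_dict returned above holds raw per-class SI-SDR and SDRi lists. Below is a minimal aggregation sketch that is not part of this commit; it assumes `pl_model` is an already-loaded AudioSep LightningModule and that `get_mean_sdr_from_dict` in `utils` averages the values of a per-class dict.

from utils import get_mean_sdr_from_dict
from evaluation.evaluate_audioset import AudioSetEvaluator

evaluator = AudioSetEvaluator(audios_dir='evaluation/data/audioset')  # assumed data layout
stats_dict = evaluator(pl_model=pl_model)  # assumption: pl_model is a loaded AudioSep model

# Median metric per class, then mean over the 527 AudioSet classes.
median_sisdrs = AudioSetEvaluator.get_median_metrics(stats_dict, 'sisdrs_dict')
median_sdris = AudioSetEvaluator.get_median_metrics(stats_dict, 'sdris_dict')
print('SI-SDR:', get_mean_sdr_from_dict(median_sisdrs))
print('SDRi:', get_mean_sdr_from_dict(median_sdris))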
evaluation/evaluate_clotho.py
ADDED
@@ -0,0 +1,102 @@
import os
import sys
import re
from typing import Dict, List

import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder

sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


class ClothoEvaluator:
    def __init__(
        self,
        sampling_rate=32000,
    ) -> None:
        r"""Clotho evaluator.

        Returns:
            None
        """

        self.sampling_rate = sampling_rate

        with open('evaluation/metadata/clotho_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            eval_list = [row for row in csv_reader][1:]

        self.eval_list = eval_list
        self.audio_dir = 'evaluation/data/clotho'

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        print('Evaluation on Clotho Evaluation with [caption] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = []
        sdris_list = []

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                idx, caption, _, _, _ = eval_data

                source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                text = [caption]

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list.append(sisdr)
                sdris_list.append(sdri)

        mean_sisdr = np.mean(sisdrs_list)
        mean_sdri = np.mean(sdris_list)

        return mean_sisdr, mean_sdri
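A rough end-to-end usage sketch for the evaluator above. The config and checkpoint paths come from elsewhere in this Space; the exact `CLAP_Encoder` and `load_ss_model` signatures are assumptions, not verified here.

import torch
from models.clap_encoder import CLAP_Encoder
from utils import parse_yaml, load_ss_model
from evaluation.evaluate_clotho import ClothoEvaluator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
configs = parse_yaml('config/audiosep_base.yaml')
query_encoder = CLAP_Encoder().eval()  # assumption: the encoder defaults to the bundled CLAP checkpoint
pl_model = load_ss_model(
    configs=configs,
    checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt',
    query_encoder=query_encoder,
).to(device)

mean_sisdr, mean_sdri = ClothoEvaluator()(pl_model)
print(f'Clotho: SI-SDR {mean_sisdr:.2f} dB, SDRi {mean_sdri:.2f} dB')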
evaluation/evaluate_esc50.py
ADDED
@@ -0,0 +1,102 @@
import os
import sys
import re
from typing import Dict, List

import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder

sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


class ESC50Evaluator:
    def __init__(
        self,
        sampling_rate=32000
    ) -> None:
        r"""ESC-50 evaluator.

        Returns:
            None
        """

        self.sampling_rate = sampling_rate

        with open('evaluation/metadata/esc50_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            eval_list = [row for row in csv_reader][1:]

        self.eval_list = eval_list
        self.audio_dir = 'evaluation/data/esc50'

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        print('Evaluation on ESC-50 with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = []
        sdris_list = []

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                idx, caption, _, _ = eval_data

                source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                text = [caption]

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list.append(sisdr)
                sdris_list.append(sdri)

        mean_sdri = np.mean(sdris_list)
        mean_sisdr = np.mean(sisdrs_list)

        return mean_sisdr, mean_sdri
evaluation/evaluate_music.py
ADDED
@@ -0,0 +1,118 @@
import os
import sys
import re
from typing import Dict, List

import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder

sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


class MUSICEvaluator:
    def __init__(
        self,
        sampling_rate=32000
    ) -> None:

        self.sampling_rate = sampling_rate

        with open('evaluation/metadata/music_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            eval_list = [row for row in csv_reader][1:]

        self.eval_list = eval_list
        self.audio_dir = 'evaluation/data/music'

        self.source_types = [
            "acoustic guitar",
            "violin",
            "accordion",
            "xylophone",
            "erhu",
            "trumpet",
            "tuba",
            "cello",
            "flute",
            "saxophone"]

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        print('Evaluation on MUSIC Test with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = {source_type: [] for source_type in self.source_types}
        sdris_list = {source_type: [] for source_type in self.source_types}

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                idx, caption, _, _ = eval_data

                source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                text = [caption]

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list[caption].append(sisdr)
                sdris_list[caption].append(sdri)

        mean_sisdr_list = []
        mean_sdri_list = []

        for source_class in self.source_types:
            sisdr = np.mean(sisdrs_list[source_class])
            sdri = np.mean(sdris_list[source_class])
            mean_sisdr_list.append(sisdr)
            mean_sdri_list.append(sdri)

        mean_sdri = np.mean(mean_sdri_list)
        mean_sisdr = np.mean(mean_sisdr_list)

        return mean_sisdr, mean_sdri
evaluation/evaluate_vggsound.py
ADDED
@@ -0,0 +1,114 @@
import os
import sys
import re
from typing import Dict, List

import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder

sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)


class VGGSoundEvaluator:
    def __init__(
        self,
        sampling_rate=32000
    ) -> None:
        r"""VGGSound evaluator.

        Returns:
            None
        """

        self.sampling_rate = sampling_rate

        with open('evaluation/metadata/vggsound_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            eval_list = [row for row in csv_reader][1:]

        self.eval_list = eval_list
        self.audio_dir = 'evaluation/data/vggsound'

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate."""

        print('Evaluation on VGGSound+ with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = []
        sdris_list = []
        sisdris_list = []

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                # labels, source_path, mixture_path = eval_data
                file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data

                labels = s0_text

                mixture_path = os.path.join(self.audio_dir, mix_wav)
                source_path = os.path.join(self.audio_dir, s0_wav)

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                text = [labels]
                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep

                sisdr_no_sep = calculate_sisdr(ref=source, est=mixture)
                sisdr = calculate_sisdr(ref=source, est=sep_segment)
                sisdri = sisdr - sisdr_no_sep

                sisdrs_list.append(sisdr)
                sdris_list.append(sdri)
                sisdris_list.append(sisdri)

        mean_sisdr = np.mean(sisdrs_list)
        mean_sdri = np.mean(sdris_list)

        return mean_sisdr, mean_sdri
evaluation/metadata/audiocaps_eval.csv
ADDED
The diff for this file is too large to render.
evaluation/metadata/audioset_eval.csv
ADDED
The diff for this file is too large to render.
evaluation/metadata/class_labels_indices.csv
ADDED
@@ -0,0 +1,528 @@
1 |
+
index,mid,display_name
|
2 |
+
0,/m/09x0r,"Speech"
|
3 |
+
1,/m/05zppz,"Male speech, man speaking"
|
4 |
+
2,/m/02zsn,"Female speech, woman speaking"
|
5 |
+
3,/m/0ytgt,"Child speech, kid speaking"
|
6 |
+
4,/m/01h8n0,"Conversation"
|
7 |
+
5,/m/02qldy,"Narration, monologue"
|
8 |
+
6,/m/0261r1,"Babbling"
|
9 |
+
7,/m/0brhx,"Speech synthesizer"
|
10 |
+
8,/m/07p6fty,"Shout"
|
11 |
+
9,/m/07q4ntr,"Bellow"
|
12 |
+
10,/m/07rwj3x,"Whoop"
|
13 |
+
11,/m/07sr1lc,"Yell"
|
14 |
+
12,/m/04gy_2,"Battle cry"
|
15 |
+
13,/t/dd00135,"Children shouting"
|
16 |
+
14,/m/03qc9zr,"Screaming"
|
17 |
+
15,/m/02rtxlg,"Whispering"
|
18 |
+
16,/m/01j3sz,"Laughter"
|
19 |
+
17,/t/dd00001,"Baby laughter"
|
20 |
+
18,/m/07r660_,"Giggle"
|
21 |
+
19,/m/07s04w4,"Snicker"
|
22 |
+
20,/m/07sq110,"Belly laugh"
|
23 |
+
21,/m/07rgt08,"Chuckle, chortle"
|
24 |
+
22,/m/0463cq4,"Crying, sobbing"
|
25 |
+
23,/t/dd00002,"Baby cry, infant cry"
|
26 |
+
24,/m/07qz6j3,"Whimper"
|
27 |
+
25,/m/07qw_06,"Wail, moan"
|
28 |
+
26,/m/07plz5l,"Sigh"
|
29 |
+
27,/m/015lz1,"Singing"
|
30 |
+
28,/m/0l14jd,"Choir"
|
31 |
+
29,/m/01swy6,"Yodeling"
|
32 |
+
30,/m/02bk07,"Chant"
|
33 |
+
31,/m/01c194,"Mantra"
|
34 |
+
32,/t/dd00003,"Male singing"
|
35 |
+
33,/t/dd00004,"Female singing"
|
36 |
+
34,/t/dd00005,"Child singing"
|
37 |
+
35,/t/dd00006,"Synthetic singing"
|
38 |
+
36,/m/06bxc,"Rapping"
|
39 |
+
37,/m/02fxyj,"Humming"
|
40 |
+
38,/m/07s2xch,"Groan"
|
41 |
+
39,/m/07r4k75,"Grunt"
|
42 |
+
40,/m/01w250,"Whistling"
|
43 |
+
41,/m/0lyf6,"Breathing"
|
44 |
+
42,/m/07mzm6,"Wheeze"
|
45 |
+
43,/m/01d3sd,"Snoring"
|
46 |
+
44,/m/07s0dtb,"Gasp"
|
47 |
+
45,/m/07pyy8b,"Pant"
|
48 |
+
46,/m/07q0yl5,"Snort"
|
49 |
+
47,/m/01b_21,"Cough"
|
50 |
+
48,/m/0dl9sf8,"Throat clearing"
|
51 |
+
49,/m/01hsr_,"Sneeze"
|
52 |
+
50,/m/07ppn3j,"Sniff"
|
53 |
+
51,/m/06h7j,"Run"
|
54 |
+
52,/m/07qv_x_,"Shuffle"
|
55 |
+
53,/m/07pbtc8,"Walk, footsteps"
|
56 |
+
54,/m/03cczk,"Chewing, mastication"
|
57 |
+
55,/m/07pdhp0,"Biting"
|
58 |
+
56,/m/0939n_,"Gargling"
|
59 |
+
57,/m/01g90h,"Stomach rumble"
|
60 |
+
58,/m/03q5_w,"Burping, eructation"
|
61 |
+
59,/m/02p3nc,"Hiccup"
|
62 |
+
60,/m/02_nn,"Fart"
|
63 |
+
61,/m/0k65p,"Hands"
|
64 |
+
62,/m/025_jnm,"Finger snapping"
|
65 |
+
63,/m/0l15bq,"Clapping"
|
66 |
+
64,/m/01jg02,"Heart sounds, heartbeat"
|
67 |
+
65,/m/01jg1z,"Heart murmur"
|
68 |
+
66,/m/053hz1,"Cheering"
|
69 |
+
67,/m/028ght,"Applause"
|
70 |
+
68,/m/07rkbfh,"Chatter"
|
71 |
+
69,/m/03qtwd,"Crowd"
|
72 |
+
70,/m/07qfr4h,"Hubbub, speech noise, speech babble"
|
73 |
+
71,/t/dd00013,"Children playing"
|
74 |
+
72,/m/0jbk,"Animal"
|
75 |
+
73,/m/068hy,"Domestic animals, pets"
|
76 |
+
74,/m/0bt9lr,"Dog"
|
77 |
+
75,/m/05tny_,"Bark"
|
78 |
+
76,/m/07r_k2n,"Yip"
|
79 |
+
77,/m/07qf0zm,"Howl"
|
80 |
+
78,/m/07rc7d9,"Bow-wow"
|
81 |
+
79,/m/0ghcn6,"Growling"
|
82 |
+
80,/t/dd00136,"Whimper (dog)"
|
83 |
+
81,/m/01yrx,"Cat"
|
84 |
+
82,/m/02yds9,"Purr"
|
85 |
+
83,/m/07qrkrw,"Meow"
|
86 |
+
84,/m/07rjwbb,"Hiss"
|
87 |
+
85,/m/07r81j2,"Caterwaul"
|
88 |
+
86,/m/0ch8v,"Livestock, farm animals, working animals"
|
89 |
+
87,/m/03k3r,"Horse"
|
90 |
+
88,/m/07rv9rh,"Clip-clop"
|
91 |
+
89,/m/07q5rw0,"Neigh, whinny"
|
92 |
+
90,/m/01xq0k1,"Cattle, bovinae"
|
93 |
+
91,/m/07rpkh9,"Moo"
|
94 |
+
92,/m/0239kh,"Cowbell"
|
95 |
+
93,/m/068zj,"Pig"
|
96 |
+
94,/t/dd00018,"Oink"
|
97 |
+
95,/m/03fwl,"Goat"
|
98 |
+
96,/m/07q0h5t,"Bleat"
|
99 |
+
97,/m/07bgp,"Sheep"
|
100 |
+
98,/m/025rv6n,"Fowl"
|
101 |
+
99,/m/09b5t,"Chicken, rooster"
|
102 |
+
100,/m/07st89h,"Cluck"
|
103 |
+
101,/m/07qn5dc,"Crowing, cock-a-doodle-doo"
|
104 |
+
102,/m/01rd7k,"Turkey"
|
105 |
+
103,/m/07svc2k,"Gobble"
|
106 |
+
104,/m/09ddx,"Duck"
|
107 |
+
105,/m/07qdb04,"Quack"
|
108 |
+
106,/m/0dbvp,"Goose"
|
109 |
+
107,/m/07qwf61,"Honk"
|
110 |
+
108,/m/01280g,"Wild animals"
|
111 |
+
109,/m/0cdnk,"Roaring cats (lions, tigers)"
|
112 |
+
110,/m/04cvmfc,"Roar"
|
113 |
+
111,/m/015p6,"Bird"
|
114 |
+
112,/m/020bb7,"Bird vocalization, bird call, bird song"
|
115 |
+
113,/m/07pggtn,"Chirp, tweet"
|
116 |
+
114,/m/07sx8x_,"Squawk"
|
117 |
+
115,/m/0h0rv,"Pigeon, dove"
|
118 |
+
116,/m/07r_25d,"Coo"
|
119 |
+
117,/m/04s8yn,"Crow"
|
120 |
+
118,/m/07r5c2p,"Caw"
|
121 |
+
119,/m/09d5_,"Owl"
|
122 |
+
120,/m/07r_80w,"Hoot"
|
123 |
+
121,/m/05_wcq,"Bird flight, flapping wings"
|
124 |
+
122,/m/01z5f,"Canidae, dogs, wolves"
|
125 |
+
123,/m/06hps,"Rodents, rats, mice"
|
126 |
+
124,/m/04rmv,"Mouse"
|
127 |
+
125,/m/07r4gkf,"Patter"
|
128 |
+
126,/m/03vt0,"Insect"
|
129 |
+
127,/m/09xqv,"Cricket"
|
130 |
+
128,/m/09f96,"Mosquito"
|
131 |
+
129,/m/0h2mp,"Fly, housefly"
|
132 |
+
130,/m/07pjwq1,"Buzz"
|
133 |
+
131,/m/01h3n,"Bee, wasp, etc."
|
134 |
+
132,/m/09ld4,"Frog"
|
135 |
+
133,/m/07st88b,"Croak"
|
136 |
+
134,/m/078jl,"Snake"
|
137 |
+
135,/m/07qn4z3,"Rattle"
|
138 |
+
136,/m/032n05,"Whale vocalization"
|
139 |
+
137,/m/04rlf,"Music"
|
140 |
+
138,/m/04szw,"Musical instrument"
|
141 |
+
139,/m/0fx80y,"Plucked string instrument"
|
142 |
+
140,/m/0342h,"Guitar"
|
143 |
+
141,/m/02sgy,"Electric guitar"
|
144 |
+
142,/m/018vs,"Bass guitar"
|
145 |
+
143,/m/042v_gx,"Acoustic guitar"
|
146 |
+
144,/m/06w87,"Steel guitar, slide guitar"
|
147 |
+
145,/m/01glhc,"Tapping (guitar technique)"
|
148 |
+
146,/m/07s0s5r,"Strum"
|
149 |
+
147,/m/018j2,"Banjo"
|
150 |
+
148,/m/0jtg0,"Sitar"
|
151 |
+
149,/m/04rzd,"Mandolin"
|
152 |
+
150,/m/01bns_,"Zither"
|
153 |
+
151,/m/07xzm,"Ukulele"
|
154 |
+
152,/m/05148p4,"Keyboard (musical)"
|
155 |
+
153,/m/05r5c,"Piano"
|
156 |
+
154,/m/01s0ps,"Electric piano"
|
157 |
+
155,/m/013y1f,"Organ"
|
158 |
+
156,/m/03xq_f,"Electronic organ"
|
159 |
+
157,/m/03gvt,"Hammond organ"
|
160 |
+
158,/m/0l14qv,"Synthesizer"
|
161 |
+
159,/m/01v1d8,"Sampler"
|
162 |
+
160,/m/03q5t,"Harpsichord"
|
163 |
+
161,/m/0l14md,"Percussion"
|
164 |
+
162,/m/02hnl,"Drum kit"
|
165 |
+
163,/m/0cfdd,"Drum machine"
|
166 |
+
164,/m/026t6,"Drum"
|
167 |
+
165,/m/06rvn,"Snare drum"
|
168 |
+
166,/m/03t3fj,"Rimshot"
|
169 |
+
167,/m/02k_mr,"Drum roll"
|
170 |
+
168,/m/0bm02,"Bass drum"
|
171 |
+
169,/m/011k_j,"Timpani"
|
172 |
+
170,/m/01p970,"Tabla"
|
173 |
+
171,/m/01qbl,"Cymbal"
|
174 |
+
172,/m/03qtq,"Hi-hat"
|
175 |
+
173,/m/01sm1g,"Wood block"
|
176 |
+
174,/m/07brj,"Tambourine"
|
177 |
+
175,/m/05r5wn,"Rattle (instrument)"
|
178 |
+
176,/m/0xzly,"Maraca"
|
179 |
+
177,/m/0mbct,"Gong"
|
180 |
+
178,/m/016622,"Tubular bells"
|
181 |
+
179,/m/0j45pbj,"Mallet percussion"
|
182 |
+
180,/m/0dwsp,"Marimba, xylophone"
|
183 |
+
181,/m/0dwtp,"Glockenspiel"
|
184 |
+
182,/m/0dwt5,"Vibraphone"
|
185 |
+
183,/m/0l156b,"Steelpan"
|
186 |
+
184,/m/05pd6,"Orchestra"
|
187 |
+
185,/m/01kcd,"Brass instrument"
|
188 |
+
186,/m/0319l,"French horn"
|
189 |
+
187,/m/07gql,"Trumpet"
|
190 |
+
188,/m/07c6l,"Trombone"
|
191 |
+
189,/m/0l14_3,"Bowed string instrument"
|
192 |
+
190,/m/02qmj0d,"String section"
|
193 |
+
191,/m/07y_7,"Violin, fiddle"
|
194 |
+
192,/m/0d8_n,"Pizzicato"
|
195 |
+
193,/m/01xqw,"Cello"
|
196 |
+
194,/m/02fsn,"Double bass"
|
197 |
+
195,/m/085jw,"Wind instrument, woodwind instrument"
|
198 |
+
196,/m/0l14j_,"Flute"
|
199 |
+
197,/m/06ncr,"Saxophone"
|
200 |
+
198,/m/01wy6,"Clarinet"
|
201 |
+
199,/m/03m5k,"Harp"
|
202 |
+
200,/m/0395lw,"Bell"
|
203 |
+
201,/m/03w41f,"Church bell"
|
204 |
+
202,/m/027m70_,"Jingle bell"
|
205 |
+
203,/m/0gy1t2s,"Bicycle bell"
|
206 |
+
204,/m/07n_g,"Tuning fork"
|
207 |
+
205,/m/0f8s22,"Chime"
|
208 |
+
206,/m/026fgl,"Wind chime"
|
209 |
+
207,/m/0150b9,"Change ringing (campanology)"
|
210 |
+
208,/m/03qjg,"Harmonica"
|
211 |
+
209,/m/0mkg,"Accordion"
|
212 |
+
210,/m/0192l,"Bagpipes"
|
213 |
+
211,/m/02bxd,"Didgeridoo"
|
214 |
+
212,/m/0l14l2,"Shofar"
|
215 |
+
213,/m/07kc_,"Theremin"
|
216 |
+
214,/m/0l14t7,"Singing bowl"
|
217 |
+
215,/m/01hgjl,"Scratching (performance technique)"
|
218 |
+
216,/m/064t9,"Pop music"
|
219 |
+
217,/m/0glt670,"Hip hop music"
|
220 |
+
218,/m/02cz_7,"Beatboxing"
|
221 |
+
219,/m/06by7,"Rock music"
|
222 |
+
220,/m/03lty,"Heavy metal"
|
223 |
+
221,/m/05r6t,"Punk rock"
|
224 |
+
222,/m/0dls3,"Grunge"
|
225 |
+
223,/m/0dl5d,"Progressive rock"
|
226 |
+
224,/m/07sbbz2,"Rock and roll"
|
227 |
+
225,/m/05w3f,"Psychedelic rock"
|
228 |
+
226,/m/06j6l,"Rhythm and blues"
|
229 |
+
227,/m/0gywn,"Soul music"
|
230 |
+
228,/m/06cqb,"Reggae"
|
231 |
+
229,/m/01lyv,"Country"
|
232 |
+
230,/m/015y_n,"Swing music"
|
233 |
+
231,/m/0gg8l,"Bluegrass"
|
234 |
+
232,/m/02x8m,"Funk"
|
235 |
+
233,/m/02w4v,"Folk music"
|
236 |
+
234,/m/06j64v,"Middle Eastern music"
|
237 |
+
235,/m/03_d0,"Jazz"
|
238 |
+
236,/m/026z9,"Disco"
|
239 |
+
237,/m/0ggq0m,"Classical music"
|
240 |
+
238,/m/05lls,"Opera"
|
241 |
+
239,/m/02lkt,"Electronic music"
|
242 |
+
240,/m/03mb9,"House music"
|
243 |
+
241,/m/07gxw,"Techno"
|
244 |
+
242,/m/07s72n,"Dubstep"
|
245 |
+
243,/m/0283d,"Drum and bass"
|
246 |
+
244,/m/0m0jc,"Electronica"
|
247 |
+
245,/m/08cyft,"Electronic dance music"
|
248 |
+
246,/m/0fd3y,"Ambient music"
|
249 |
+
247,/m/07lnk,"Trance music"
|
250 |
+
248,/m/0g293,"Music of Latin America"
|
251 |
+
249,/m/0ln16,"Salsa music"
|
252 |
+
250,/m/0326g,"Flamenco"
|
253 |
+
251,/m/0155w,"Blues"
|
254 |
+
252,/m/05fw6t,"Music for children"
|
255 |
+
253,/m/02v2lh,"New-age music"
|
256 |
+
254,/m/0y4f8,"Vocal music"
|
257 |
+
255,/m/0z9c,"A capella"
|
258 |
+
256,/m/0164x2,"Music of Africa"
|
259 |
+
257,/m/0145m,"Afrobeat"
|
260 |
+
258,/m/02mscn,"Christian music"
|
261 |
+
259,/m/016cjb,"Gospel music"
|
262 |
+
260,/m/028sqc,"Music of Asia"
|
263 |
+
261,/m/015vgc,"Carnatic music"
|
264 |
+
262,/m/0dq0md,"Music of Bollywood"
|
265 |
+
263,/m/06rqw,"Ska"
|
266 |
+
264,/m/02p0sh1,"Traditional music"
|
267 |
+
265,/m/05rwpb,"Independent music"
|
268 |
+
266,/m/074ft,"Song"
|
269 |
+
267,/m/025td0t,"Background music"
|
270 |
+
268,/m/02cjck,"Theme music"
|
271 |
+
269,/m/03r5q_,"Jingle (music)"
|
272 |
+
270,/m/0l14gg,"Soundtrack music"
|
273 |
+
271,/m/07pkxdp,"Lullaby"
|
274 |
+
272,/m/01z7dr,"Video game music"
|
275 |
+
273,/m/0140xf,"Christmas music"
|
276 |
+
274,/m/0ggx5q,"Dance music"
|
277 |
+
275,/m/04wptg,"Wedding music"
|
278 |
+
276,/t/dd00031,"Happy music"
|
279 |
+
277,/t/dd00032,"Funny music"
|
280 |
+
278,/t/dd00033,"Sad music"
|
281 |
+
279,/t/dd00034,"Tender music"
|
282 |
+
280,/t/dd00035,"Exciting music"
|
283 |
+
281,/t/dd00036,"Angry music"
|
284 |
+
282,/t/dd00037,"Scary music"
|
285 |
+
283,/m/03m9d0z,"Wind"
|
286 |
+
284,/m/09t49,"Rustling leaves"
|
287 |
+
285,/t/dd00092,"Wind noise (microphone)"
|
288 |
+
286,/m/0jb2l,"Thunderstorm"
|
289 |
+
287,/m/0ngt1,"Thunder"
|
290 |
+
288,/m/0838f,"Water"
|
291 |
+
289,/m/06mb1,"Rain"
|
292 |
+
290,/m/07r10fb,"Raindrop"
|
293 |
+
291,/t/dd00038,"Rain on surface"
|
294 |
+
292,/m/0j6m2,"Stream"
|
295 |
+
293,/m/0j2kx,"Waterfall"
|
296 |
+
294,/m/05kq4,"Ocean"
|
297 |
+
295,/m/034srq,"Waves, surf"
|
298 |
+
296,/m/06wzb,"Steam"
|
299 |
+
297,/m/07swgks,"Gurgling"
|
300 |
+
298,/m/02_41,"Fire"
|
301 |
+
299,/m/07pzfmf,"Crackle"
|
302 |
+
300,/m/07yv9,"Vehicle"
|
303 |
+
301,/m/019jd,"Boat, Water vehicle"
|
304 |
+
302,/m/0hsrw,"Sailboat, sailing ship"
|
305 |
+
303,/m/056ks2,"Rowboat, canoe, kayak"
|
306 |
+
304,/m/02rlv9,"Motorboat, speedboat"
|
307 |
+
305,/m/06q74,"Ship"
|
308 |
+
306,/m/012f08,"Motor vehicle (road)"
|
309 |
+
307,/m/0k4j,"Car"
|
310 |
+
308,/m/0912c9,"Vehicle horn, car horn, honking"
|
311 |
+
309,/m/07qv_d5,"Toot"
|
312 |
+
310,/m/02mfyn,"Car alarm"
|
313 |
+
311,/m/04gxbd,"Power windows, electric windows"
|
314 |
+
312,/m/07rknqz,"Skidding"
|
315 |
+
313,/m/0h9mv,"Tire squeal"
|
316 |
+
314,/t/dd00134,"Car passing by"
|
317 |
+
315,/m/0ltv,"Race car, auto racing"
|
318 |
+
316,/m/07r04,"Truck"
|
319 |
+
317,/m/0gvgw0,"Air brake"
|
320 |
+
318,/m/05x_td,"Air horn, truck horn"
|
321 |
+
319,/m/02rhddq,"Reversing beeps"
|
322 |
+
320,/m/03cl9h,"Ice cream truck, ice cream van"
|
323 |
+
321,/m/01bjv,"Bus"
|
324 |
+
322,/m/03j1ly,"Emergency vehicle"
|
325 |
+
323,/m/04qvtq,"Police car (siren)"
|
326 |
+
324,/m/012n7d,"Ambulance (siren)"
|
327 |
+
325,/m/012ndj,"Fire engine, fire truck (siren)"
|
328 |
+
326,/m/04_sv,"Motorcycle"
|
329 |
+
327,/m/0btp2,"Traffic noise, roadway noise"
|
330 |
+
328,/m/06d_3,"Rail transport"
|
331 |
+
329,/m/07jdr,"Train"
|
332 |
+
330,/m/04zmvq,"Train whistle"
|
333 |
+
331,/m/0284vy3,"Train horn"
|
334 |
+
332,/m/01g50p,"Railroad car, train wagon"
|
335 |
+
333,/t/dd00048,"Train wheels squealing"
|
336 |
+
334,/m/0195fx,"Subway, metro, underground"
|
337 |
+
335,/m/0k5j,"Aircraft"
|
338 |
+
336,/m/014yck,"Aircraft engine"
|
339 |
+
337,/m/04229,"Jet engine"
|
340 |
+
338,/m/02l6bg,"Propeller, airscrew"
|
341 |
+
339,/m/09ct_,"Helicopter"
|
342 |
+
340,/m/0cmf2,"Fixed-wing aircraft, airplane"
|
343 |
+
341,/m/0199g,"Bicycle"
|
344 |
+
342,/m/06_fw,"Skateboard"
|
345 |
+
343,/m/02mk9,"Engine"
|
346 |
+
344,/t/dd00065,"Light engine (high frequency)"
|
347 |
+
345,/m/08j51y,"Dental drill, dentist's drill"
|
348 |
+
346,/m/01yg9g,"Lawn mower"
|
349 |
+
347,/m/01j4z9,"Chainsaw"
|
350 |
+
348,/t/dd00066,"Medium engine (mid frequency)"
|
351 |
+
349,/t/dd00067,"Heavy engine (low frequency)"
|
352 |
+
350,/m/01h82_,"Engine knocking"
|
353 |
+
351,/t/dd00130,"Engine starting"
|
354 |
+
352,/m/07pb8fc,"Idling"
|
355 |
+
353,/m/07q2z82,"Accelerating, revving, vroom"
|
356 |
+
354,/m/02dgv,"Door"
|
357 |
+
355,/m/03wwcy,"Doorbell"
|
358 |
+
356,/m/07r67yg,"Ding-dong"
|
359 |
+
357,/m/02y_763,"Sliding door"
|
360 |
+
358,/m/07rjzl8,"Slam"
|
361 |
+
359,/m/07r4wb8,"Knock"
|
362 |
+
360,/m/07qcpgn,"Tap"
|
363 |
+
361,/m/07q6cd_,"Squeak"
|
364 |
+
362,/m/0642b4,"Cupboard open or close"
|
365 |
+
363,/m/0fqfqc,"Drawer open or close"
|
366 |
+
364,/m/04brg2,"Dishes, pots, and pans"
|
367 |
+
365,/m/023pjk,"Cutlery, silverware"
|
368 |
+
366,/m/07pn_8q,"Chopping (food)"
|
369 |
+
367,/m/0dxrf,"Frying (food)"
|
370 |
+
368,/m/0fx9l,"Microwave oven"
|
371 |
+
369,/m/02pjr4,"Blender"
|
372 |
+
370,/m/02jz0l,"Water tap, faucet"
|
373 |
+
371,/m/0130jx,"Sink (filling or washing)"
|
374 |
+
372,/m/03dnzn,"Bathtub (filling or washing)"
|
375 |
+
373,/m/03wvsk,"Hair dryer"
|
376 |
+
374,/m/01jt3m,"Toilet flush"
|
377 |
+
375,/m/012xff,"Toothbrush"
|
378 |
+
376,/m/04fgwm,"Electric toothbrush"
|
379 |
+
377,/m/0d31p,"Vacuum cleaner"
|
380 |
+
378,/m/01s0vc,"Zipper (clothing)"
|
381 |
+
379,/m/03v3yw,"Keys jangling"
|
382 |
+
380,/m/0242l,"Coin (dropping)"
|
383 |
+
381,/m/01lsmm,"Scissors"
|
384 |
+
382,/m/02g901,"Electric shaver, electric razor"
|
385 |
+
383,/m/05rj2,"Shuffling cards"
|
386 |
+
384,/m/0316dw,"Typing"
|
387 |
+
385,/m/0c2wf,"Typewriter"
|
388 |
+
386,/m/01m2v,"Computer keyboard"
|
389 |
+
387,/m/081rb,"Writing"
|
390 |
+
388,/m/07pp_mv,"Alarm"
|
391 |
+
389,/m/07cx4,"Telephone"
|
392 |
+
390,/m/07pp8cl,"Telephone bell ringing"
|
393 |
+
391,/m/01hnzm,"Ringtone"
|
394 |
+
392,/m/02c8p,"Telephone dialing, DTMF"
|
395 |
+
393,/m/015jpf,"Dial tone"
|
396 |
+
394,/m/01z47d,"Busy signal"
|
397 |
+
395,/m/046dlr,"Alarm clock"
|
398 |
+
396,/m/03kmc9,"Siren"
|
399 |
+
397,/m/0dgbq,"Civil defense siren"
|
400 |
+
398,/m/030rvx,"Buzzer"
|
401 |
+
399,/m/01y3hg,"Smoke detector, smoke alarm"
|
402 |
+
400,/m/0c3f7m,"Fire alarm"
|
403 |
+
401,/m/04fq5q,"Foghorn"
|
404 |
+
402,/m/0l156k,"Whistle"
|
405 |
+
403,/m/06hck5,"Steam whistle"
|
406 |
+
404,/t/dd00077,"Mechanisms"
|
407 |
+
405,/m/02bm9n,"Ratchet, pawl"
|
408 |
+
406,/m/01x3z,"Clock"
|
409 |
+
407,/m/07qjznt,"Tick"
|
410 |
+
408,/m/07qjznl,"Tick-tock"
|
411 |
+
409,/m/0l7xg,"Gears"
|
412 |
+
410,/m/05zc1,"Pulleys"
|
413 |
+
411,/m/0llzx,"Sewing machine"
|
414 |
+
412,/m/02x984l,"Mechanical fan"
|
415 |
+
413,/m/025wky1,"Air conditioning"
|
416 |
+
414,/m/024dl,"Cash register"
|
417 |
+
415,/m/01m4t,"Printer"
|
418 |
+
416,/m/0dv5r,"Camera"
|
419 |
+
417,/m/07bjf,"Single-lens reflex camera"
|
420 |
+
418,/m/07k1x,"Tools"
|
421 |
+
419,/m/03l9g,"Hammer"
|
422 |
+
420,/m/03p19w,"Jackhammer"
|
423 |
+
421,/m/01b82r,"Sawing"
|
424 |
+
422,/m/02p01q,"Filing (rasp)"
|
425 |
+
423,/m/023vsd,"Sanding"
|
426 |
+
424,/m/0_ksk,"Power tool"
|
427 |
+
425,/m/01d380,"Drill"
|
428 |
+
426,/m/014zdl,"Explosion"
|
429 |
+
427,/m/032s66,"Gunshot, gunfire"
|
430 |
+
428,/m/04zjc,"Machine gun"
|
431 |
+
429,/m/02z32qm,"Fusillade"
|
432 |
+
430,/m/0_1c,"Artillery fire"
|
433 |
+
431,/m/073cg4,"Cap gun"
|
434 |
+
432,/m/0g6b5,"Fireworks"
|
435 |
+
433,/g/122z_qxw,"Firecracker"
|
436 |
+
434,/m/07qsvvw,"Burst, pop"
|
437 |
+
435,/m/07pxg6y,"Eruption"
|
438 |
+
436,/m/07qqyl4,"Boom"
|
439 |
+
437,/m/083vt,"Wood"
|
440 |
+
438,/m/07pczhz,"Chop"
|
441 |
+
439,/m/07pl1bw,"Splinter"
|
442 |
+
440,/m/07qs1cx,"Crack"
|
443 |
+
441,/m/039jq,"Glass"
|
444 |
+
442,/m/07q7njn,"Chink, clink"
|
445 |
+
443,/m/07rn7sz,"Shatter"
|
446 |
+
444,/m/04k94,"Liquid"
|
447 |
+
445,/m/07rrlb6,"Splash, splatter"
|
448 |
+
446,/m/07p6mqd,"Slosh"
|
449 |
+
447,/m/07qlwh6,"Squish"
|
450 |
+
448,/m/07r5v4s,"Drip"
|
451 |
+
449,/m/07prgkl,"Pour"
|
452 |
+
450,/m/07pqc89,"Trickle, dribble"
|
453 |
+
451,/t/dd00088,"Gush"
|
454 |
+
452,/m/07p7b8y,"Fill (with liquid)"
|
455 |
+
453,/m/07qlf79,"Spray"
|
456 |
+
454,/m/07ptzwd,"Pump (liquid)"
|
457 |
+
455,/m/07ptfmf,"Stir"
|
458 |
+
456,/m/0dv3j,"Boiling"
|
459 |
+
457,/m/0790c,"Sonar"
|
460 |
+
458,/m/0dl83,"Arrow"
|
461 |
+
459,/m/07rqsjt,"Whoosh, swoosh, swish"
|
462 |
+
460,/m/07qnq_y,"Thump, thud"
|
463 |
+
461,/m/07rrh0c,"Thunk"
|
464 |
+
462,/m/0b_fwt,"Electronic tuner"
|
465 |
+
463,/m/02rr_,"Effects unit"
|
466 |
+
464,/m/07m2kt,"Chorus effect"
|
467 |
+
465,/m/018w8,"Basketball bounce"
|
468 |
+
466,/m/07pws3f,"Bang"
|
469 |
+
467,/m/07ryjzk,"Slap, smack"
|
470 |
+
468,/m/07rdhzs,"Whack, thwack"
|
471 |
+
469,/m/07pjjrj,"Smash, crash"
|
472 |
+
470,/m/07pc8lb,"Breaking"
|
473 |
+
471,/m/07pqn27,"Bouncing"
|
474 |
+
472,/m/07rbp7_,"Whip"
|
475 |
+
473,/m/07pyf11,"Flap"
|
476 |
+
474,/m/07qb_dv,"Scratch"
|
477 |
+
475,/m/07qv4k0,"Scrape"
|
478 |
+
476,/m/07pdjhy,"Rub"
|
479 |
+
477,/m/07s8j8t,"Roll"
|
480 |
+
478,/m/07plct2,"Crushing"
|
481 |
+
479,/t/dd00112,"Crumpling, crinkling"
|
482 |
+
480,/m/07qcx4z,"Tearing"
|
483 |
+
481,/m/02fs_r,"Beep, bleep"
|
484 |
+
482,/m/07qwdck,"Ping"
|
485 |
+
483,/m/07phxs1,"Ding"
|
486 |
+
484,/m/07rv4dm,"Clang"
|
487 |
+
485,/m/07s02z0,"Squeal"
|
488 |
+
486,/m/07qh7jl,"Creak"
|
489 |
+
487,/m/07qwyj0,"Rustle"
|
490 |
+
488,/m/07s34ls,"Whir"
|
491 |
+
489,/m/07qmpdm,"Clatter"
|
492 |
+
490,/m/07p9k1k,"Sizzle"
|
493 |
+
491,/m/07qc9xj,"Clicking"
|
494 |
+
492,/m/07rwm0c,"Clickety-clack"
|
495 |
+
493,/m/07phhsh,"Rumble"
|
496 |
+
494,/m/07qyrcz,"Plop"
|
497 |
+
495,/m/07qfgpx,"Jingle, tinkle"
|
498 |
+
496,/m/07rcgpl,"Hum"
|
499 |
+
497,/m/07p78v5,"Zing"
|
500 |
+
498,/t/dd00121,"Boing"
|
501 |
+
499,/m/07s12q4,"Crunch"
|
502 |
+
500,/m/028v0c,"Silence"
|
503 |
+
501,/m/01v_m0,"Sine wave"
|
504 |
+
502,/m/0b9m1,"Harmonic"
|
505 |
+
503,/m/0hdsk,"Chirp tone"
|
506 |
+
504,/m/0c1dj,"Sound effect"
|
507 |
+
505,/m/07pt_g0,"Pulse"
|
508 |
+
506,/t/dd00125,"Inside, small room"
|
509 |
+
507,/t/dd00126,"Inside, large room or hall"
|
510 |
+
508,/t/dd00127,"Inside, public space"
|
511 |
+
509,/t/dd00128,"Outside, urban or manmade"
|
512 |
+
510,/t/dd00129,"Outside, rural or natural"
|
513 |
+
511,/m/01b9nn,"Reverberation"
|
514 |
+
512,/m/01jnbd,"Echo"
|
515 |
+
513,/m/096m7z,"Noise"
|
516 |
+
514,/m/06_y0by,"Environmental noise"
|
517 |
+
515,/m/07rgkc5,"Static"
|
518 |
+
516,/m/06xkwv,"Mains hum"
|
519 |
+
517,/m/0g12c5,"Distortion"
|
520 |
+
518,/m/08p9q4,"Sidetone"
|
521 |
+
519,/m/07szfh9,"Cacophony"
|
522 |
+
520,/m/0chx_,"White noise"
|
523 |
+
521,/m/0cj0r,"Pink noise"
|
524 |
+
522,/m/07p_0gm,"Throbbing"
|
525 |
+
523,/m/01jwx6,"Vibration"
|
526 |
+
524,/m/07c52,"Television"
|
527 |
+
525,/m/06bz3,"Radio"
|
528 |
+
526,/m/07hvw1,"Field recording"
|
evaluation/metadata/clotho_eval.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation/metadata/esc50_eval.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation/metadata/music_eval.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation/metadata/vggsound_eval.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
losses.py
ADDED
@@ -0,0 +1,17 @@
import torch


def l1(output, target):
    return torch.mean(torch.abs(output - target))


def l1_wav(output_dict, target_dict):
    return l1(output_dict['segment'], target_dict['segment'])


def get_loss_function(loss_type):
    if loss_type == "l1_wav":
        return l1_wav

    else:
        raise NotImplementedError(f"Unsupported loss type: {loss_type}")
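A minimal sketch of how this loss factory is consumed (the tensor shapes are illustrative assumptions):

import torch
from losses import get_loss_function

loss_function = get_loss_function("l1_wav")

# Dummy separated and target waveforms: (batch_size, channels_num, segment_samples)
output_dict = {"segment": torch.randn(4, 1, 32000)}
target_dict = {"segment": torch.randn(4, 1, 32000)}

loss = loss_function(output_dict, target_dict)  # scalar mean absolute error over the waveform
print(loss.item())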
models/CLAP/__init__.py
ADDED
File without changes
models/CLAP/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (158 Bytes).
models/CLAP/open_clip/__init__.py
ADDED
@@ -0,0 +1,25 @@
from .factory import (
    list_models,
    create_model,
    create_model_and_transforms,
    add_model_config,
)
from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
from .model import (
    CLAP,
    CLAPTextCfg,
    CLAPVisionCfg,
    CLAPAudioCfp,
    convert_weights_to_fp16,
    trace_model,
)
from .openai import load_openai_model, list_openai_models
from .pretrained import (
    list_pretrained,
    list_pretrained_tag_models,
    list_pretrained_model_tags,
    get_pretrained_url,
    download_pretrained,
)
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
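For reference, a tiny sketch of the tokenizer exposed by this package; the output shape is an assumption based on the standard open_clip tokenizer, which pads or truncates to a fixed context length:

from models.CLAP.open_clip import tokenize

tokens = tokenize(["a man is speaking", "a dog barking"])
print(tokens.shape)  # expected: (2, context_length) LongTensor of token ids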
models/CLAP/open_clip/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (964 Bytes).
models/CLAP/open_clip/__pycache__/factory.cpython-310.pyc
ADDED
Binary file (6.65 kB).
models/CLAP/open_clip/__pycache__/feature_fusion.cpython-310.pyc
ADDED
Binary file (4.12 kB).
models/CLAP/open_clip/__pycache__/htsat.cpython-310.pyc
ADDED
Binary file (30.8 kB).
models/CLAP/open_clip/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (7.97 kB).
models/CLAP/open_clip/__pycache__/model.cpython-310.pyc
ADDED
Binary file (24.1 kB).
models/CLAP/open_clip/__pycache__/openai.cpython-310.pyc
ADDED
Binary file (4.52 kB).
models/CLAP/open_clip/__pycache__/pann_model.cpython-310.pyc
ADDED
Binary file (13.1 kB).
models/CLAP/open_clip/__pycache__/pretrained.cpython-310.pyc
ADDED
Binary file (5.04 kB).
models/CLAP/open_clip/__pycache__/timm_model.cpython-310.pyc
ADDED
Binary file (3.44 kB).
models/CLAP/open_clip/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (7.36 kB).
models/CLAP/open_clip/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (982 Bytes).
models/CLAP/open_clip/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (10.5 kB).
models/CLAP/open_clip/bert.py
ADDED
@@ -0,0 +1,40 @@
# Note: each of the three blocks below rebinds the module-level `tokenizer` / `model`
# names, so after the whole module is imported all three helpers resolve to the most
# recently loaded pair (BART).
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
text = "Replace me by any text you'd like."


def bert_embeddings(text):
    # text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors="pt")
    output = model(**encoded_input)
    return output


from transformers import RobertaTokenizer, RobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
text = "Replace me by any text you'd like."


def Roberta_embeddings(text):
    # text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors="pt")
    output = model(**encoded_input)
    return output


from transformers import BartTokenizer, BartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")
text = "Replace me by any text you'd like."


def bart_embeddings(text):
    # text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors="pt")
    output = model(**encoded_input)
    return output
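Illustrative only: each helper above returns the raw Hugging Face model output, and because the module-level names are rebound, the helpers use whichever tokenizer/model pair was loaded last. A sentence-level vector still has to be pooled from that output:

out = bert_embeddings("a dog barking in the distance")
sentence_vec = out.last_hidden_state.mean(dim=1)  # mean-pool over tokens -> (1, hidden_size)
print(sentence_vec.shape)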
models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
models/CLAP/open_clip/factory.py
ADDED
@@ -0,0 +1,277 @@
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path

import torch

from .model import CLAP, convert_weights_to_fp16
from .openai import load_openai_model
from .pretrained import get_pretrained_url, download_pretrained
from .transform import image_transform

_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        if os.path.basename(cf)[0] == ".":
            continue  # Ignore hidden files

        with open(cf, "r") as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {
        k: v
        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
    }


_rescan_model_configs()  # initial populate of model config registry


def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if skip_params:
        if next(iter(state_dict.items()))[0].startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
    # for k in state_dict:
    #     if k.startswith('transformer'):
    #         v = state_dict.pop(k)
    #         state_dict['text_branch.' + k[12:]] = v
    return state_dict


def create_model(
    amodel_name: str,
    tmodel_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
    openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
    skip_params=True,
    pretrained_audio: str = "",
    pretrained_text: str = "",
    enable_fusion: bool = False,
    fusion_type: str = "None"
    # pretrained_image: bool = False,
):
    amodel_name = amodel_name.replace(
        "/", "-"
    )  # for callers using old naming with / in ViT names
    pretrained_orig = pretrained
    pretrained = pretrained.lower()
    if pretrained == "openai":
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
        # Hard Code in model name
        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model = load_openai_model(
            "ViT-B-16",
            model_cfg,
            device=device,
            jit=jit,
            cache_dir=openai_model_cache_dir,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
        # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
        if precision == "amp" or precision == "fp32":
            model = model.float()
    else:
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        # if pretrained_image:
        #     if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
        #         # pretrained weight loading for timm models set via vision_cfg
        #         model_cfg['vision_cfg']['timm_model_pretrained'] = True
        #     else:
        #         assert False, 'pretrained image towers currently only supported for timm models'
        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model_cfg["enable_fusion"] = enable_fusion
        model_cfg["fusion_type"] = fusion_type
        model = CLAP(**model_cfg)

        if pretrained:
            checkpoint_path = ""
            url = get_pretrained_url(amodel_name, pretrained)
            if url:
                checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
            elif os.path.exists(pretrained_orig):
                checkpoint_path = pretrained_orig
            if checkpoint_path:
                logging.info(
                    f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
                )
                ckpt = load_state_dict(checkpoint_path, skip_params=True)
                model.load_state_dict(ckpt)
                param_names = [n for n, p in model.named_parameters()]
                # for n in param_names:
                #     print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
            else:
                logging.warning(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )
                raise RuntimeError(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )

        if pretrained_audio:
            if amodel_name.startswith("PANN"):
                if "Cnn14_mAP" in pretrained_audio:  # official checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["model"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key] = v
                elif os.path.basename(pretrained_audio).startswith(
                    "PANN"
                ):  # checkpoint trained via HTSAT codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith(
                    "finetuned"
                ):  # checkpoint trained via linear probe codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            elif amodel_name.startswith("HTSAT"):
                if "HTSAT_AudioSet_Saved" in pretrained_audio:  # official checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model") and (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith(
                    "HTSAT"
                ):  # checkpoint trained via HTSAT codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith(
                    "finetuned"
                ):  # checkpoint trained via linear probe codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            else:
                raise ValueError("this audio encoder pretrained checkpoint is not supported")

            model.load_state_dict(audio_ckpt, strict=False)
            logging.info(
                f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
            )
            param_names = [n for n, p in model.named_parameters()]
            for n in param_names:
                print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")

    model.to(device=device)
    if precision == "fp16":
        assert device.type != "cpu"
        convert_weights_to_fp16(model)

    if jit:
        model = torch.jit.script(model)

    return model, model_cfg


def create_model_and_transforms(
    model_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
    # pretrained_image: bool = False,
):
    model = create_model(
        model_name,
        pretrained,
        precision,
        device,
        jit,
        force_quick_gelu=force_quick_gelu,
        # pretrained_image=pretrained_image
    )
    preprocess_train = image_transform(model.visual.image_size, is_train=True)
    preprocess_val = image_transform(model.visual.image_size, is_train=False)
    return model, preprocess_train, preprocess_val


def list_models():
    """enumerate available model architectures based on config files"""
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """add model config path or file and update registry"""
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
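A hedged sketch of building a CLAP model through this factory. The architecture name, text-encoder name, and checkpoint path below are assumptions: the architecture name must match a JSON file under model_configs/, and the .pt file is the CLAP checkpoint shipped with this Space.

import torch
from models.CLAP.open_clip import create_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, model_cfg = create_model(
    amodel_name='HTSAT-base',      # assumed audio-branch config name
    tmodel_name='roberta',         # assumed text-branch type
    pretrained='checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt',
    precision='fp32',
    device=device,
    enable_fusion=False,
)
model.eval()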
models/CLAP/open_clip/feature_fusion.py
ADDED
@@ -0,0 +1,192 @@
"""
Feature Fusion for Variable-Length Data Processing
AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
"""

import torch
import torch.nn as nn


class DAF(nn.Module):
    """
    DirectAddFuse: fuse the two inputs by direct element-wise addition.
    """

    def __init__(self):
        super(DAF, self).__init__()

    def forward(self, x, residual):
        return x + residual


class iAFF(nn.Module):
    """
    iAFF: iterative attentional feature fusion of multiple features.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super(iAFF, self).__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            # local attention
            self.local_att = nn.Sequential(
                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(channels),
            )

            # global attention
            self.global_att = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(channels),
            )

            # second local attention
            self.local_att2 = nn.Sequential(
                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(channels),
            )
            # second global attention
            self.global_att2 = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(channels),
            )
        elif type == "2D":
            # local attention
            self.local_att = nn.Sequential(
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            )

            # global attention
            self.global_att = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            )

            # second local attention
            self.local_att2 = nn.Sequential(
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            )
            # second global attention
            self.global_att2 = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            )
        else:
            raise ValueError("the type is not supported")

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        xg2 = self.global_att(xi)  # note: reuses global_att here; global_att2 is defined above but not used at this step
        xlg2 = xl2 + xg2
        wei2 = self.sigmoid(xlg2)
|
127 |
+
xo = x * wei2 + residual * (1 - wei2)
|
128 |
+
if flag:
|
129 |
+
xo = xo[0].unsqueeze(0)
|
130 |
+
return xo
|
131 |
+
|
132 |
+
|
133 |
+
class AFF(nn.Module):
|
134 |
+
"""
|
135 |
+
多特征融合 AFF
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, channels=64, r=4, type="2D"):
|
139 |
+
super(AFF, self).__init__()
|
140 |
+
inter_channels = int(channels // r)
|
141 |
+
|
142 |
+
if type == "1D":
|
143 |
+
self.local_att = nn.Sequential(
|
144 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
145 |
+
nn.BatchNorm1d(inter_channels),
|
146 |
+
nn.ReLU(inplace=True),
|
147 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
148 |
+
nn.BatchNorm1d(channels),
|
149 |
+
)
|
150 |
+
self.global_att = nn.Sequential(
|
151 |
+
nn.AdaptiveAvgPool1d(1),
|
152 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
153 |
+
nn.BatchNorm1d(inter_channels),
|
154 |
+
nn.ReLU(inplace=True),
|
155 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
156 |
+
nn.BatchNorm1d(channels),
|
157 |
+
)
|
158 |
+
elif type == "2D":
|
159 |
+
self.local_att = nn.Sequential(
|
160 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
161 |
+
nn.BatchNorm2d(inter_channels),
|
162 |
+
nn.ReLU(inplace=True),
|
163 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
164 |
+
nn.BatchNorm2d(channels),
|
165 |
+
)
|
166 |
+
self.global_att = nn.Sequential(
|
167 |
+
nn.AdaptiveAvgPool2d(1),
|
168 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
169 |
+
nn.BatchNorm2d(inter_channels),
|
170 |
+
nn.ReLU(inplace=True),
|
171 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
172 |
+
nn.BatchNorm2d(channels),
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
raise f"the type is not supported."
|
176 |
+
|
177 |
+
self.sigmoid = nn.Sigmoid()
|
178 |
+
|
179 |
+
def forward(self, x, residual):
|
180 |
+
flag = False
|
181 |
+
xa = x + residual
|
182 |
+
if xa.size(0) == 1:
|
183 |
+
xa = torch.cat([xa, xa], dim=0)
|
184 |
+
flag = True
|
185 |
+
xl = self.local_att(xa)
|
186 |
+
xg = self.global_att(xa)
|
187 |
+
xlg = xl + xg
|
188 |
+
wei = self.sigmoid(xlg)
|
189 |
+
xo = 2 * x * wei + 2 * residual * (1 - wei)
|
190 |
+
if flag:
|
191 |
+
xo = xo[0].unsqueeze(0)
|
192 |
+
return xo
|
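For reference, here is a minimal sketch (not part of the commit) of how these fusion blocks could be exercised on dummy feature maps; the tensor shapes are arbitrary and the import path assumes this repo's package layout.

import torch
from models.CLAP.open_clip.feature_fusion import DAF, AFF, iAFF

# two feature maps of identical shape (batch, channels, height, width)
x = torch.randn(4, 64, 16, 16)
residual = torch.randn(4, 64, 16, 16)

fuse_add = DAF()                              # plain element-wise addition
fuse_aff = AFF(channels=64, r=4, type="2D")   # single attention-weighted fusion
fuse_iaff = iAFF(channels=64, r=4, type="2D") # iterative (two-stage) fusion

for fuse in (fuse_add, fuse_aff, fuse_iaff):
    fuse.eval()                               # BatchNorm in eval mode for a quick check
    with torch.no_grad():
        out = fuse(x, residual)
    print(type(fuse).__name__, tuple(out.shape))  # output keeps the input shape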