arnavkumar24 commited on
Commit
89040ed
1 Parent(s): ebbe80d
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. AudioSep_Colab.ipynb +128 -0
  2. CONTRIBUTING.md +92 -0
  3. Dockerfile +22 -0
  4. LICENSE +21 -0
  5. assets/results.png +0 -0
  6. benchmark.py +116 -0
  7. callbacks/base.py +35 -0
  8. checkpoint/audiosep_base_4M_steps.ckpt +3 -0
  9. checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt +3 -0
  10. cog.yaml +21 -0
  11. config/audiosep_base.yaml +41 -0
  12. data/audiotext_dataset.py +91 -0
  13. data/datamodules.py +122 -0
  14. data/waveform_mixers.py +127 -0
  15. datafiles/template.json +8 -0
  16. environment.yml +326 -0
  17. evaluation/evaluate_audiocaps.py +110 -0
  18. evaluation/evaluate_audioset.py +155 -0
  19. evaluation/evaluate_clotho.py +102 -0
  20. evaluation/evaluate_esc50.py +102 -0
  21. evaluation/evaluate_music.py +118 -0
  22. evaluation/evaluate_vggsound.py +114 -0
  23. evaluation/metadata/audiocaps_eval.csv +0 -0
  24. evaluation/metadata/audioset_eval.csv +0 -0
  25. evaluation/metadata/class_labels_indices.csv +528 -0
  26. evaluation/metadata/clotho_eval.csv +0 -0
  27. evaluation/metadata/esc50_eval.csv +0 -0
  28. evaluation/metadata/music_eval.csv +0 -0
  29. evaluation/metadata/vggsound_eval.csv +0 -0
  30. losses.py +17 -0
  31. models/CLAP/__init__.py +0 -0
  32. models/CLAP/__pycache__/__init__.cpython-310.pyc +0 -0
  33. models/CLAP/open_clip/__init__.py +25 -0
  34. models/CLAP/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
  35. models/CLAP/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
  36. models/CLAP/open_clip/__pycache__/feature_fusion.cpython-310.pyc +0 -0
  37. models/CLAP/open_clip/__pycache__/htsat.cpython-310.pyc +0 -0
  38. models/CLAP/open_clip/__pycache__/loss.cpython-310.pyc +0 -0
  39. models/CLAP/open_clip/__pycache__/model.cpython-310.pyc +0 -0
  40. models/CLAP/open_clip/__pycache__/openai.cpython-310.pyc +0 -0
  41. models/CLAP/open_clip/__pycache__/pann_model.cpython-310.pyc +0 -0
  42. models/CLAP/open_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
  43. models/CLAP/open_clip/__pycache__/timm_model.cpython-310.pyc +0 -0
  44. models/CLAP/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
  45. models/CLAP/open_clip/__pycache__/transform.cpython-310.pyc +0 -0
  46. models/CLAP/open_clip/__pycache__/utils.cpython-310.pyc +0 -0
  47. models/CLAP/open_clip/bert.py +40 -0
  48. models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  49. models/CLAP/open_clip/factory.py +277 -0
  50. models/CLAP/open_clip/feature_fusion.py +192 -0
AudioSep_Colab.ipynb ADDED
@@ -0,0 +1,128 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "\n",
11
+ "repo_path = Path(\"/content/AudioSep\")\n",
12
+ "if not repo_path.exists():\n",
13
+ " !git clone https://github.com/Audio-AGI/AudioSep.git\n",
14
+ "\n",
15
+ "%cd /content/AudioSep"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {
22
+ "id": "pjIhw5ECS_3_"
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "!pip install torchlibrosa==0.1.0 gradio==3.47.1 gdown lightning transformers==4.28.1 ftfy braceexpand webdataset soundfile wget h5py"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {
33
+ "id": "t6h9KB3CcjBd"
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "checkpoints_dir = Path(\"checkpoint\")\n",
38
+ "checkpoints_dir.mkdir(exist_ok=True)\n",
39
+ "\n",
40
+ "models = (\n",
41
+ " (\n",
42
+ " \"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/audiosep_base_4M_steps.ckpt\",\n",
43
+ " checkpoints_dir / \"audiosep_base_4M_steps.ckpt\"\n",
44
+ " ),\n",
45
+ " (\n",
46
+ " \"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt\",\n",
47
+ " checkpoints_dir / \"music_speech_audioset_epoch_15_esc_89.98.pt\"\n",
48
+ " )\n",
49
+ ")\n",
50
+ "\n",
51
+ "for model_url, model_path in models:\n",
52
+ " if not model_path.exists():\n",
53
+ " !wget {model_url} -O {model_path}"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {
60
+ "id": "3uDrzCQyY58h"
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "!wget \"https://audio-agi.github.io/Separate-Anything-You-Describe/demos/exp31_water drops_mixture.wav\""
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {
71
+ "id": "0nr77CGXTwO1"
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "import torch\n",
76
+ "from pipeline import build_audiosep, inference\n",
77
+ "\n",
78
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
79
+ "\n",
80
+ "model = build_audiosep(\n",
81
+ " config_yaml='config/audiosep_base.yaml',\n",
82
+ " checkpoint_path=str(models[0][1]),\n",
83
+ " device=device)\n",
84
+ "\n",
85
+ "audio_file = 'exp31_water drops_mixture.wav'\n",
86
+ "text = 'water drops'\n",
87
+ "output_file='separated_audio.wav'\n",
88
+ "\n",
89
+ "# AudioSep processes the audio at 32 kHz sampling rate\n",
90
+ "inference(model, audio_file, text, output_file, device)"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {
97
+ "id": "kssOe0pbPSWp"
98
+ },
99
+ "outputs": [],
100
+ "source": [
101
+ "print(f\"The separated audio is saved to '{output_file}'.\")"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {
108
+ "id": "sl35U3dAR6KN"
109
+ },
110
+ "outputs": [],
111
+ "source": []
112
+ }
113
+ ],
114
+ "metadata": {
115
+ "colab": {
116
+ "provenance": []
117
+ },
118
+ "kernelspec": {
119
+ "display_name": "Python 3",
120
+ "name": "python3"
121
+ },
122
+ "language_info": {
123
+ "name": "python"
124
+ }
125
+ },
126
+ "nbformat": 4,
127
+ "nbformat_minor": 0
128
+ }
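
For readers who prefer a plain script over the notebook, the cells above reduce to a short Python program. The sketch below reuses the same `build_audiosep` and `inference` calls shown in the cells and additionally loops over several text queries to illustrate that the model is query-conditioned; the checkpoint path assumes the downloads performed earlier in the notebook.

```python
# Sketch: script version of the notebook cells above; separates several
# sources from one mixture by changing only the text query.
import torch
from pipeline import build_audiosep, inference

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = build_audiosep(
    config_yaml='config/audiosep_base.yaml',
    checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt',
    device=device)

mixture = 'exp31_water drops_mixture.wav'

# AudioSep processes audio at 32 kHz; the same mixture can be separated
# repeatedly with different natural-language descriptions.
for query in ['water drops', 'background noise']:
    inference(model, mixture, query, f"{query.replace(' ', '_')}.wav", device)
```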
CONTRIBUTING.md ADDED
@@ -0,0 +1,92 @@
1
+ # 🎵 Contributing to AudioSep
2
+
3
+ Welcome to the AudioSep repository, where your contributions can harmonize the world of audio separation. To ensure a harmonious and organized collaboration, please follow the contribution guidelines outlined below.
4
+
5
+ ## **Submitting Contributions**
6
+
7
+ To contribute to this project, please adhere to the following steps:
8
+
9
+ ### **1. Choose or Create an Issue**
10
+
11
+ - Start by reviewing the existing issues to identify areas where your contributions can make a significant impact.
12
+ - If you have ideas for new features, enhancements, or bug fixes, feel free to create a new issue to propose your contributions. Provide comprehensive details for clarity.
13
+
14
+ ### **2. Fork the Repository**
15
+
16
+ - To initiate your contribution, fork the primary repository by clicking the "Fork" button. This will create a copy of the repository in your personal GitHub account.
17
+
18
+ ### **3. Clone Your Forked Repository**
19
+
20
+ - Clone your forked repository to your local development environment using the following command:
21
+
22
+ ```bash
23
+ git clone https://github.com/your-username/AudioSep.git
24
+ ```
25
+
26
+ ### **4. Set Up the Upstream Remote**
27
+
28
+ - Maintain a reference to the primary project by adding it as the upstream remote:
29
+
30
+ ```bash
31
+ cd AudioSep
32
+ git remote add upstream https://github.com/Audio-AGI/AudioSep
33
+ git remote -v
34
+ ```
35
+
36
+ ### **5. Create a New Branch**
37
+
38
+ - Before starting your contribution, establish a new branch dedicated to your specific task:
39
+
40
+ ```bash
41
+ git checkout -b my-contribution
42
+ ```
43
+
44
+ ## **Working on Your Contribution**
45
+
46
+ Now that your development environment is ready and a new branch is established, you can start working on your contribution. Please ensure you adhere to the following guidelines:
47
+
48
+ ### **6. Make Changes**
49
+
50
+ - Implement the necessary changes, including code additions, enhancements, or bug fixes. Ensure your contributions are well-structured, documented, and aligned with the project's objectives.
51
+
52
+ ### **7. Commit Your Changes**
53
+
54
+ - Commit your changes using informative commit messages that clearly convey the purpose of your contributions:
55
+
56
+ ```bash
57
+ git commit -m "Add a descriptive message here"
58
+ ```
59
+
60
+ ### **8. Push Your Changes**
61
+
62
+ - Push the committed changes to your remote repository on GitHub:
63
+
64
+ ```bash
65
+ git push origin my-contribution
66
+ ```
67
+
68
+ ### **9. Create a Pull Request**
69
+
70
+ - Visit your repository on GitHub and click the "New Pull Request" button to initiate a pull request from your branch to the primary repository.
71
+
72
+ ### **10. Await Review**
73
+
74
+ - Your pull request will undergo review, and feedback will be provided by the project maintainers or fellow contributors. Be prepared to address any suggested changes or refinements.
75
+
76
+ ## **Community Engagement**
77
+
78
+ While contributing, please consider engaging with the community in the following ways:
79
+
80
+ ### **11. Join Discussions**
81
+
82
+ - Participate in discussions related to audio separation techniques and their applications. Share your insights, experiences, and expertise in the audio field.
83
+
84
+ ### **12. Share Ideas**
85
+
86
+ - If you have innovative ideas for advancing the project or optimizing audio separation, such as new algorithms or research findings, feel free to open issues to initiate productive discussions.
87
+
88
+ ## **Acknowledgment**
89
+
90
+ We appreciate your dedication to the world of audio separation. Your contributions play a crucial role in harmonizing audio and improving the listening experience for all. If you have questions or require assistance, please don't hesitate to contact the project maintainers.
91
+
92
+ Thank you for your valuable contributions, and we eagerly anticipate collaborating with you on AudioSep! 🎶🙌
Dockerfile ADDED
@@ -0,0 +1,22 @@
1
+ FROM python:3.10.11
2
+
3
+ # Copy the current directory contents into the container at .
4
+ COPY . .
5
+
6
+ # Set the working directory to /
7
+ WORKDIR /
8
+
9
+ # Install requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r /requirements.txt
11
+
12
+ RUN useradd -m -u 1000 user
13
+ USER user
14
+ ENV HOME=/home/user \
15
+ PATH=/home/user/.local/bin:$PATH
16
+
17
+ WORKDIR $HOME/app
18
+
19
+ COPY --chown=user . $HOME/app
20
+
21
+ # Start the FastAPI app on port 7860, the default port expected by Spaces
22
+ CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Xubo Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
assets/results.png ADDED
benchmark.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import numpy as np
4
+ from evaluation.evaluate_audioset import AudioSetEvaluator
5
+ from evaluation.evaluate_audiocaps import AudioCapsEvaluator
6
+ from evaluation.evaluate_vggsound import VGGSoundEvaluator
7
+ from evaluation.evaluate_music import MUSICEvaluator
8
+ from evaluation.evaluate_esc50 import ESC50Evaluator
9
+ from evaluation.evaluate_clotho import ClothoEvaluator
10
+ from models.clap_encoder import CLAP_Encoder
11
+
12
+ from utils import (
13
+ load_ss_model,
14
+ calculate_sdr,
15
+ calculate_sisdr,
16
+ parse_yaml,
17
+ get_mean_sdr_from_dict,
18
+ )
19
+
20
+ def eval(checkpoint_path, config_yaml='config/audiosep_base.yaml'):
21
+
22
+ log_dir = 'eval_logs'
23
+ os.makedirs(log_dir, exist_ok=True)
24
+
25
+ device = "cuda"
26
+
27
+ configs = parse_yaml(config_yaml)
28
+
29
+ # AudioSet Evaluators
30
+ audioset_evaluator = AudioSetEvaluator()
31
+ # AudioCaps Evaluator
32
+ audiocaps_evaluator = AudioCapsEvaluator()
33
+ # VGGSound+ Evaluator
34
+ vggsound_evaluator = VGGSoundEvaluator()
35
+ # Clotho Evaluator
36
+ clotho_evaluator = ClothoEvaluator()
37
+ # MUSIC Evaluator
38
+ music_evaluator = MUSICEvaluator()
39
+ # ESC-50 Evaluator
40
+ esc50_evaluator = ESC50Evaluator()
41
+
42
+ # Load model
43
+ query_encoder = CLAP_Encoder().eval()
44
+
45
+ pl_model = load_ss_model(
46
+ configs=configs,
47
+ checkpoint_path=checkpoint_path,
48
+ query_encoder=query_encoder
49
+ ).to(device)
50
+
51
+ print(f'------- Start Evaluation -------')
52
+
53
+ # evaluation on Clotho
54
+ SISDR, SDRi = clotho_evaluator(pl_model)
55
+ msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
56
+ print(msg_clotho)
57
+
58
+ # evaluation on VGGSound+ (YAN)
59
+ SISDR, SDRi = vggsound_evaluator(pl_model)
60
+ msg_vgg = "VGGSound Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
61
+ print(msg_vgg)
62
+
63
+ # evaluation on MUSIC
64
+ SISDR, SDRi = music_evaluator(pl_model)
65
+ msg_music = "MUSIC Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
66
+ print(msg_music)
67
+
68
+ # evaluation on ESC-50
69
+ SISDR, SDRi = esc50_evaluator(pl_model)
70
+ msg_esc50 = "ESC-50 Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
71
+ print(msg_esc50)
72
+
73
+ # evaluation on AudioSet
74
+ stats_dict = audioset_evaluator(pl_model=pl_model)
75
+ median_sdris = {}
76
+ median_sisdrs = {}
77
+
78
+ for class_id in range(527):
79
+ median_sdris[class_id] = np.nanmedian(stats_dict["sdris_dict"][class_id])
80
+ median_sisdrs[class_id] = np.nanmedian(stats_dict["sisdrs_dict"][class_id])
81
+
82
+ SDRi = get_mean_sdr_from_dict(median_sdris)
83
+ SISDR = get_mean_sdr_from_dict(median_sisdrs)
84
+ msg_audioset = "AudioSet Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
85
+ print(msg_audioset)
86
+
87
+ # evaluation on AudioCaps
88
+ SISDR, SDRi = audiocaps_evaluator(pl_model)
89
+ msg_audiocaps = "AudioCaps Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
90
+ print(msg_audiocaps)
91
+
92
+ # evaluation on Clotho
93
+ SISDR, SDRi = clotho_evaluator(pl_model)
94
+ msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR)
95
+ print(msg_clotho)
96
+
97
+ msgs = [msg_audioset, msg_vgg, msg_audiocaps, msg_clotho, msg_music, msg_esc50]
98
+
99
+ # open file in write mode
100
+ log_path = os.path.join(log_dir, 'eval_results.txt')
101
+ with open(log_path, 'w') as fp:
102
+ for msg in msgs:
103
+ fp.write(msg + '\n')
104
+ print(f'Eval log is written to {log_path} ...')
105
+ print('------------------------- Done ---------------------------')
106
+
107
+
108
+ if __name__ == '__main__':
109
+ eval(checkpoint_path='checkpoint/audiosep_base.ckpt')
110
+
111
+
112
+
113
+
114
+
115
+
116
+
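
Note: the `__main__` entry point above defaults to `checkpoint/audiosep_base.ckpt`, which is not among the files added in this commit; the bundled checkpoint is `checkpoint/audiosep_base_4M_steps.ckpt`. A minimal sketch of a local run with the included checkpoint follows; it assumes the evaluation audio has been prepared under `evaluation/data/`, which this commit does not include.

```python
# Sketch: running the benchmark with the checkpoint that ships in this commit.
# The evaluation audio under evaluation/data/ is an assumed prerequisite.
from benchmark import eval

eval(
    checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt',
    config_yaml='config/audiosep_base.yaml',
)
```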
callbacks/base.py ADDED
@@ -0,0 +1,35 @@
1
+ import os
2
+ import lightning.pytorch as pl
3
+ from lightning.pytorch.utilities import rank_zero_only
4
+
5
+
6
+ class CheckpointEveryNSteps(pl.Callback):
7
+ def __init__(
8
+ self,
9
+ checkpoints_dir,
10
+ save_step_frequency,
11
+ ) -> None:
12
+ r"""Save a checkpoint every N steps.
13
+
14
+ Args:
15
+ checkpoints_dir (str): directory to save checkpoints
16
+ save_step_frequency (int): save checkpoint every N step
17
+ """
18
+
19
+ self.checkpoints_dir = checkpoints_dir
20
+ self.save_step_frequency = save_step_frequency
21
+
22
+ @rank_zero_only
23
+ def on_train_batch_end(self, *args, **kwargs) -> None:
24
+ r"""Save a checkpoint every N steps."""
25
+
26
+ trainer = args[0]
27
+ global_step = trainer.global_step
28
+
29
+ if global_step == 1 or global_step % self.save_step_frequency == 0:
30
+
31
+ ckpt_path = os.path.join(
32
+ self.checkpoints_dir,
33
+ "step={}.ckpt".format(global_step))
34
+ trainer.save_checkpoint(ckpt_path)
35
+ print("Save checkpoint to {}".format(ckpt_path))
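
A sketch of how this callback would typically be attached to a `lightning.pytorch.Trainer`; the actual training entry point is not part of this 50-file view, so the Trainer arguments are illustrative.

```python
# Sketch: attaching CheckpointEveryNSteps to a Lightning Trainer.
import lightning.pytorch as pl
from callbacks.base import CheckpointEveryNSteps

checkpoint_callback = CheckpointEveryNSteps(
    checkpoints_dir='checkpoints/audiosep',
    save_step_frequency=20000,  # matches save_step_frequency in the config
)

trainer = pl.Trainer(
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback],
)
# trainer.fit(pl_model, datamodule=data_module)
```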
checkpoint/audiosep_base_4M_steps.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8cda01bfd0ebd141eef45d41db7a3ada23a56568465840d3cff04b8010ce82c
3
+ size 1264844076
checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51c68f12f9d7ea25fdaaccf741ec7f81e93ee594455410f3bca4f47f88d8e006
3
+ size 2352471003
cog.yaml ADDED
@@ -0,0 +1,21 @@
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3
+
4
+ build:
5
+ gpu: true
6
+ python_version: "3.11"
7
+ python_packages:
8
+ - "torchlibrosa==0.1.0"
9
+ - "lightning==2.1.0"
10
+ - "torch==2.0.1"
11
+ - "transformers==4.28.1"
12
+ - "braceexpand==0.1.7"
13
+ - "webdataset==0.2.60"
14
+ - "soundfile==0.12.1"
15
+ - "torchaudio==2.0.2"
16
+ - "torchvision==0.15.2"
17
+ - "h5py==3.10.0"
18
+ - "ftfy==6.1.1"
19
+ - "pandas==2.1.1"
20
+ - "wget==3.2"
21
+ predict: "predict.py:Predictor"
config/audiosep_base.yaml ADDED
@@ -0,0 +1,41 @@
1
+ ---
2
+ task_name: AudioSep
3
+
4
+ data:
5
+ datafiles:
6
+ - 'datafiles/template.json'
7
+
8
+ sampling_rate: 32000
9
+ segment_seconds: 5
10
+ loudness_norm:
11
+ lower_db: -10
12
+ higher_db: 10
13
+ max_mix_num: 2
14
+
15
+ model:
16
+ query_net: CLAP
17
+ condition_size: 512
18
+ model_type: ResUNet30
19
+ input_channels: 1
20
+ output_channels: 1
21
+ resume_checkpoint: ""
22
+ use_text_ratio: 1.0
23
+
24
+ train:
25
+ optimizer:
26
+ optimizer_type: AdamW
27
+ learning_rate: 1e-3
28
+ warm_up_steps: 10000
29
+ reduce_lr_steps: 1000000
30
+ lr_lambda_type: constant_warm_up
31
+ num_nodes: 1
32
+ num_workers: 6
33
+ loss_type: l1_wav
34
+ sync_batchnorm: True
35
+ batch_size_per_device: 12
36
+ steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`.
37
+ evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps.
38
+ save_step_frequency: 20000 # Save every #save_step_frequency steps.
39
+ early_stop_steps: 10000001
40
+ random_seed: 1234
41
+
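
The training and evaluation code reads this file through `parse_yaml` from `utils`, which falls outside this 50-file view. A minimal sketch with PyYAML (which `parse_yaml` presumably wraps) shows how the nested keys are addressed.

```python
# Sketch: loading config/audiosep_base.yaml with PyYAML.
import yaml

with open('config/audiosep_base.yaml') as f:
    configs = yaml.safe_load(f)

sampling_rate = configs['data']['sampling_rate']   # 32000
max_mix_num = configs['data']['max_mix_num']       # 2

# PyYAML loads the bare scientific notation `1e-3` as a string, hence float().
learning_rate = float(configs['train']['optimizer']['learning_rate'])  # 1e-3
```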
data/audiotext_dataset.py ADDED
@@ -0,0 +1,91 @@
1
+ import json
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class AudioTextDataset(Dataset):
9
+ """Can sample data from audio-text databases
10
+ Params:
11
+ sampling_rate: audio sampling rate
12
+ max_clip_len: max length (seconds) of audio clip to be sampled
13
+ """
14
+ def __init__(
15
+ self,
16
+ datafiles=[''],
17
+ sampling_rate=32000,
18
+ max_clip_len=5,
19
+ ):
20
+ all_data_json = []
21
+ for datafile in datafiles:
22
+ with open(datafile, 'r') as fp:
23
+ data_json = json.load(fp)['data']
24
+ all_data_json.extend(data_json)
25
+ self.all_data_json = all_data_json
26
+
27
+ self.sampling_rate = sampling_rate
28
+ self.max_length = max_clip_len * sampling_rate
29
+
30
+ def __len__(self):
31
+ return len(self.all_data_json)
32
+
33
+ def _cut_or_randomcrop(self, waveform):
34
+ # waveform: [1, samples]
35
+ # random crop
36
+ if waveform.size(1) > self.max_length:
37
+ random_idx = random.randint(0, waveform.size(1)-self.max_length)
38
+ waveform = waveform[:, random_idx:random_idx+self.max_length]
39
+ else:
40
+ temp_wav = torch.zeros(1, self.max_length)
41
+ temp_wav[:, 0:waveform.size(1)] = waveform
42
+ waveform = temp_wav
43
+
44
+ assert waveform.size(1) == self.max_length, \
45
+ f"number of audio samples is {waveform.size(1)}"
46
+
47
+ return waveform
48
+
49
+ def _read_audio(self, index):
50
+ try:
51
+ audio_path = self.all_data_json[index]['wav']
52
+ audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True)
53
+ text = self.all_data_json[index]['caption']
54
+
55
+ # drop short utterance
56
+ if audio_data.size(1) < self.sampling_rate * 1:
57
+ raise Exception(f'{audio_path} is too short, drop it ...')
58
+
59
+ return text, audio_data, audio_rate
60
+
61
+ except Exception as e:
62
+ print(f'error: {e} occurs, when loading {audio_path}')
63
+ random_index = random.randint(0, len(self.all_data_json)-1)
64
+ return self._read_audio(index=random_index)
65
+
66
+ def __getitem__(self, index):
67
+ # create an audio tensor
68
+ text, audio_data, audio_rate = self._read_audio(index)
69
+ audio_len = audio_data.shape[1] / audio_rate
70
+ # convert stereo to single channel
71
+ if audio_data.shape[0] > 1:
72
+ # audio_data: [samples]
73
+ audio_data = (audio_data[0] + audio_data[1]) / 2
74
+ else:
75
+ audio_data = audio_data.squeeze(0)
76
+
77
+ # resample audio clip
78
+ if audio_rate != self.sampling_rate:
79
+ audio_data = torchaudio.functional.resample(audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate)
80
+
81
+ audio_data = audio_data.unsqueeze(0)
82
+
83
+ audio_data = self._cut_or_randomcrop(audio_data)
84
+
85
+ data_dict = {
86
+ 'text': text,
87
+ 'waveform': audio_data,
88
+ 'modality': 'audio_text'
89
+ }
90
+
91
+ return data_dict
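
A sketch of what one item from this dataset looks like, assuming a datafile in the `datafiles/template.json` format that points to real audio (`datafiles/my_data.json` is a hypothetical path, not part of this commit).

```python
# Sketch: inspecting a single AudioTextDataset item.
from data.audiotext_dataset import AudioTextDataset

dataset = AudioTextDataset(
    datafiles=['datafiles/my_data.json'],  # hypothetical datafile
    sampling_rate=32000,
    max_clip_len=5,
)

item = dataset[0]
print(item['text'])            # caption string
print(item['waveform'].shape)  # torch.Size([1, 160000]) -> 5 s at 32 kHz
print(item['modality'])        # 'audio_text'
```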
data/datamodules.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import Dict, List, Optional, NoReturn
2
+ import torch
3
+ import lightning.pytorch as pl
4
+ from torch.utils.data import DataLoader
5
+ from data.audiotext_dataset import AudioTextDataset
6
+
7
+
8
+ class DataModule(pl.LightningDataModule):
9
+ def __init__(
10
+ self,
11
+ train_dataset: object,
12
+ batch_size: int,
13
+ num_workers: int
14
+ ):
15
+ r"""Data module. To get one batch of data:
16
+
17
+ .. code-block:: python
18
+
19
+ data_module.setup()
20
+
21
+ for batch_data_dict in data_module.train_dataloader():
22
+ print(batch_data_dict.keys())
23
+ break
24
+
25
+ Args:
26
+ train_dataset: Dataset object
27
+ batch_size: int, mini-batch size
28
+ num_workers: int, number of dataloader workers
29
+
30
+ """
31
+ super().__init__()
32
+ self._train_dataset = train_dataset
33
+ self.num_workers = num_workers
34
+ self.batch_size = batch_size
35
+ self.collate_fn = collate_fn
36
+
37
+
38
+ def prepare_data(self):
39
+ # download, split, etc...
40
+ # only called on 1 GPU/TPU in distributed
41
+ pass
42
+
43
+ def setup(self, stage: Optional[str] = None) -> NoReturn:
44
+ r"""called on every device."""
45
+
46
+ # make assignments here (val/train/test split)
47
+ # called on every process in DDP
48
+
49
+ # SegmentSampler is used for selecting segments for training.
50
+ # On multiple devices, each SegmentSampler samples a part of mini-batch
51
+ # data.
52
+ self.train_dataset = self._train_dataset
53
+
54
+
55
+ def train_dataloader(self) -> torch.utils.data.DataLoader:
56
+ r"""Get train loader."""
57
+ train_loader = DataLoader(
58
+ dataset=self.train_dataset,
59
+ batch_size=self.batch_size,
60
+ collate_fn=self.collate_fn,
61
+ num_workers=self.num_workers,
62
+ pin_memory=True,
63
+ persistent_workers=False,
64
+ shuffle=True
65
+ )
66
+
67
+ return train_loader
68
+
69
+ def val_dataloader(self):
70
+ # val_split = Dataset(...)
71
+ # return DataLoader(val_split)
72
+ pass
73
+
74
+ def test_dataloader(self):
75
+ # test_split = Dataset(...)
76
+ # return DataLoader(test_split)
77
+ pass
78
+
79
+ def teardown(self):
80
+ # clean up after fit or test
81
+ # called on every process in DDP
82
+ pass
83
+
84
+
85
+ def collate_fn(list_data_dict):
86
+ r"""Collate mini-batch data to inputs and targets for training.
87
+
88
+ Args:
89
+ list_data_dict: e.g., [
90
+ {
91
+ 'text': 'a sound of dog',
92
+ 'waveform': (1, samples),
93
+ 'modality': 'audio_text'
94
+ }
95
+ ...
96
+ ]
97
+ Returns:
98
+ data_dict: e.g.
99
+ 'audio_text': {
100
+ 'text': ['a sound of dog', ...]
101
+ 'waveform': (batch_size, 1, samples)
102
+ }
103
+ """
104
+
105
+ at_list_data_dict = [data_dict for data_dict in list_data_dict if data_dict['modality']=='audio_text']
106
+
107
+ at_data_dict = {}
108
+
109
+ if len(at_list_data_dict) > 0:
110
+ for key in at_list_data_dict[0].keys():
111
+ at_data_dict[key] = [d[key] for d in at_list_data_dict]
112
+ if key == 'waveform':
113
+ at_data_dict[key] = torch.stack(at_data_dict[key])
114
+ elif key == 'text':
115
+ at_data_dict[key] = [text for text in at_data_dict[key]]
116
+
117
+
118
+ data_dict = {
119
+ 'audio_text': at_data_dict
120
+ }
121
+
122
+ return data_dict
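
How the dataset, `collate_fn`, and `DataModule` fit together is easiest to see from one collated batch. The sketch below reuses the hypothetical datafile from the previous note and illustrative sizes; the real training wiring is not part of this view.

```python
# Sketch: wiring AudioTextDataset into DataModule and inspecting a batch.
from data.audiotext_dataset import AudioTextDataset
from data.datamodules import DataModule

dataset = AudioTextDataset(datafiles=['datafiles/my_data.json'],
                           sampling_rate=32000, max_clip_len=5)
data_module = DataModule(train_dataset=dataset, batch_size=12, num_workers=0)
data_module.setup()

batch = next(iter(data_module.train_dataloader()))
at = batch['audio_text']
print(len(at['text']))       # up to 12 captions
print(at['waveform'].shape)  # torch.Size([12, 1, 160000])
```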
data/waveform_mixers.py ADDED
@@ -0,0 +1,127 @@
1
+ import random
2
+ import sre_compile
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import pyloudnorm as pyln
7
+
8
+
9
+ class SegmentMixer(nn.Module):
10
+ def __init__(self, max_mix_num, lower_db, higher_db):
11
+ super(SegmentMixer, self).__init__()
12
+
13
+ self.max_mix_num = max_mix_num
14
+ self.loudness_param = {
15
+ 'lower_db': lower_db,
16
+ 'higher_db': higher_db,
17
+ }
18
+
19
+ def __call__(self, waveforms):
20
+
21
+ batch_size = waveforms.shape[0]
22
+
23
+ data_dict = {
24
+ 'segment': [],
25
+ 'mixture': [],
26
+ }
27
+
28
+ for n in range(0, batch_size):
29
+
30
+ segment = waveforms[n].clone()
31
+
32
+ # create zero tensors as the background template
33
+ noise = torch.zeros_like(segment)
34
+
35
+ mix_num = random.randint(2, self.max_mix_num)
36
+ assert mix_num >= 2
37
+
38
+ for i in range(1, mix_num):
39
+ next_segment = waveforms[(n + i) % batch_size]
40
+ rescaled_next_segment = dynamic_loudnorm(audio=next_segment, reference=segment, **self.loudness_param)
41
+ noise += rescaled_next_segment
42
+
43
+ # randomly normalize background noise
44
+ noise = dynamic_loudnorm(audio=noise, reference=segment, **self.loudness_param)
45
+
46
+ # create the audio mixture
47
+ mixture = segment + noise
48
+
49
+ # declip if needed
50
+ max_value = torch.max(torch.abs(mixture))
51
+ if max_value > 1:
52
+ segment *= 0.9 / max_value
53
+ mixture *= 0.9 / max_value
54
+
55
+ data_dict['segment'].append(segment)
56
+ data_dict['mixture'].append(mixture)
57
+
58
+ for key in data_dict.keys():
59
+ data_dict[key] = torch.stack(data_dict[key], dim=0)
60
+
61
+ # return data_dict
62
+ return data_dict['mixture'], data_dict['segment']
63
+
64
+
65
+ def rescale_to_match_energy(segment1, segment2):
66
+
67
+ ratio = get_energy_ratio(segment1, segment2)
68
+ rescaled_segment1 = segment1 / ratio
69
+ return rescaled_segment1
70
+
71
+
72
+ def get_energy(x):
73
+ return torch.mean(x ** 2)
74
+
75
+
76
+ def get_energy_ratio(segment1, segment2):
77
+
78
+ energy1 = get_energy(segment1)
79
+ energy2 = max(get_energy(segment2), 1e-10)
80
+ ratio = (energy1 / energy2) ** 0.5
81
+ ratio = torch.clamp(ratio, 0.02, 50)
82
+ return ratio
83
+
84
+
85
+ def dynamic_loudnorm(audio, reference, lower_db=-10, higher_db=10):
86
+ rescaled_audio = rescale_to_match_energy(audio, reference)
87
+
88
+ delta_loudness = random.randint(lower_db, higher_db)
89
+
90
+ gain = np.power(10.0, delta_loudness / 20.0)
91
+
92
+ return gain * rescaled_audio
93
+
94
+
95
+ def torch_to_numpy(tensor):
96
+ """Convert a PyTorch tensor to a NumPy array."""
97
+ if isinstance(tensor, torch.Tensor):
98
+ return tensor.detach().cpu().numpy()
99
+ else:
100
+ raise ValueError("Input must be a PyTorch tensor.")
101
+
102
+
103
+ def numpy_to_torch(array):
104
+ """Convert a NumPy array to a PyTorch tensor."""
105
+ if isinstance(array, np.ndarray):
106
+ return torch.from_numpy(array)
107
+ else:
108
+ raise ValueError("Input must be a NumPy array.")
109
+
110
+
111
+ # decayed
112
+ def random_loudness_norm(audio, lower_db=-35, higher_db=-15, sr=32000):
113
+ device = audio.device
114
+ audio = torch_to_numpy(audio.squeeze(0))
115
+ # randomly select a norm volume
116
+ norm_vol = random.randint(lower_db, higher_db)
117
+
118
+ # measure the loudness first
119
+ meter = pyln.Meter(sr) # create BS.1770 meter
120
+ loudness = meter.integrated_loudness(audio)
121
+ # loudness normalize audio
122
+ normalized_audio = pyln.normalize.loudness(audio, loudness, norm_vol)
123
+
124
+ normalized_audio = numpy_to_torch(normalized_audio).unsqueeze(0)
125
+
126
+ return normalized_audio.to(device)
127
+
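
`SegmentMixer` builds each training mixture by adding loudness-jittered copies of other segments in the batch and returns `(mixture, segment)` pairs. Below is a sketch of invoking it on a batch shaped like the dataloader output, with `max_mix_num`/`lower_db`/`higher_db` taken from `config/audiosep_base.yaml`; the random tensor stands in for real audio.

```python
# Sketch: using SegmentMixer on a batch of 5-second, 32 kHz waveforms.
import torch
from data.waveform_mixers import SegmentMixer

mixer = SegmentMixer(max_mix_num=2, lower_db=-10, higher_db=10)

waveforms = torch.randn(12, 1, 160000) * 0.1   # (batch, channels, samples)
mixture, segment = mixer(waveforms)

print(mixture.shape, segment.shape)  # both torch.Size([12, 1, 160000])
```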
datafiles/template.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "data": [
3
+ {
4
+ "wav": "path_to_audio_file",
5
+ "caption": "textual_descriptions"
6
+ }
7
+ ]
8
+ }
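
Each datafile referenced in `config/audiosep_base.yaml` is a JSON object with a `data` list of `{"wav", "caption"}` entries, as templated above. The sketch below generates one from a folder of wav files with sidecar `.txt` captions; the folder layout is an assumption, purely for illustration.

```python
# Sketch: building a datafile in the template.json format from a folder where
# every foo.wav has a matching foo.txt caption (an assumed layout).
import json
from pathlib import Path

entries = []
for wav_path in sorted(Path('my_audio').glob('*.wav')):
    caption = wav_path.with_suffix('.txt').read_text().strip()
    entries.append({'wav': str(wav_path), 'caption': caption})

with open('datafiles/my_data.json', 'w') as f:
    json.dump({'data': entries}, f, indent=2)
```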
environment.yml ADDED
@@ -0,0 +1,326 @@
1
+ name: AudioSep
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - backcall=0.2.0=pyhd3eb1b0_0
10
+ - blas=1.0=mkl
11
+ - boltons=23.0.0=py310h06a4308_0
12
+ - brotlipy=0.7.0=py310h7f8727e_1002
13
+ - bzip2=1.0.8=h7b6447c_0
14
+ - ca-certificates=2023.01.10=h06a4308_0
15
+ - certifi=2022.12.7=py310h06a4308_0
16
+ - cffi=1.15.1=py310h5eee18b_3
17
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
18
+ - comm=0.1.2=py310h06a4308_0
19
+ - conda=23.3.1=py310h06a4308_0
20
+ - conda-content-trust=0.1.3=py310h06a4308_0
21
+ - conda-package-handling=2.0.2=py310h06a4308_0
22
+ - conda-package-streaming=0.7.0=py310h06a4308_0
23
+ - cryptography=38.0.4=py310h9ce1e76_0
24
+ - cuda=11.6.1=0
25
+ - cuda-cccl=11.6.55=hf6102b2_0
26
+ - cuda-command-line-tools=11.6.2=0
27
+ - cuda-compiler=11.6.2=0
28
+ - cuda-cudart=11.6.55=he381448_0
29
+ - cuda-cudart-dev=11.6.55=h42ad0f4_0
30
+ - cuda-cuobjdump=11.6.124=h2eeebcb_0
31
+ - cuda-cupti=11.6.124=h86345e5_0
32
+ - cuda-cuxxfilt=11.6.124=hecbf4f6_0
33
+ - cuda-driver-dev=11.6.55=0
34
+ - cuda-gdb=12.1.55=0
35
+ - cuda-libraries=11.6.1=0
36
+ - cuda-libraries-dev=11.6.1=0
37
+ - cuda-memcheck=11.8.86=0
38
+ - cuda-nsight=12.1.55=0
39
+ - cuda-nsight-compute=12.1.0=0
40
+ - cuda-nvcc=11.6.124=hbba6d2d_0
41
+ - cuda-nvdisasm=12.1.55=0
42
+ - cuda-nvml-dev=11.6.55=haa9ef22_0
43
+ - cuda-nvprof=12.1.55=0
44
+ - cuda-nvprune=11.6.124=he22ec0a_0
45
+ - cuda-nvrtc=11.6.124=h020bade_0
46
+ - cuda-nvrtc-dev=11.6.124=h249d397_0
47
+ - cuda-nvtx=11.6.124=h0630a44_0
48
+ - cuda-nvvp=12.1.55=0
49
+ - cuda-runtime=11.6.1=0
50
+ - cuda-samples=11.6.101=h8efea70_0
51
+ - cuda-sanitizer-api=12.1.55=0
52
+ - cuda-toolkit=11.6.1=0
53
+ - cuda-tools=11.6.1=0
54
+ - cuda-visual-tools=11.6.1=0
55
+ - debugpy=1.5.1=py310h295c915_0
56
+ - decorator=5.1.1=pyhd3eb1b0_0
57
+ - flit-core=3.8.0=py310h06a4308_0
58
+ - freetype=2.12.1=h4a9f257_0
59
+ - gds-tools=1.6.0.25=0
60
+ - giflib=5.2.1=h5eee18b_3
61
+ - gmp=6.2.1=h295c915_3
62
+ - gnutls=3.6.15=he1e5248_0
63
+ - idna=3.4=py310h06a4308_0
64
+ - intel-openmp=2021.4.0=h06a4308_3561
65
+ - ipykernel=6.19.2=py310h2f386ee_0
66
+ - ipython=8.12.0=py310h06a4308_0
67
+ - jpeg=9e=h5eee18b_1
68
+ - jsonpatch=1.32=pyhd3eb1b0_0
69
+ - jsonpointer=2.1=pyhd3eb1b0_0
70
+ - jupyter_client=8.1.0=py310h06a4308_0
71
+ - jupyter_core=5.3.0=py310h06a4308_0
72
+ - lame=3.100=h7b6447c_0
73
+ - lcms2=2.12=h3be6417_0
74
+ - ld_impl_linux-64=2.38=h1181459_1
75
+ - lerc=3.0=h295c915_0
76
+ - libcublas=11.9.2.110=h5e84587_0
77
+ - libcublas-dev=11.9.2.110=h5c901ab_0
78
+ - libcufft=10.7.1.112=hf425ae0_0
79
+ - libcufft-dev=10.7.1.112=ha5ce4c0_0
80
+ - libcufile=1.6.0.25=0
81
+ - libcufile-dev=1.6.0.25=0
82
+ - libcurand=10.3.2.56=0
83
+ - libcurand-dev=10.3.2.56=0
84
+ - libcusolver=11.3.4.124=h33c3c4e_0
85
+ - libcusparse=11.7.2.124=h7538f96_0
86
+ - libcusparse-dev=11.7.2.124=hbbe9722_0
87
+ - libdeflate=1.17=h5eee18b_0
88
+ - libffi=3.4.2=h6a678d5_6
89
+ - libgcc-ng=11.2.0=h1234567_1
90
+ - libgomp=11.2.0=h1234567_1
91
+ - libiconv=1.16=h7f8727e_2
92
+ - libidn2=2.3.2=h7f8727e_0
93
+ - libnpp=11.6.3.124=hd2722f0_0
94
+ - libnpp-dev=11.6.3.124=h3c42840_0
95
+ - libnvjpeg=11.6.2.124=hd473ad6_0
96
+ - libnvjpeg-dev=11.6.2.124=hb5906b9_0
97
+ - libpng=1.6.39=h5eee18b_0
98
+ - libsodium=1.0.18=h7b6447c_0
99
+ - libstdcxx-ng=11.2.0=h1234567_1
100
+ - libtasn1=4.19.0=h5eee18b_0
101
+ - libtiff=4.5.0=h6a678d5_2
102
+ - libunistring=0.9.10=h27cfd23_0
103
+ - libuuid=1.41.5=h5eee18b_0
104
+ - libwebp=1.2.4=h11a3e52_1
105
+ - libwebp-base=1.2.4=h5eee18b_1
106
+ - lz4-c=1.9.4=h6a678d5_0
107
+ - matplotlib-inline=0.1.6=py310h06a4308_0
108
+ - mkl=2021.4.0=h06a4308_640
109
+ - mkl-service=2.4.0=py310h7f8727e_0
110
+ - mkl_fft=1.3.1=py310hd6ae3a3_0
111
+ - mkl_random=1.2.2=py310h00e6091_0
112
+ - ncurses=6.4=h6a678d5_0
113
+ - nest-asyncio=1.5.6=py310h06a4308_0
114
+ - nettle=3.7.3=hbbd107a_1
115
+ - nsight-compute=2023.1.0.15=0
116
+ - numpy=1.23.5=py310hd5efca6_0
117
+ - numpy-base=1.23.5=py310h8e6c178_0
118
+ - openh264=2.1.1=h4ff587b_0
119
+ - openssl=1.1.1t=h7f8727e_0
120
+ - packaging=23.0=py310h06a4308_0
121
+ - parso=0.8.3=pyhd3eb1b0_0
122
+ - pexpect=4.8.0=pyhd3eb1b0_3
123
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
124
+ - pip=22.3.1=py310h06a4308_0
125
+ - platformdirs=2.5.2=py310h06a4308_0
126
+ - pluggy=1.0.0=py310h06a4308_1
127
+ - psutil=5.9.0=py310h5eee18b_0
128
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
129
+ - pure_eval=0.2.2=pyhd3eb1b0_0
130
+ - pycosat=0.6.4=py310h5eee18b_0
131
+ - pycparser=2.21=pyhd3eb1b0_0
132
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
133
+ - pysocks=1.7.1=py310h06a4308_0
134
+ - python=3.10.9=h7a1cb2a_0
135
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
136
+ - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
137
+ - pytorch-cuda=11.6=h867d48c_1
138
+ - pytorch-mutex=1.0=cuda
139
+ - pyzmq=23.2.0=py310h6a678d5_0
140
+ - readline=8.2=h5eee18b_0
141
+ - requests=2.28.1=py310h06a4308_0
142
+ - ruamel.yaml=0.17.21=py310h5eee18b_0
143
+ - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
144
+ - setuptools=65.6.3=py310h06a4308_0
145
+ - six=1.16.0=pyhd3eb1b0_1
146
+ - sqlite=3.40.1=h5082296_0
147
+ - stack_data=0.2.0=pyhd3eb1b0_0
148
+ - tk=8.6.12=h1ccaba5_0
149
+ - toolz=0.12.0=py310h06a4308_0
150
+ - torchaudio=0.13.1=py310_cu116
151
+ - torchvision=0.14.1=py310_cu116
152
+ - tornado=6.2=py310h5eee18b_0
153
+ - tqdm=4.64.1=py310h06a4308_0
154
+ - typing_extensions=4.4.0=py310h06a4308_0
155
+ - tzdata=2022g=h04d1e81_0
156
+ - urllib3=1.26.14=py310h06a4308_0
157
+ - wheel=0.37.1=pyhd3eb1b0_0
158
+ - xz=5.2.10=h5eee18b_1
159
+ - zeromq=4.3.4=h2531618_0
160
+ - zlib=1.2.13=h5eee18b_0
161
+ - zstandard=0.18.0=py310h5eee18b_0
162
+ - zstd=1.5.4=hc292b87_0
163
+ - pip:
164
+ - absl-py==1.4.0
165
+ - aiohttp==3.8.4
166
+ - aiosignal==1.3.1
167
+ - anyio==3.6.2
168
+ - appdirs==1.4.4
169
+ - arrow==1.2.3
170
+ - asttokens==2.2.1
171
+ - async-generator==1.10
172
+ - async-timeout==4.0.2
173
+ - attrs==22.2.0
174
+ - audioread==3.0.0
175
+ - av==10.0.0
176
+ - beartype==0.12.0
177
+ - beautifulsoup4==4.12.2
178
+ - blessed==1.20.0
179
+ - braceexpand==0.1.7
180
+ - cachetools==5.3.0
181
+ - click==8.1.3
182
+ - contourpy==1.0.7
183
+ - croniter==1.3.10
184
+ - cycler==0.11.0
185
+ - dataclasses-json==0.5.8
186
+ - dateutils==0.6.12
187
+ - decord==0.6.0
188
+ - deepdiff==6.3.0
189
+ - dtk==0.2
190
+ - exceptiongroup==1.1.1
191
+ - executing==1.2.0
192
+ - fastapi==0.88.0
193
+ - ffmpeg==1.4
194
+ - ffmpeg-python==0.2.0
195
+ - filelock==3.12.0
196
+ - fonttools==4.39.3
197
+ - frozenlist==1.3.3
198
+ - fsspec==2023.4.0
199
+ - ftfy==6.1.1
200
+ - future==0.18.3
201
+ - gammatone==1.0
202
+ - google-auth==2.17.3
203
+ - google-auth-oauthlib==1.0.0
204
+ - greenlet==2.0.2
205
+ - grpcio==1.54.0
206
+ - h11==0.14.0
207
+ - h5py==3.8.0
208
+ - hickle==5.0.2
209
+ - huggingface-hub==0.14.1
210
+ - humanize==4.6.0
211
+ - imageio==2.27.0
212
+ - inquirer==3.1.3
213
+ - ipdb==0.13.13
214
+ - itsdangerous==2.1.2
215
+ - jedi==0.18.2
216
+ - jinja2==3.1.2
217
+ - joblib==1.2.0
218
+ - kiwisolver==1.4.4
219
+ - langchain==0.0.216
220
+ - langchainplus-sdk==0.0.17
221
+ - lazy-loader==0.2
222
+ - librosa==0.10.0.post2
223
+ - lightning==2.0.0
224
+ - lightning-cloud==0.5.33
225
+ - lightning-utilities==0.8.0
226
+ - llvmlite==0.39.1
227
+ - markdown==3.4.3
228
+ - markdown-it-py==2.2.0
229
+ - markupsafe==2.1.2
230
+ - marshmallow==3.19.0
231
+ - marshmallow-enum==1.5.1
232
+ - matplotlib==3.7.1
233
+ - mdurl==0.1.2
234
+ - mergedeep==1.3.4
235
+ - mock==5.0.2
236
+ - msgpack==1.0.5
237
+ - msgpack-numpy==0.4.8
238
+ - multidict==6.0.4
239
+ - musdb==0.4.0
240
+ - mypy-extensions==1.0.0
241
+ - networkx==3.1
242
+ - nose==1.3.7
243
+ - numba==0.56.4
244
+ - numexpr==2.8.4
245
+ - oauthlib==3.2.2
246
+ - openai==0.27.8
247
+ - openapi-schema-pydantic==1.2.4
248
+ - opencv-python==4.7.0.72
249
+ - ordered-set==4.1.0
250
+ - outcome==1.2.0
251
+ - pandas==1.5.3
252
+ - panns-inference==0.1.0
253
+ - pesq==0.0.4
254
+ - pillow==9.5.0
255
+ - pooch==1.6.0
256
+ - prompt-toolkit==3.0.38
257
+ - protobuf==4.22.3
258
+ - pyaml==23.5.9
259
+ - pyasn1==0.5.0
260
+ - pyasn1-modules==0.3.0
261
+ - pydantic==1.10.7
262
+ - pygments==2.14.0
263
+ - pyjwt==2.6.0
264
+ - pyloudnorm==0.1.1
265
+ - pyparsing==3.0.9
266
+ - pystoi==0.3.3
267
+ - python-editor==1.0.4
268
+ - python-multipart==0.0.6
269
+ - pytorch-ignite==0.3.0
270
+ - pytorch-lightning==2.0.1.post0
271
+ - pytz==2023.3
272
+ - pywavelets==1.4.1
273
+ - pyyaml==6.0
274
+ - readchar==4.0.5
275
+ - regex==2023.3.23
276
+ - requests-oauthlib==1.3.1
277
+ - resampy==0.4.2
278
+ - rich==13.3.3
279
+ - rsa==4.9
280
+ - scikit-image==0.20.0
281
+ - scikit-learn==1.2.2
282
+ - scipy==1.10.1
283
+ - selenium==4.8.3
284
+ - simplejpeg==1.6.6
285
+ - sniffio==1.3.0
286
+ - sortedcontainers==2.4.0
287
+ - soundfile==0.12.1
288
+ - soupsieve==2.4
289
+ - soxr==0.3.5
290
+ - sqlalchemy==2.0.17
291
+ - stack-data==0.6.2
292
+ - starlette==0.22.0
293
+ - starsessions==1.3.0
294
+ - stempeg==0.2.3
295
+ - tenacity==8.2.2
296
+ - tensorboard==2.12.2
297
+ - tensorboard-data-server==0.7.0
298
+ - tensorboard-plugin-wit==1.8.1
299
+ - termcolor==1.1.0
300
+ - threadpoolctl==3.1.0
301
+ - tifffile==2023.3.21
302
+ - timm==0.3.2
303
+ - tokenizers==0.13.3
304
+ - tomli==2.0.1
305
+ - torchfile==0.1.0
306
+ - torchlibrosa==0.1.0
307
+ - torchmetrics==0.11.4
308
+ - traitlets==5.9.0
309
+ - transformers==4.28.1
310
+ - trio==0.22.0
311
+ - trio-websocket==0.10.2
312
+ - typeguard==3.0.2
313
+ - typing-extensions==4.5.0
314
+ - typing-inspect==0.9.0
315
+ - uvicorn==0.21.1
316
+ - visdom==0.1.8.9
317
+ - wcwidth==0.2.6
318
+ - webdataset==0.2.48
319
+ - websocket-client==1.5.1
320
+ - websockets==11.0.1
321
+ - werkzeug==2.2.3
322
+ - wget==3.2
323
+ - wsproto==1.2.0
324
+ - yarl==1.8.2
325
+ - zenodo-get==1.3.4
326
+ - zsvision==0.7.8
evaluation/evaluate_audiocaps.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Dict, List
5
+
6
+ import csv
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import pathlib
12
+ import librosa
13
+ import lightning.pytorch as pl
14
+ from models.clap_encoder import CLAP_Encoder
15
+
16
+ sys.path.append('../AudioSep/')
17
+ from utils import (
18
+ load_ss_model,
19
+ calculate_sdr,
20
+ calculate_sisdr,
21
+ parse_yaml,
22
+ get_mean_sdr_from_dict,
23
+ )
24
+
25
+
26
+ class AudioCapsEvaluator:
27
+ def __init__(
28
+ self,
29
+ query='caption',
30
+ sampling_rate=32000,
31
+ ) -> None:
32
+ r"""AudioCaps evaluator.
33
+
34
+ Args:
35
+ query (str): type of query, 'caption' or 'labels'
36
+ Returns:
37
+ None
38
+ """
39
+
40
+ self.query = query
41
+ self.sampling_rate = sampling_rate
42
+
43
+ with open(f'evaluation/metadata/audiocaps_eval.csv') as csv_file:
44
+ csv_reader = csv.reader(csv_file, delimiter=',')
45
+ eval_list = [row for row in csv_reader][1:]
46
+
47
+ self.eval_list = eval_list
48
+ self.audio_dir = f'evaluation/data/audiocaps'
49
+
50
+ def __call__(
51
+ self,
52
+ pl_model: pl.LightningModule
53
+ ) -> Dict:
54
+ r"""Evaluate."""
55
+
56
+ print(f'Evaluation on AudioCaps with [{self.query}] queries.')
57
+
58
+ pl_model.eval()
59
+ device = pl_model.device
60
+
61
+ sisdrs_list = []
62
+ sdris_list = []
63
+
64
+ with torch.no_grad():
65
+ for eval_data in tqdm(self.eval_list):
66
+
67
+ idx, caption, labels, _, _ = eval_data
68
+
69
+ source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
70
+ mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')
71
+
72
+ source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
73
+ mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
74
+
75
+ sdr_no_sep = calculate_sdr(ref=source, est=mixture)
76
+
77
+ if self.query == 'caption':
78
+ text = [caption]
79
+ elif self.query == 'labels':
80
+ text = [labels]
81
+
82
+ conditions = pl_model.query_encoder.get_query_embed(
83
+ modality='text',
84
+ text=text,
85
+ device=device
86
+ )
87
+
88
+ input_dict = {
89
+ "mixture": torch.Tensor(mixture)[None, None, :].to(device),
90
+ "condition": conditions,
91
+ }
92
+
93
+
94
+ sep_segment = pl_model.ss_model(input_dict)["waveform"]
95
+ # sep_segment: (batch_size=1, channels_num=1, segment_samples)
96
+
97
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
98
+ # sep_segment: (segment_samples,)
99
+
100
+ sdr = calculate_sdr(ref=source, est=sep_segment)
101
+ sdri = sdr - sdr_no_sep
102
+ sisdr = calculate_sisdr(ref=source, est=sep_segment)
103
+
104
+ sisdrs_list.append(sisdr)
105
+ sdris_list.append(sdri)
106
+
107
+ mean_sisdr = np.mean(sisdrs_list)
108
+ mean_sdri = np.mean(sdris_list)
109
+
110
+ return mean_sisdr, mean_sdri
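
All evaluators report SI-SDR and SDRi (the SDR improvement of the separated signal over the raw mixture) through `calculate_sdr` and `calculate_sisdr` from `utils`, which fall outside this 50-file view. The sketch below gives the standard definitions these helpers presumably implement.

```python
# Sketch of the standard SDR / SI-SDR definitions (an assumption about what
# utils.calculate_sdr and utils.calculate_sisdr compute).
import numpy as np

def sdr(ref: np.ndarray, est: np.ndarray, eps: float = 1e-10) -> float:
    noise = est - ref
    return 10.0 * np.log10((np.sum(ref ** 2) + eps) / (np.sum(noise ** 2) + eps))

def si_sdr(ref: np.ndarray, est: np.ndarray, eps: float = 1e-10) -> float:
    # project the estimate onto the reference, then compare energies
    alpha = np.dot(est, ref) / (np.dot(ref, ref) + eps)
    target = alpha * ref
    noise = est - target
    return 10.0 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))

# SDRi as reported by the evaluators: sdr(ref, separated) - sdr(ref, mixture)
```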
evaluation/evaluate_audioset.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Dict, List
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+ import pathlib
11
+ import librosa
12
+ import lightning.pytorch as pl
13
+ from models.clap_encoder import CLAP_Encoder
14
+
15
+ sys.path.append('../AudioSep/')
16
+ from utils import (
17
+ load_ss_model,
18
+ calculate_sdr,
19
+ calculate_sisdr,
20
+ parse_yaml,
21
+ get_mean_sdr_from_dict,
22
+ )
23
+
24
+
25
+ meta_csv_file = "evaluation/metadata/class_labels_indices.csv"
26
+ df = pd.read_csv(meta_csv_file, sep=',')
27
+
28
+ IDS = df['mid'].tolist()
29
+ LABELS = df['display_name'].tolist()
30
+
31
+ CLASSES_NUM = len(LABELS)
32
+
33
+ IX_TO_LB = {i : label for i, label in enumerate(LABELS)}
34
+
35
+
36
+ class AudioSetEvaluator:
37
+ def __init__(
38
+ self,
39
+ audios_dir='evaluation/data/audioset',
40
+ classes_num=527,
41
+ sampling_rate=32000,
42
+ number_per_class=10,
43
+ ) -> None:
44
+ r"""AudioSet evaluator.
45
+
46
+ Args:
47
+ audios_dir (str): directory of evaluation segments
48
+ classes_num (int): the number of sound classes
49
+ number_per_class (int), the number of samples to evaluate for each sound class
50
+
51
+ Returns:
52
+ None
53
+ """
54
+
55
+ self.audios_dir = audios_dir
56
+ self.classes_num = classes_num
57
+ self.number_per_class = number_per_class
58
+ self.sampling_rate = sampling_rate
59
+
60
+ @torch.no_grad()
61
+ def __call__(
62
+ self,
63
+ pl_model: pl.LightningModule
64
+ ) -> Dict:
65
+ r"""Evaluate."""
66
+
67
+ pl_model.eval()
68
+
69
+ sisdrs_dict = {class_id: [] for class_id in range(self.classes_num)}
70
+ sdris_dict = {class_id: [] for class_id in range(self.classes_num)}
71
+
72
+ print('Evaluation on AudioSet with [text label] queries.')
73
+
74
+ for class_id in tqdm(range(self.classes_num)):
75
+
76
+ sub_dir = os.path.join(
77
+ self.audios_dir,
78
+ "class_id={}".format(class_id))
79
+
80
+ audio_names = self._get_audio_names(audios_dir=sub_dir)
81
+
82
+ for audio_index, audio_name in enumerate(audio_names):
83
+
84
+ if audio_index == self.number_per_class:
85
+ break
86
+
87
+ source_path = os.path.join(
88
+ sub_dir, "{},source.wav".format(audio_name))
89
+ mixture_path = os.path.join(
90
+ sub_dir, "{},mixture.wav".format(audio_name))
91
+
92
+ source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
93
+ mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
94
+
95
+ sdr_no_sep = calculate_sdr(ref=source, est=mixture)
96
+
97
+ device = pl_model.device
98
+
99
+ text = [IX_TO_LB[class_id]]
100
+
101
+ conditions = pl_model.query_encoder.get_query_embed(
102
+ modality='text',
103
+ text=text,
104
+ device=device
105
+ )
106
+
107
+ input_dict = {
108
+ "mixture": torch.Tensor(mixture)[None, None, :].to(device),
109
+ "condition": conditions,
110
+ }
111
+
112
+ sep_segment = pl_model.ss_model(input_dict)["waveform"]
113
+ # sep_segment: (batch_size=1, channels_num=1, segment_samples)
114
+
115
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
116
+ # sep_segment: (segment_samples,)
117
+
118
+ sdr = calculate_sdr(ref=source, est=sep_segment)
119
+ sdri = sdr - sdr_no_sep
120
+ sisdr = calculate_sisdr(ref=source, est=sep_segment)
121
+
122
+
123
+ sisdrs_dict[class_id].append(sisdr)
124
+ sdris_dict[class_id].append(sdri)
125
+
126
+
127
+ stats_dict = {
128
+ "sisdrs_dict": sisdrs_dict,
129
+ "sdris_dict": sdris_dict,
130
+ }
131
+
132
+ return stats_dict
133
+
134
+ def _get_audio_names(self, audios_dir: str) -> List[str]:
135
+ r"""Get evaluation audio names."""
136
+ audio_names = sorted(os.listdir(audios_dir))
137
+
138
+ audio_names = [audio_name for audio_name in audio_names if '.wav' in audio_name]
139
+
140
+ audio_names = [
141
+ re.search(
142
+ "(.*),(mixture|source).wav",
143
+ audio_name).group(1) for audio_name in audio_names]
144
+
145
+ audio_names = sorted(list(set(audio_names)))
146
+
147
+ return audio_names
148
+
149
+ @staticmethod
150
+ def get_median_metrics(stats_dict, metric_type):
151
+ class_ids = stats_dict[metric_type].keys()
152
+ median_stats_dict = {
153
+ class_id: np.nanmedian(
154
+ stats_dict[metric_type][class_id]) for class_id in class_ids}
155
+ return median_stats_dict
evaluation/evaluate_clotho.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Dict, List
5
+
6
+ import csv
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import pathlib
12
+ import librosa
13
+ import lightning.pytorch as pl
14
+ from models.clap_encoder import CLAP_Encoder
15
+
16
+ sys.path.append('../AudioSep/')
17
+ from utils import (
18
+ load_ss_model,
19
+ calculate_sdr,
20
+ calculate_sisdr,
21
+ parse_yaml,
22
+ get_mean_sdr_from_dict,
23
+ )
24
+
25
+
26
+ class ClothoEvaluator:
27
+ def __init__(
28
+ self,
29
+ sampling_rate=32000,
30
+ ) -> None:
31
+ r"""Clotho evaluator.
32
+ Returns:
33
+ None
34
+ """
35
+
36
+ self.sampling_rate = sampling_rate
37
+
38
+ with open('evaluation/metadata/clotho_eval.csv') as csv_file:
39
+ csv_reader = csv.reader(csv_file, delimiter=',')
40
+ eval_list = [row for row in csv_reader][1:]
41
+
42
+ self.eval_list = eval_list
43
+ self.audio_dir = 'evaluation/data/clotho'
44
+
45
+ def __call__(
46
+ self,
47
+ pl_model: pl.LightningModule
48
+ ) -> Dict:
49
+ r"""Evaluate."""
50
+
51
+ print('Evaluation on the Clotho evaluation set with [caption] queries.')
52
+
53
+ pl_model.eval()
54
+ device = pl_model.device
55
+
56
+ sisdrs_list = []
57
+ sdris_list = []
58
+
59
+ with torch.no_grad():
60
+ for eval_data in tqdm(self.eval_list):
61
+
62
+ idx, caption, _, _, _ = eval_data
63
+
64
+ source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
65
+ mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')
66
+
67
+ source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
68
+ mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
69
+
70
+ sdr_no_sep = calculate_sdr(ref=source, est=mixture)
71
+
72
+ text = [caption]
73
+
74
+ conditions = pl_model.query_encoder.get_query_embed(
75
+ modality='text',
76
+ text=text,
77
+ device=device
78
+ )
79
+
80
+ input_dict = {
81
+ "mixture": torch.Tensor(mixture)[None, None, :].to(device),
82
+ "condition": conditions,
83
+ }
84
+
85
+ sep_segment = pl_model.ss_model(input_dict)["waveform"]
86
+ # sep_segment: (batch_size=1, channels_num=1, segment_samples)
87
+
88
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
89
+ # sep_segment: (segment_samples,)
90
+
91
+ sdr = calculate_sdr(ref=source, est=sep_segment)
92
+ sdri = sdr - sdr_no_sep
93
+ sisdr = calculate_sisdr(ref=source, est=sep_segment)
94
+
95
+
96
+ sisdrs_list.append(sisdr)
97
+ sdris_list.append(sdri)
98
+
99
+ mean_sisdr = np.mean(sisdrs_list)
100
+ mean_sdri = np.mean(sdris_list)
101
+
102
+ return mean_sisdr, mean_sdri
evaluation/evaluate_esc50.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Dict, List
5
+
6
+ import csv
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import pathlib
12
+ import librosa
13
+ import lightning.pytorch as pl
14
+ from models.clap_encoder import CLAP_Encoder
15
+
16
+ sys.path.append('../AudioSep/')
17
+ from utils import (
18
+ load_ss_model,
19
+ calculate_sdr,
20
+ calculate_sisdr,
21
+ parse_yaml,
22
+ get_mean_sdr_from_dict,
23
+ )
24
+
25
+
26
+ class ESC50Evaluator:
27
+ def __init__(
28
+ self,
29
+ sampling_rate=32000
30
+ ) -> None:
31
+ r"""ESC-50 evaluator.
32
+
33
+ Returns:
34
+ None
35
+ """
36
+
37
+ self.sampling_rate = sampling_rate
38
+
39
+ with open('evaluation/metadata/esc50_eval.csv') as csv_file:
40
+ csv_reader = csv.reader(csv_file, delimiter=',')
41
+ eval_list = [row for row in csv_reader][1:]
42
+
43
+ self.eval_list = eval_list
44
+ self.audio_dir = 'evaluation/data/esc50'
45
+
46
+ def __call__(
47
+ self,
48
+ pl_model: pl.LightningModule
49
+ ) -> Dict:
50
+ r"""Evaluate."""
51
+
52
+ print(f'Evaluation on ESC-50 with [text label] queries.')
53
+
54
+ pl_model.eval()
55
+ device = pl_model.device
56
+
57
+ sisdrs_list = []
58
+ sdris_list = []
59
+
60
+ with torch.no_grad():
61
+ for eval_data in tqdm(self.eval_list):
62
+
63
+ idx, caption, _, _, = eval_data
64
+
65
+ source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
66
+ mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')
67
+
68
+ source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
69
+ mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
70
+
71
+ sdr_no_sep = calculate_sdr(ref=source, est=mixture)
72
+
73
+ text = [caption]
74
+
75
+ conditions = pl_model.query_encoder.get_query_embed(
76
+ modality='text',
77
+ text=text,
78
+ device=device
79
+ )
80
+
81
+ input_dict = {
82
+ "mixture": torch.Tensor(mixture)[None, None, :].to(device),
83
+ "condition": conditions,
84
+ }
85
+
86
+ sep_segment = pl_model.ss_model(input_dict)["waveform"]
87
+ # sep_segment: (batch_size=1, channels_num=1, segment_samples)
88
+
89
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
90
+ # sep_segment: (segment_samples,)
91
+
92
+ sdr = calculate_sdr(ref=source, est=sep_segment)
93
+ sdri = sdr - sdr_no_sep
94
+ sisdr = calculate_sisdr(ref=source, est=sep_segment)
95
+
96
+ sisdrs_list.append(sisdr)
97
+ sdris_list.append(sdri)
98
+
99
+ mean_sdri = np.mean(sdris_list)
100
+ mean_sisdr = np.mean(sisdrs_list)
101
+
102
+ return mean_sisdr, mean_sdri
evaluation/evaluate_music.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Dict, List
5
+
6
+ import csv
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import pathlib
12
+ import librosa
13
+ import lightning.pytorch as pl
14
+ from models.clap_encoder import CLAP_Encoder
15
+
16
+ sys.path.append('../AudioSep/')
17
+ from utils import (
18
+ load_ss_model,
19
+ calculate_sdr,
20
+ calculate_sisdr,
21
+ parse_yaml,
22
+ get_mean_sdr_from_dict,
23
+ )
24
+
25
+
26
+ class MUSICEvaluator:
27
+ def __init__(
28
+ self,
29
+ sampling_rate=32000
30
+ ) -> None:
31
+
32
+ self.sampling_rate = sampling_rate
33
+
34
+ with open('evaluation/metadata/music_eval.csv') as csv_file:
35
+ csv_reader = csv.reader(csv_file, delimiter=',')
36
+ eval_list = [row for row in csv_reader][1:]
37
+
38
+ self.eval_list = eval_list
39
+ self.audio_dir = 'evaluation/data/music'
40
+
41
+ self.source_types = [
42
+ "acoustic guitar",
43
+ "violin",
44
+ "accordion",
45
+ "xylophone",
46
+ "erhu",
47
+ "trumpet",
48
+ "tuba",
49
+ "cello",
50
+ "flute",
51
+ "saxophone"]
52
+
53
+ def __call__(
54
+ self,
55
+ pl_model: pl.LightningModule
56
+ ) -> Dict:
57
+ r"""Evaluate."""
58
+
59
+ print(f'Evaluation on MUSIC Test with [text label] queries.')
60
+
61
+ pl_model.eval()
62
+ device = pl_model.device
63
+
64
+ sisdrs_list = {source_type: [] for source_type in self.source_types}
65
+ sdris_list = {source_type: [] for source_type in self.source_types}
66
+
67
+ with torch.no_grad():
68
+ for eval_data in tqdm(self.eval_list):
69
+
70
+ idx, caption, _, _, = eval_data
71
+
72
+ source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
73
+ mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')
74
+
75
+ source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
76
+ mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
77
+
78
+ sdr_no_sep = calculate_sdr(ref=source, est=mixture)
79
+
80
+ text = [caption]
81
+
82
+ conditions = pl_model.query_encoder.get_query_embed(
83
+ modality='text',
84
+ text=text,
85
+ device=device
86
+ )
87
+
88
+ input_dict = {
89
+ "mixture": torch.Tensor(mixture)[None, None, :].to(device),
90
+ "condition": conditions,
91
+ }
92
+
93
+ sep_segment = pl_model.ss_model(input_dict)["waveform"]
94
+ # sep_segment: (batch_size=1, channels_num=1, segment_samples)
95
+
96
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
97
+ # sep_segment: (segment_samples,)
98
+
99
+ sdr = calculate_sdr(ref=source, est=sep_segment)
100
+ sdri = sdr - sdr_no_sep
101
+ sisdr = calculate_sisdr(ref=source, est=sep_segment)
102
+
103
+ sisdrs_list[caption].append(sisdr)
104
+ sdris_list[caption].append(sdri)
105
+
106
+ mean_sisdr_list = []
107
+ mean_sdri_list = []
108
+
109
+ for source_class in self.source_types:
110
+ sisdr = np.mean(sisdrs_list[source_class])
111
+ sdri = np.mean(sdris_list[source_class])
112
+ mean_sisdr_list.append(sisdr)
113
+ mean_sdri_list.append(sdri)
114
+
115
+ mean_sdri = np.mean(mean_sdri_list)
116
+ mean_sisdr = np.mean(mean_sisdr_list)
117
+
118
+ return mean_sisdr, mean_sdri
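A minimal driver sketch for the evaluator above. It assumes CLAP_Encoder and load_ss_model keep the call signatures used elsewhere in this repo (e.g. in benchmark.py), and it uses the checkpoint and config paths that appear in this commit's file list; treat the exact argument names as assumptions rather than a confirmed API.

    import torch
    from models.clap_encoder import CLAP_Encoder
    from utils import load_ss_model, parse_yaml
    from evaluation.evaluate_music import MUSICEvaluator

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    configs = parse_yaml('config/audiosep_base.yaml')
    query_encoder = CLAP_Encoder().eval()

    # load_ss_model is assumed to return a LightningModule exposing
    # .ss_model and .query_encoder, which is what the evaluator expects.
    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt',
        query_encoder=query_encoder,
    ).to(device)

    mean_sisdr, mean_sdri = MUSICEvaluator(sampling_rate=32000)(pl_model)
    print(f'MUSIC test set: SI-SDR {mean_sisdr:.2f} dB, SDRi {mean_sdri:.2f} dB')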
evaluation/evaluate_vggsound.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ import sys
+ import re
+ from typing import Dict, List
+
+ import csv
+ import pandas as pd
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ import pathlib
+ import librosa
+ import lightning.pytorch as pl
+ from models.clap_encoder import CLAP_Encoder
+
+ sys.path.append('../AudioSep/')
+ from utils import (
+     load_ss_model,
+     calculate_sdr,
+     calculate_sisdr,
+     parse_yaml,
+     get_mean_sdr_from_dict,
+ )
+
+
+ class VGGSoundEvaluator:
+     def __init__(
+         self,
+         sampling_rate=32000
+     ) -> None:
+         r"""VGGSound evaluator.
+
+         Args:
+             sampling_rate (int): sampling rate used to load the evaluation audio
+         Returns:
+             None
+         """
+
+         self.sampling_rate = sampling_rate
+
+         with open('evaluation/metadata/vggsound_eval.csv') as csv_file:
+             csv_reader = csv.reader(csv_file, delimiter=',')
+             eval_list = [row for row in csv_reader][1:]
+
+         self.eval_list = eval_list
+         self.audio_dir = 'evaluation/data/vggsound'
+
+     def __call__(
+         self,
+         pl_model: pl.LightningModule
+     ) -> Dict:
+         r"""Evaluate."""
+
+         print('Evaluation on VGGSound+ with [text label] queries.')
+
+         pl_model.eval()
+         device = pl_model.device
+
+         sisdrs_list = []
+         sdris_list = []
+         sisdris_list = []
+
+         with torch.no_grad():
+             for eval_data in tqdm(self.eval_list):
+
+                 # each row: file id, mixture wav, two source wavs and their text labels
+                 file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data
+
+                 labels = s0_text
+
+                 mixture_path = os.path.join(self.audio_dir, mix_wav)
+                 source_path = os.path.join(self.audio_dir, s0_wav)
+
+                 source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
+                 mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
+
+                 sdr_no_sep = calculate_sdr(ref=source, est=mixture)
+
+                 text = [labels]
+                 conditions = pl_model.query_encoder.get_query_embed(
+                     modality='text',
+                     text=text,
+                     device=device
+                 )
+
+                 input_dict = {
+                     "mixture": torch.Tensor(mixture)[None, None, :].to(device),
+                     "condition": conditions,
+                 }
+
+                 sep_segment = pl_model.ss_model(input_dict)["waveform"]
+                 # sep_segment: (batch_size=1, channels_num=1, segment_samples)
+
+                 sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
+                 # sep_segment: (segment_samples,)
+
+                 sdr = calculate_sdr(ref=source, est=sep_segment)
+                 sdri = sdr - sdr_no_sep
+
+                 sisdr_no_sep = calculate_sisdr(ref=source, est=mixture)
+                 sisdr = calculate_sisdr(ref=source, est=sep_segment)
+                 sisdri = sisdr - sisdr_no_sep
+
+                 sisdrs_list.append(sisdr)
+                 sdris_list.append(sdri)
+                 sisdris_list.append(sisdri)
+
+         mean_sisdr = np.mean(sisdrs_list)
+         mean_sdri = np.mean(sdris_list)
+
+         return mean_sisdr, mean_sdri
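Both evaluators also report SI-SDR via utils.calculate_sisdr, which is not part of this diff. For reference, a standard scale-invariant SDR looks like the sketch below; the actual implementation in utils.py may differ in details such as epsilon handling or zero-mean normalization.

    import numpy as np

    def si_sdr_db(ref: np.ndarray, est: np.ndarray, eps: float = 1e-10) -> float:
        """Scale-invariant SDR in dB (illustrative; see utils.calculate_sisdr)."""
        # Project the estimate onto the reference so a global gain change
        # of the estimate does not affect the score.
        alpha = np.dot(est, ref) / (np.dot(ref, ref) + eps)
        target = alpha * ref
        noise = est - target
        return 10.0 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))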
evaluation/metadata/audiocaps_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/metadata/audioset_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/metadata/class_labels_indices.csv ADDED
@@ -0,0 +1,528 @@
1
+ index,mid,display_name
2
+ 0,/m/09x0r,"Speech"
3
+ 1,/m/05zppz,"Male speech, man speaking"
4
+ 2,/m/02zsn,"Female speech, woman speaking"
5
+ 3,/m/0ytgt,"Child speech, kid speaking"
6
+ 4,/m/01h8n0,"Conversation"
7
+ 5,/m/02qldy,"Narration, monologue"
8
+ 6,/m/0261r1,"Babbling"
9
+ 7,/m/0brhx,"Speech synthesizer"
10
+ 8,/m/07p6fty,"Shout"
11
+ 9,/m/07q4ntr,"Bellow"
12
+ 10,/m/07rwj3x,"Whoop"
13
+ 11,/m/07sr1lc,"Yell"
14
+ 12,/m/04gy_2,"Battle cry"
15
+ 13,/t/dd00135,"Children shouting"
16
+ 14,/m/03qc9zr,"Screaming"
17
+ 15,/m/02rtxlg,"Whispering"
18
+ 16,/m/01j3sz,"Laughter"
19
+ 17,/t/dd00001,"Baby laughter"
20
+ 18,/m/07r660_,"Giggle"
21
+ 19,/m/07s04w4,"Snicker"
22
+ 20,/m/07sq110,"Belly laugh"
23
+ 21,/m/07rgt08,"Chuckle, chortle"
24
+ 22,/m/0463cq4,"Crying, sobbing"
25
+ 23,/t/dd00002,"Baby cry, infant cry"
26
+ 24,/m/07qz6j3,"Whimper"
27
+ 25,/m/07qw_06,"Wail, moan"
28
+ 26,/m/07plz5l,"Sigh"
29
+ 27,/m/015lz1,"Singing"
30
+ 28,/m/0l14jd,"Choir"
31
+ 29,/m/01swy6,"Yodeling"
32
+ 30,/m/02bk07,"Chant"
33
+ 31,/m/01c194,"Mantra"
34
+ 32,/t/dd00003,"Male singing"
35
+ 33,/t/dd00004,"Female singing"
36
+ 34,/t/dd00005,"Child singing"
37
+ 35,/t/dd00006,"Synthetic singing"
38
+ 36,/m/06bxc,"Rapping"
39
+ 37,/m/02fxyj,"Humming"
40
+ 38,/m/07s2xch,"Groan"
41
+ 39,/m/07r4k75,"Grunt"
42
+ 40,/m/01w250,"Whistling"
43
+ 41,/m/0lyf6,"Breathing"
44
+ 42,/m/07mzm6,"Wheeze"
45
+ 43,/m/01d3sd,"Snoring"
46
+ 44,/m/07s0dtb,"Gasp"
47
+ 45,/m/07pyy8b,"Pant"
48
+ 46,/m/07q0yl5,"Snort"
49
+ 47,/m/01b_21,"Cough"
50
+ 48,/m/0dl9sf8,"Throat clearing"
51
+ 49,/m/01hsr_,"Sneeze"
52
+ 50,/m/07ppn3j,"Sniff"
53
+ 51,/m/06h7j,"Run"
54
+ 52,/m/07qv_x_,"Shuffle"
55
+ 53,/m/07pbtc8,"Walk, footsteps"
56
+ 54,/m/03cczk,"Chewing, mastication"
57
+ 55,/m/07pdhp0,"Biting"
58
+ 56,/m/0939n_,"Gargling"
59
+ 57,/m/01g90h,"Stomach rumble"
60
+ 58,/m/03q5_w,"Burping, eructation"
61
+ 59,/m/02p3nc,"Hiccup"
62
+ 60,/m/02_nn,"Fart"
63
+ 61,/m/0k65p,"Hands"
64
+ 62,/m/025_jnm,"Finger snapping"
65
+ 63,/m/0l15bq,"Clapping"
66
+ 64,/m/01jg02,"Heart sounds, heartbeat"
67
+ 65,/m/01jg1z,"Heart murmur"
68
+ 66,/m/053hz1,"Cheering"
69
+ 67,/m/028ght,"Applause"
70
+ 68,/m/07rkbfh,"Chatter"
71
+ 69,/m/03qtwd,"Crowd"
72
+ 70,/m/07qfr4h,"Hubbub, speech noise, speech babble"
73
+ 71,/t/dd00013,"Children playing"
74
+ 72,/m/0jbk,"Animal"
75
+ 73,/m/068hy,"Domestic animals, pets"
76
+ 74,/m/0bt9lr,"Dog"
77
+ 75,/m/05tny_,"Bark"
78
+ 76,/m/07r_k2n,"Yip"
79
+ 77,/m/07qf0zm,"Howl"
80
+ 78,/m/07rc7d9,"Bow-wow"
81
+ 79,/m/0ghcn6,"Growling"
82
+ 80,/t/dd00136,"Whimper (dog)"
83
+ 81,/m/01yrx,"Cat"
84
+ 82,/m/02yds9,"Purr"
85
+ 83,/m/07qrkrw,"Meow"
86
+ 84,/m/07rjwbb,"Hiss"
87
+ 85,/m/07r81j2,"Caterwaul"
88
+ 86,/m/0ch8v,"Livestock, farm animals, working animals"
89
+ 87,/m/03k3r,"Horse"
90
+ 88,/m/07rv9rh,"Clip-clop"
91
+ 89,/m/07q5rw0,"Neigh, whinny"
92
+ 90,/m/01xq0k1,"Cattle, bovinae"
93
+ 91,/m/07rpkh9,"Moo"
94
+ 92,/m/0239kh,"Cowbell"
95
+ 93,/m/068zj,"Pig"
96
+ 94,/t/dd00018,"Oink"
97
+ 95,/m/03fwl,"Goat"
98
+ 96,/m/07q0h5t,"Bleat"
99
+ 97,/m/07bgp,"Sheep"
100
+ 98,/m/025rv6n,"Fowl"
101
+ 99,/m/09b5t,"Chicken, rooster"
102
+ 100,/m/07st89h,"Cluck"
103
+ 101,/m/07qn5dc,"Crowing, cock-a-doodle-doo"
104
+ 102,/m/01rd7k,"Turkey"
105
+ 103,/m/07svc2k,"Gobble"
106
+ 104,/m/09ddx,"Duck"
107
+ 105,/m/07qdb04,"Quack"
108
+ 106,/m/0dbvp,"Goose"
109
+ 107,/m/07qwf61,"Honk"
110
+ 108,/m/01280g,"Wild animals"
111
+ 109,/m/0cdnk,"Roaring cats (lions, tigers)"
112
+ 110,/m/04cvmfc,"Roar"
113
+ 111,/m/015p6,"Bird"
114
+ 112,/m/020bb7,"Bird vocalization, bird call, bird song"
115
+ 113,/m/07pggtn,"Chirp, tweet"
116
+ 114,/m/07sx8x_,"Squawk"
117
+ 115,/m/0h0rv,"Pigeon, dove"
118
+ 116,/m/07r_25d,"Coo"
119
+ 117,/m/04s8yn,"Crow"
120
+ 118,/m/07r5c2p,"Caw"
121
+ 119,/m/09d5_,"Owl"
122
+ 120,/m/07r_80w,"Hoot"
123
+ 121,/m/05_wcq,"Bird flight, flapping wings"
124
+ 122,/m/01z5f,"Canidae, dogs, wolves"
125
+ 123,/m/06hps,"Rodents, rats, mice"
126
+ 124,/m/04rmv,"Mouse"
127
+ 125,/m/07r4gkf,"Patter"
128
+ 126,/m/03vt0,"Insect"
129
+ 127,/m/09xqv,"Cricket"
130
+ 128,/m/09f96,"Mosquito"
131
+ 129,/m/0h2mp,"Fly, housefly"
132
+ 130,/m/07pjwq1,"Buzz"
133
+ 131,/m/01h3n,"Bee, wasp, etc."
134
+ 132,/m/09ld4,"Frog"
135
+ 133,/m/07st88b,"Croak"
136
+ 134,/m/078jl,"Snake"
137
+ 135,/m/07qn4z3,"Rattle"
138
+ 136,/m/032n05,"Whale vocalization"
139
+ 137,/m/04rlf,"Music"
140
+ 138,/m/04szw,"Musical instrument"
141
+ 139,/m/0fx80y,"Plucked string instrument"
142
+ 140,/m/0342h,"Guitar"
143
+ 141,/m/02sgy,"Electric guitar"
144
+ 142,/m/018vs,"Bass guitar"
145
+ 143,/m/042v_gx,"Acoustic guitar"
146
+ 144,/m/06w87,"Steel guitar, slide guitar"
147
+ 145,/m/01glhc,"Tapping (guitar technique)"
148
+ 146,/m/07s0s5r,"Strum"
149
+ 147,/m/018j2,"Banjo"
150
+ 148,/m/0jtg0,"Sitar"
151
+ 149,/m/04rzd,"Mandolin"
152
+ 150,/m/01bns_,"Zither"
153
+ 151,/m/07xzm,"Ukulele"
154
+ 152,/m/05148p4,"Keyboard (musical)"
155
+ 153,/m/05r5c,"Piano"
156
+ 154,/m/01s0ps,"Electric piano"
157
+ 155,/m/013y1f,"Organ"
158
+ 156,/m/03xq_f,"Electronic organ"
159
+ 157,/m/03gvt,"Hammond organ"
160
+ 158,/m/0l14qv,"Synthesizer"
161
+ 159,/m/01v1d8,"Sampler"
162
+ 160,/m/03q5t,"Harpsichord"
163
+ 161,/m/0l14md,"Percussion"
164
+ 162,/m/02hnl,"Drum kit"
165
+ 163,/m/0cfdd,"Drum machine"
166
+ 164,/m/026t6,"Drum"
167
+ 165,/m/06rvn,"Snare drum"
168
+ 166,/m/03t3fj,"Rimshot"
169
+ 167,/m/02k_mr,"Drum roll"
170
+ 168,/m/0bm02,"Bass drum"
171
+ 169,/m/011k_j,"Timpani"
172
+ 170,/m/01p970,"Tabla"
173
+ 171,/m/01qbl,"Cymbal"
174
+ 172,/m/03qtq,"Hi-hat"
175
+ 173,/m/01sm1g,"Wood block"
176
+ 174,/m/07brj,"Tambourine"
177
+ 175,/m/05r5wn,"Rattle (instrument)"
178
+ 176,/m/0xzly,"Maraca"
179
+ 177,/m/0mbct,"Gong"
180
+ 178,/m/016622,"Tubular bells"
181
+ 179,/m/0j45pbj,"Mallet percussion"
182
+ 180,/m/0dwsp,"Marimba, xylophone"
183
+ 181,/m/0dwtp,"Glockenspiel"
184
+ 182,/m/0dwt5,"Vibraphone"
185
+ 183,/m/0l156b,"Steelpan"
186
+ 184,/m/05pd6,"Orchestra"
187
+ 185,/m/01kcd,"Brass instrument"
188
+ 186,/m/0319l,"French horn"
189
+ 187,/m/07gql,"Trumpet"
190
+ 188,/m/07c6l,"Trombone"
191
+ 189,/m/0l14_3,"Bowed string instrument"
192
+ 190,/m/02qmj0d,"String section"
193
+ 191,/m/07y_7,"Violin, fiddle"
194
+ 192,/m/0d8_n,"Pizzicato"
195
+ 193,/m/01xqw,"Cello"
196
+ 194,/m/02fsn,"Double bass"
197
+ 195,/m/085jw,"Wind instrument, woodwind instrument"
198
+ 196,/m/0l14j_,"Flute"
199
+ 197,/m/06ncr,"Saxophone"
200
+ 198,/m/01wy6,"Clarinet"
201
+ 199,/m/03m5k,"Harp"
202
+ 200,/m/0395lw,"Bell"
203
+ 201,/m/03w41f,"Church bell"
204
+ 202,/m/027m70_,"Jingle bell"
205
+ 203,/m/0gy1t2s,"Bicycle bell"
206
+ 204,/m/07n_g,"Tuning fork"
207
+ 205,/m/0f8s22,"Chime"
208
+ 206,/m/026fgl,"Wind chime"
209
+ 207,/m/0150b9,"Change ringing (campanology)"
210
+ 208,/m/03qjg,"Harmonica"
211
+ 209,/m/0mkg,"Accordion"
212
+ 210,/m/0192l,"Bagpipes"
213
+ 211,/m/02bxd,"Didgeridoo"
214
+ 212,/m/0l14l2,"Shofar"
215
+ 213,/m/07kc_,"Theremin"
216
+ 214,/m/0l14t7,"Singing bowl"
217
+ 215,/m/01hgjl,"Scratching (performance technique)"
218
+ 216,/m/064t9,"Pop music"
219
+ 217,/m/0glt670,"Hip hop music"
220
+ 218,/m/02cz_7,"Beatboxing"
221
+ 219,/m/06by7,"Rock music"
222
+ 220,/m/03lty,"Heavy metal"
223
+ 221,/m/05r6t,"Punk rock"
224
+ 222,/m/0dls3,"Grunge"
225
+ 223,/m/0dl5d,"Progressive rock"
226
+ 224,/m/07sbbz2,"Rock and roll"
227
+ 225,/m/05w3f,"Psychedelic rock"
228
+ 226,/m/06j6l,"Rhythm and blues"
229
+ 227,/m/0gywn,"Soul music"
230
+ 228,/m/06cqb,"Reggae"
231
+ 229,/m/01lyv,"Country"
232
+ 230,/m/015y_n,"Swing music"
233
+ 231,/m/0gg8l,"Bluegrass"
234
+ 232,/m/02x8m,"Funk"
235
+ 233,/m/02w4v,"Folk music"
236
+ 234,/m/06j64v,"Middle Eastern music"
237
+ 235,/m/03_d0,"Jazz"
238
+ 236,/m/026z9,"Disco"
239
+ 237,/m/0ggq0m,"Classical music"
240
+ 238,/m/05lls,"Opera"
241
+ 239,/m/02lkt,"Electronic music"
242
+ 240,/m/03mb9,"House music"
243
+ 241,/m/07gxw,"Techno"
244
+ 242,/m/07s72n,"Dubstep"
245
+ 243,/m/0283d,"Drum and bass"
246
+ 244,/m/0m0jc,"Electronica"
247
+ 245,/m/08cyft,"Electronic dance music"
248
+ 246,/m/0fd3y,"Ambient music"
249
+ 247,/m/07lnk,"Trance music"
250
+ 248,/m/0g293,"Music of Latin America"
251
+ 249,/m/0ln16,"Salsa music"
252
+ 250,/m/0326g,"Flamenco"
253
+ 251,/m/0155w,"Blues"
254
+ 252,/m/05fw6t,"Music for children"
255
+ 253,/m/02v2lh,"New-age music"
256
+ 254,/m/0y4f8,"Vocal music"
257
+ 255,/m/0z9c,"A capella"
258
+ 256,/m/0164x2,"Music of Africa"
259
+ 257,/m/0145m,"Afrobeat"
260
+ 258,/m/02mscn,"Christian music"
261
+ 259,/m/016cjb,"Gospel music"
262
+ 260,/m/028sqc,"Music of Asia"
263
+ 261,/m/015vgc,"Carnatic music"
264
+ 262,/m/0dq0md,"Music of Bollywood"
265
+ 263,/m/06rqw,"Ska"
266
+ 264,/m/02p0sh1,"Traditional music"
267
+ 265,/m/05rwpb,"Independent music"
268
+ 266,/m/074ft,"Song"
269
+ 267,/m/025td0t,"Background music"
270
+ 268,/m/02cjck,"Theme music"
271
+ 269,/m/03r5q_,"Jingle (music)"
272
+ 270,/m/0l14gg,"Soundtrack music"
273
+ 271,/m/07pkxdp,"Lullaby"
274
+ 272,/m/01z7dr,"Video game music"
275
+ 273,/m/0140xf,"Christmas music"
276
+ 274,/m/0ggx5q,"Dance music"
277
+ 275,/m/04wptg,"Wedding music"
278
+ 276,/t/dd00031,"Happy music"
279
+ 277,/t/dd00032,"Funny music"
280
+ 278,/t/dd00033,"Sad music"
281
+ 279,/t/dd00034,"Tender music"
282
+ 280,/t/dd00035,"Exciting music"
283
+ 281,/t/dd00036,"Angry music"
284
+ 282,/t/dd00037,"Scary music"
285
+ 283,/m/03m9d0z,"Wind"
286
+ 284,/m/09t49,"Rustling leaves"
287
+ 285,/t/dd00092,"Wind noise (microphone)"
288
+ 286,/m/0jb2l,"Thunderstorm"
289
+ 287,/m/0ngt1,"Thunder"
290
+ 288,/m/0838f,"Water"
291
+ 289,/m/06mb1,"Rain"
292
+ 290,/m/07r10fb,"Raindrop"
293
+ 291,/t/dd00038,"Rain on surface"
294
+ 292,/m/0j6m2,"Stream"
295
+ 293,/m/0j2kx,"Waterfall"
296
+ 294,/m/05kq4,"Ocean"
297
+ 295,/m/034srq,"Waves, surf"
298
+ 296,/m/06wzb,"Steam"
299
+ 297,/m/07swgks,"Gurgling"
300
+ 298,/m/02_41,"Fire"
301
+ 299,/m/07pzfmf,"Crackle"
302
+ 300,/m/07yv9,"Vehicle"
303
+ 301,/m/019jd,"Boat, Water vehicle"
304
+ 302,/m/0hsrw,"Sailboat, sailing ship"
305
+ 303,/m/056ks2,"Rowboat, canoe, kayak"
306
+ 304,/m/02rlv9,"Motorboat, speedboat"
307
+ 305,/m/06q74,"Ship"
308
+ 306,/m/012f08,"Motor vehicle (road)"
309
+ 307,/m/0k4j,"Car"
310
+ 308,/m/0912c9,"Vehicle horn, car horn, honking"
311
+ 309,/m/07qv_d5,"Toot"
312
+ 310,/m/02mfyn,"Car alarm"
313
+ 311,/m/04gxbd,"Power windows, electric windows"
314
+ 312,/m/07rknqz,"Skidding"
315
+ 313,/m/0h9mv,"Tire squeal"
316
+ 314,/t/dd00134,"Car passing by"
317
+ 315,/m/0ltv,"Race car, auto racing"
318
+ 316,/m/07r04,"Truck"
319
+ 317,/m/0gvgw0,"Air brake"
320
+ 318,/m/05x_td,"Air horn, truck horn"
321
+ 319,/m/02rhddq,"Reversing beeps"
322
+ 320,/m/03cl9h,"Ice cream truck, ice cream van"
323
+ 321,/m/01bjv,"Bus"
324
+ 322,/m/03j1ly,"Emergency vehicle"
325
+ 323,/m/04qvtq,"Police car (siren)"
326
+ 324,/m/012n7d,"Ambulance (siren)"
327
+ 325,/m/012ndj,"Fire engine, fire truck (siren)"
328
+ 326,/m/04_sv,"Motorcycle"
329
+ 327,/m/0btp2,"Traffic noise, roadway noise"
330
+ 328,/m/06d_3,"Rail transport"
331
+ 329,/m/07jdr,"Train"
332
+ 330,/m/04zmvq,"Train whistle"
333
+ 331,/m/0284vy3,"Train horn"
334
+ 332,/m/01g50p,"Railroad car, train wagon"
335
+ 333,/t/dd00048,"Train wheels squealing"
336
+ 334,/m/0195fx,"Subway, metro, underground"
337
+ 335,/m/0k5j,"Aircraft"
338
+ 336,/m/014yck,"Aircraft engine"
339
+ 337,/m/04229,"Jet engine"
340
+ 338,/m/02l6bg,"Propeller, airscrew"
341
+ 339,/m/09ct_,"Helicopter"
342
+ 340,/m/0cmf2,"Fixed-wing aircraft, airplane"
343
+ 341,/m/0199g,"Bicycle"
344
+ 342,/m/06_fw,"Skateboard"
345
+ 343,/m/02mk9,"Engine"
346
+ 344,/t/dd00065,"Light engine (high frequency)"
347
+ 345,/m/08j51y,"Dental drill, dentist's drill"
348
+ 346,/m/01yg9g,"Lawn mower"
349
+ 347,/m/01j4z9,"Chainsaw"
350
+ 348,/t/dd00066,"Medium engine (mid frequency)"
351
+ 349,/t/dd00067,"Heavy engine (low frequency)"
352
+ 350,/m/01h82_,"Engine knocking"
353
+ 351,/t/dd00130,"Engine starting"
354
+ 352,/m/07pb8fc,"Idling"
355
+ 353,/m/07q2z82,"Accelerating, revving, vroom"
356
+ 354,/m/02dgv,"Door"
357
+ 355,/m/03wwcy,"Doorbell"
358
+ 356,/m/07r67yg,"Ding-dong"
359
+ 357,/m/02y_763,"Sliding door"
360
+ 358,/m/07rjzl8,"Slam"
361
+ 359,/m/07r4wb8,"Knock"
362
+ 360,/m/07qcpgn,"Tap"
363
+ 361,/m/07q6cd_,"Squeak"
364
+ 362,/m/0642b4,"Cupboard open or close"
365
+ 363,/m/0fqfqc,"Drawer open or close"
366
+ 364,/m/04brg2,"Dishes, pots, and pans"
367
+ 365,/m/023pjk,"Cutlery, silverware"
368
+ 366,/m/07pn_8q,"Chopping (food)"
369
+ 367,/m/0dxrf,"Frying (food)"
370
+ 368,/m/0fx9l,"Microwave oven"
371
+ 369,/m/02pjr4,"Blender"
372
+ 370,/m/02jz0l,"Water tap, faucet"
373
+ 371,/m/0130jx,"Sink (filling or washing)"
374
+ 372,/m/03dnzn,"Bathtub (filling or washing)"
375
+ 373,/m/03wvsk,"Hair dryer"
376
+ 374,/m/01jt3m,"Toilet flush"
377
+ 375,/m/012xff,"Toothbrush"
378
+ 376,/m/04fgwm,"Electric toothbrush"
379
+ 377,/m/0d31p,"Vacuum cleaner"
380
+ 378,/m/01s0vc,"Zipper (clothing)"
381
+ 379,/m/03v3yw,"Keys jangling"
382
+ 380,/m/0242l,"Coin (dropping)"
383
+ 381,/m/01lsmm,"Scissors"
384
+ 382,/m/02g901,"Electric shaver, electric razor"
385
+ 383,/m/05rj2,"Shuffling cards"
386
+ 384,/m/0316dw,"Typing"
387
+ 385,/m/0c2wf,"Typewriter"
388
+ 386,/m/01m2v,"Computer keyboard"
389
+ 387,/m/081rb,"Writing"
390
+ 388,/m/07pp_mv,"Alarm"
391
+ 389,/m/07cx4,"Telephone"
392
+ 390,/m/07pp8cl,"Telephone bell ringing"
393
+ 391,/m/01hnzm,"Ringtone"
394
+ 392,/m/02c8p,"Telephone dialing, DTMF"
395
+ 393,/m/015jpf,"Dial tone"
396
+ 394,/m/01z47d,"Busy signal"
397
+ 395,/m/046dlr,"Alarm clock"
398
+ 396,/m/03kmc9,"Siren"
399
+ 397,/m/0dgbq,"Civil defense siren"
400
+ 398,/m/030rvx,"Buzzer"
401
+ 399,/m/01y3hg,"Smoke detector, smoke alarm"
402
+ 400,/m/0c3f7m,"Fire alarm"
403
+ 401,/m/04fq5q,"Foghorn"
404
+ 402,/m/0l156k,"Whistle"
405
+ 403,/m/06hck5,"Steam whistle"
406
+ 404,/t/dd00077,"Mechanisms"
407
+ 405,/m/02bm9n,"Ratchet, pawl"
408
+ 406,/m/01x3z,"Clock"
409
+ 407,/m/07qjznt,"Tick"
410
+ 408,/m/07qjznl,"Tick-tock"
411
+ 409,/m/0l7xg,"Gears"
412
+ 410,/m/05zc1,"Pulleys"
413
+ 411,/m/0llzx,"Sewing machine"
414
+ 412,/m/02x984l,"Mechanical fan"
415
+ 413,/m/025wky1,"Air conditioning"
416
+ 414,/m/024dl,"Cash register"
417
+ 415,/m/01m4t,"Printer"
418
+ 416,/m/0dv5r,"Camera"
419
+ 417,/m/07bjf,"Single-lens reflex camera"
420
+ 418,/m/07k1x,"Tools"
421
+ 419,/m/03l9g,"Hammer"
422
+ 420,/m/03p19w,"Jackhammer"
423
+ 421,/m/01b82r,"Sawing"
424
+ 422,/m/02p01q,"Filing (rasp)"
425
+ 423,/m/023vsd,"Sanding"
426
+ 424,/m/0_ksk,"Power tool"
427
+ 425,/m/01d380,"Drill"
428
+ 426,/m/014zdl,"Explosion"
429
+ 427,/m/032s66,"Gunshot, gunfire"
430
+ 428,/m/04zjc,"Machine gun"
431
+ 429,/m/02z32qm,"Fusillade"
432
+ 430,/m/0_1c,"Artillery fire"
433
+ 431,/m/073cg4,"Cap gun"
434
+ 432,/m/0g6b5,"Fireworks"
435
+ 433,/g/122z_qxw,"Firecracker"
436
+ 434,/m/07qsvvw,"Burst, pop"
437
+ 435,/m/07pxg6y,"Eruption"
438
+ 436,/m/07qqyl4,"Boom"
439
+ 437,/m/083vt,"Wood"
440
+ 438,/m/07pczhz,"Chop"
441
+ 439,/m/07pl1bw,"Splinter"
442
+ 440,/m/07qs1cx,"Crack"
443
+ 441,/m/039jq,"Glass"
444
+ 442,/m/07q7njn,"Chink, clink"
445
+ 443,/m/07rn7sz,"Shatter"
446
+ 444,/m/04k94,"Liquid"
447
+ 445,/m/07rrlb6,"Splash, splatter"
448
+ 446,/m/07p6mqd,"Slosh"
449
+ 447,/m/07qlwh6,"Squish"
450
+ 448,/m/07r5v4s,"Drip"
451
+ 449,/m/07prgkl,"Pour"
452
+ 450,/m/07pqc89,"Trickle, dribble"
453
+ 451,/t/dd00088,"Gush"
454
+ 452,/m/07p7b8y,"Fill (with liquid)"
455
+ 453,/m/07qlf79,"Spray"
456
+ 454,/m/07ptzwd,"Pump (liquid)"
457
+ 455,/m/07ptfmf,"Stir"
458
+ 456,/m/0dv3j,"Boiling"
459
+ 457,/m/0790c,"Sonar"
460
+ 458,/m/0dl83,"Arrow"
461
+ 459,/m/07rqsjt,"Whoosh, swoosh, swish"
462
+ 460,/m/07qnq_y,"Thump, thud"
463
+ 461,/m/07rrh0c,"Thunk"
464
+ 462,/m/0b_fwt,"Electronic tuner"
465
+ 463,/m/02rr_,"Effects unit"
466
+ 464,/m/07m2kt,"Chorus effect"
467
+ 465,/m/018w8,"Basketball bounce"
468
+ 466,/m/07pws3f,"Bang"
469
+ 467,/m/07ryjzk,"Slap, smack"
470
+ 468,/m/07rdhzs,"Whack, thwack"
471
+ 469,/m/07pjjrj,"Smash, crash"
472
+ 470,/m/07pc8lb,"Breaking"
473
+ 471,/m/07pqn27,"Bouncing"
474
+ 472,/m/07rbp7_,"Whip"
475
+ 473,/m/07pyf11,"Flap"
476
+ 474,/m/07qb_dv,"Scratch"
477
+ 475,/m/07qv4k0,"Scrape"
478
+ 476,/m/07pdjhy,"Rub"
479
+ 477,/m/07s8j8t,"Roll"
480
+ 478,/m/07plct2,"Crushing"
481
+ 479,/t/dd00112,"Crumpling, crinkling"
482
+ 480,/m/07qcx4z,"Tearing"
483
+ 481,/m/02fs_r,"Beep, bleep"
484
+ 482,/m/07qwdck,"Ping"
485
+ 483,/m/07phxs1,"Ding"
486
+ 484,/m/07rv4dm,"Clang"
487
+ 485,/m/07s02z0,"Squeal"
488
+ 486,/m/07qh7jl,"Creak"
489
+ 487,/m/07qwyj0,"Rustle"
490
+ 488,/m/07s34ls,"Whir"
491
+ 489,/m/07qmpdm,"Clatter"
492
+ 490,/m/07p9k1k,"Sizzle"
493
+ 491,/m/07qc9xj,"Clicking"
494
+ 492,/m/07rwm0c,"Clickety-clack"
495
+ 493,/m/07phhsh,"Rumble"
496
+ 494,/m/07qyrcz,"Plop"
497
+ 495,/m/07qfgpx,"Jingle, tinkle"
498
+ 496,/m/07rcgpl,"Hum"
499
+ 497,/m/07p78v5,"Zing"
500
+ 498,/t/dd00121,"Boing"
501
+ 499,/m/07s12q4,"Crunch"
502
+ 500,/m/028v0c,"Silence"
503
+ 501,/m/01v_m0,"Sine wave"
504
+ 502,/m/0b9m1,"Harmonic"
505
+ 503,/m/0hdsk,"Chirp tone"
506
+ 504,/m/0c1dj,"Sound effect"
507
+ 505,/m/07pt_g0,"Pulse"
508
+ 506,/t/dd00125,"Inside, small room"
509
+ 507,/t/dd00126,"Inside, large room or hall"
510
+ 508,/t/dd00127,"Inside, public space"
511
+ 509,/t/dd00128,"Outside, urban or manmade"
512
+ 510,/t/dd00129,"Outside, rural or natural"
513
+ 511,/m/01b9nn,"Reverberation"
514
+ 512,/m/01jnbd,"Echo"
515
+ 513,/m/096m7z,"Noise"
516
+ 514,/m/06_y0by,"Environmental noise"
517
+ 515,/m/07rgkc5,"Static"
518
+ 516,/m/06xkwv,"Mains hum"
519
+ 517,/m/0g12c5,"Distortion"
520
+ 518,/m/08p9q4,"Sidetone"
521
+ 519,/m/07szfh9,"Cacophony"
522
+ 520,/m/0chx_,"White noise"
523
+ 521,/m/0cj0r,"Pink noise"
524
+ 522,/m/07p_0gm,"Throbbing"
525
+ 523,/m/01jwx6,"Vibration"
526
+ 524,/m/07c52,"Television"
527
+ 525,/m/06bz3,"Radio"
528
+ 526,/m/07hvw1,"Field recording"
evaluation/metadata/clotho_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/metadata/esc50_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/metadata/music_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/metadata/vggsound_eval.csv ADDED
The diff for this file is too large to render. See raw diff
 
losses.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+
+
+ def l1(output, target):
+     return torch.mean(torch.abs(output - target))
+
+
+ def l1_wav(output_dict, target_dict):
+     return l1(output_dict['segment'], target_dict['segment'])
+
+
+ def get_loss_function(loss_type):
+     if loss_type == "l1_wav":
+         return l1_wav
+
+     else:
+         raise NotImplementedError(f"Loss type {loss_type} is not implemented.")
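A short usage sketch for the loss registry above: the training config names a loss_type, and the returned function operates on dicts keyed by 'segment' holding waveform tensors. The tensor shapes below are illustrative only.

    import torch
    from losses import get_loss_function

    loss_fn = get_loss_function("l1_wav")

    output_dict = {'segment': torch.randn(4, 1, 32000)}  # separated waveforms
    target_dict = {'segment': torch.randn(4, 1, 32000)}  # ground-truth sources
    loss = loss_fn(output_dict, target_dict)             # mean absolute error (scalar)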
models/CLAP/__init__.py ADDED
File without changes
models/CLAP/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (158 Bytes). View file
 
models/CLAP/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .factory import (
+     list_models,
+     create_model,
+     create_model_and_transforms,
+     add_model_config,
+ )
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+ from .model import (
+     CLAP,
+     CLAPTextCfg,
+     CLAPVisionCfg,
+     CLAPAudioCfp,
+     convert_weights_to_fp16,
+     trace_model,
+ )
+ from .openai import load_openai_model, list_openai_models
+ from .pretrained import (
+     list_pretrained,
+     list_pretrained_tag_models,
+     list_pretrained_model_tags,
+     get_pretrained_url,
+     download_pretrained,
+ )
+ from .tokenizer import SimpleTokenizer, tokenize
+ from .transform import image_transform
models/CLAP/open_clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (964 Bytes). View file
 
models/CLAP/open_clip/__pycache__/factory.cpython-310.pyc ADDED
Binary file (6.65 kB). View file
 
models/CLAP/open_clip/__pycache__/feature_fusion.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
models/CLAP/open_clip/__pycache__/htsat.cpython-310.pyc ADDED
Binary file (30.8 kB). View file
 
models/CLAP/open_clip/__pycache__/loss.cpython-310.pyc ADDED
Binary file (7.97 kB). View file
 
models/CLAP/open_clip/__pycache__/model.cpython-310.pyc ADDED
Binary file (24.1 kB). View file
 
models/CLAP/open_clip/__pycache__/openai.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
models/CLAP/open_clip/__pycache__/pann_model.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
models/CLAP/open_clip/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
models/CLAP/open_clip/__pycache__/timm_model.cpython-310.pyc ADDED
Binary file (3.44 kB). View file
 
models/CLAP/open_clip/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (7.36 kB). View file
 
models/CLAP/open_clip/__pycache__/transform.cpython-310.pyc ADDED
Binary file (982 Bytes). View file
 
models/CLAP/open_clip/__pycache__/utils.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
models/CLAP/open_clip/bert.py ADDED
@@ -0,0 +1,40 @@
+ from transformers import BertTokenizer, BertModel
+
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ model = BertModel.from_pretrained("bert-base-uncased")
+ text = "Replace me by any text you'd like."
+
+
+ def bert_embeddings(text):
+     # text = "Replace me by any text you'd like."
+     encoded_input = tokenizer(text, return_tensors="pt")
+     output = model(**encoded_input)
+     return output
+
+
+ from transformers import RobertaTokenizer, RobertaModel
+
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ model = RobertaModel.from_pretrained("roberta-base")
+ text = "Replace me by any text you'd like."
+
+
+ def Roberta_embeddings(text):
+     # text = "Replace me by any text you'd like."
+     encoded_input = tokenizer(text, return_tensors="pt")
+     output = model(**encoded_input)
+     return output
+
+
+ from transformers import BartTokenizer, BartModel
+
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+ model = BartModel.from_pretrained("facebook/bart-base")
+ text = "Replace me by any text you'd like."
+
+
+ def bart_embeddings(text):
+     # text = "Replace me by any text you'd like."
+     encoded_input = tokenizer(text, return_tensors="pt")
+     output = model(**encoded_input)
+     return output
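One caveat worth noting about this module: importing it eagerly downloads and instantiates all three text encoders, and each block rebinds the module-level tokenizer and model names, so after import every helper ends up calling the last-loaded (BART) tokenizer and model. A minimal call sketch, with that behaviour left as-is:

    from models.CLAP.open_clip.bert import bert_embeddings

    # Despite its name, this currently runs the BART model loaded last,
    # because `tokenizer` and `model` are module-level and rebound below it.
    out = bert_embeddings("a dog barking in the distance")
    print(out.last_hidden_state.shape)  # (1, sequence_length, hidden_size)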
models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
models/CLAP/open_clip/factory.py ADDED
@@ -0,0 +1,277 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ from .model import CLAP, convert_weights_to_fp16
12
+ from .openai import load_openai_model
13
+ from .pretrained import get_pretrained_url, download_pretrained
14
+ from .transform import image_transform
15
+
16
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
+ _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
18
+
19
+
20
+ def _natural_key(string_):
21
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
+
23
+
24
+ def _rescan_model_configs():
25
+ global _MODEL_CONFIGS
26
+
27
+ config_ext = (".json",)
28
+ config_files = []
29
+ for config_path in _MODEL_CONFIG_PATHS:
30
+ if config_path.is_file() and config_path.suffix in config_ext:
31
+ config_files.append(config_path)
32
+ elif config_path.is_dir():
33
+ for ext in config_ext:
34
+ config_files.extend(config_path.glob(f"*{ext}"))
35
+
36
+ for cf in config_files:
37
+ if os.path.basename(cf)[0] == ".":
38
+ continue # Ignore hidden files
39
+
40
+ with open(cf, "r") as f:
41
+ model_cfg = json.load(f)
42
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
+ _MODEL_CONFIGS[cf.stem] = model_cfg
44
+
45
+ _MODEL_CONFIGS = {
46
+ k: v
47
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
+ }
49
+
50
+
51
+ _rescan_model_configs() # initial populate of model config registry
52
+
53
+
54
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
+ state_dict = checkpoint["state_dict"]
58
+ else:
59
+ state_dict = checkpoint
60
+ if skip_params:
61
+ if next(iter(state_dict.items()))[0].startswith("module"):
62
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
63
+ # for k in state_dict:
64
+ # if k.startswith('transformer'):
65
+ # v = state_dict.pop(k)
66
+ # state_dict['text_branch.' + k[12:]] = v
67
+ return state_dict
68
+
69
+
70
+ def create_model(
71
+ amodel_name: str,
72
+ tmodel_name: str,
73
+ pretrained: str = "",
74
+ precision: str = "fp32",
75
+ device: torch.device = torch.device("cpu"),
76
+ jit: bool = False,
77
+ force_quick_gelu: bool = False,
78
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
+ skip_params=True,
80
+ pretrained_audio: str = "",
81
+ pretrained_text: str = "",
82
+ enable_fusion: bool = False,
83
+ fusion_type: str = "None"
84
+ # pretrained_image: bool = False,
85
+ ):
86
+ amodel_name = amodel_name.replace(
87
+ "/", "-"
88
+ ) # for callers using old naming with / in ViT names
89
+ pretrained_orig = pretrained
90
+ pretrained = pretrained.lower()
91
+ if pretrained == "openai":
92
+ if amodel_name in _MODEL_CONFIGS:
93
+ logging.info(f"Loading {amodel_name} model config.")
94
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
+ else:
96
+ logging.error(
97
+ f"Model config for {amodel_name} not found; available models {list_models()}."
98
+ )
99
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
100
+
101
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
+ # Hard Code in model name
103
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
104
+ model = load_openai_model(
105
+ "ViT-B-16",
106
+ model_cfg,
107
+ device=device,
108
+ jit=jit,
109
+ cache_dir=openai_model_cache_dir,
110
+ enable_fusion=enable_fusion,
111
+ fusion_type=fusion_type,
112
+ )
113
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
+ if precision == "amp" or precision == "fp32":
115
+ model = model.float()
116
+ else:
117
+ if amodel_name in _MODEL_CONFIGS:
118
+ logging.info(f"Loading {amodel_name} model config.")
119
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
+ else:
121
+ logging.error(
122
+ f"Model config for {amodel_name} not found; available models {list_models()}."
123
+ )
124
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
125
+
126
+ if force_quick_gelu:
127
+ # override for use of QuickGELU on non-OpenAI transformer models
128
+ model_cfg["quick_gelu"] = True
129
+
130
+ # if pretrained_image:
131
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
+ # # pretrained weight loading for timm models set via vision_cfg
133
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
+ # else:
135
+ # assert False, 'pretrained image towers currently only supported for timm models'
136
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
137
+ model_cfg["enable_fusion"] = enable_fusion
138
+ model_cfg["fusion_type"] = fusion_type
139
+ model = CLAP(**model_cfg)
140
+
141
+ if pretrained:
142
+ checkpoint_path = ""
143
+ url = get_pretrained_url(amodel_name, pretrained)
144
+ if url:
145
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
+ elif os.path.exists(pretrained_orig):
147
+ checkpoint_path = pretrained_orig
148
+ if checkpoint_path:
149
+ logging.info(
150
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
+ )
152
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
+ model.load_state_dict(ckpt)
154
+ param_names = [n for n, p in model.named_parameters()]
155
+ # for n in param_names:
156
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
+ else:
158
+ logging.warning(
159
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
+ )
161
+ raise RuntimeError(
162
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
+ )
164
+
165
+ if pretrained_audio:
166
+ if amodel_name.startswith("PANN"):
167
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
+ audio_ckpt = audio_ckpt["model"]
170
+ keys = list(audio_ckpt.keys())
171
+ for key in keys:
172
+ if (
173
+ "spectrogram_extractor" not in key
174
+ and "logmel_extractor" not in key
175
+ ):
176
+ v = audio_ckpt.pop(key)
177
+ audio_ckpt["audio_branch." + key] = v
178
+ elif os.path.basename(pretrained_audio).startswith(
179
+ "PANN"
180
+ ): # checkpoint trained via HTSAT codebase
181
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
+ audio_ckpt = audio_ckpt["state_dict"]
183
+ keys = list(audio_ckpt.keys())
184
+ for key in keys:
185
+ if key.startswith("sed_model"):
186
+ v = audio_ckpt.pop(key)
187
+ audio_ckpt["audio_branch." + key[10:]] = v
188
+ elif os.path.basename(pretrained_audio).startswith(
189
+ "finetuned"
190
+ ): # checkpoint trained via linear probe codebase
191
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
+ else:
193
+ raise ValueError("Unknown audio checkpoint")
194
+ elif amodel_name.startswith("HTSAT"):
195
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
+ audio_ckpt = audio_ckpt["state_dict"]
198
+ keys = list(audio_ckpt.keys())
199
+ for key in keys:
200
+ if key.startswith("sed_model") and (
201
+ "spectrogram_extractor" not in key
202
+ and "logmel_extractor" not in key
203
+ ):
204
+ v = audio_ckpt.pop(key)
205
+ audio_ckpt["audio_branch." + key[10:]] = v
206
+ elif os.path.basename(pretrained_audio).startswith(
207
+ "HTSAT"
208
+ ): # checkpoint trained via HTSAT codebase
209
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
+ audio_ckpt = audio_ckpt["state_dict"]
211
+ keys = list(audio_ckpt.keys())
212
+ for key in keys:
213
+ if key.startswith("sed_model"):
214
+ v = audio_ckpt.pop(key)
215
+ audio_ckpt["audio_branch." + key[10:]] = v
216
+ elif os.path.basename(pretrained_audio).startswith(
217
+ "finetuned"
218
+ ): # checkpoint trained via linear probe codebase
219
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
+ else:
221
+ raise ValueError("Unknown audio checkpoint")
222
+ else:
223
+ raise ValueError("This audio encoder pretrained checkpoint is not supported.")
224
+
225
+ model.load_state_dict(audio_ckpt, strict=False)
226
+ logging.info(
227
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
+ )
229
+ param_names = [n for n, p in model.named_parameters()]
230
+ for n in param_names:
231
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
+
233
+ model.to(device=device)
234
+ if precision == "fp16":
235
+ assert device.type != "cpu"
236
+ convert_weights_to_fp16(model)
237
+
238
+ if jit:
239
+ model = torch.jit.script(model)
240
+
241
+ return model, model_cfg
242
+
243
+
244
+ def create_model_and_transforms(
245
+ model_name: str,
246
+ pretrained: str = "",
247
+ precision: str = "fp32",
248
+ device: torch.device = torch.device("cpu"),
249
+ jit: bool = False,
250
+ force_quick_gelu: bool = False,
251
+ # pretrained_image: bool = False,
252
+ ):
253
+ model = create_model(
254
+ model_name,
255
+ pretrained,
256
+ precision,
257
+ device,
258
+ jit,
259
+ force_quick_gelu=force_quick_gelu,
260
+ # pretrained_image=pretrained_image
261
+ )
262
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
+ return model, preprocess_train, preprocess_val
265
+
266
+
267
+ def list_models():
268
+ """enumerate available model architectures based on config files"""
269
+ return list(_MODEL_CONFIGS.keys())
270
+
271
+
272
+ def add_model_config(path):
273
+ """add model config path or file and update registry"""
274
+ if not isinstance(path, Path):
275
+ path = Path(path)
276
+ _MODEL_CONFIG_PATHS.append(path)
277
+ _rescan_model_configs()
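A hypothetical usage sketch for create_model above. The audio/text config names ('HTSAT-base', 'roberta') are assumptions: the real names are whatever JSON files live under open_clip/model_configs/, which are not shown in this diff.

    import torch
    from models.CLAP.open_clip import create_model

    device = torch.device('cpu')

    # create_model returns (model, model_cfg); the config names below are assumed.
    model, model_cfg = create_model(
        amodel_name='HTSAT-base',   # audio tower config (assumed name)
        tmodel_name='roberta',      # text tower type (assumed name)
        pretrained='',              # empty string skips pretrained CLAP weights
        precision='fp32',
        device=device,
        enable_fusion=False,
    )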
models/CLAP/open_clip/feature_fusion.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ Feature Fusion for Variable-Length Data Processing
3
+ AFF/iAFF are adapted, with modifications, from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ """
13
+ Direct addition fusion (DirectAddFuse)
14
+ """
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ """
25
+ Multi-feature fusion (iAFF: iterative attentional feature fusion)
26
+ """
27
+
28
+ def __init__(self, channels=64, r=4, type="2D"):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == "1D":
33
+ # local attention
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # global attention
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # second local attention
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # second global attention
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == "2D":
70
+ # local attention
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # global attention
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # second local attention
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # second global attention
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise ValueError("the fusion type is not supported")
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa, xa], dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ """
135
+ Multi-feature fusion (AFF: attentional feature fusion)
136
+ """
137
+
138
+ def __init__(self, channels=64, r=4, type="2D"):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == "1D":
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == "2D":
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise ValueError("the fusion type is not supported")
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa, xa], dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
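A minimal sketch exercising the fusion modules defined above: both inputs must share the same shape and channel count, and the "2D" variants expect (batch, channels, height, width) feature maps. The modules here are randomly initialized, so this only demonstrates shapes and wiring.

    import torch
    from models.CLAP.open_clip.feature_fusion import DAF, AFF, iAFF

    x = torch.randn(2, 64, 32, 32)         # features from one branch
    residual = torch.randn(2, 64, 32, 32)  # features from the other branch

    print(DAF()(x, residual).shape)                               # direct addition
    print(AFF(channels=64, r=4, type="2D")(x, residual).shape)    # attentional fusion
    print(iAFF(channels=64, r=4, type="2D")(x, residual).shape)   # iterative variant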