badayvedat committed
Commit: ae29df4
Parent: ea1edf1

Initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +160 -0
  2. LICENSE +21 -0
  3. README.md +1 -3
  4. app.py +115 -0
  5. assets/results.png +0 -0
  6. callbacks/base.py +35 -0
  7. config/audiosep_base.yaml +41 -0
  8. data/audiotext_dataset.py +91 -0
  9. data/datamodules.py +122 -0
  10. data/waveform_mixers.py +127 -0
  11. datafiles/template.json +8 -0
  12. environment.yml +326 -0
  13. losses.py +17 -0
  14. models/CLAP/__init__.py +0 -0
  15. models/CLAP/open_clip/__init__.py +25 -0
  16. models/CLAP/open_clip/bert.py +40 -0
  17. models/CLAP/open_clip/factory.py +277 -0
  18. models/CLAP/open_clip/feature_fusion.py +192 -0
  19. models/CLAP/open_clip/htsat.py +1308 -0
  20. models/CLAP/open_clip/linear_probe.py +66 -0
  21. models/CLAP/open_clip/loss.py +398 -0
  22. models/CLAP/open_clip/model.py +935 -0
  23. models/CLAP/open_clip/model_configs/HTSAT-base.json +23 -0
  24. models/CLAP/open_clip/model_configs/HTSAT-large.json +23 -0
  25. models/CLAP/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
  26. models/CLAP/open_clip/model_configs/HTSAT-tiny.json +23 -0
  27. models/CLAP/open_clip/model_configs/PANN-10.json +23 -0
  28. models/CLAP/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
  29. models/CLAP/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  30. models/CLAP/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
  31. models/CLAP/open_clip/model_configs/PANN-14-win-1536.json +23 -0
  32. models/CLAP/open_clip/model_configs/PANN-14.json +23 -0
  33. models/CLAP/open_clip/model_configs/PANN-6.json +23 -0
  34. models/CLAP/open_clip/model_configs/RN101-quickgelu.json +22 -0
  35. models/CLAP/open_clip/model_configs/RN101.json +21 -0
  36. models/CLAP/open_clip/model_configs/RN50-quickgelu.json +22 -0
  37. models/CLAP/open_clip/model_configs/RN50.json +21 -0
  38. models/CLAP/open_clip/model_configs/RN50x16.json +21 -0
  39. models/CLAP/open_clip/model_configs/RN50x4.json +21 -0
  40. models/CLAP/open_clip/model_configs/ViT-B-16.json +16 -0
  41. models/CLAP/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  42. models/CLAP/open_clip/model_configs/ViT-B-32.json +16 -0
  43. models/CLAP/open_clip/model_configs/ViT-L-14.json +16 -0
  44. models/CLAP/open_clip/openai.py +156 -0
  45. models/CLAP/open_clip/pann_model.py +704 -0
  46. models/CLAP/open_clip/pretrained.py +167 -0
  47. models/CLAP/open_clip/timm_model.py +112 -0
  48. models/CLAP/open_clip/tokenizer.py +197 -0
  49. models/CLAP/open_clip/transform.py +45 -0
  50. models/CLAP/open_clip/utils.py +361 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Xubo Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -8,6 +8,4 @@ sdk_version: 3.47.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
11
+ ---
 
 
app.py ADDED
@@ -0,0 +1,115 @@
1
+ from pathlib import Path
2
+ from threading import Thread
3
+
4
+ import gdown
5
+ import gradio as gr
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+
10
+ from pipeline import build_audiosep
11
+
12
+ CHECKPOINTS_DIR = Path("checkpoint")
13
+
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # The model itself is loaded lazily, after download_models() has fetched the checkpoints
17
+ MODEL_NAME = CHECKPOINTS_DIR / "audiosep_base_4M_steps.ckpt"
18
+ MODEL = None
19
+
20
+
21
+ description = """
22
+ # AudioSep: Separate Anything You Describe
23
+ [[Project Page]](https://audio-agi.github.io/Separate-Anything-You-Describe) [[Paper]](https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf) [[Code]](https://github.com/Audio-AGI/AudioSep)
24
+
25
+ We introduce AudioSep, a foundation model for open-domain sound separation with natural language queries.
26
+ AudioSep demonstrates strong separation performance and impressive zero-shot generalization ability on
27
+ numerous tasks such as audio event separation, musical instrument separation, and speech enhancement.
28
+ """
29
+
30
+
31
+ def get_model():
32
+ model = build_audiosep(
33
+ config_yaml="config/audiosep_base.yaml",
34
+ checkpoint_path=MODEL_NAME,
35
+ device=DEVICE,
36
+ )
37
+ return model
38
+
39
+
40
+ def inference(audio_file_path: str, text: str):
41
+ print(f"Separate audio from [{audio_file_path}] with textual query [{text}]")
42
+ mixture, _ = librosa.load(audio_file_path, sr=32000, mono=True)
43
+
44
+ with torch.no_grad():
45
+ text = [text]
46
+
47
+ conditions = MODEL.query_encoder.get_query_embed(
48
+ modality="text", text=text, device=DEVICE
49
+ )
50
+
51
+ input_dict = {
52
+ "mixture": torch.Tensor(mixture)[None, None, :].to(DEVICE),
53
+ "condition": conditions,
54
+ }
55
+
56
+ sep_segment = MODEL.ss_model(input_dict)["waveform"]
57
+
58
+ sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
59
+
60
+ return 32000, np.round(sep_segment * 32767).astype(np.int16)
61
+
62
+
63
+ def download_models():
64
+ CHECKPOINTS_DIR.mkdir(exist_ok=True)
65
+ success_file = CHECKPOINTS_DIR / "_SUCCESS"
66
+
67
+ models = (
68
+ (
69
+ "https://drive.google.com/file/d/1wQuXThdATXrkmkPM2sRGaNapJ4mTqmlY/view?usp=sharing",
70
+ MODEL_NAME,
71
+ ),
72
+ (
73
+ "https://drive.google.com/file/d/11oj8_tPG6SXgw5fIEsZ5HiWZnJOrvdhw/view?usp=sharing",
74
+ CHECKPOINTS_DIR / "music_speech_audioset_epoch_15_esc_89.98.pt",
75
+ ),
76
+ )
77
+
78
+ def download(models):
79
+ for model_url, model_path in models:
80
+ gdown.download(model_url, str(model_path), quiet=False, fuzzy=True)
81
+
82
+ success_file.touch()
83
+
84
+ global MODEL
85
+ MODEL = get_model()
86
+ button.update(value="Separate", interactive=True)
87
+
88
+ if not success_file.exists():
89
+ thread = Thread(target=download, args=[models])
90
+ thread.start()
91
+
92
+
93
+ with gr.Blocks(title="AudioSep") as demo:
94
+ gr.Markdown(description)
95
+ with gr.Row():
96
+ with gr.Column():
97
+ input_audio = gr.Audio()
98
+ text = gr.Textbox()
99
+ with gr.Column():
100
+ with gr.Column():
101
+ output_audio = gr.Audio(scale=10)
102
+ button = gr.Button(
103
+ "Downloading the models...",
104
+ variant="primary",
105
+ scale=2,
106
+ size="lg",
107
+ interactive=False,
108
+ )
109
+ button.click(
110
+ fn=inference, inputs=[input_audio, text], outputs=[output_audio]
111
+ )
112
+
113
+ download_models()
114
+
115
+ demo.queue().launch(share=True)
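
The script above downloads the checkpoints in a background thread and then runs text-queried separation inside the Gradio UI. For reference, the same inference path can be exercised offline; a minimal sketch mirroring the inference function above, assuming the checkpoints are already in checkpoint/ and that mixture.wav exists (both names are placeholders):

    import librosa
    import numpy as np
    import soundfile as sf
    import torch

    from pipeline import build_audiosep

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_audiosep(
        config_yaml="config/audiosep_base.yaml",
        checkpoint_path="checkpoint/audiosep_base_4M_steps.ckpt",
        device=device,
    )

    # Load the mixture at the model's 32 kHz sampling rate and build the text condition.
    mixture, _ = librosa.load("mixture.wav", sr=32000, mono=True)
    with torch.no_grad():
        condition = model.query_encoder.get_query_embed(
            modality="text", text=["a dog barking"], device=device
        )
        separated = model.ss_model(
            {
                "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                "condition": condition,
            }
        )["waveform"]

    # Write the separated source as 16-bit PCM, as the Gradio app does.
    sf.write(
        "separated.wav",
        np.round(separated.squeeze().cpu().numpy() * 32767).astype(np.int16),
        32000,
    )
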
assets/results.png ADDED
callbacks/base.py ADDED
@@ -0,0 +1,35 @@
1
+ import os
2
+ import lightning.pytorch as pl
3
+ from lightning.pytorch.utilities import rank_zero_only
4
+
5
+
6
+ class CheckpointEveryNSteps(pl.Callback):
7
+ def __init__(
8
+ self,
9
+ checkpoints_dir,
10
+ save_step_frequency,
11
+ ) -> None:
12
+ r"""Save a checkpoint every N steps.
13
+
14
+ Args:
15
+ checkpoints_dir (str): directory to save checkpoints
16
+ save_step_frequency (int): save a checkpoint every N steps
17
+ """
18
+
19
+ self.checkpoints_dir = checkpoints_dir
20
+ self.save_step_frequency = save_step_frequency
21
+
22
+ @rank_zero_only
23
+ def on_train_batch_end(self, *args, **kwargs) -> None:
24
+ r"""Save a checkpoint every N steps."""
25
+
26
+ trainer = args[0]
27
+ global_step = trainer.global_step
28
+
29
+ if global_step == 1 or global_step % self.save_step_frequency == 0:
30
+
31
+ ckpt_path = os.path.join(
32
+ self.checkpoints_dir,
33
+ "step={}.ckpt".format(global_step))
34
+ trainer.save_checkpoint(ckpt_path)
35
+ print("Save checkpoint to {}".format(ckpt_path))
config/audiosep_base.yaml ADDED
@@ -0,0 +1,41 @@
1
+ ---
2
+ task_name: AudioSep
3
+
4
+ data:
5
+ datafiles:
6
+ - 'datafiles/template.json'
7
+
8
+ sampling_rate: 32000
9
+ segment_seconds: 5
10
+ loudness_norm:
11
+ lower_db: -10
12
+ higher_db: 10
13
+ max_mix_num: 2
14
+
15
+ model:
16
+ query_net: CLAP
17
+ condition_size: 512
18
+ model_type: ResUNet30
19
+ input_channels: 1
20
+ output_channels: 1
21
+ resume_checkpoint: ""
22
+ use_text_ratio: 1.0
23
+
24
+ train:
25
+ optimizer:
26
+ optimizer_type: AdamW
27
+ learning_rate: 1e-3
28
+ warm_up_steps: 10000
29
+ reduce_lr_steps: 1000000
30
+ lr_lambda_type: constant_warm_up
31
+ num_nodes: 1
32
+ num_workers: 6
33
+ loss_type: l1_wav
34
+ sync_batchnorm: True
35
+ batch_size_per_device: 12
36
+ steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`.
37
+ evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps.
38
+ save_step_frequency: 20000 # Save every #save_step_frequency steps.
39
+ early_stop_steps: 10000001
40
+ random_seed: 1234
41
+
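
The YAML above groups the data pipeline, model, and optimizer settings used for training. A small sketch of loading it, assuming PyYAML (listed as pyyaml in environment.yml below):

    import yaml

    with open("config/audiosep_base.yaml") as f:
        config = yaml.safe_load(f)

    sampling_rate = config["data"]["sampling_rate"]   # 32000
    loss_type = config["train"]["loss_type"]          # "l1_wav"
    # "1e-3" has no decimal point, so PyYAML loads it as a string; cast explicitly.
    learning_rate = float(config["train"]["optimizer"]["learning_rate"])
    print(sampling_rate, loss_type, learning_rate)
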
data/audiotext_dataset.py ADDED
@@ -0,0 +1,91 @@
1
+ import json
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class AudioTextDataset(Dataset):
9
+ """Can sample data from audio-text databases
10
+ Params:
11
+ sampling_rate: audio sampling rate
12
+ max_clip_len: max length (seconds) of audio clip to be sampled
13
+ """
14
+ def __init__(
15
+ self,
16
+ datafiles=[''],
17
+ sampling_rate=32000,
18
+ max_clip_len=5,
19
+ ):
20
+ all_data_json = []
21
+ for datafile in datafiles:
22
+ with open(datafile, 'r') as fp:
23
+ data_json = json.load(fp)['data']
24
+ all_data_json.extend(data_json)
25
+ self.all_data_json = all_data_json
26
+
27
+ self.sampling_rate = sampling_rate
28
+ self.max_length = max_clip_len * sampling_rate
29
+
30
+ def __len__(self):
31
+ return len(self.all_data_json)
32
+
33
+ def _cut_or_randomcrop(self, waveform):
34
+ # waveform: [1, samples]
35
+ # random crop
36
+ if waveform.size(1) > self.max_length:
37
+ random_idx = random.randint(0, waveform.size(1)-self.max_length)
38
+ waveform = waveform[:, random_idx:random_idx+self.max_length]
39
+ else:
40
+ temp_wav = torch.zeros(1, self.max_length)
41
+ temp_wav[:, 0:waveform.size(1)] = waveform
42
+ waveform = temp_wav
43
+
44
+ assert waveform.size(1) == self.max_length, \
45
+ f"number of audio samples is {waveform.size(1)}"
46
+
47
+ return waveform
48
+
49
+ def _read_audio(self, index):
50
+ try:
51
+ audio_path = self.all_data_json[index]['wav']
52
+ audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True)
53
+ text = self.all_data_json[index]['caption']
54
+
55
+ # drop short utterance
56
+ if audio_data.size(1) < self.sampling_rate * 1:
57
+ raise Exception(f'{audio_path} is too short, drop it ...')
58
+
59
+ return text, audio_data, audio_rate
60
+
61
+ except Exception as e:
62
+ print(f'error: {e} occurs, when loading {audio_path}')
63
+ random_index = random.randint(0, len(self.all_data_json)-1)
64
+ return self._read_audio(index=random_index)
65
+
66
+ def __getitem__(self, index):
67
+ # create a audio tensor
68
+ text, audio_data, audio_rate = self._read_audio(index)
69
+ audio_len = audio_data.shape[1] / audio_rate
70
+ # convert stero to single channel
71
+ if audio_data.shape[0] > 1:
72
+ # audio_data: [samples]
73
+ audio_data = (audio_data[0] + audio_data[1]) / 2
74
+ else:
75
+ audio_data = audio_data.squeeze(0)
76
+
77
+ # resample audio clip
78
+ if audio_rate != self.sampling_rate:
79
+ audio_data = torchaudio.functional.resample(audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate)
80
+
81
+ audio_data = audio_data.unsqueeze(0)
82
+
83
+ audio_data = self._cut_or_randomcrop(audio_data)
84
+
85
+ data_dict = {
86
+ 'text': text,
87
+ 'waveform': audio_data,
88
+ 'modality': 'audio_text'
89
+ }
90
+
91
+ return data_dict
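
A short usage sketch for the dataset above, assuming a datafile in the datafiles/template.json format whose "wav" entries point at real audio files:

    from data.audiotext_dataset import AudioTextDataset

    dataset = AudioTextDataset(
        datafiles=["datafiles/template.json"],
        sampling_rate=32000,
        max_clip_len=5,
    )

    item = dataset[0]
    print(item["text"])            # caption string
    print(item["waveform"].shape)  # torch.Size([1, 160000]): 5 s at 32 kHz
    print(item["modality"])        # "audio_text"
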
data/datamodules.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import Dict, List, Optional, NoReturn
2
+ import torch
3
+ import lightning.pytorch as pl
4
+ from torch.utils.data import DataLoader
5
+ from data.audiotext_dataset import AudioTextDataset
6
+
7
+
8
+ class DataModule(pl.LightningDataModule):
9
+ def __init__(
10
+ self,
11
+ train_dataset: object,
12
+ batch_size: int,
13
+ num_workers: int
14
+ ):
15
+ r"""Data module. To get one batch of data:
16
+
17
+ .. code-block:: python
18
+
19
+ data_module.setup()
20
+
21
+ for batch_data_dict in data_module.train_dataloader():
22
+ print(batch_data_dict.keys())
23
+ break
24
+
25
+ Args:
26
+ train_dataset: Dataset object
27
+ batch_size: int
28
+ num_workers: int
29
+
30
+ """
31
+ super().__init__()
32
+ self._train_dataset = train_dataset
33
+ self.num_workers = num_workers
34
+ self.batch_size = batch_size
35
+ self.collate_fn = collate_fn
36
+
37
+
38
+ def prepare_data(self):
39
+ # download, split, etc...
40
+ # only called on 1 GPU/TPU in distributed
41
+ pass
42
+
43
+ def setup(self, stage: Optional[str] = None) -> NoReturn:
44
+ r"""called on every device."""
45
+
46
+ # make assignments here (val/train/test split)
47
+ # called on every process in DDP
48
+
49
+ # SegmentSampler is used for selecting segments for training.
50
+ # On multiple devices, each SegmentSampler samples a part of mini-batch
51
+ # data.
52
+ self.train_dataset = self._train_dataset
53
+
54
+
55
+ def train_dataloader(self) -> torch.utils.data.DataLoader:
56
+ r"""Get train loader."""
57
+ train_loader = DataLoader(
58
+ dataset=self.train_dataset,
59
+ batch_size=self.batch_size,
60
+ collate_fn=self.collate_fn,
61
+ num_workers=self.num_workers,
62
+ pin_memory=True,
63
+ persistent_workers=False,
64
+ shuffle=True
65
+ )
66
+
67
+ return train_loader
68
+
69
+ def val_dataloader(self):
70
+ # val_split = Dataset(...)
71
+ # return DataLoader(val_split)
72
+ pass
73
+
74
+ def test_dataloader(self):
75
+ # test_split = Dataset(...)
76
+ # return DataLoader(test_split)
77
+ pass
78
+
79
+ def teardown(self):
80
+ # clean up after fit or test
81
+ # called on every process in DDP
82
+ pass
83
+
84
+
85
+ def collate_fn(list_data_dict):
86
+ r"""Collate mini-batch data to inputs and targets for training.
87
+
88
+ Args:
89
+ list_data_dict: e.g., [
90
+ {
91
+ 'text': 'a sound of dog',
92
+ 'waveform': (1, samples),
93
+ 'modality': 'audio_text'
94
+ }
95
+ ...
96
+ ]
97
+ Returns:
98
+ data_dict: e.g.
99
+ 'audio_text': {
100
+ 'text': ['a sound of dog', ...]
101
+ 'waveform': (batch_size, 1, samples)
102
+ }
103
+ """
104
+
105
+ at_list_data_dict = [data_dict for data_dict in list_data_dict if data_dict['modality']=='audio_text']
106
+
107
+ at_data_dict = {}
108
+
109
+ if len(at_list_data_dict) > 0:
110
+ for key in at_list_data_dict[0].keys():
111
+ at_data_dict[key] = [at_data_dict[key] for at_data_dict in at_list_data_dict]
112
+ if key == 'waveform':
113
+ at_data_dict[key] = torch.stack(at_data_dict[key])
114
+ elif key == 'text':
115
+ at_data_dict[key] = [text for text in at_data_dict[key]]
116
+
117
+
118
+ data_dict = {
119
+ 'audio_text': at_data_dict
120
+ }
121
+
122
+ return data_dict
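
The dataset, collate function, and data module combine as follows; a sketch, again assuming a valid datafile:

    from data.audiotext_dataset import AudioTextDataset
    from data.datamodules import DataModule

    dataset = AudioTextDataset(datafiles=["datafiles/template.json"])
    data_module = DataModule(train_dataset=dataset, batch_size=12, num_workers=6)
    data_module.setup()

    for batch in data_module.train_dataloader():
        audio_text = batch["audio_text"]
        print(audio_text["text"][:2])        # list of caption strings
        print(audio_text["waveform"].shape)  # (batch_size, 1, samples)
        break
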
data/waveform_mixers.py ADDED
@@ -0,0 +1,127 @@
1
+ import random
2
+ import sre_compile
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import pyloudnorm as pyln
7
+
8
+
9
+ class SegmentMixer(nn.Module):
10
+ def __init__(self, max_mix_num, lower_db, higher_db):
11
+ super(SegmentMixer, self).__init__()
12
+
13
+ self.max_mix_num = max_mix_num
14
+ self.loudness_param = {
15
+ 'lower_db': lower_db,
16
+ 'higher_db': higher_db,
17
+ }
18
+
19
+ def __call__(self, waveforms):
20
+
21
+ batch_size = waveforms.shape[0]
22
+
23
+ data_dict = {
24
+ 'segment': [],
25
+ 'mixture': [],
26
+ }
27
+
28
+ for n in range(0, batch_size):
29
+
30
+ segment = waveforms[n].clone()
31
+
32
+ # create zero tensors as the background template
33
+ noise = torch.zeros_like(segment)
34
+
35
+ mix_num = random.randint(2, self.max_mix_num)
36
+ assert mix_num >= 2
37
+
38
+ for i in range(1, mix_num):
39
+ next_segment = waveforms[(n + i) % batch_size]
40
+ rescaled_next_segment = dynamic_loudnorm(audio=next_segment, reference=segment, **self.loudness_param)
41
+ noise += rescaled_next_segment
42
+
43
+ # randomly normalize background noise
44
+ noise = dynamic_loudnorm(audio=noise, reference=segment, **self.loudness_param)
45
+
46
+ # create the audio mixture
47
+ mixture = segment + noise
48
+
49
+ # declipping if need be
50
+ max_value = torch.max(torch.abs(mixture))
51
+ if max_value > 1:
52
+ segment *= 0.9 / max_value
53
+ mixture *= 0.9 / max_value
54
+
55
+ data_dict['segment'].append(segment)
56
+ data_dict['mixture'].append(mixture)
57
+
58
+ for key in data_dict.keys():
59
+ data_dict[key] = torch.stack(data_dict[key], dim=0)
60
+
61
+ # return data_dict
62
+ return data_dict['mixture'], data_dict['segment']
63
+
64
+
65
+ def rescale_to_match_energy(segment1, segment2):
66
+
67
+ ratio = get_energy_ratio(segment1, segment2)
68
+ rescaled_segment1 = segment1 / ratio
69
+ return rescaled_segment1
70
+
71
+
72
+ def get_energy(x):
73
+ return torch.mean(x ** 2)
74
+
75
+
76
+ def get_energy_ratio(segment1, segment2):
77
+
78
+ energy1 = get_energy(segment1)
79
+ energy2 = max(get_energy(segment2), 1e-10)
80
+ ratio = (energy1 / energy2) ** 0.5
81
+ ratio = torch.clamp(ratio, 0.02, 50)
82
+ return ratio
83
+
84
+
85
+ def dynamic_loudnorm(audio, reference, lower_db=-10, higher_db=10):
86
+ rescaled_audio = rescale_to_match_energy(audio, reference)
87
+
88
+ delta_loudness = random.randint(lower_db, higher_db)
89
+
90
+ gain = np.power(10.0, delta_loudness / 20.0)
91
+
92
+ return gain * rescaled_audio
93
+
94
+
95
+ def torch_to_numpy(tensor):
96
+ """Convert a PyTorch tensor to a NumPy array."""
97
+ if isinstance(tensor, torch.Tensor):
98
+ return tensor.detach().cpu().numpy()
99
+ else:
100
+ raise ValueError("Input must be a PyTorch tensor.")
101
+
102
+
103
+ def numpy_to_torch(array):
104
+ """Convert a NumPy array to a PyTorch tensor."""
105
+ if isinstance(array, np.ndarray):
106
+ return torch.from_numpy(array)
107
+ else:
108
+ raise ValueError("Input must be a NumPy array.")
109
+
110
+
111
+ # decayed
112
+ def random_loudness_norm(audio, lower_db=-35, higher_db=-15, sr=32000):
113
+ device = audio.device
114
+ audio = torch_to_numpy(audio.squeeze(0))
115
+ # randomly select a norm volume
116
+ norm_vol = random.randint(lower_db, higher_db)
117
+
118
+ # measure the loudness first
119
+ meter = pyln.Meter(sr) # create BS.1770 meter
120
+ loudness = meter.integrated_loudness(audio)
121
+ # loudness normalize audio
122
+ normalized_audio = pyln.normalize.loudness(audio, loudness, norm_vol)
123
+
124
+ normalized_audio = numpy_to_torch(normalized_audio).unsqueeze(0)
125
+
126
+ return normalized_audio.to(device)
127
+
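
SegmentMixer turns a batch of single-source segments into (mixture, target) training pairs by loudness-matching other items from the same batch into the background. A quick sketch on random data:

    import torch

    from data.waveform_mixers import SegmentMixer

    # max_mix_num / lower_db / higher_db follow config/audiosep_base.yaml.
    mixer = SegmentMixer(max_mix_num=2, lower_db=-10, higher_db=10)

    waveforms = torch.randn(4, 1, 160000)  # (batch, channels, samples): 5 s at 32 kHz
    mixture, segment = mixer(waveforms)

    print(mixture.shape, segment.shape)    # both torch.Size([4, 1, 160000])
    print(mixture.abs().max().item())      # <= 1.0: pairs are rescaled when the mixture would clip
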
datafiles/template.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "data": [
3
+ {
4
+ "wav": "path_to_audio_file",
5
+ "caption": "textual_desciptions"
6
+ }
7
+ ]
8
+ }
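
Each datafile is a JSON object with a top-level "data" list of {"wav", "caption"} entries; this is exactly what AudioTextDataset above consumes. A sketch for generating one from a folder of audio clips (the folder name and caption rule are placeholders):

    import json
    from pathlib import Path

    entries = [
        {"wav": str(path), "caption": path.stem.replace("_", " ")}  # naive caption from the filename
        for path in sorted(Path("my_audio").glob("*.wav"))
    ]

    with open("datafiles/my_dataset.json", "w") as f:
        json.dump({"data": entries}, f, indent=2)
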
environment.yml ADDED
@@ -0,0 +1,326 @@
1
+ name: AudioSep
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - backcall=0.2.0=pyhd3eb1b0_0
10
+ - blas=1.0=mkl
11
+ - boltons=23.0.0=py310h06a4308_0
12
+ - brotlipy=0.7.0=py310h7f8727e_1002
13
+ - bzip2=1.0.8=h7b6447c_0
14
+ - ca-certificates=2023.01.10=h06a4308_0
15
+ - certifi=2022.12.7=py310h06a4308_0
16
+ - cffi=1.15.1=py310h5eee18b_3
17
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
18
+ - comm=0.1.2=py310h06a4308_0
19
+ - conda=23.3.1=py310h06a4308_0
20
+ - conda-content-trust=0.1.3=py310h06a4308_0
21
+ - conda-package-handling=2.0.2=py310h06a4308_0
22
+ - conda-package-streaming=0.7.0=py310h06a4308_0
23
+ - cryptography=38.0.4=py310h9ce1e76_0
24
+ - cuda=11.6.1=0
25
+ - cuda-cccl=11.6.55=hf6102b2_0
26
+ - cuda-command-line-tools=11.6.2=0
27
+ - cuda-compiler=11.6.2=0
28
+ - cuda-cudart=11.6.55=he381448_0
29
+ - cuda-cudart-dev=11.6.55=h42ad0f4_0
30
+ - cuda-cuobjdump=11.6.124=h2eeebcb_0
31
+ - cuda-cupti=11.6.124=h86345e5_0
32
+ - cuda-cuxxfilt=11.6.124=hecbf4f6_0
33
+ - cuda-driver-dev=11.6.55=0
34
+ - cuda-gdb=12.1.55=0
35
+ - cuda-libraries=11.6.1=0
36
+ - cuda-libraries-dev=11.6.1=0
37
+ - cuda-memcheck=11.8.86=0
38
+ - cuda-nsight=12.1.55=0
39
+ - cuda-nsight-compute=12.1.0=0
40
+ - cuda-nvcc=11.6.124=hbba6d2d_0
41
+ - cuda-nvdisasm=12.1.55=0
42
+ - cuda-nvml-dev=11.6.55=haa9ef22_0
43
+ - cuda-nvprof=12.1.55=0
44
+ - cuda-nvprune=11.6.124=he22ec0a_0
45
+ - cuda-nvrtc=11.6.124=h020bade_0
46
+ - cuda-nvrtc-dev=11.6.124=h249d397_0
47
+ - cuda-nvtx=11.6.124=h0630a44_0
48
+ - cuda-nvvp=12.1.55=0
49
+ - cuda-runtime=11.6.1=0
50
+ - cuda-samples=11.6.101=h8efea70_0
51
+ - cuda-sanitizer-api=12.1.55=0
52
+ - cuda-toolkit=11.6.1=0
53
+ - cuda-tools=11.6.1=0
54
+ - cuda-visual-tools=11.6.1=0
55
+ - debugpy=1.5.1=py310h295c915_0
56
+ - decorator=5.1.1=pyhd3eb1b0_0
57
+ - flit-core=3.8.0=py310h06a4308_0
58
+ - freetype=2.12.1=h4a9f257_0
59
+ - gds-tools=1.6.0.25=0
60
+ - giflib=5.2.1=h5eee18b_3
61
+ - gmp=6.2.1=h295c915_3
62
+ - gnutls=3.6.15=he1e5248_0
63
+ - idna=3.4=py310h06a4308_0
64
+ - intel-openmp=2021.4.0=h06a4308_3561
65
+ - ipykernel=6.19.2=py310h2f386ee_0
66
+ - ipython=8.12.0=py310h06a4308_0
67
+ - jpeg=9e=h5eee18b_1
68
+ - jsonpatch=1.32=pyhd3eb1b0_0
69
+ - jsonpointer=2.1=pyhd3eb1b0_0
70
+ - jupyter_client=8.1.0=py310h06a4308_0
71
+ - jupyter_core=5.3.0=py310h06a4308_0
72
+ - lame=3.100=h7b6447c_0
73
+ - lcms2=2.12=h3be6417_0
74
+ - ld_impl_linux-64=2.38=h1181459_1
75
+ - lerc=3.0=h295c915_0
76
+ - libcublas=11.9.2.110=h5e84587_0
77
+ - libcublas-dev=11.9.2.110=h5c901ab_0
78
+ - libcufft=10.7.1.112=hf425ae0_0
79
+ - libcufft-dev=10.7.1.112=ha5ce4c0_0
80
+ - libcufile=1.6.0.25=0
81
+ - libcufile-dev=1.6.0.25=0
82
+ - libcurand=10.3.2.56=0
83
+ - libcurand-dev=10.3.2.56=0
84
+ - libcusolver=11.3.4.124=h33c3c4e_0
85
+ - libcusparse=11.7.2.124=h7538f96_0
86
+ - libcusparse-dev=11.7.2.124=hbbe9722_0
87
+ - libdeflate=1.17=h5eee18b_0
88
+ - libffi=3.4.2=h6a678d5_6
89
+ - libgcc-ng=11.2.0=h1234567_1
90
+ - libgomp=11.2.0=h1234567_1
91
+ - libiconv=1.16=h7f8727e_2
92
+ - libidn2=2.3.2=h7f8727e_0
93
+ - libnpp=11.6.3.124=hd2722f0_0
94
+ - libnpp-dev=11.6.3.124=h3c42840_0
95
+ - libnvjpeg=11.6.2.124=hd473ad6_0
96
+ - libnvjpeg-dev=11.6.2.124=hb5906b9_0
97
+ - libpng=1.6.39=h5eee18b_0
98
+ - libsodium=1.0.18=h7b6447c_0
99
+ - libstdcxx-ng=11.2.0=h1234567_1
100
+ - libtasn1=4.19.0=h5eee18b_0
101
+ - libtiff=4.5.0=h6a678d5_2
102
+ - libunistring=0.9.10=h27cfd23_0
103
+ - libuuid=1.41.5=h5eee18b_0
104
+ - libwebp=1.2.4=h11a3e52_1
105
+ - libwebp-base=1.2.4=h5eee18b_1
106
+ - lz4-c=1.9.4=h6a678d5_0
107
+ - matplotlib-inline=0.1.6=py310h06a4308_0
108
+ - mkl=2021.4.0=h06a4308_640
109
+ - mkl-service=2.4.0=py310h7f8727e_0
110
+ - mkl_fft=1.3.1=py310hd6ae3a3_0
111
+ - mkl_random=1.2.2=py310h00e6091_0
112
+ - ncurses=6.4=h6a678d5_0
113
+ - nest-asyncio=1.5.6=py310h06a4308_0
114
+ - nettle=3.7.3=hbbd107a_1
115
+ - nsight-compute=2023.1.0.15=0
116
+ - numpy=1.23.5=py310hd5efca6_0
117
+ - numpy-base=1.23.5=py310h8e6c178_0
118
+ - openh264=2.1.1=h4ff587b_0
119
+ - openssl=1.1.1t=h7f8727e_0
120
+ - packaging=23.0=py310h06a4308_0
121
+ - parso=0.8.3=pyhd3eb1b0_0
122
+ - pexpect=4.8.0=pyhd3eb1b0_3
123
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
124
+ - pip=22.3.1=py310h06a4308_0
125
+ - platformdirs=2.5.2=py310h06a4308_0
126
+ - pluggy=1.0.0=py310h06a4308_1
127
+ - psutil=5.9.0=py310h5eee18b_0
128
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
129
+ - pure_eval=0.2.2=pyhd3eb1b0_0
130
+ - pycosat=0.6.4=py310h5eee18b_0
131
+ - pycparser=2.21=pyhd3eb1b0_0
132
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
133
+ - pysocks=1.7.1=py310h06a4308_0
134
+ - python=3.10.9=h7a1cb2a_0
135
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
136
+ - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
137
+ - pytorch-cuda=11.6=h867d48c_1
138
+ - pytorch-mutex=1.0=cuda
139
+ - pyzmq=23.2.0=py310h6a678d5_0
140
+ - readline=8.2=h5eee18b_0
141
+ - requests=2.28.1=py310h06a4308_0
142
+ - ruamel.yaml=0.17.21=py310h5eee18b_0
143
+ - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
144
+ - setuptools=65.6.3=py310h06a4308_0
145
+ - six=1.16.0=pyhd3eb1b0_1
146
+ - sqlite=3.40.1=h5082296_0
147
+ - stack_data=0.2.0=pyhd3eb1b0_0
148
+ - tk=8.6.12=h1ccaba5_0
149
+ - toolz=0.12.0=py310h06a4308_0
150
+ - torchaudio=0.13.1=py310_cu116
151
+ - torchvision=0.14.1=py310_cu116
152
+ - tornado=6.2=py310h5eee18b_0
153
+ - tqdm=4.64.1=py310h06a4308_0
154
+ - typing_extensions=4.4.0=py310h06a4308_0
155
+ - tzdata=2022g=h04d1e81_0
156
+ - urllib3=1.26.14=py310h06a4308_0
157
+ - wheel=0.37.1=pyhd3eb1b0_0
158
+ - xz=5.2.10=h5eee18b_1
159
+ - zeromq=4.3.4=h2531618_0
160
+ - zlib=1.2.13=h5eee18b_0
161
+ - zstandard=0.18.0=py310h5eee18b_0
162
+ - zstd=1.5.4=hc292b87_0
163
+ - pip:
164
+ - absl-py==1.4.0
165
+ - aiohttp==3.8.4
166
+ - aiosignal==1.3.1
167
+ - anyio==3.6.2
168
+ - appdirs==1.4.4
169
+ - arrow==1.2.3
170
+ - asttokens==2.2.1
171
+ - async-generator==1.10
172
+ - async-timeout==4.0.2
173
+ - attrs==22.2.0
174
+ - audioread==3.0.0
175
+ - av==10.0.0
176
+ - beartype==0.12.0
177
+ - beautifulsoup4==4.12.2
178
+ - blessed==1.20.0
179
+ - braceexpand==0.1.7
180
+ - cachetools==5.3.0
181
+ - click==8.1.3
182
+ - contourpy==1.0.7
183
+ - croniter==1.3.10
184
+ - cycler==0.11.0
185
+ - dataclasses-json==0.5.8
186
+ - dateutils==0.6.12
187
+ - decord==0.6.0
188
+ - deepdiff==6.3.0
189
+ - dtk==0.2
190
+ - exceptiongroup==1.1.1
191
+ - executing==1.2.0
192
+ - fastapi==0.88.0
193
+ - ffmpeg==1.4
194
+ - ffmpeg-python==0.2.0
195
+ - filelock==3.12.0
196
+ - fonttools==4.39.3
197
+ - frozenlist==1.3.3
198
+ - fsspec==2023.4.0
199
+ - ftfy==6.1.1
200
+ - future==0.18.3
201
+ - gammatone==1.0
202
+ - google-auth==2.17.3
203
+ - google-auth-oauthlib==1.0.0
204
+ - greenlet==2.0.2
205
+ - grpcio==1.54.0
206
+ - h11==0.14.0
207
+ - h5py==3.8.0
208
+ - hickle==5.0.2
209
+ - huggingface-hub==0.14.1
210
+ - humanize==4.6.0
211
+ - imageio==2.27.0
212
+ - inquirer==3.1.3
213
+ - ipdb==0.13.13
214
+ - itsdangerous==2.1.2
215
+ - jedi==0.18.2
216
+ - jinja2==3.1.2
217
+ - joblib==1.2.0
218
+ - kiwisolver==1.4.4
219
+ - langchain==0.0.216
220
+ - langchainplus-sdk==0.0.17
221
+ - lazy-loader==0.2
222
+ - librosa==0.10.0.post2
223
+ - lightning==2.0.0
224
+ - lightning-cloud==0.5.33
225
+ - lightning-utilities==0.8.0
226
+ - llvmlite==0.39.1
227
+ - markdown==3.4.3
228
+ - markdown-it-py==2.2.0
229
+ - markupsafe==2.1.2
230
+ - marshmallow==3.19.0
231
+ - marshmallow-enum==1.5.1
232
+ - matplotlib==3.7.1
233
+ - mdurl==0.1.2
234
+ - mergedeep==1.3.4
235
+ - mock==5.0.2
236
+ - msgpack==1.0.5
237
+ - msgpack-numpy==0.4.8
238
+ - multidict==6.0.4
239
+ - musdb==0.4.0
240
+ - mypy-extensions==1.0.0
241
+ - networkx==3.1
242
+ - nose==1.3.7
243
+ - numba==0.56.4
244
+ - numexpr==2.8.4
245
+ - oauthlib==3.2.2
246
+ - openai==0.27.8
247
+ - openapi-schema-pydantic==1.2.4
248
+ - opencv-python==4.7.0.72
249
+ - ordered-set==4.1.0
250
+ - outcome==1.2.0
251
+ - pandas==1.5.3
252
+ - panns-inference==0.1.0
253
+ - pesq==0.0.4
254
+ - pillow==9.5.0
255
+ - pooch==1.6.0
256
+ - prompt-toolkit==3.0.38
257
+ - protobuf==4.22.3
258
+ - pyaml==23.5.9
259
+ - pyasn1==0.5.0
260
+ - pyasn1-modules==0.3.0
261
+ - pydantic==1.10.7
262
+ - pygments==2.14.0
263
+ - pyjwt==2.6.0
264
+ - pyloudnorm==0.1.1
265
+ - pyparsing==3.0.9
266
+ - pystoi==0.3.3
267
+ - python-editor==1.0.4
268
+ - python-multipart==0.0.6
269
+ - pytorch-ignite==0.3.0
270
+ - pytorch-lightning==2.0.1.post0
271
+ - pytz==2023.3
272
+ - pywavelets==1.4.1
273
+ - pyyaml==6.0
274
+ - readchar==4.0.5
275
+ - regex==2023.3.23
276
+ - requests-oauthlib==1.3.1
277
+ - resampy==0.4.2
278
+ - rich==13.3.3
279
+ - rsa==4.9
280
+ - scikit-image==0.20.0
281
+ - scikit-learn==1.2.2
282
+ - scipy==1.10.1
283
+ - selenium==4.8.3
284
+ - simplejpeg==1.6.6
285
+ - sniffio==1.3.0
286
+ - sortedcontainers==2.4.0
287
+ - soundfile==0.12.1
288
+ - soupsieve==2.4
289
+ - soxr==0.3.5
290
+ - sqlalchemy==2.0.17
291
+ - stack-data==0.6.2
292
+ - starlette==0.22.0
293
+ - starsessions==1.3.0
294
+ - stempeg==0.2.3
295
+ - tenacity==8.2.2
296
+ - tensorboard==2.12.2
297
+ - tensorboard-data-server==0.7.0
298
+ - tensorboard-plugin-wit==1.8.1
299
+ - termcolor==1.1.0
300
+ - threadpoolctl==3.1.0
301
+ - tifffile==2023.3.21
302
+ - timm==0.3.2
303
+ - tokenizers==0.13.3
304
+ - tomli==2.0.1
305
+ - torchfile==0.1.0
306
+ - torchlibrosa==0.1.0
307
+ - torchmetrics==0.11.4
308
+ - traitlets==5.9.0
309
+ - transformers==4.28.1
310
+ - trio==0.22.0
311
+ - trio-websocket==0.10.2
312
+ - typeguard==3.0.2
313
+ - typing-extensions==4.5.0
314
+ - typing-inspect==0.9.0
315
+ - uvicorn==0.21.1
316
+ - visdom==0.1.8.9
317
+ - wcwidth==0.2.6
318
+ - webdataset==0.2.48
319
+ - websocket-client==1.5.1
320
+ - websockets==11.0.1
321
+ - werkzeug==2.2.3
322
+ - wget==3.2
323
+ - wsproto==1.2.0
324
+ - yarl==1.8.2
325
+ - zenodo-get==1.3.4
326
+ - zsvision==0.7.8
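
To reproduce this environment locally, the standard conda workflow applies: run conda env create -f environment.yml and then conda activate AudioSep (the name: field at the top of the file).
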
losses.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+
3
+
4
+ def l1(output, target):
5
+ return torch.mean(torch.abs(output - target))
6
+
7
+
8
+ def l1_wav(output_dict, target_dict):
9
+ return l1(output_dict['segment'], target_dict['segment'])
10
+
11
+
12
+ def get_loss_function(loss_type):
13
+ if loss_type == "l1_wav":
14
+ return l1_wav
15
+
16
+ else:
17
+ raise NotImplementedError(f"Unsupported loss type: {loss_type}")
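
The loss functions operate on dictionaries keyed by 'segment', so they can be resolved by name from the training config. A tiny usage sketch on dummy tensors:

    import torch

    from losses import get_loss_function

    loss_fn = get_loss_function("l1_wav")  # matches loss_type in config/audiosep_base.yaml

    output_dict = {"segment": torch.randn(4, 1, 160000)}
    target_dict = {"segment": torch.randn(4, 1, 160000)}

    loss = loss_fn(output_dict, target_dict)
    print(loss)  # scalar tensor: mean absolute error over all elements
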
models/CLAP/__init__.py ADDED
File without changes
models/CLAP/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from .factory import (
2
+ list_models,
3
+ create_model,
4
+ create_model_and_transforms,
5
+ add_model_config,
6
+ )
7
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
+ from .model import (
9
+ CLAP,
10
+ CLAPTextCfg,
11
+ CLAPVisionCfg,
12
+ CLAPAudioCfp,
13
+ convert_weights_to_fp16,
14
+ trace_model,
15
+ )
16
+ from .openai import load_openai_model, list_openai_models
17
+ from .pretrained import (
18
+ list_pretrained,
19
+ list_pretrained_tag_models,
20
+ list_pretrained_model_tags,
21
+ get_pretrained_url,
22
+ download_pretrained,
23
+ )
24
+ from .tokenizer import SimpleTokenizer, tokenize
25
+ from .transform import image_transform
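
This module re-exports the vendored CLAP package's public API. A hedged sketch of two of the lighter entry points; list_models and tokenize come from the factory and tokenizer modules included here, and the exact token shape depends on the tokenizer's context length:

    from models.CLAP.open_clip import list_models, tokenize

    print(list_models())  # registered config names from model_configs/, e.g. "HTSAT-tiny"
    tokens = tokenize(["a dog barking", "acoustic guitar"])
    print(tokens.shape)   # typically (2, 77) token ids
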
models/CLAP/open_clip/bert.py ADDED
@@ -0,0 +1,40 @@
1
+ from transformers import BertTokenizer, BertModel
2
+
3
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4
+ model = BertModel.from_pretrained("bert-base-uncased")
5
+ text = "Replace me by any text you'd like."
6
+
7
+
8
+ def bert_embeddings(text):
9
+ # text = "Replace me by any text you'd like."
10
+ encoded_input = tokenizer(text, return_tensors="pt")
11
+ output = model(**encoded_input)
12
+ return output
13
+
14
+
15
+ from transformers import RobertaTokenizer, RobertaModel
16
+
17
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18
+ model = RobertaModel.from_pretrained("roberta-base")
19
+ text = "Replace me by any text you'd like."
20
+
21
+
22
+ def Roberta_embeddings(text):
23
+ # text = "Replace me by any text you'd like."
24
+ encoded_input = tokenizer(text, return_tensors="pt")
25
+ output = model(**encoded_input)
26
+ return output
27
+
28
+
29
+ from transformers import BartTokenizer, BartModel
30
+
31
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32
+ model = BartModel.from_pretrained("facebook/bart-base")
33
+ text = "Replace me by any text you'd like."
34
+
35
+
36
+ def bart_embeddings(text):
37
+ # text = "Replace me by any text you'd like."
38
+ encoded_input = tokenizer(text, return_tensors="pt")
39
+ output = model(**encoded_input)
40
+ return output
models/CLAP/open_clip/factory.py ADDED
@@ -0,0 +1,277 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ from .model import CLAP, convert_weights_to_fp16
12
+ from .openai import load_openai_model
13
+ from .pretrained import get_pretrained_url, download_pretrained
14
+ from .transform import image_transform
15
+
16
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
+ _MODEL_CONFIGS = {} # dictionary of (model_name: config) model architecture configs
18
+
19
+
20
+ def _natural_key(string_):
21
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
+
23
+
24
+ def _rescan_model_configs():
25
+ global _MODEL_CONFIGS
26
+
27
+ config_ext = (".json",)
28
+ config_files = []
29
+ for config_path in _MODEL_CONFIG_PATHS:
30
+ if config_path.is_file() and config_path.suffix in config_ext:
31
+ config_files.append(config_path)
32
+ elif config_path.is_dir():
33
+ for ext in config_ext:
34
+ config_files.extend(config_path.glob(f"*{ext}"))
35
+
36
+ for cf in config_files:
37
+ if os.path.basename(cf)[0] == ".":
38
+ continue # Ignore hidden files
39
+
40
+ with open(cf, "r") as f:
41
+ model_cfg = json.load(f)
42
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
+ _MODEL_CONFIGS[cf.stem] = model_cfg
44
+
45
+ _MODEL_CONFIGS = {
46
+ k: v
47
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
+ }
49
+
50
+
51
+ _rescan_model_configs() # initial populate of model config registry
52
+
53
+
54
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
+ state_dict = checkpoint["state_dict"]
58
+ else:
59
+ state_dict = checkpoint
60
+ if skip_params:
61
+ if next(iter(state_dict.items()))[0].startswith("module"):
62
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
63
+ # for k in state_dict:
64
+ # if k.startswith('transformer'):
65
+ # v = state_dict.pop(k)
66
+ # state_dict['text_branch.' + k[12:]] = v
67
+ return state_dict
68
+
69
+
70
+ def create_model(
71
+ amodel_name: str,
72
+ tmodel_name: str,
73
+ pretrained: str = "",
74
+ precision: str = "fp32",
75
+ device: torch.device = torch.device("cpu"),
76
+ jit: bool = False,
77
+ force_quick_gelu: bool = False,
78
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
+ skip_params=True,
80
+ pretrained_audio: str = "",
81
+ pretrained_text: str = "",
82
+ enable_fusion: bool = False,
83
+ fusion_type: str = "None"
84
+ # pretrained_image: bool = False,
85
+ ):
86
+ amodel_name = amodel_name.replace(
87
+ "/", "-"
88
+ ) # for callers using old naming with / in ViT names
89
+ pretrained_orig = pretrained
90
+ pretrained = pretrained.lower()
91
+ if pretrained == "openai":
92
+ if amodel_name in _MODEL_CONFIGS:
93
+ logging.info(f"Loading {amodel_name} model config.")
94
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
+ else:
96
+ logging.error(
97
+ f"Model config for {amodel_name} not found; available models {list_models()}."
98
+ )
99
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
100
+
101
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
+ # Hard Code in model name
103
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
104
+ model = load_openai_model(
105
+ "ViT-B-16",
106
+ model_cfg,
107
+ device=device,
108
+ jit=jit,
109
+ cache_dir=openai_model_cache_dir,
110
+ enable_fusion=enable_fusion,
111
+ fusion_type=fusion_type,
112
+ )
113
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
+ if precision == "amp" or precision == "fp32":
115
+ model = model.float()
116
+ else:
117
+ if amodel_name in _MODEL_CONFIGS:
118
+ logging.info(f"Loading {amodel_name} model config.")
119
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
+ else:
121
+ logging.error(
122
+ f"Model config for {amodel_name} not found; available models {list_models()}."
123
+ )
124
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
125
+
126
+ if force_quick_gelu:
127
+ # override for use of QuickGELU on non-OpenAI transformer models
128
+ model_cfg["quick_gelu"] = True
129
+
130
+ # if pretrained_image:
131
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
+ # # pretrained weight loading for timm models set via vision_cfg
133
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
+ # else:
135
+ # assert False, 'pretrained image towers currently only supported for timm models'
136
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
137
+ model_cfg["enable_fusion"] = enable_fusion
138
+ model_cfg["fusion_type"] = fusion_type
139
+ model = CLAP(**model_cfg)
140
+
141
+ if pretrained:
142
+ checkpoint_path = ""
143
+ url = get_pretrained_url(amodel_name, pretrained)
144
+ if url:
145
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
+ elif os.path.exists(pretrained_orig):
147
+ checkpoint_path = pretrained_orig
148
+ if checkpoint_path:
149
+