fpaissan committed on
Commit
5e02fce
1 Parent(s): 70d8c7d

tinyCLAP Space

.gitignore ADDED
@@ -0,0 +1,174 @@
+results/
+
+*.pth
+drop*
+confs.py
+# CLAP/
+build.sh
+run.sh
+Dockerfile
+debug
+
+*.swp
+ckp/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
1-20133-A-39.wav ADDED
Binary file (441 kB).
 
app.py ADDED
@@ -0,0 +1,164 @@
+"""This recipe trains CLAP.
+It supports distillation using tinyCLAP (https://arxiv.org/abs/2311.14517).
+
+Authors
+ * Francesco Paissan 2024
+"""
+
+import sys
+
+import gradio as gr
+import speechbrain as sb
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.transforms as T
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.utils.metric_stats import MetricStats
+
+torch.backends.cudnn.enabled = False
+
+eps = 1e-10
+
+
+class CLAPBrain(sb.Brain):
+    def preprocess(self, wavs):
+        """Pre-process wavs."""
+        x = self.hparams.spectrogram_extractor(wavs)
+        x = self.hparams.logmel_extractor(x)
+
+        return x
+
+    def prepare_txt_features(self, text):
+        """Prepares text features to input in CLAP text encoder."""
+        txt_inp = self.hparams.txt_tokenizer(
+            text,
+            max_length=self.hparams.text_max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        ).to(self.device)
+
+        return txt_inp
+
+    def compute_sim(self, audio_embed, caption_embed):
+        """Computes CLAP similarity metric."""
+        similarity = audio_embed @ caption_embed.t()
+
+        return similarity
+
+    def compute_forward(self, batch, stage):
+        if len(batch) == 2:
+            wavs, caption = batch
+        else:
+            wavs, caption, _, _ = batch
+
+        wavs = wavs.to(self.device).squeeze(1)
+
+        x_sb = self.preprocess(wavs)
+
+        text_inp = self.prepare_txt_features(caption)
+
+        txt_shared, aud_shared = self.hparams.clap(
+            x_sb,
+            text_inp.input_ids.data,
+            text_inp.token_type_ids.data,
+            text_inp.attention_mask.data,
+        )
+
+        if not hasattr(self.modules, "clap"):
+            aud_shared_student, _, _ = self.modules.clap_student(x_sb)
+            aud_shared_student = aud_shared_student / aud_shared_student.norm(
+                dim=1, keepdim=True
+            )
+
+        return txt_shared, aud_shared, aud_shared_student
+
+
+def audio_preprocess(x, sample_rate):
+    tmp, sr = torchaudio.load(x)
+    resample = T.Resample(sr, sample_rate)
+
+    tmp = resample(tmp)
+    tmp = tmp.sum(0, keepdims=True)
+
+    return tmp
+
+
+@torch.no_grad()
+def inference_wrapper(clap_brain):
+    def f(wav_path, prompt):
+        clap_brain.modules.eval()
+        tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate)
+
+        ret = clap_brain.compute_forward([tmp, prompt], stage=sb.Stage.TEST)
+        sim = clap_brain.compute_sim(ret[2], ret[0])
+
+        return f"tinyCLAP similarity is: {round(sim.item(), 2)}"
+
+    return f
+
+
+if __name__ == "__main__":
+
+    # CLI:
+    # hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    hparams_file = "hparams/inference.yaml"
+
+    # Load hyperparameters file with command-line overrides
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, {})
+
+    # Tensorboard logging
+    if hparams["use_tensorboard"]:
+        from speechbrain.utils.train_logger import TensorboardLogger
+
+        hparams["tensorboard_train_logger"] = TensorboardLogger(
+            hparams["tensorboard_logs_folder"]
+        )
+
+    hparams["clap"].to(hparams["device"])
+    hparams["clap"].requires_grad_(False)
+    hparams["clap"].eval()
+
+    if hparams["zs_eval"]:
+        hparams["class_list"] = datasets["train"].dataset.classes
+
+    if hparams["audioenc_name_student"] is not None:
+        if hparams["projection_only"]:
+            print("Freezing Base AudioEncoder. Updating only the projection layers.")
+            hparams["student_model"].base.requires_grad_(False)
+
+    hparams["spectrogram_extractor"].to(hparams["device"])
+    hparams["logmel_extractor"].to(hparams["device"])
+
+    clap_brain = CLAPBrain(
+        modules=hparams["modules"],
+        hparams=hparams,
+    )
+
+    if hparams["pretrained_CLAP"] is not None:
+        print("Loading CLAP model...")
+        run_on_main(hparams["load_CLAP"].collect_files)
+        hparams["load_CLAP"].load_collected()
+
+    inference_api = inference_wrapper(clap_brain)
+
+    examples_list = [
+        ["./tunztunz_music.wav", "this is the sound of house music"],
+        ["./siren.wav", "this is the sound of sirens wailing"],
+        [
+            "./whistling_and_chirping.wav",
+            "someone is whistling while birds are chirping",
+        ],
+    ]
+
+    demo = gr.Interface(
+        fn=inference_api,
+        inputs=[gr.Audio(type="filepath"), gr.Textbox()],
+        outputs=["text"],
+        examples=examples_list,
+    )
+    demo.launch()
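For reference, here is a minimal sketch (not part of this commit) of how the same inference path could be exercised without launching the Gradio UI. It simply replays the __main__ setup above, so it assumes the files added in this commit are in the working directory and that building the YAML objects downloads the ~2GB CLAP teacher checkpoint and the pretrained tinyCLAP student weights.

from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main

from app import CLAPBrain, inference_wrapper

# Build the hyperparameters exactly as app.py does; this instantiates the
# CLAP teacher (fetching its checkpoint) and the PhiNet student.
with open("hparams/inference.yaml") as fin:
    hparams = load_hyperpyyaml(fin, {})

hparams["clap"].to(hparams["device"])
hparams["clap"].requires_grad_(False)
hparams["clap"].eval()

clap_brain = CLAPBrain(modules=hparams["modules"], hparams=hparams)

# Fetch the distilled student weights referenced by pretrained_CLAP in the YAML.
run_on_main(hparams["load_CLAP"].collect_files)
hparams["load_CLAP"].load_collected()

query = inference_wrapper(clap_brain)
print(query("./siren.wav", "this is the sound of sirens wailing"))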
hparams/inference.yaml ADDED
@@ -0,0 +1,153 @@
+# #################################
+# The recipe for distilling the CLAP baseline.
+#
+# Author:
+# * Francesco Paissan 2024
+# #################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1234
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+
+# Set up folders for reading from and writing to -- if null dataset is ignored
+esc_folder: null
+us8k_folder: null
+tut17_folder: null
+
+audiocaps_folder: null
+macs_folder: null
+clotho_folder: null
+fsd50k_folder: null
+
+device: "cpu"
+
+projection_only: False
+
+# Audio Enc Student type
+audioenc_name_student: phinet_alpha_1.50_beta_0.75_t0_6_N_7
+aud_emb_dim_student: 2048
+
+zs_eval: False
+
+clap_ckpt: "https://zenodo.org/records/7312125/files/CLAP_weights_2022.pth"
+
+experiment_name: tinyCLAP
+output_folder: !ref ./results/<experiment_name>/<seed>
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# Tensorboard logs
+use_tensorboard: False
+tensorboard_logs_folder: !ref <output_folder>/tb_logs/
+
+ckpt_interval_minutes: 15 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 100
+batch_size: 64
+
+lr: 0.012
+
+sample_rate: 44100
+signal_length_s: 5
+
+# Feature parameters
+n_mels: 64
+spec_mag_power: 1
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+opt_class: !name:torch.optim.Adam
+    lr: !ref <lr>
+
+lr_annealing: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
+    factor: 0.1
+    patience: 10
+
+# Logging + checkpoints
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        student_model: !ref <student_model>
+        counter: !ref <epoch_counter>
+
+pretrained_CLAP: !ref fpaissan/tinyCLAP/<audioenc_name_student>.ckpt
+load_CLAP: !new:speechbrain.utils.parameter_transfer.Pretrainer
+    collect_in: !ref <save_folder>
+    loadables:
+        student_model: !ref <student_model>
+    paths:
+        student_model: !ref <pretrained_CLAP>
+
+fmin: 50
+fmax: 14000
+aud_emb_classes_num: 527
+
+emb_norm_type: bn
+aud_emb_dim: 2048
+txt_emb_dim: 768
+shared_emb_dim: 1024
+text_max_length: 100
+
+use_pretrained: True
+clap: !new:modules.CLAP
+    audioenc_name: Cnn14
+    classes_num: !ref <aud_emb_classes_num>
+    out_emb: !ref <aud_emb_dim>
+    text_model: bert-base-uncased
+    transformer_embed_dim: !ref <txt_emb_dim>
+    d_proj: !ref <shared_emb_dim>
+    pretrained_weights: !ref <use_pretrained>
+    CLAP_weights: !ref <clap_ckpt>
+    audioenc_name_student: !ref <audioenc_name_student>
+    out_emb_student: !ref <aud_emb_dim_student>
+
+txt_tokenizer: !apply:transformers.AutoTokenizer.from_pretrained
+    pretrained_model_name_or_path: bert-base-uncased
+
+# Interpretation hyperparams
+K: 1024
+
+# pre-processing
+n_fft: 1024
+hop_length: 320
+win_length: 1024
+use_melspectra_log1p: False
+use_melspectra: True
+use_stft2mel: True
+
+# Spectrogram extractor
+spectrogram_extractor: !new:torchlibrosa.stft.Spectrogram
+    n_fft: !ref <n_fft>
+    hop_length: !ref <hop_length>
+    win_length: !ref <win_length>
+    window: "hann"
+    center: True
+    pad_mode: "reflect"
+    freeze_parameters: True
+
+# Logmel feature extractor
+logmel_extractor: !new:torchlibrosa.stft.LogmelFilterBank
+    sr: !ref <sample_rate>
+    n_fft: !ref <win_length>
+    n_mels: !ref <n_mels>
+    fmin: !ref <fmin>
+    fmax: !ref <fmax>
+    ref: 1.0
+    amin: 0.0000000001
+    top_db: null
+    freeze_parameters: True
+
+
+student_model: !new:modules.AudioEncoder
+    audioenc_name: !ref <audioenc_name_student>
+    d_in: !ref <aud_emb_dim_student>
+    d_out: !ref <shared_emb_dim>
+    classes_num: !ref <aud_emb_classes_num>
+
+modules:
+    clap_student: !ref <student_model>
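The student encoder name above is more than a label: modules.py (added below) parses the alpha/beta/t0/N tokens out of the string to configure the PhiNet backbone. A minimal sketch of that parsing, assuming the dependencies from requirements.txt are installed:

from modules import get_model_from_str

# Parse the student identifier used as audioenc_name_student in this YAML.
conf = get_model_from_str("phinet_alpha_1.50_beta_0.75_t0_6_N_7")
print(conf)  # {'alpha': 1.5, 'beta': 0.75, 't_zero': 6.0, 'num_layers': 7.0}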
modules.py ADDED
@@ -0,0 +1,423 @@
+"""
+Code to define CLAP-related networks.
+Some code inspired by https://github.com/zhepeiw/clap_curation
+
+Credits:
+ * Francesco Paissan 2024
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from micromind.networks import PhiNet
+from speechbrain.utils.fetching import fetch
+from torch import nn
+from torchinfo import summary
+from transformers import AutoModel, BatchEncoding
+
+
+def get_model_from_str(s, vs=("alpha", "beta", "t0", "N")):
+    def get_var(s, key):
+        tmp = s.split("_")
+        return tmp[tmp.index(key) + 1]
+
+    verb = "PhiNet initialized with "
+    ret = {}
+    for k in vs:
+        verb += f"{k}={get_var(s, k)} "
+        ret[k] = float(get_var(s, k))
+
+    ret["t_zero"] = ret["t0"]
+    ret["num_layers"] = ret["N"]
+    del ret["t0"]
+    del ret["N"]
+
+    return ret
+
+
+def get_audio_encoder(name: str):
+    if name == "Cnn14":
+        return Cnn14
+    elif "phinet" in name:
+        phinet_conf = get_model_from_str(name)
+        return PhiNet(input_shape=(1, 640, 64), compatibility=True, **phinet_conf)
+    else:
+        raise Exception(
+            "The audio encoder name {} is incorrect or not supported".format(name)
+        )
+
+
+class Projection(nn.Module):
+    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
+        super().__init__()
+        self.linear1 = nn.Linear(d_in, d_out, bias=False)
+        self.linear2 = nn.Linear(d_out, d_out, bias=False)
+        self.layer_norm = nn.LayerNorm(d_out)
+        self.drop = nn.Dropout(p)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        embed1 = self.linear1(x)
+        embed2 = self.drop(self.linear2(F.gelu(embed1)))
+        embeds = self.layer_norm(embed1 + embed2)
+        return embeds
+
+
+class PhiNet(PhiNet):
+    def __init__(self, embedding_dim=2048, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.bn0 = nn.BatchNorm2d(64)
+
+        if embedding_dim is not None:
+            in_channels_next = self._layers[-1]._layers[-2].weight.shape[0]
+            self.pn_block = nn.Conv2d(
+                in_channels_next,
+                embedding_dim,
+                kernel_size=1,
+                stride=2,
+            )
+
+    def forward(self, x):
+        if x.dim() == 3:
+            x = x[:, None]
+
+        x = x.transpose(1, 3)
+        x = self.bn0(x)
+        x = x.transpose(1, 3)
+
+        x = super().forward(x)
+        embedding = x
+
+        x = self.pn_block(x)
+        x = x.mean((-1, -2))
+
+        return {"embedding": (x, embedding), "clipwise_output": x}
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
+
+
+class ConvBlock5x5(nn.Module):
+    def __init__(self, in_channels, out_channels):
+
+        super(ConvBlock5x5, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(5, 5),
+            stride=(1, 1),
+            padding=(2, 2),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
+
+
+class AttBlock(nn.Module):
+    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
+        super(AttBlock, self).__init__()
+
+        self.activation = activation
+        self.temperature = temperature
+        self.att = nn.Conv1d(
+            in_channels=n_in,
+            out_channels=n_out,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+        self.cla = nn.Conv1d(
+            in_channels=n_in,
+            out_channels=n_out,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+
+        self.bn_att = nn.BatchNorm1d(n_out)
+
+    def forward(self, x):
+        # x: (n_samples, n_in, n_time)
+        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
+        cla = self.nonlinear_transform(self.cla(x))
+        x = torch.sum(norm_att * cla, dim=2)
+        return x, norm_att, cla
+
+    def nonlinear_transform(self, x):
+        if self.activation == "linear":
+            return x
+        elif self.activation == "sigmoid":
+            return torch.sigmoid(x)
+
+
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        classes_num,
+        out_emb,
+    ):
+
+        super(Cnn14, self).__init__()
+
+        self.bn0 = nn.BatchNorm2d(64)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        # out_emb is 2048 for best Cnn14
+        self.fc1 = nn.Linear(2048, out_emb, bias=True)
+        self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)
+
+    def forward(self, x, mixup_lambda=None):
+        """
+        Input: (batch_size, data_length)
+        """
+        # (batch_size, 1, time_steps, mel_bins)
+
+        if x.dim() == 3:
+            x = x.unsqueeze(1)
+
+        x = x.transpose(1, 3)
+        x = self.bn0(x)
+        x = x.transpose(1, 3)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x4_out = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x4_out, p=0.2, training=self.training)
+        x3_out = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x3_out, p=0.2, training=self.training)
+        x2_out = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x2_out, p=0.2, training=self.training)
+        x1_out = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x1_out, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        embedding = F.dropout(x, p=0.5, training=self.training)
+        clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+        output_dict = {
+            "clipwise_output": clipwise_output,
+            "embedding": (embedding, x1_out, x2_out, x3_out, x4_out),
+        }
+
+        return output_dict
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self,
+        audioenc_name: str,
+        d_in: int,
+        d_out: int,
+        classes_num: int,
+    ) -> None:
+        super().__init__()
+
+        audio_encoder = get_audio_encoder(audioenc_name)
+
+        if "phinet" not in audioenc_name:
+            self.base = audio_encoder(
+                classes_num,
+                d_in,
+            )
+        else:
+            self.base = audio_encoder
+
+        self.projection = Projection(d_in, d_out)
+
+    def forward(self, x):
+        out_dict = self.base(x)
+        audio_features, audio_classification_output = (
+            out_dict["embedding"][0],
+            out_dict["clipwise_output"],
+        )
+        projected_vec = self.projection(audio_features)
+
+        return (
+            projected_vec,
+            out_dict["embedding"][1:],
+            audio_classification_output,
+        )
+
+
+class TextEncoder(nn.Module):
+    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
+        super().__init__()
+        self.base = AutoModel.from_pretrained(text_model)
+
+        self.projection = Projection(transformer_embed_dim, d_out)
+
+    def forward(self, x):
+        out = self.base(**x)[0]
+        hidden_state = out
+        out = out[:, 0, :]  # get CLS token output
+        projected_vec = self.projection(out)
+        self.hidden_state = hidden_state.detach()
+        return projected_vec
+
+
+class CLAP(nn.Module):
+    def __init__(
+        self,
+        # audio
+        audioenc_name: str,
+        classes_num: int,
+        out_emb: int,
+        # text
+        text_model: str,
+        transformer_embed_dim: int,
+        # common
+        d_proj: int,
+        pretrained_weights: bool = True,
+        CLAP_weights: str = None,
+        # audio student
+        audioenc_name_student=None,
+        out_emb_student=None,
+    ):
+        super().__init__()
+        ckpt_path = None
+        if pretrained_weights and CLAP_weights is not None:
+            weights_path = "CLAP_weights.pth"
+            tmp = CLAP_weights.split("/")
+            print(
+                " ".join(
+                    """Fetching CLAP weights.
+                    The checkpoint is ~2GB, so be patient.
+                    The process will start right after.
+                    """.split()
+                )
+            )
+            fetch(
+                tmp[-1],
+                "/".join(tmp[:-1]),
+                savedir=".",
+                save_filename=weights_path,
+            )
+
+            ckpt_path = weights_path
+
+        self.audio_encoder = AudioEncoder(
+            audioenc_name,
+            out_emb,
+            d_proj,
+            classes_num,
+        )
+
+        self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        state_dict = torch.load(ckpt_path)["model"]
+        self.load_state_dict(self.clean_state_dict(state_dict))
+        print("Loaded pretrained CLAP checkpoint.")
+
+    @staticmethod
+    def clean_state_dict(state_dict):
+        """Removes pre-processing keys from the state-dict."""
+        keys_to_remove = []
+        for k in state_dict:
+            if "spectrogram" in k or "mel" in k:
+                keys_to_remove.append(k)
+
+        for k in keys_to_remove:
+            state_dict.pop(
+                k,
+                None,
+            )
+
+        return state_dict
+
+    def forward(self, audio, input_ids, token_type_ids, attention_mask, single=None):
+        audio_embed = None
+        caption_embed = None
+
+        if not single == "txt":
+            audio_embed, _, _ = self.audio_encoder(audio)
+            audio_embed = audio_embed / audio_embed.norm(dim=1, keepdim=True)
+
+        if not single == "aud":
+            text = BatchEncoding(
+                {
+                    "input_ids": input_ids,
+                    "token_type_ids": token_type_ids,
+                    "attention_mask": attention_mask,
+                }
+            )
+            caption_embed = self.caption_encoder(text)
+            caption_embed = caption_embed / caption_embed.norm(dim=1, keepdim=True)
+
+        return caption_embed, audio_embed
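Both encoders end in the Projection head defined above, which maps the 2048-dimensional audio features (or the 768-dimensional BERT CLS features) into the shared 1024-dimensional CLAP space; the dimensions here follow aud_emb_dim and shared_emb_dim in hparams/inference.yaml. A minimal shape check:

import torch

from modules import Projection

# d_in/d_out mirror aud_emb_dim (2048) and shared_emb_dim (1024) from the YAML.
proj = Projection(d_in=2048, d_out=1024)
audio_features = torch.randn(4, 2048)  # dummy batch of pooled audio embeddings
print(proj(audio_features).shape)  # torch.Size([4, 1024])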
requirements.txt ADDED
@@ -0,0 +1,7 @@
+speechbrain
+pandas
+transformers==4.28.1
+torchlibrosa
+micromind
+torchinfo
+gradio
siren.wav ADDED
Binary file (640 kB).
 
tunztunz_music.wav ADDED
Binary file (963 kB).
 
whistling_and_chirping.wav ADDED
Binary file (328 kB).