berkaygkv54 committed on
Commit
19759e2
1 Parent(s): 4b3ccc3

first push

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +5 -0
  2. app.py +52 -0
  3. data/.DS_Store +0 -0
  4. data/audio/.gitkeep +0 -0
  5. data/json/saved_tracks.json +0 -0
  6. data/vectors/audio_representations.npy +3 -0
  7. model_checkpoints/.gitkeep +0 -0
  8. model_checkpoints/music_audioset_epoch_15_esc_90.14.pt +3 -0
  9. notebooks/notebook.ipynb +788 -0
  10. orchestrate_audio_data.py +8 -0
  11. recommender.py +11 -0
  12. requirements.txt +89 -0
  13. src/config/__init__.py +0 -0
  14. src/config/configs.py +16 -0
  15. src/data/__init__.py +0 -0
  16. src/data/get_yt_links.py +52 -0
  17. src/data/pytuber.py +35 -0
  18. src/data/spotify.py +24 -0
  19. src/laion_clap/__init__.py +5 -0
  20. src/laion_clap/clap_module/__init__.py +8 -0
  21. src/laion_clap/clap_module/bert.py +32 -0
  22. src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz +3 -0
  23. src/laion_clap/clap_module/factory.py +263 -0
  24. src/laion_clap/clap_module/feature_fusion.py +193 -0
  25. src/laion_clap/clap_module/htsat.py +1031 -0
  26. src/laion_clap/clap_module/linear_probe.py +63 -0
  27. src/laion_clap/clap_module/loss.py +307 -0
  28. src/laion_clap/clap_module/model.py +892 -0
  29. src/laion_clap/clap_module/model_configs/HTSAT-base.json +23 -0
  30. src/laion_clap/clap_module/model_configs/HTSAT-large.json +23 -0
  31. src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json +23 -0
  32. src/laion_clap/clap_module/model_configs/HTSAT-tiny.json +23 -0
  33. src/laion_clap/clap_module/model_configs/PANN-10.json +23 -0
  34. src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json +23 -0
  35. src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  36. src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json +23 -0
  37. src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json +23 -0
  38. src/laion_clap/clap_module/model_configs/PANN-14.json +23 -0
  39. src/laion_clap/clap_module/model_configs/PANN-6.json +23 -0
  40. src/laion_clap/clap_module/model_configs/RN101-quickgelu.json +22 -0
  41. src/laion_clap/clap_module/model_configs/RN101.json +21 -0
  42. src/laion_clap/clap_module/model_configs/RN50-quickgelu.json +22 -0
  43. src/laion_clap/clap_module/model_configs/RN50.json +21 -0
  44. src/laion_clap/clap_module/model_configs/RN50x16.json +21 -0
  45. src/laion_clap/clap_module/model_configs/RN50x4.json +21 -0
  46. src/laion_clap/clap_module/model_configs/ViT-B-16.json +16 -0
  47. src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json +17 -0
  48. src/laion_clap/clap_module/model_configs/ViT-B-32.json +16 -0
  49. src/laion_clap/clap_module/model_configs/ViT-L-14.json +16 -0
  50. src/laion_clap/clap_module/openai.py +129 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .venv
+ .env
+ .cache
+ __pycache__
+ data/audio/*.wav
app.py ADDED
@@ -0,0 +1,52 @@
+ import streamlit as st
+ from streamlit import session_state as session
+ from src.config.configs import ProjectPaths
+ import numpy as np
+ from src.laion_clap.inference import AudioEncoder
+
+
+ @st.cache(persist=True, show_spinner=False, suppress_st_warning=True)
+ def load_data():
+     vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
+     return vectors
+
+
+ recommender = AudioEncoder()
+ audio_vectors = load_data()
+
+ dataframe = None
+
+ st.title("""
+ Curate me a Playlist.
+ """)
+
+ st.text("")
+ st.text("")
+ st.text("")
+ st.text("")
+
+ session.text_input = st.text_input(label="Describe a playlist")
+
+ st.text("")
+ st.text("")
+
+ session.slider_count = st.slider(label="Track count", min_value=5, max_value=50)
+
+ st.text("")
+ st.text("")
+
+ buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
+
+ is_clicked = col1.button(label="Curate")
+
+ if is_clicked:
+     text_embed = recommender.get_text_embedding(session.text_input)
+
+
+ st.text("")
+ st.text("")
+ st.text("")
+ st.text("")
+
+ if dataframe is not None:
+     st.table(dataframe)
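As committed, app.py computes text_embed when "Curate" is clicked but never populates dataframe, so nothing is rendered yet. A minimal sketch of how the ranking could be wired in, assuming the saved vectors and the text embedding share the same dimensionality and that a track_names list (hypothetical, not part of this commit) maps vector rows to song titles:

    import numpy as np
    import pandas as pd

    def rank_tracks(text_embed: np.ndarray, audio_vectors: np.ndarray,
                    track_names: list[str], top_k: int) -> pd.DataFrame:
        # Cosine similarity between the single text vector and every audio vector.
        text = text_embed.reshape(-1)
        text = text / np.linalg.norm(text)
        audio = audio_vectors / np.linalg.norm(audio_vectors, axis=1, keepdims=True)
        scores = audio @ text
        order = np.argsort(scores)[::-1][:top_k]
        return pd.DataFrame({"track": [track_names[i] for i in order],
                             "score": scores[order]})

    # e.g. inside `if is_clicked:`:
    #     dataframe = rank_tracks(text_embed, audio_vectors, track_names, session.slider_count)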
data/.DS_Store ADDED
Binary file (6.15 kB).
 
data/audio/.gitkeep ADDED
File without changes
data/json/saved_tracks.json ADDED
File without changes
data/vectors/audio_representations.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe4a3ff8cfd2a6b13407352868f3f74fb290ebc11e8473e7132dd4bf947108da
+ size 1290368
model_checkpoints/.gitkeep ADDED
File without changes
model_checkpoints/music_audioset_epoch_15_esc_90.14.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fae3e9c087f2909c28a09dc31c8dfcdacbc42ba44c70e972b58c1bd1caf6dedd
+ size 2352471003
notebooks/notebook.ipynb ADDED
@@ -0,0 +1,788 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import librosa\n",
+ "import torch\n",
+ "from src import laion_clap\n",
+ "from glob import glob\n",
+ "import pandas as pd\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']\n",
+ "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Load the specified checkpoint music_audioset_epoch_15_esc_90.14.pt from users.\n",
+ "Load Checkpoint...\n",
+ "logit_scale_a \t Loaded\n",
+ "logit_scale_t \t Loaded\n",
+ "audio_branch.spectrogram_extractor.stft.conv_real.weight \t Loaded\n",
+ "[... per-parameter 'Loaded' lines for the remaining audio_branch and text_branch weights truncated ...]\n",
+ "audio_projection.2.weight \t Loaded\n",
+ "audio_projection.2.bias \t Loaded\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base')\n",
+ "model.load_ckpt(ckpt=\"music_audioset_epoch_15_esc_90.14.pt\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_music_file(file_name):\n",
+ "    audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n",
+ "    audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T)\n",
+ "    # audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before send it in to the model\n",
+ "    with torch.no_grad():\n",
+ "        audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False)\n",
+ "    return audio_embed\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "music_files = glob(\"/Users/berkayg/Codes/music-project/AudioCLIP/data/downloaded_tracks/*.wav\")[:100]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/sr/r72219hj06x_1xvw7hhd517h0000gn/T/ipykernel_18860/3009710654.py:2: UserWarning: PySoundFile failed. Trying audioread instead.\n",
+ "  audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n",
+ "/Users/berkayg/miniforge3/envs/playlist-curator/lib/python3.10/site-packages/librosa/core/audio.py:183: FutureWarning: librosa.core.audio.__audioread_load\n",
+ "\tDeprecated as of librosa version 0.10.0.\n",
+ "\tIt will be removed in librosa version 1.0.\n",
+ "  y, sr_native = __audioread_load(path, offset, duration, dtype)\n"
+ ]
+ }
+ ],
+ "source": [
+ "music_data = np.zeros((len(music_files), 512), dtype=np.float32)\n",
+ "for m in range(music_data.shape[0]):\n",
+ "    music_data[m] = load_music_file(music_files[m])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1, 512)\n"
+ ]
+ }
+ ],
+ "source": [
+ "text_data = [\"This audio is a romantic song\"] \n",
+ "text_embed = model.get_text_embedding(text_data)\n",
+ "print(text_embed.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "song_names = [k.split(\"/\")[-1] for k in music_files]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([100, 1])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ "    ranking = torch.tensor(music_data) @ torch.tensor(text_embed).t()\n",
+ "    ranking = ranking[:, 0].reshape(-1, 1)\n",
+ "print(ranking.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[... HTML rendering of the same 15-row ranking shown in the text/plain output below, truncated ...]"
+ ],
+ "text/plain": [
+ " This audio is a romantic song\n",
+ "Coldplay - Charlie Brown.wav 0.400684\n",
+ "Sam Smith - I'm Not The Only One.wav 0.373561\n",
+ "Pink Floyd - The Great Gig In The Sky - 2011 Re... 0.371584\n",
+ "Christina Aguilera - You Lost Me.wav 0.370390\n",
+ "Lana Del Rey - Yayo.wav 0.370379\n",
+ "Queen - It's A Hard Life - Remastered 2011.wav 0.348699\n",
+ "Teoman - Haziran.wav 0.331220\n",
+ "John Lennon - Imagine - Remastered 2010.wav 0.330397\n",
+ "Sleeping At Last - Mars.wav 0.328770\n",
+ "Adele - Someone Like You.wav 0.325650\n",
+ "Coldplay - What If.wav 0.315717\n",
+ "Adamlar - Orda Ortada.wav 0.306465\n",
+ "Eric Clapton - Autumn Leaves.wav 0.305451\n",
+ "Premiata Forneria Marconi - Impressioni di sett... 0.295878\n",
+ "Guthrie Govan - Lost in Rio.wav 0.284883"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(ranking, columns=[text_data[0]], index=song_names).nlargest(15, text_data[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "playlist-curator",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
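The commented-out quantization line in load_music_file refers to int16 round-trip helpers that are not included in this commit. A minimal sketch of what such helpers typically look like in LAION-CLAP usage examples (treat the exact definitions as an assumption):

    import numpy as np

    def float32_to_int16(x: np.ndarray) -> np.ndarray:
        # Clip to [-1, 1] and scale to the int16 range.
        x = np.clip(x, a_min=-1.0, a_max=1.0)
        return (x * 32767.0).astype(np.int16)

    def int16_to_float32(x: np.ndarray) -> np.ndarray:
        # Scale back to [-1, 1] float32; the round trip emulates 16-bit quantization.
        return (x / 32767.0).astype(np.float32)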
orchestrate_audio_data.py ADDED
@@ -0,0 +1,8 @@
+ from src.data.spotify import list_personal_saved_tracks
+ from src.data.get_yt_links import collect_youtube_links
+ from src.data.pytuber import start_download_process
+
+ if __name__ == "__main__":
+     list_personal_saved_tracks()
+     collect_youtube_links()
+     start_download_process()
recommender.py ADDED
@@ -0,0 +1,11 @@
+ from src.laion_clap.inference import AudioEncoder
+ from src.config.configs import ProjectPaths
+ from glob import glob
+
+ recommender = AudioEncoder()
+ # audio = recommender.extract_bulk_audio_representaions(save=False)
+ result = recommender.get_text_embedding("This audio is a romantic song")
+ music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
+ song_names = [k.split("/")[-1] for k in music_files]
+ print(result)
+ pass
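recommender.py currently stops at printing the text embedding. A small sketch of how the precomputed audio_representations.npy could be ranked against it, assuming the stored rows follow the same ordering as the glob above (hypothetical continuation, not part of the commit):

    import numpy as np
    import pandas as pd
    from src.config.configs import ProjectPaths

    # Load the precomputed audio embeddings saved by the bulk-extraction step.
    vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))

    # Dot-product scores between every audio vector and the text embedding computed above.
    scores = vectors @ np.asarray(result).reshape(-1)

    # Show the ten best-matching tracks, assuming row order matches song_names.
    print(pd.Series(scores, index=song_names).nlargest(10))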
requirements.txt ADDED
@@ -0,0 +1,89 @@
+ altair==5.1.2
+ anyio==4.0.0
+ appdirs==1.4.4
+ async-timeout==4.0.3
+ attrs==23.1.0
+ audioread==3.0.1
+ blinker==1.7.0
+ braceexpand==0.1.7
+ cachetools==5.3.2
+ certifi==2023.7.22
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.1.7
+ docker-pycreds==0.4.0
+ filelock==3.13.1
+ fsspec==2023.10.0
+ ftfy==6.1.1
+ gitdb==4.0.11
+ GitPython==3.1.40
+ google-api-python-client==2.105.0
+ google-auth-httplib2==0.1.1
+ h11==0.14.0
+ h5py==3.10.0
+ httpcore==1.0.2
+ httplib2==0.22.0
+ httpx==0.25.1
+ huggingface-hub==0.19.4
+ idna==3.4
+ Jinja2==3.1.2
+ joblib==1.3.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.1
+ lazy_loader==0.3
+ librosa==0.10.1
+ llvmlite==0.41.1
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ mdurl==0.1.2
+ msgpack==1.0.7
+ numba==0.58.1
+ numpy==1.23.5
+ pandas==2.1.3
+ Pillow==10.1.0
+ pooch==1.8.0
+ progressbar==2.5
+ protobuf==3.20.1
+ pyarrow==14.0.1
+ pycparser==2.21
+ pydeck==0.8.1b0
+ pytube==15.0.0
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ redis==5.0.1
+ referencing==0.31.0
+ regex==2023.10.3
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.13.0
+ safetensors==0.4.0
+ scikit-learn==1.3.2
+ scipy==1.11.3
+ sentry-sdk==1.35.0
+ setproctitle==1.3.3
+ smmap==5.0.1
+ sniffio==1.3.0
+ soundfile==0.12.1
+ soxr==0.3.7
+ spotipy==2.23.0
+ streamlit==1.28.2
+ tenacity==8.2.3
+ threadpoolctl==3.2.0
+ tokenizers==0.13.3
+ toml==0.10.2
+ toolz==0.12.0
+ torch==1.11.0
+ torchaudio==0.11.0
+ torchlibrosa==0.1.0
+ torchvision==0.12.0
+ tqdm==4.66.1
+ transformers==4.30.2
+ tzdata==2023.3
+ tzlocal==5.2
+ uritemplate==4.1.1
+ urllib3==2.1.0
+ validators==0.22.0
+ wandb==0.16.0
+ webdataset==0.2.77
+ wget==3.2
+ youtube-search-python==1.6.6
src/config/__init__.py ADDED
File without changes
src/config/configs.py ADDED
@@ -0,0 +1,16 @@
+ from pathlib import Path
+ from dataclasses import dataclass
+ from os import getenv
+
+
+ @dataclass
+ class ProjectPaths:
+     ROOT: Path = Path(__file__).parents[2]
+     DATA_DIR: Path = ROOT.joinpath("data")
+     MODEL_PATH: Path = ROOT.joinpath("model_checkpoints", "music_audioset_epoch_15_esc_90.14.pt")
+
+
+ @dataclass
+ class Credentials:
+     SPOTIFY_CLIENT_ID: str = getenv("SPOTIFY_CLIENT_ID")
+     SPOTIFY_SECRET_ID: str = getenv("SPOTIFY_SECRET_ID")
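These dataclasses act purely as namespaces for paths and environment-backed credentials. A quick usage sketch, assuming (not shown in this commit) that the Spotify variables are exported to the environment, e.g. from the .env file ignored above, before this module is imported:

    from src.config.configs import ProjectPaths, Credentials

    # Paths resolve relative to the repository root (two levels above configs.py).
    print(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"))
    print(ProjectPaths.MODEL_PATH.exists())

    # Credentials are read from the environment when the class is defined,
    # so SPOTIFY_CLIENT_ID / SPOTIFY_SECRET_ID must be set beforehand.
    print(Credentials.SPOTIFY_CLIENT_ID is not None)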
src/data/__init__.py ADDED
File without changes
src/data/get_yt_links.py ADDED
@@ -0,0 +1,52 @@
+ from youtubesearchpython import VideosSearch
+ import json
+ import time
+ from src.config.configs import ProjectPaths
+ from tqdm import tqdm
+
+
+ def read_json_data():
+     with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "r") as rd:
+         data = json.load(rd)
+     return data
+
+
+ def get_track_link(artist_name, track_name):
+     search_result = VideosSearch(f'{artist_name} - {track_name}', limit=1)
+     result = search_result.result()["result"][0]
+     data = {
+         "artist_name": artist_name,
+         "track_name": track_name,
+         "duration": result.get("duration"),
+         "published_time": result.get("publishedTime"),
+         "title": result.get("title"),
+         "view_count": result.get("viewCount").get("text"),
+         "link": result.get("link")
+     }
+     return data
+
+
+ def save_youtube_data(data):
+     with open(ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json"), "w") as wr:
+         json.dump(data, wr, indent=4)
+
+
+ def collect_youtube_links():
+     data = read_json_data()
+     youtube_data = []
+     for track_data in tqdm(data):
+         yt_data = get_track_link(track_data["artist"], track_data["track"])
+         youtube_data.append(yt_data)
+         time.sleep(0.2)
+     save_youtube_data(youtube_data)
+
+
+ if __name__ == "__main__":
+     data = read_json_data()
+     youtube_data = []
+     for track_data in tqdm(data):
+         yt_data = get_track_link(track_data["artist"], track_data["track"])
+         youtube_data.append(yt_data)
+         time.sleep(0.2)
+         pass
+     save_youtube_data(youtube_data)
src/data/pytuber.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ from src.config.configs import ProjectPaths
+ import json
+ import pytube
+ from tqdm import tqdm
+ from pytube.exceptions import AgeRestrictedError
+
+
+ def read_youtube_data():
+     input_data = ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json")
+     with open(input_data, "r") as rd:
+         return json.load(rd)
+
+
+ def download_mp3(link, track_full_name):
+     data_dir = ProjectPaths.DATA_DIR.joinpath("audio")
+     try:
+         mp3 = pytube.YouTube(link, use_oauth=True, allow_oauth_cache=True).streams.filter(only_audio=True).first()
+         mp3.download(data_dir)
+
+         new_file = track_full_name + '.wav'
+         os.rename(data_dir.joinpath(mp3.default_filename), data_dir.joinpath(new_file))
+     except AgeRestrictedError:
+         pass
+
+
+ def start_download_process():
+     input_data = read_youtube_data()
+     done_pieces = os.listdir(ProjectPaths.DATA_DIR.joinpath("audio"))
+     for i in tqdm(input_data):
+         link = i["link"]
+         full_name = f'{i["artist_name"]} - {i["track_name"]}'.replace("/", "_")
+         if full_name + ".wav" in done_pieces:
+             continue
+         download_mp3(link, full_name)
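Note that pytube delivers an MP4/WebM audio stream, so renaming the download to .wav does not convert it; this is consistent with the "PySoundFile failed. Trying audioread instead." warning captured in the notebook. If genuine WAV files are wanted, a conversion step could be added, for example with pydub (not listed in requirements.txt; purely a sketch under that assumption):

    from pathlib import Path
    from pydub import AudioSegment  # requires pydub plus ffmpeg, neither pinned in this repo

    def convert_to_wav(src: Path, dst: Path) -> None:
        # Decode whatever container pytube produced and re-encode it as PCM WAV.
        AudioSegment.from_file(src).export(dst, format="wav")
        src.unlink()  # drop the original download once converted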
src/data/spotify.py ADDED
@@ -0,0 +1,24 @@
+ import spotipy
+ from spotipy.oauth2 import SpotifyOAuth
+ from ..config.configs import Credentials, ProjectPaths
+ import json
+
+
+ def list_personal_saved_tracks():
+     scope = "user-library-read"
+     auth = SpotifyOAuth(client_id=Credentials.SPOTIFY_CLIENT_ID, client_secret=Credentials.SPOTIFY_SECRET_ID, scope=scope, redirect_uri="https://localhost:5000")
+     sp = spotipy.Spotify(auth_manager=auth)
+
+     tracks = []
+     offset_count = 0
+     for _ in range(50):
+         results = sp.current_user_saved_tracks(limit=50, offset=offset_count)
+         for idx, item in enumerate(results['items']):
+             track = item['track']
+             data = {"artist": track['artists'][0]['name'], "track": track['name']}
+             tracks.append(data)
+             print(idx, track['artists'][0]['name'], " - ", track['name'])
+         offset_count += 50
+
+     with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "w", encoding="UTF-8") as wr:
+         json.dump(tracks, wr, indent=4)
src/laion_clap/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ import os
2
+ import sys
3
+ dir_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(dir_path)
5
+ from .hook import CLAP_Module
src/laion_clap/clap_module/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
3
+ from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model
4
+ from .openai import load_openai_model, list_openai_models
5
+ from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6
+ get_pretrained_url, download_pretrained
7
+ from .tokenizer import SimpleTokenizer, tokenize
8
+ from .transform import image_transform
src/laion_clap/clap_module/bert.py ADDED
@@ -0,0 +1,32 @@
1
+ from transformers import BertTokenizer, BertModel
2
+ bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
3
+ bert_model = BertModel.from_pretrained("bert-base-uncased")
4
+ text = "Replace me by any text you'd like."
5
+
6
+ def bert_embeddings(text):
7
+ # text = "Replace me by any text you'd like."
8
+ encoded_input = bert_tokenizer(text, return_tensors='pt')
9
+ output = bert_model(**encoded_input)
10
+ return output
11
+
12
+ from transformers import RobertaTokenizer, RobertaModel
13
+
14
+ roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
15
+ roberta_model = RobertaModel.from_pretrained('roberta-base')
16
+ text = "Replace me by any text you'd like."
17
+ def Roberta_embeddings(text):
18
+ # text = "Replace me by any text you'd like."
19
+ encoded_input = roberta_tokenizer(text, return_tensors='pt')
20
+ output = roberta_model(**encoded_input)
21
+ return output
22
+
23
+ from transformers import BartTokenizer, BartModel
24
+
25
+ bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
26
+ bart_model = BartModel.from_pretrained('facebook/bart-base')
27
+ text = "Replace me by any text you'd like."
28
+ def bart_embeddings(text):
29
+ # text = "Replace me by any text you'd like."
30
+ encoded_input = bart_tokenizer(text, return_tensors='pt')
31
+ output = bart_model(**encoded_input)
32
+ return output
src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
src/laion_clap/clap_module/factory.py ADDED
@@ -0,0 +1,263 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from packaging import version
9
+
10
+ import torch
11
+ import transformers
12
+
13
+ from .model import CLAP, convert_weights_to_fp16
14
+ from .openai import load_openai_model
15
+ from .pretrained import get_pretrained_url, download_pretrained
16
+ from .transform import image_transform
17
+
18
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
19
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
20
+
21
+
22
+ def _natural_key(string_):
23
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
24
+
25
+
26
+ def _rescan_model_configs():
27
+ global _MODEL_CONFIGS
28
+
29
+ config_ext = (".json",)
30
+ config_files = []
31
+ for config_path in _MODEL_CONFIG_PATHS:
32
+ if config_path.is_file() and config_path.suffix in config_ext:
33
+ config_files.append(config_path)
34
+ elif config_path.is_dir():
35
+ for ext in config_ext:
36
+ config_files.extend(config_path.glob(f"*{ext}"))
37
+
38
+ for cf in config_files:
39
+ with open(cf, "r") as f:
40
+ model_cfg = json.load(f)
41
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
+ _MODEL_CONFIGS[cf.stem] = model_cfg
43
+
44
+ _MODEL_CONFIGS = {
45
+ k: v
46
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
+ }
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
+ state_dict = checkpoint["state_dict"]
57
+ else:
58
+ state_dict = checkpoint
59
+ if skip_params:
60
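+ # strip the "module." prefix that (Distributed)DataParallel adds to parameter names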
+ if next(iter(state_dict.items()))[0].startswith("module"):
61
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
62
+
63
+ # removing position_ids to maintain compatibility with latest transformers update
64
+ if version.parse(transformers.__version__) >= version.parse("4.31.0"):
65
+ state_dict.pop("text_branch.embeddings.position_ids", None)
66
+ # for k in state_dict:
67
+ # if k.startswith('transformer'):
68
+ # v = state_dict.pop(k)
69
+ # state_dict['text_branch.' + k[12:]] = v
70
+ return state_dict
71
+
72
+
73
+ def create_model(
74
+ amodel_name: str,
75
+ tmodel_name: str,
76
+ pretrained: str = "",
77
+ precision: str = "fp32",
78
+ device: torch.device = torch.device("cpu"),
79
+ jit: bool = False,
80
+ force_quick_gelu: bool = False,
81
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
82
+ skip_params=True,
83
+ pretrained_audio: str = "",
84
+ pretrained_text: str = "",
85
+ enable_fusion: bool = False,
86
+ fusion_type: str = 'None'
87
+ # pretrained_image: bool = False,
88
+ ):
89
+ amodel_name = amodel_name.replace(
90
+ "/", "-"
91
+ ) # for callers using old naming with / in ViT names
92
+ pretrained_orig = pretrained
93
+ pretrained = pretrained.lower()
94
+ if pretrained == "openai":
95
+ if amodel_name in _MODEL_CONFIGS:
96
+ logging.info(f"Loading {amodel_name} model config.")
97
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
98
+ else:
99
+ logging.error(
100
+ f"Model config for {amodel_name} not found; available models {list_models()}."
101
+ )
102
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
103
+
104
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
105
+ # Hard Code in model name
106
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
107
+ model = load_openai_model(
108
+ "ViT-B-16",
109
+ model_cfg,
110
+ device=device,
111
+ jit=jit,
112
+ cache_dir=openai_model_cache_dir,
113
+ enable_fusion=enable_fusion,
114
+ fusion_type=fusion_type
115
+ )
116
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
117
+ if precision == "amp" or precision == "fp32":
118
+ model = model.float()
119
+ else:
120
+ if amodel_name in _MODEL_CONFIGS:
121
+ logging.info(f"Loading {amodel_name} model config.")
122
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
123
+ else:
124
+ logging.error(
125
+ f"Model config for {amodel_name} not found; available models {list_models()}."
126
+ )
127
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
128
+
129
+ if force_quick_gelu:
130
+ # override for use of QuickGELU on non-OpenAI transformer models
131
+ model_cfg["quick_gelu"] = True
132
+
133
+ # if pretrained_image:
134
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
135
+ # # pretrained weight loading for timm models set via vision_cfg
136
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
137
+ # else:
138
+ # assert False, 'pretrained image towers currently only supported for timm models'
139
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
140
+ model_cfg["enable_fusion"] = enable_fusion
141
+ model_cfg["fusion_type"] = fusion_type
142
+ model = CLAP(**model_cfg)
143
+
144
+ if pretrained:
145
+ checkpoint_path = ""
146
+ url = get_pretrained_url(amodel_name, pretrained)
147
+ if url:
148
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
149
+ elif os.path.exists(pretrained_orig):
150
+ checkpoint_path = pretrained_orig
151
+ if checkpoint_path:
152
+ logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).")
153
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
154
+ model.load_state_dict(ckpt)
155
+ param_names = [n for n, p in model.named_parameters()]
156
+ for n in param_names:
157
+ print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
158
+ else:
159
+ logging.warning(
160
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
161
+ )
162
+ raise RuntimeError(
163
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
164
+ )
165
+
166
+ if pretrained_audio:
167
+ if amodel_name.startswith('PANN'):
168
+ if 'Cnn14_mAP' in pretrained_audio: # official checkpoint
169
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
170
+ audio_ckpt = audio_ckpt['model']
171
+ keys = list(audio_ckpt.keys())
172
+ for key in keys:
173
+ if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key:
174
+ v = audio_ckpt.pop(key)
175
+ audio_ckpt['audio_branch.' + key] = v
176
+ elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase
177
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
178
+ audio_ckpt = audio_ckpt['state_dict']
179
+ keys = list(audio_ckpt.keys())
180
+ for key in keys:
181
+ if key.startswith('sed_model'):
182
+ v = audio_ckpt.pop(key)
183
+ audio_ckpt['audio_branch.' + key[10:]] = v
184
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
185
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
186
+ else:
187
+ raise ValueError('Unknown audio checkpoint')
188
+ elif amodel_name.startswith('HTSAT'):
189
+ if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint
190
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
191
+ audio_ckpt = audio_ckpt['state_dict']
192
+ keys = list(audio_ckpt.keys())
193
+ for key in keys:
194
+ if key.startswith('sed_model') and ('spectrogram_extractor' not in key
195
+ and 'logmel_extractor' not in key):
196
+ v = audio_ckpt.pop(key)
197
+ audio_ckpt['audio_branch.' + key[10:]] = v
198
+ elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase
199
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
200
+ audio_ckpt = audio_ckpt['state_dict']
201
+ keys = list(audio_ckpt.keys())
202
+ for key in keys:
203
+ if key.startswith('sed_model'):
204
+ v = audio_ckpt.pop(key)
205
+ audio_ckpt['audio_branch.' + key[10:]] = v
206
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
207
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
208
+ else:
209
+ raise ValueError('Unknown audio checkpoint')
210
+ else:
211
+ raise ValueError('this audio encoder pretrained checkpoint is not supported')
212
+
213
+ model.load_state_dict(audio_ckpt, strict=False)
214
+ logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).")
215
+ param_names = [n for n, p in model.named_parameters()]
216
+ for n in param_names:
217
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
218
+
219
+ model.to(device=device)
220
+ if precision == "fp16":
221
+ assert device.type != "cpu"
222
+ convert_weights_to_fp16(model)
223
+
224
+ if jit:
225
+ model = torch.jit.script(model)
226
+
227
+ return model, model_cfg
228
+
229
+
230
+ def create_model_and_transforms(
231
+ model_name: str,
232
+ pretrained: str = "",
233
+ precision: str = "fp32",
234
+ device: torch.device = torch.device("cpu"),
235
+ jit: bool = False,
236
+ force_quick_gelu: bool = False,
237
+ # pretrained_image: bool = False,
238
+ ):
239
+ model = create_model(
240
+ model_name,
241
+ pretrained,
242
+ precision,
243
+ device,
244
+ jit,
245
+ force_quick_gelu=force_quick_gelu,
246
+ # pretrained_image=pretrained_image
247
+ )
248
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
249
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
250
+ return model, preprocess_train, preprocess_val
251
+
252
+
253
+ def list_models():
254
+ """enumerate available model architectures based on config files"""
255
+ return list(_MODEL_CONFIGS.keys())
256
+
257
+
258
+ def add_model_config(path):
259
+ """add model config path or file and update registry"""
260
+ if not isinstance(path, Path):
261
+ path = Path(path)
262
+ _MODEL_CONFIG_PATHS.append(path)
263
+ _rescan_model_configs()
src/laion_clap/clap_module/feature_fusion.py ADDED
@@ -0,0 +1,193 @@
1
+ '''
2
+ Feature Fusion for Variable-Length Data Processing
3
+ AFF/iAFF are adapted from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ '''
13
+ Direct addition (DirectAddFuse)
14
+ '''
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ '''
25
+ Multi-feature fusion (iAFF)
26
+ '''
27
+
28
+ def __init__(self, channels=64, r=4, type='2D'):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == '1D':
33
+ # local attention
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # global attention
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # second local attention
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # second global attention
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == '2D':
70
+ # local attention
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # global attention
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # second local attention
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # second global attention
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise ValueError('the type is not supported')
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
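+ # a single-sample batch is duplicated here and reduced back to one sample at the end, presumably to keep BatchNorm usable on batch size 1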
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa,xa],dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ '''
135
+ Multi-feature fusion (AFF)
136
+ '''
137
+
138
+ def __init__(self, channels=64, r=4, type='2D'):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == '1D':
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == '2D':
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise ValueError('the type is not supported.')
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa,xa],dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
193
+
src/laion_clap/clap_module/htsat.py ADDED
@@ -0,0 +1,1031 @@
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed for the model
5
+ # The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+
24
+ from itertools import repeat
25
+ from .utils import do_mixup, interpolate
26
+
27
+ from .feature_fusion import iAFF, AFF, DAF
28
+
29
+ # from PyTorch internals
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+ return parse
36
+
37
+ to_1tuple = _ntuple(1)
38
+ to_2tuple = _ntuple(2)
39
+ to_3tuple = _ntuple(3)
40
+ to_4tuple = _ntuple(4)
41
+ to_ntuple = _ntuple
42
+
43
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
44
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
45
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
46
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
47
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
48
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
49
+ 'survival rate' as the argument.
50
+ """
51
+ if drop_prob == 0. or not training:
52
+ return x
53
+ keep_prob = 1 - drop_prob
54
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
55
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
56
+ random_tensor.floor_() # binarize
57
+ output = x.div(keep_prob) * random_tensor
58
+ return output
59
+
60
+
61
+ class DropPath(nn.Module):
62
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
63
+ """
64
+ def __init__(self, drop_prob=None):
65
+ super(DropPath, self).__init__()
66
+ self.drop_prob = drop_prob
67
+
68
+ def forward(self, x):
69
+ return drop_path(x, self.drop_prob, self.training)
70
+
71
+ class PatchEmbed(nn.Module):
72
+ """ 2D Image to Patch Embedding
73
+ """
74
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16,
75
+ enable_fusion=False, fusion_type='None'):
76
+ super().__init__()
77
+ img_size = to_2tuple(img_size)
78
+ patch_size = to_2tuple(patch_size)
79
+ patch_stride = to_2tuple(patch_stride)
80
+ self.img_size = img_size
81
+ self.patch_size = patch_size
82
+ self.patch_stride = patch_stride
83
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
84
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
85
+ self.flatten = flatten
86
+ self.in_chans = in_chans
87
+ self.embed_dim = embed_dim
88
+
89
+ self.enable_fusion = enable_fusion
90
+ self.fusion_type = fusion_type
91
+
92
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
93
+
94
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
95
+ self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
96
+ else:
97
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
98
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
99
+
100
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
101
+ self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding)
102
+ if self.fusion_type == 'daf_2d':
103
+ self.fusion_model = DAF()
104
+ elif self.fusion_type == 'aff_2d':
105
+ self.fusion_model = AFF(channels=embed_dim, type='2D')
106
+ elif self.fusion_type == 'iaff_2d':
107
+ self.fusion_model = iAFF(channels=embed_dim, type='2D')
108
+ def forward(self, x, longer_idx = None):
109
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
110
+ global_x = x[:,0:1,:,:]
111
+
112
+
113
+ # global processing
114
+ B, C, H, W = global_x.shape
115
+ assert H == self.img_size[0] and W == self.img_size[1], \
116
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
117
+ global_x = self.proj(global_x)
118
+ TW = global_x.size(-1)
119
+ if len(longer_idx) > 0:
120
+ # local processing
121
+ local_x = x[longer_idx,1:,:,:].contiguous()
122
+ B, C, H, W = local_x.shape
123
+ local_x = local_x.view(B*C,1,H,W)
124
+ local_x = self.mel_conv2d(local_x)
125
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
126
+ local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3)
127
+ TB,TC,TH,_ = local_x.size()
128
+ if local_x.size(-1) < TW:
129
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1)
130
+ else:
131
+ local_x = local_x[:,:,:,:TW]
132
+
133
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x)
134
+ x = global_x
135
+ else:
136
+ B, C, H, W = x.shape
137
+ assert H == self.img_size[0] and W == self.img_size[1], \
138
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
139
+ x = self.proj(x)
140
+
141
+ if self.flatten:
142
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
143
+ x = self.norm(x)
144
+ return x
145
+
146
+ class Mlp(nn.Module):
147
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
148
+ """
149
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = act_layer()
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+ self.drop = nn.Dropout(drop)
157
+
158
+ def forward(self, x):
159
+ x = self.fc1(x)
160
+ x = self.act(x)
161
+ x = self.drop(x)
162
+ x = self.fc2(x)
163
+ x = self.drop(x)
164
+ return x
165
+
166
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
167
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
168
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
169
+ def norm_cdf(x):
170
+ # Computes standard normal cumulative distribution function
171
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
172
+
173
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
174
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
175
+ "The distribution of values may be incorrect.",
176
+ stacklevel=2)
177
+
178
+ with torch.no_grad():
179
+ # Values are generated by using a truncated uniform distribution and
180
+ # then using the inverse CDF for the normal distribution.
181
+ # Get upper and lower cdf values
182
+ l = norm_cdf((a - mean) / std)
183
+ u = norm_cdf((b - mean) / std)
184
+
185
+ # Uniformly fill tensor with values from [l, u], then translate to
186
+ # [2l-1, 2u-1].
187
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
188
+
189
+ # Use inverse cdf transform for normal distribution to get truncated
190
+ # standard normal
191
+ tensor.erfinv_()
192
+
193
+ # Transform to proper mean, std
194
+ tensor.mul_(std * math.sqrt(2.))
195
+ tensor.add_(mean)
196
+
197
+ # Clamp to ensure it's in the proper range
198
+ tensor.clamp_(min=a, max=b)
199
+ return tensor
200
+
201
+
202
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
203
+ # type: (Tensor, float, float, float, float) -> Tensor
204
+ r"""Fills the input Tensor with values drawn from a truncated
205
+ normal distribution. The values are effectively drawn from the
206
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
207
+ with values outside :math:`[a, b]` redrawn until they are within
208
+ the bounds. The method used for generating the random values works
209
+ best when :math:`a \leq \text{mean} \leq b`.
210
+ Args:
211
+ tensor: an n-dimensional `torch.Tensor`
212
+ mean: the mean of the normal distribution
213
+ std: the standard deviation of the normal distribution
214
+ a: the minimum cutoff value
215
+ b: the maximum cutoff value
216
+ Examples:
217
+ >>> w = torch.empty(3, 5)
218
+ >>> nn.init.trunc_normal_(w)
219
+ """
220
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
221
+
222
+
223
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
224
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
225
+ if mode == 'fan_in':
226
+ denom = fan_in
227
+ elif mode == 'fan_out':
228
+ denom = fan_out
229
+ elif mode == 'fan_avg':
230
+ denom = (fan_in + fan_out) / 2
231
+
232
+ variance = scale / denom
233
+
234
+ if distribution == "truncated_normal":
235
+ # constant is stddev of standard normal truncated to (-2, 2)
236
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
237
+ elif distribution == "normal":
238
+ tensor.normal_(std=math.sqrt(variance))
239
+ elif distribution == "uniform":
240
+ bound = math.sqrt(3 * variance)
241
+ tensor.uniform_(-bound, bound)
242
+ else:
243
+ raise ValueError(f"invalid distribution {distribution}")
244
+
245
+
246
+ def lecun_normal_(tensor):
247
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
248
+
249
+ def window_partition(x, window_size):
250
+ """
251
+ Args:
252
+ x: (B, H, W, C)
253
+ window_size (int): window size
254
+ Returns:
255
+ windows: (num_windows*B, window_size, window_size, C)
256
+ """
257
+ B, H, W, C = x.shape
258
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
259
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
260
+ return windows
261
+
262
+
263
+ def window_reverse(windows, window_size, H, W):
264
+ """
265
+ Args:
266
+ windows: (num_windows*B, window_size, window_size, C)
267
+ window_size (int): Window size
268
+ H (int): Height of image
269
+ W (int): Width of image
270
+ Returns:
271
+ x: (B, H, W, C)
272
+ """
273
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
274
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
275
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
276
+ return x
277
+
278
+
279
+ class WindowAttention(nn.Module):
280
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
281
+ It supports both shifted and non-shifted windows.
282
+ Args:
283
+ dim (int): Number of input channels.
284
+ window_size (tuple[int]): The height and width of the window.
285
+ num_heads (int): Number of attention heads.
286
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
287
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
288
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
289
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
290
+ """
291
+
292
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
293
+
294
+ super().__init__()
295
+ self.dim = dim
296
+ self.window_size = window_size # Wh, Ww
297
+ self.num_heads = num_heads
298
+ head_dim = dim // num_heads
299
+ self.scale = qk_scale or head_dim ** -0.5
300
+
301
+ # define a parameter table of relative position bias
302
+ self.relative_position_bias_table = nn.Parameter(
303
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
304
+
305
+ # get pair-wise relative position index for each token inside the window
306
+ coords_h = torch.arange(self.window_size[0])
307
+ coords_w = torch.arange(self.window_size[1])
308
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
309
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
310
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
311
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
312
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
313
+ relative_coords[:, :, 1] += self.window_size[1] - 1
314
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
315
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
316
+ self.register_buffer("relative_position_index", relative_position_index)
317
+
318
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
319
+ self.attn_drop = nn.Dropout(attn_drop)
320
+ self.proj = nn.Linear(dim, dim)
321
+ self.proj_drop = nn.Dropout(proj_drop)
322
+
323
+ trunc_normal_(self.relative_position_bias_table, std=.02)
324
+ self.softmax = nn.Softmax(dim=-1)
325
+
326
+ def forward(self, x, mask=None):
327
+ """
328
+ Args:
329
+ x: input features with shape of (num_windows*B, N, C)
330
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
331
+ """
332
+ B_, N, C = x.shape
333
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
334
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
335
+
336
+ q = q * self.scale
337
+ attn = (q @ k.transpose(-2, -1))
338
+
339
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
340
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
341
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
342
+ attn = attn + relative_position_bias.unsqueeze(0)
343
+
344
+ if mask is not None:
345
+ nW = mask.shape[0]
346
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
347
+ attn = attn.view(-1, self.num_heads, N, N)
348
+ attn = self.softmax(attn)
349
+ else:
350
+ attn = self.softmax(attn)
351
+
352
+ attn = self.attn_drop(attn)
353
+
354
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
355
+ x = self.proj(x)
356
+ x = self.proj_drop(x)
357
+ return x, attn
358
+
359
+ def extra_repr(self):
360
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
361
+
362
+
363
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
364
+ class SwinTransformerBlock(nn.Module):
365
+ r""" Swin Transformer Block.
366
+ Args:
367
+ dim (int): Number of input channels.
368
+ input_resolution (tuple[int]): Input resolution.
369
+ num_heads (int): Number of attention heads.
370
+ window_size (int): Window size.
371
+ shift_size (int): Shift size for SW-MSA.
372
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
373
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
374
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
375
+ drop (float, optional): Dropout rate. Default: 0.0
376
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
377
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
378
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
379
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
380
+ """
381
+
382
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
383
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
384
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
385
+ super().__init__()
386
+ self.dim = dim
387
+ self.input_resolution = input_resolution
388
+ self.num_heads = num_heads
389
+ self.window_size = window_size
390
+ self.shift_size = shift_size
391
+ self.mlp_ratio = mlp_ratio
392
+ self.norm_before_mlp = norm_before_mlp
393
+ if min(self.input_resolution) <= self.window_size:
394
+ # if window size is larger than input resolution, we don't partition windows
395
+ self.shift_size = 0
396
+ self.window_size = min(self.input_resolution)
397
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
398
+
399
+ self.norm1 = norm_layer(dim)
400
+ self.attn = WindowAttention(
401
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
402
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
403
+
404
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
405
+ if self.norm_before_mlp == 'ln':
406
+ self.norm2 = nn.LayerNorm(dim)
407
+ elif self.norm_before_mlp == 'bn':
408
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
409
+ else:
410
+ raise NotImplementedError
411
+ mlp_hidden_dim = int(dim * mlp_ratio)
412
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
413
+
414
+ if self.shift_size > 0:
415
+ # calculate attention mask for SW-MSA
416
+ H, W = self.input_resolution
417
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
418
+ h_slices = (slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None))
421
+ w_slices = (slice(0, -self.window_size),
422
+ slice(-self.window_size, -self.shift_size),
423
+ slice(-self.shift_size, None))
424
+ cnt = 0
425
+ for h in h_slices:
426
+ for w in w_slices:
427
+ img_mask[:, h, w, :] = cnt
428
+ cnt += 1
429
+
430
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
431
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
432
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
433
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
434
+ else:
435
+ attn_mask = None
436
+
437
+ self.register_buffer("attn_mask", attn_mask)
438
+
439
+ def forward(self, x):
440
+ # pdb.set_trace()
441
+ H, W = self.input_resolution
442
+ # print("H: ", H)
443
+ # print("W: ", W)
444
+ # pdb.set_trace()
445
+ B, L, C = x.shape
446
+ # assert L == H * W, "input feature has wrong size"
447
+
448
+ shortcut = x
449
+ x = self.norm1(x)
450
+ x = x.view(B, H, W, C)
451
+
452
+ # cyclic shift
453
+ if self.shift_size > 0:
454
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
455
+ else:
456
+ shifted_x = x
457
+
458
+ # partition windows
459
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
460
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
461
+
462
+ # W-MSA/SW-MSA
463
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
464
+
465
+ # merge windows
466
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
467
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
468
+
469
+ # reverse cyclic shift
470
+ if self.shift_size > 0:
471
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
472
+ else:
473
+ x = shifted_x
474
+ x = x.view(B, H * W, C)
475
+
476
+ # FFN
477
+ x = shortcut + self.drop_path(x)
478
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
479
+
480
+ return x, attn
481
+
482
+ def extra_repr(self):
483
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
484
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
485
+
486
+
487
+
488
+ class PatchMerging(nn.Module):
489
+ r""" Patch Merging Layer.
490
+ Args:
491
+ input_resolution (tuple[int]): Resolution of input feature.
492
+ dim (int): Number of input channels.
493
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
494
+ """
495
+
496
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
497
+ super().__init__()
498
+ self.input_resolution = input_resolution
499
+ self.dim = dim
500
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
501
+ self.norm = norm_layer(4 * dim)
502
+
503
+ def forward(self, x):
504
+ """
505
+ x: B, H*W, C
506
+ """
507
+ H, W = self.input_resolution
508
+ B, L, C = x.shape
509
+ assert L == H * W, "input feature has wrong size"
510
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
511
+
512
+ x = x.view(B, H, W, C)
513
+
514
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
515
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
516
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
517
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
518
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
519
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
520
+
521
+ x = self.norm(x)
522
+ x = self.reduction(x)
523
+
524
+ return x
525
+
526
+ def extra_repr(self):
527
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
528
+
529
+
530
+ class BasicLayer(nn.Module):
531
+ """ A basic Swin Transformer layer for one stage.
532
+ Args:
533
+ dim (int): Number of input channels.
534
+ input_resolution (tuple[int]): Input resolution.
535
+ depth (int): Number of blocks.
536
+ num_heads (int): Number of attention heads.
537
+ window_size (int): Local window size.
538
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
539
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
540
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
541
+ drop (float, optional): Dropout rate. Default: 0.0
542
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
543
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
544
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
545
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
546
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
547
+ """
548
+
549
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
550
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
551
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
552
+ norm_before_mlp='ln'):
553
+
554
+ super().__init__()
555
+ self.dim = dim
556
+ self.input_resolution = input_resolution
557
+ self.depth = depth
558
+ self.use_checkpoint = use_checkpoint
559
+
560
+ # build blocks
561
+ self.blocks = nn.ModuleList([
562
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
563
+ num_heads=num_heads, window_size=window_size,
564
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
565
+ mlp_ratio=mlp_ratio,
566
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
567
+ drop=drop, attn_drop=attn_drop,
568
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
569
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
570
+ for i in range(depth)])
571
+
572
+ # patch merging layer
573
+ if downsample is not None:
574
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
575
+ else:
576
+ self.downsample = None
577
+
578
+ def forward(self, x):
579
+ attns = []
580
+ for blk in self.blocks:
581
+ if self.use_checkpoint:
582
+ x = checkpoint.checkpoint(blk, x)
583
+ else:
584
+ x, attn = blk(x)
585
+ if not self.training:
586
+ attns.append(attn.unsqueeze(0))
587
+ if self.downsample is not None:
588
+ x = self.downsample(x)
589
+ if not self.training:
590
+ attn = torch.cat(attns, dim = 0)
591
+ attn = torch.mean(attn, dim = 0)
592
+ return x, attn
593
+
594
+ def extra_repr(self):
595
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
596
+
597
+
598
+ # The Core of HTSAT
599
+ class HTSAT_Swin_Transformer(nn.Module):
600
+ r"""HTSAT based on the Swin Transformer
601
+ Args:
602
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
603
+ patch_size (int | tuple(int)): Patch size. Default: 4
604
+ patch_stride (int | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
605
+ in_chans (int): Number of input image channels. Default: 1 (mono)
606
+ num_classes (int): Number of classes for classification head. Default: 527
607
+ embed_dim (int): Patch embedding dimension. Default: 96
608
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
609
+ num_heads (tuple(int)): Number of attention heads in different layers.
610
+ window_size (int): Window size. Default: 8
611
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
612
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
613
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
614
+ drop_rate (float): Dropout rate. Default: 0
615
+ attn_drop_rate (float): Attention dropout rate. Default: 0
616
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
617
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
618
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
619
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
620
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
621
+ config (module): The configuration Module from config.py
622
+ """
623
+
624
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
625
+ in_chans=1, num_classes=527,
626
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
627
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
628
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
629
+ norm_layer=nn.LayerNorm,
630
+ ape=False, patch_norm=True,
631
+ use_checkpoint=False, norm_before_mlp='ln', config = None,
632
+ enable_fusion = False, fusion_type = 'None', **kwargs):
633
+ super(HTSAT_Swin_Transformer, self).__init__()
634
+
635
+ self.config = config
636
+ self.spec_size = spec_size
637
+ self.patch_stride = patch_stride
638
+ self.patch_size = patch_size
639
+ self.window_size = window_size
640
+ self.embed_dim = embed_dim
641
+ self.depths = depths
642
+ self.ape = ape
643
+ self.in_chans = in_chans
644
+ self.num_classes = num_classes
645
+ self.num_heads = num_heads
646
+ self.num_layers = len(self.depths)
647
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
648
+
649
+ self.drop_rate = drop_rate
650
+ self.attn_drop_rate = attn_drop_rate
651
+ self.drop_path_rate = drop_path_rate
652
+
653
+ self.qkv_bias = qkv_bias
654
+ self.qk_scale = None
655
+
656
+ self.patch_norm = patch_norm
657
+ self.norm_layer = norm_layer if self.patch_norm else None
658
+ self.norm_before_mlp = norm_before_mlp
659
+ self.mlp_ratio = mlp_ratio
660
+
661
+ self.use_checkpoint = use_checkpoint
662
+
663
+ self.enable_fusion = enable_fusion
664
+ self.fusion_type = fusion_type
665
+
666
+ # process mel-spec ; used only once
667
+ self.freq_ratio = self.spec_size // self.config.mel_bins
668
+ window = 'hann'
669
+ center = True
670
+ pad_mode = 'reflect'
671
+ ref = 1.0
672
+ amin = 1e-10
673
+ top_db = None
674
+ self.interpolate_ratio = 32 # Downsampled ratio
675
+ # Spectrogram extractor
676
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
677
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
678
+ freeze_parameters=True)
679
+ # Logmel feature extractor
680
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
681
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
682
+ freeze_parameters=True)
683
+ # Spec augmenter
684
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
685
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
686
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
687
+
688
+
689
+ # split spectrogram into non-overlapping patches
690
+ self.patch_embed = PatchEmbed(
691
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
692
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
693
+ enable_fusion=self.enable_fusion, fusion_type=self.fusion_type
694
+ )
695
+
696
+ num_patches = self.patch_embed.num_patches
697
+ patches_resolution = self.patch_embed.grid_size
698
+ self.patches_resolution = patches_resolution
699
+
700
+ # absolute position embedding
701
+ if self.ape:
702
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
703
+ trunc_normal_(self.absolute_pos_embed, std=.02)
704
+
705
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
706
+
707
+ # stochastic depth
708
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
709
+
710
+ # build layers
711
+ self.layers = nn.ModuleList()
712
+ for i_layer in range(self.num_layers):
713
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
714
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
715
+ patches_resolution[1] // (2 ** i_layer)),
716
+ depth=self.depths[i_layer],
717
+ num_heads=self.num_heads[i_layer],
718
+ window_size=self.window_size,
719
+ mlp_ratio=self.mlp_ratio,
720
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
721
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
722
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
723
+ norm_layer=self.norm_layer,
724
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
725
+ use_checkpoint=use_checkpoint,
726
+ norm_before_mlp=self.norm_before_mlp)
727
+ self.layers.append(layer)
728
+
729
+ self.norm = self.norm_layer(self.num_features)
730
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
731
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
732
+
733
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
734
+ self.tscam_conv = nn.Conv2d(
735
+ in_channels = self.num_features,
736
+ out_channels = self.num_classes,
737
+ kernel_size = (SF,3),
738
+ padding = (0,1)
739
+ )
740
+ self.head = nn.Linear(num_classes, num_classes)
741
+
742
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
743
+ self.mel_conv1d = nn.Sequential(
744
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
745
+ nn.BatchNorm1d(64)
746
+ )
747
+ if self.fusion_type == 'daf_1d':
748
+ self.fusion_model = DAF()
749
+ elif self.fusion_type == 'aff_1d':
750
+ self.fusion_model = AFF(channels=64, type='1D')
751
+ elif self.fusion_type == 'iaff_1d':
752
+ self.fusion_model = iAFF(channels=64, type='1D')
753
+
754
+ self.apply(self._init_weights)
755
+
756
+ def _init_weights(self, m):
757
+ if isinstance(m, nn.Linear):
758
+ trunc_normal_(m.weight, std=.02)
759
+ if isinstance(m, nn.Linear) and m.bias is not None:
760
+ nn.init.constant_(m.bias, 0)
761
+ elif isinstance(m, nn.LayerNorm):
762
+ nn.init.constant_(m.bias, 0)
763
+ nn.init.constant_(m.weight, 1.0)
764
+
765
+ @torch.jit.ignore
766
+ def no_weight_decay(self):
767
+ return {'absolute_pos_embed'}
768
+
769
+ @torch.jit.ignore
770
+ def no_weight_decay_keywords(self):
771
+ return {'relative_position_bias_table'}
772
+
773
+
774
+ def forward_features(self, x, longer_idx = None):
775
+ # A deprecated optimization for using a hierarchical output from different blocks
776
+
777
+ frames_num = x.shape[2]
778
+ x = self.patch_embed(x, longer_idx = longer_idx)
779
+ if self.ape:
780
+ x = x + self.absolute_pos_embed
781
+ x = self.pos_drop(x)
782
+ for i, layer in enumerate(self.layers):
783
+ x, attn = layer(x)
784
+ # for x
785
+ x = self.norm(x)
786
+ B, N, C = x.shape
787
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
788
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
789
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
790
+ B, C, F, T = x.shape
791
+ # group 2D CNN
792
+ c_freq_bin = F // self.freq_ratio
793
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
794
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
795
+ # get latent_output
796
+ fine_grained_latent_output = torch.mean(x, dim = 2)
797
+ fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
798
+
799
+ latent_output = self.avgpool(torch.flatten(x,2))
800
+ latent_output = torch.flatten(latent_output, 1)
801
+
802
+ # display the attention map, if needed
803
+
804
+ x = self.tscam_conv(x)
805
+ x = torch.flatten(x, 2) # B, C, T
806
+
807
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
808
+
809
+ x = self.avgpool(x)
810
+ x = torch.flatten(x, 1)
811
+
812
+ output_dict = {
813
+ 'framewise_output': fpx, # already sigmoided
814
+ 'clipwise_output': torch.sigmoid(x),
815
+ 'fine_grained_embedding': fine_grained_latent_output,
816
+ 'embedding': latent_output
817
+ }
818
+
819
+ return output_dict
820
+
821
+ def crop_wav(self, x, crop_size, spe_pos = None):
822
+ time_steps = x.shape[2]
823
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
824
+ for i in range(len(x)):
825
+ if spe_pos is None:
826
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
827
+ else:
828
+ crop_pos = spe_pos
829
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
830
+ return tx
831
+
832
+ # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
833
+ def reshape_wav2img(self, x):
834
+ B, C, T, F = x.shape
835
+ target_T = int(self.spec_size * self.freq_ratio)
836
+ target_F = self.spec_size // self.freq_ratio
837
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
838
+ # to avoid bicubic zero error
839
+ if T < target_T:
840
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
841
+ if F < target_F:
842
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
843
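+ # fold the time axis into freq_ratio chunks and stack them along the frequency axis so the long spectrogram fits the square Swin input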
+ x = x.permute(0,1,3,2).contiguous()
844
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
845
+ # print(x.shape)
846
+ x = x.permute(0,1,3,2,4).contiguous()
847
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
848
+ return x
849
+
850
+ # Repeat the waveform to an image-sized tensor if you want to use the pretrained swin transformer model
851
+ def repeat_wat2img(self, x, cur_pos):
852
+ B, C, T, F = x.shape
853
+ target_T = int(self.spec_size * self.freq_ratio)
854
+ target_F = self.spec_size // self.freq_ratio
855
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
856
+ # to avoid bicubic zero error
857
+ if T < target_T:
858
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
859
+ if F < target_F:
860
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
861
+ x = x.permute(0,1,3,2).contiguous() # B C F T
862
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
863
+ x = x.repeat(repeats = (1,1,4,1))
864
+ return x
865
+
866
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
867
+
868
+ if self.enable_fusion and x["longer"].sum() == 0:
869
+ # if no audio is longer than 10s, then randomly select one audio to be longer
870
+ if self.training:
871
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
872
+ else:
873
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
874
+ x = x.transpose(1, 3)
875
+ x = self.bn0(x)
876
+ x = x.transpose(1, 3)
877
+ x = self.reshape_wav2img(x)
878
+ output_dict = self.forward_features(x, longer_idx=[])
879
+ return output_dict
880
+
881
+ if not self.enable_fusion:
882
+ x = x["waveform"].to(device=device, non_blocking=True)
883
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
884
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
885
+ x = x.transpose(1, 3)
886
+ x = self.bn0(x)
887
+ x = x.transpose(1, 3)
888
+ if self.training:
889
+ x = self.spec_augmenter(x)
890
+
891
+ if self.training and mixup_lambda is not None:
892
+ x = do_mixup(x, mixup_lambda)
893
+
894
+ x = self.reshape_wav2img(x)
895
+ output_dict = self.forward_features(x)
896
+ else:
897
+ longer_list = x["longer"].to(device=device, non_blocking=True)
898
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
899
+ x = x.transpose(1, 3)
900
+ x = self.bn0(x)
901
+ x = x.transpose(1, 3)
902
+ longer_list_idx = torch.where(longer_list)[0]
903
+ if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']:
904
+ new_x = x[:,0:1,:,:].clone().contiguous()
905
+ if len(longer_list_idx) > 0:
906
+ # local processing
907
+ fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous()
908
+ FB,FC,FT,FF = fusion_x_local.size()
909
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
910
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous()
911
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
912
+ fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1))
913
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2)
914
+ if fusion_x_local.size(-1) < FT:
915
+ fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1)
916
+ else:
917
+ fusion_x_local = fusion_x_local[:,:,:FT]
918
+ # 1D fusion
919
+ new_x = new_x.squeeze(1).permute((0,2,1)).contiguous()
920
+ new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local)
921
+ x = new_x.permute((0,2,1)).contiguous()[:,None,:,:]
922
+ else:
923
+ x = new_x
924
+
925
+ elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']:
926
+ x = x # no change
927
+
928
+ if self.training:
929
+ x = self.spec_augmenter(x)
930
+ if self.training and mixup_lambda is not None:
931
+ x = do_mixup(x, mixup_lambda)
932
+
933
+ x = self.reshape_wav2img(x)
934
+ output_dict = self.forward_features(x, longer_idx = longer_list_idx)
935
+
936
+ # if infer_mode:
937
+ # # in infer mode. we need to handle different length audio input
938
+ # frame_num = x.shape[2]
939
+ # target_T = int(self.spec_size * self.freq_ratio)
940
+ # repeat_ratio = math.floor(target_T / frame_num)
941
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
942
+ # x = self.reshape_wav2img(x)
943
+ # output_dict = self.forward_features(x)
944
+ # else:
945
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
946
+ # if self.training:
947
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
948
+ # x = self.reshape_wav2img(x)
949
+ # output_dict = self.forward_features(x)
950
+ # else:
951
+ # # Change: Hard code here
952
+ # overlap_size = (x.shape[2] - 1) // 4
953
+ # output_dicts = []
954
+ # crop_size = (x.shape[2] - 1) // 2
955
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
956
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
957
+ # tx = self.reshape_wav2img(tx)
958
+ # output_dicts.append(self.forward_features(tx))
959
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
960
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
961
+ # for d in output_dicts:
962
+ # clipwise_output += d["clipwise_output"]
963
+ # framewise_output += d["framewise_output"]
964
+ # clipwise_output = clipwise_output / len(output_dicts)
965
+ # framewise_output = framewise_output / len(output_dicts)
966
+ # output_dict = {
967
+ # 'framewise_output': framewise_output,
968
+ # 'clipwise_output': clipwise_output
969
+ # }
970
+ # else: # this part is typically used, and most easy one
971
+ # x = self.reshape_wav2img(x)
972
+ # output_dict = self.forward_features(x)
973
+ # x = self.head(x)
974
+
975
+ # The data is already processed in the dataloader, so here we only need to handle input_T < fixed_T
976
+
977
+
978
+
979
+ return output_dict
980
+
981
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'):
982
+ try:
983
+
984
+ assert audio_cfg.model_name in ["tiny", "base", "large"], "model name for HTS-AT is wrong!"
985
+ if audio_cfg.model_name == "tiny":
986
+ model = HTSAT_Swin_Transformer(
987
+ spec_size=256,
988
+ patch_size=4,
989
+ patch_stride=(4,4),
990
+ num_classes=audio_cfg.class_num,
991
+ embed_dim=96,
992
+ depths=[2,2,6,2],
993
+ num_heads=[4,8,16,32],
994
+ window_size=8,
995
+ config = audio_cfg,
996
+ enable_fusion = enable_fusion,
997
+ fusion_type = fusion_type
998
+ )
999
+ elif audio_cfg.model_name == "base":
1000
+ model = HTSAT_Swin_Transformer(
1001
+ spec_size=256,
1002
+ patch_size=4,
1003
+ patch_stride=(4,4),
1004
+ num_classes=audio_cfg.class_num,
1005
+ embed_dim=128,
1006
+ depths=[2,2,12,2],
1007
+ num_heads=[4,8,16,32],
1008
+ window_size=8,
1009
+ config = audio_cfg,
1010
+ enable_fusion = enable_fusion,
1011
+ fusion_type = fusion_type
1012
+ )
1013
+ elif audio_cfg.model_name == "large":
1014
+ model = HTSAT_Swin_Transformer(
1015
+ spec_size=256,
1016
+ patch_size=4,
1017
+ patch_stride=(4,4),
1018
+ num_classes=audio_cfg.class_num,
1019
+ embed_dim=256,
1020
+ depths=[2,2,12,2],
1021
+ num_heads=[4,8,16,32],
1022
+ window_size=8,
1023
+ config = audio_cfg,
1024
+ enable_fusion = enable_fusion,
1025
+ fusion_type = fusion_type
1026
+ )
1027
+
1028
+ return model
1029
+ except Exception as e:
1030
+ raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.') from e
1031
+
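The reshape_wav2img step above is what lets a Swin backbone pretrained on 256x256 images consume a long log-mel spectrogram: the time axis is cut into freq_ratio chunks that are stacked along the frequency axis. Below is a standalone sketch of that packing, with spec_size=256, freq_ratio=4 and 64 mel bins taken as assumed example values for the tiny model.

# Standalone sketch of the reshape_wav2img packing (assumed tiny-model numbers).
import torch
import torch.nn.functional as F

spec_size, freq_ratio = 256, 4                 # assumed values for the tiny/base models
B, C, T, Fbins = 2, 1, 1001, 64                # a batch of log-mel spectrograms
x = torch.randn(B, C, T, Fbins)

target_T = spec_size * freq_ratio              # 1024 frames after padding
target_F = spec_size // freq_ratio             # 64 mel bins
x = F.interpolate(x, (target_T, Fbins), mode="bicubic", align_corners=True)

x = x.permute(0, 1, 3, 2).contiguous()                              # B, C, F, T
x = x.reshape(B, C, target_F, freq_ratio, target_T // freq_ratio)   # split T into 4 chunks
x = x.permute(0, 1, 3, 2, 4).contiguous()                           # B, C, 4, F, T/4
x = x.reshape(B, C, freq_ratio * target_F, target_T // freq_ratio)
print(x.shape)                                                      # torch.Size([2, 1, 256, 256])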
src/laion_clap/clap_module/linear_probe.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from .model import MLPLayers
5
+
6
+
7
+ class LinearProbe(nn.Module):
8
+ def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9
+ """
10
+ Args:
11
+ model: nn.Module
12
+ mlp: bool, if True, then use the MLP layer as the linear probe module
13
+ freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
14
+ in_ch: int, the output channel from CLAP model
15
+ out_ch: int, the output channel from linear probe (class_num)
16
+ act: torch.nn.functional, the activation function before the loss function
17
+ """
18
+ super().__init__()
19
+ in_ch = 512
20
+ self.clap_model = model
21
+ self.clap_model.text_branch = None # to save memory
22
+ self.freeze = freeze
23
+ if mlp:
24
+ self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25
+ else:
26
+ self.lp_layer = nn.Linear(in_ch, out_ch)
27
+
28
+ if self.freeze:
29
+ for param in self.clap_model.parameters():
30
+ param.requires_grad = False
31
+
32
+ if act == 'None':
33
+ self.act = None
34
+ elif act == 'relu':
35
+ self.act = nn.ReLU()
36
+ elif act == 'elu':
37
+ self.act = nn.ELU()
38
+ elif act == 'prelu':
39
+ self.act = nn.PReLU(num_parameters=in_ch)
40
+ elif act == 'softmax':
41
+ self.act = nn.Softmax(dim=-1)
42
+ elif act == 'sigmoid':
43
+ self.act = nn.Sigmoid()
44
+
45
+ def forward(self, x, mix_lambda=None, device=None):
46
+ """
47
+ Args:
48
+ x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49
+ mix_lambda: torch.tensor [batch], the mixup lambda
50
+ Returns:
51
+ class_prob: torch.tensor [batch, class_num]
52
+
53
+ """
54
+ # keep the frozen CLAP model in eval mode so BatchNorm statistics are not updated
55
+ if self.freeze:
56
+ self.clap_model.eval()
57
+
58
+ x = self.clap_model.audio_projection(
59
+ self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"])
60
+ out = self.lp_layer(x)
61
+ if self.act is not None:
62
+ out = self.act(out)
63
+ return out
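A minimal sketch of how LinearProbe is wired; the dummy backbone below is a hypothetical stand-in exposing only the attributes the probe actually touches, so the head logic can be exercised without loading CLAP weights.

# Sketch only: exercise LinearProbe with a dummy stand-in for the CLAP backbone.
import torch
from torch import nn
from src.laion_clap.clap_module.linear_probe import LinearProbe

class _DummyCLAP(nn.Module):
    # Stand-in with the attributes LinearProbe uses: text_branch,
    # audio_branch and audio_projection.
    def __init__(self):
        super().__init__()
        self.text_branch = None              # LinearProbe nulls this anyway
        self.audio_projection = nn.Identity()
    def audio_branch(self, x, mixup_lambda=None, device=None):
        return {"embedding": x}              # pretend x is already a 512-d embedding

probe = LinearProbe(_DummyCLAP(), mlp=False, freeze=True,
                    in_ch=512, out_ch=50, act="sigmoid")
fake_audio_embeddings = torch.randn(4, 512)
print(probe(fake_audio_embeddings).shape)    # torch.Size([4, 50])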
src/laion_clap/clap_module/loss.py ADDED
@@ -0,0 +1,307 @@
1
+ from multiprocessing.sharedctypes import Value
2
+ import torch
3
+ import torch.distributed.nn
4
+ from torch import distributed as dist, nn as nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
8
+
9
+ try:
10
+ import horovod.torch as hvd
11
+ except ImportError:
12
+ hvd = None
13
+
14
+
15
+ def gather_features(
16
+ audio_features,
17
+ text_features,
18
+ audio_features_mlp=None,
19
+ text_features_mlp=None,
20
+ local_loss=False,
21
+ gather_with_grad=False,
22
+ rank=0,
23
+ world_size=1,
24
+ use_horovod=False,
25
+ mlp_loss=False
26
+ ):
27
+ if use_horovod:
28
+ assert hvd is not None, 'Please install horovod'
29
+ if gather_with_grad:
30
+ all_audio_features = hvd.allgather(audio_features)
31
+ all_text_features = hvd.allgather(text_features)
32
+ if mlp_loss:
33
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
34
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
35
+ else:
36
+ with torch.no_grad():
37
+ all_audio_features = hvd.allgather(audio_features)
38
+ all_text_features = hvd.allgather(text_features)
39
+ if mlp_loss:
40
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
41
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
42
+ if not local_loss:
43
+ # ensure grads for local rank when all_* features don't have a gradient
44
+ gathered_audio_features = list(all_audio_features.chunk(world_size, dim=0))
45
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
46
+ gathered_audio_features[rank] = audio_features
47
+ gathered_text_features[rank] = text_features
48
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
49
+ all_text_features = torch.cat(gathered_text_features, dim=0)
50
+ if mlp_loss:
51
+ gathered_audio_features_mlp = list(all_audio_features_mlp.chunk(world_size, dim=0))
52
+ gathered_text_features_mlp = list(all_text_features_mlp.chunk(world_size, dim=0))
53
+ gathered_audio_features_mlp[rank] = audio_features_mlp
54
+ gathered_text_features_mlp[rank] = text_features_mlp
55
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
56
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
57
+ else:
58
+ # We gather tensors from all gpus
59
+ if gather_with_grad:
60
+ all_audio_features = torch.cat(torch.distributed.nn.all_gather(audio_features), dim=0)
61
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
62
+ if mlp_loss:
63
+ all_audio_features_mlp = torch.cat(torch.distributed.nn.all_gather(audio_features_mlp), dim=0)
64
+ all_text_features_mlp = torch.cat(torch.distributed.nn.all_gather(text_features_mlp), dim=0)
65
+ else:
66
+ gathered_audio_features = [torch.zeros_like(audio_features) for _ in range(world_size)]
67
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
68
+ dist.all_gather(gathered_audio_features, audio_features)
69
+ dist.all_gather(gathered_text_features, text_features)
70
+ if mlp_loss:
71
+ gathered_audio_features_mlp = [torch.zeros_like(audio_features_mlp) for _ in range(world_size)]
72
+ gathered_text_features_mlp = [torch.zeros_like(text_features_mlp) for _ in range(world_size)]
73
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
74
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
75
+ if not local_loss:
76
+ # ensure grads for local rank when all_* features don't have a gradient
77
+ gathered_audio_features[rank] = audio_features
78
+ gathered_text_features[rank] = text_features
79
+ if mlp_loss:
80
+ gathered_audio_features_mlp[rank] = audio_features_mlp
81
+ gathered_text_features_mlp[rank] = text_features_mlp
82
+
83
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
84
+ all_text_features = torch.cat(gathered_text_features, dim=0)
85
+ if mlp_loss:
86
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
87
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
88
+ if mlp_loss:
89
+ return all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp
90
+ else:
91
+ return all_audio_features, all_text_features
92
+
93
+ class ClipLoss(nn.Module):
94
+
95
+ def __init__(
96
+ self,
97
+ local_loss=False,
98
+ gather_with_grad=False,
99
+ cache_labels=False,
100
+ rank=0,
101
+ world_size=1,
102
+ use_horovod=False,
103
+ mlp_loss=False,
104
+ weight_loss_kappa=0,
105
+ ):
106
+ super().__init__()
107
+ self.local_loss = local_loss
108
+ self.gather_with_grad = gather_with_grad
109
+ self.cache_labels = cache_labels
110
+ self.rank = rank
111
+ self.world_size = world_size
112
+ self.use_horovod = use_horovod
113
+ self.mlp_loss = mlp_loss
114
+ self.weighted_loss = bool(weight_loss_kappa!=0)
115
+ self.weight_loss_kappa = weight_loss_kappa
116
+ # cache state
117
+ self.prev_num_logits = 0
118
+ self.labels = {}
119
+
120
+ def forward(self, audio_features, text_features, logit_scale_a, logit_scale_t=None, audio_features_mlp=None, text_features_mlp=None):
121
+ device = audio_features.device
122
+ if self.mlp_loss:
123
+ if self.world_size > 1:
124
+ all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = gather_features(
125
+ audio_features=audio_features,text_features=text_features,
126
+ audio_features_mlp=audio_features_mlp,text_features_mlp=text_features_mlp,
127
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
128
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
129
+ mlp_loss=self.mlp_loss
130
+ )
131
+ if self.local_loss:
132
+ a_logits_per_audio = logit_scale_a * audio_features @ all_text_features_mlp.T
133
+ a_logits_per_text = logit_scale_a * text_features_mlp @ all_audio_features.T
134
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ all_text_features.T
135
+ t_logits_per_text = logit_scale_t * text_features @ all_audio_features_mlp.T
136
+ else:
137
+ a_logits_per_audio = logit_scale_a * all_audio_features @ all_text_features_mlp.T
138
+ a_logits_per_text = a_logits_per_audio.T
139
+ t_logits_per_audio = logit_scale_t * all_audio_features_mlp @ all_text_features.T
140
+ t_logits_per_text = t_logits_per_audio.T
141
+ else:
142
+ a_logits_per_audio = logit_scale_a * audio_features @ text_features_mlp.T
143
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
144
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ text_features.T
145
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
146
+
147
+ # calculate the ground-truth labels and cache them if enabled
148
+ num_logits = a_logits_per_audio.shape[0]
149
+ if self.prev_num_logits != num_logits or device not in self.labels:
150
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
151
+ if self.world_size > 1 and self.local_loss:
152
+ labels = labels + num_logits * self.rank
153
+ if self.cache_labels:
154
+ self.labels[device] = labels
155
+ self.prev_num_logits = num_logits
156
+ else:
157
+ labels = self.labels[device]
158
+
159
+ if not self.weighted_loss:
160
+ total_loss = (
161
+ F.cross_entropy(a_logits_per_audio, labels) +
162
+ F.cross_entropy(a_logits_per_text, labels) +
163
+ F.cross_entropy(t_logits_per_audio, labels) +
164
+ F.cross_entropy(t_logits_per_text, labels)
165
+ ) / 4
166
+ else:
167
+ audio_weight = (audio_features@audio_features.T).detach()
168
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(audio_weight)))).detach()
169
+ text_weight = (text_features@text_features.T).detach()
170
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(text_features)))).detach()
171
+ total_loss = (
172
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) +
173
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) +
174
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) +
175
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
176
+ ) / 4
177
+ else:
178
+ if self.world_size > 1:
179
+ all_audio_features, all_text_features = gather_features(
180
+ audio_features=audio_features,text_features=text_features,
181
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
182
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
183
+ mlp_loss=self.mlp_loss
184
+ )
185
+
186
+ if self.local_loss:
187
+ logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
188
+ logits_per_text = logit_scale_a * text_features @ all_audio_features.T
189
+ else:
190
+ logits_per_audio = logit_scale_a * all_audio_features @ all_text_features.T
191
+ logits_per_text = logits_per_audio.T
192
+ else:
193
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
194
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
195
+
196
+ # calculate the ground-truth labels and cache them if enabled
197
+ num_logits = logits_per_audio.shape[0]
198
+ if self.prev_num_logits != num_logits or device not in self.labels:
199
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
200
+ if self.world_size > 1 and self.local_loss:
201
+ labels = labels + num_logits * self.rank
202
+ if self.cache_labels:
203
+ self.labels[device] = labels
204
+ self.prev_num_logits = num_logits
205
+ else:
206
+ labels = self.labels[device]
207
+ if not self.weighted_loss:
208
+ total_loss = (
209
+ F.cross_entropy(logits_per_audio, labels) +
210
+ F.cross_entropy(logits_per_text, labels)
211
+ ) / 2
212
+ else:
213
+ audio_weight = (all_audio_features@all_audio_features.T).detach()
214
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(all_audio_features)))).detach()
215
+ text_weight = (all_text_features@all_text_features.T).detach()
216
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(all_text_features)))).detach()
217
+ total_loss = (
218
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight) +
219
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight)
220
+ ) / 2
221
+ return total_loss
222
+
223
+ def lp_gather_features(
224
+ pred,
225
+ target,
226
+ world_size=1,
227
+ use_horovod=False
228
+ ):
229
+ if use_horovod:
230
+ assert hvd is not None, 'Please install horovod'
231
+ with torch.no_grad():
232
+ all_preds = hvd.allgather(pred)
233
+ all_targets = hvd.allgather(target)
234
+ else:
235
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
236
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
237
+
238
+ dist.all_gather(gathered_preds, pred)
239
+ dist.all_gather(gathered_targets, target)
240
+ all_preds = torch.cat(gathered_preds, dim=0)
241
+ all_targets = torch.cat(gathered_targets, dim=0)
242
+
243
+ return all_preds, all_targets
244
+
245
+
246
+ def get_map(pred, target):
247
+ pred = torch.sigmoid(pred).numpy()
248
+ target = target.numpy()
249
+ return np.mean(average_precision_score(target, pred, average=None))
250
+
251
+ def get_acc(pred, target):
252
+ pred = torch.argmax(pred,1).numpy()
253
+ target = torch.argmax(target,1).numpy()
254
+ return accuracy_score(target, pred)
255
+
256
+ def get_mauc(pred, target):
257
+ pred = torch.sigmoid(pred).numpy()
258
+ target = target.numpy()
259
+ return np.mean(roc_auc_score(target, pred, average=None))
260
+
261
+
262
+ class LPMetrics(object):
263
+ def __init__(self, metric_names = ['map','acc','mauc']):
264
+ self.metrics = []
265
+ for name in metric_names:
266
+ self.metrics.append(self.get_metric(name))
267
+ self.metric_names = metric_names
268
+
269
+ def get_metric(self,name):
270
+ if name == 'map':
271
+ return get_map
272
+ elif name == 'acc':
273
+ return get_acc
274
+ elif name == 'mauc':
275
+ return get_mauc
276
+ else:
277
+ raise ValueError('the metric should be one of [map, acc, mauc]')
278
+
279
+ def evaluate_mertics(self, pred, target):
280
+ metric_dict = {}
281
+ for i in range(len(self.metric_names)):
282
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
283
+ return metric_dict
284
+
285
+
286
+ def calc_celoss(pred, target):
287
+ target = torch.argmax(target, 1).long()
288
+ return nn.CrossEntropyLoss()(pred, target)
289
+
290
+
291
+ class LPLoss(nn.Module):
292
+
293
+ def __init__(self, loss_name):
294
+ super().__init__()
295
+ if loss_name == 'bce':
296
+ self.loss_func = nn.BCEWithLogitsLoss()
297
+ elif loss_name == 'ce':
298
+ self.loss_func = calc_celoss
299
+ elif loss_name == 'mse':
300
+ self.loss_func = nn.MSELoss()
301
+ else:
302
+ raise ValueError('the loss func should be one of [bce, ce, mse]')
303
+
304
+ def forward(self, pred, target):
305
+ loss = self.loss_func(pred, target)
306
+ return loss
307
+
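A single-process sketch of ClipLoss; the random unit-norm tensors below are stand-ins for the audio and text tower outputs, and 1/0.07 matches the exp() of the initial learned logit scale in model.py.

# Sketch only: ClipLoss on random stand-in features, single process (world_size=1).
import torch
import torch.nn.functional as F
from src.laion_clap.clap_module.loss import ClipLoss

loss_fn = ClipLoss(world_size=1, mlp_loss=False)
audio_features = F.normalize(torch.randn(8, 512), dim=-1)  # stand-in audio embeddings
text_features = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in text embeddings
logit_scale_a = torch.tensor(1 / 0.07)                     # exp() of the initial log scale

loss = loss_fn(audio_features, text_features, logit_scale_a)
print(loss)  # symmetric cross-entropy over the 8x8 similarity matrix, roughly log(8) here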
src/laion_clap/clap_module/model.py ADDED
@@ -0,0 +1,892 @@
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from email.mime import audio
10
+ from typing import Tuple, Union, Callable, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn
16
+
17
+ from .timm_model import TimmModel
18
+ import logging
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+ from .pann_model import create_pann_model
22
+ from .htsat import create_htsat_model
23
+ from transformers import BertModel, RobertaModel, BartModel
24
+ from transformers.tokenization_utils_base import BatchEncoding
25
+
26
+
27
+ class MLPLayers(nn.Module):
28
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
29
+ super(MLPLayers, self).__init__()
30
+ self.nonlin = nonlin
31
+ self.dropout = dropout
32
+
33
+ sequence = []
34
+ for u0, u1 in zip(units[:-1], units[1:]):
35
+ sequence.append(nn.Linear(u0, u1))
36
+ sequence.append(self.nonlin)
37
+ sequence.append(nn.Dropout(self.dropout))
38
+ sequence = sequence[:-2]
39
+
40
+ self.sequential = nn.Sequential(*sequence)
41
+
42
+ def forward(self, X):
43
+ X = self.sequential(X)
44
+ return X
45
+
46
+
47
+ class Bottleneck(nn.Module):
48
+ expansion = 4
49
+
50
+ def __init__(self, inplanes, planes, stride=1):
51
+ super().__init__()
52
+
53
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55
+ self.bn1 = nn.BatchNorm2d(planes)
56
+
57
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58
+ self.bn2 = nn.BatchNorm2d(planes)
59
+
60
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61
+
62
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64
+
65
+ self.relu = nn.ReLU(inplace=True)
66
+ self.downsample = None
67
+ self.stride = stride
68
+
69
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
70
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71
+ self.downsample = nn.Sequential(
72
+ OrderedDict(
73
+ [
74
+ ("-1", nn.AvgPool2d(stride)),
75
+ (
76
+ "0",
77
+ nn.Conv2d(
78
+ inplanes,
79
+ planes * self.expansion,
80
+ 1,
81
+ stride=1,
82
+ bias=False,
83
+ ),
84
+ ),
85
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
86
+ ]
87
+ )
88
+ )
89
+
90
+ def forward(self, x: torch.Tensor):
91
+ identity = x
92
+
93
+ out = self.relu(self.bn1(self.conv1(x)))
94
+ out = self.relu(self.bn2(self.conv2(out)))
95
+ out = self.avgpool(out)
96
+ out = self.bn3(self.conv3(out))
97
+
98
+ if self.downsample is not None:
99
+ identity = self.downsample(x)
100
+
101
+ out += identity
102
+ out = self.relu(out)
103
+ return out
104
+
105
+
106
+ class AttentionPool2d(nn.Module):
107
+ def __init__(
108
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
109
+ ):
110
+ super().__init__()
111
+ self.positional_embedding = nn.Parameter(
112
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
113
+ )
114
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
115
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
116
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
117
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
118
+ self.num_heads = num_heads
119
+
120
+ def forward(self, x):
121
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
122
+ 2, 0, 1
123
+ ) # NCHW -> (HW)NC
124
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
125
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
126
+ x, _ = F.multi_head_attention_forward(
127
+ query=x,
128
+ key=x,
129
+ value=x,
130
+ embed_dim_to_check=x.shape[-1],
131
+ num_heads=self.num_heads,
132
+ q_proj_weight=self.q_proj.weight,
133
+ k_proj_weight=self.k_proj.weight,
134
+ v_proj_weight=self.v_proj.weight,
135
+ in_proj_weight=None,
136
+ in_proj_bias=torch.cat(
137
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
138
+ ),
139
+ bias_k=None,
140
+ bias_v=None,
141
+ add_zero_attn=False,
142
+ dropout_p=0,
143
+ out_proj_weight=self.c_proj.weight,
144
+ out_proj_bias=self.c_proj.bias,
145
+ use_separate_proj_weight=True,
146
+ training=self.training,
147
+ need_weights=False,
148
+ )
149
+
150
+ return x[0]
151
+
152
+
153
+ class ModifiedResNet(nn.Module):
154
+ """
155
+ A ResNet class that is similar to torchvision's but contains the following changes:
156
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
157
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
158
+ - The final pooling layer is a QKV attention instead of an average pool
159
+ """
160
+
161
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
162
+ super().__init__()
163
+ self.output_dim = output_dim
164
+ self.image_size = image_size
165
+
166
+ # the 3-layer stem
167
+ self.conv1 = nn.Conv2d(
168
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
169
+ )
170
+ self.bn1 = nn.BatchNorm2d(width // 2)
171
+ self.conv2 = nn.Conv2d(
172
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
173
+ )
174
+ self.bn2 = nn.BatchNorm2d(width // 2)
175
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
176
+ self.bn3 = nn.BatchNorm2d(width)
177
+ self.avgpool = nn.AvgPool2d(2)
178
+ self.relu = nn.ReLU(inplace=True)
179
+
180
+ # residual layers
181
+ self._inplanes = width # this is a *mutable* variable used during construction
182
+ self.layer1 = self._make_layer(width, layers[0])
183
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
184
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
185
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
186
+
187
+ embed_dim = width * 32 # the ResNet feature dimension
188
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
189
+
190
+ self.init_parameters()
191
+
192
+ def _make_layer(self, planes, blocks, stride=1):
193
+ layers = [Bottleneck(self._inplanes, planes, stride)]
194
+
195
+ self._inplanes = planes * Bottleneck.expansion
196
+ for _ in range(1, blocks):
197
+ layers.append(Bottleneck(self._inplanes, planes))
198
+
199
+ return nn.Sequential(*layers)
200
+
201
+ def init_parameters(self):
202
+ if self.attnpool is not None:
203
+ std = self.attnpool.c_proj.in_features**-0.5
204
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
205
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
206
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
207
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
208
+
209
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
210
+ for name, param in resnet_block.named_parameters():
211
+ if name.endswith("bn3.weight"):
212
+ nn.init.zeros_(param)
213
+
214
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
215
+ assert (
216
+ unlocked_groups == 0
217
+ ), "partial locking not currently supported for this model"
218
+ for param in self.parameters():
219
+ param.requires_grad = False
220
+ if freeze_bn_stats:
221
+ freeze_batch_norm_2d(self)
222
+
223
+ def stem(self, x):
224
+ for conv, bn in [
225
+ (self.conv1, self.bn1),
226
+ (self.conv2, self.bn2),
227
+ (self.conv3, self.bn3),
228
+ ]:
229
+ x = self.relu(bn(conv(x)))
230
+ x = self.avgpool(x)
231
+ return x
232
+
233
+ def forward(self, x):
234
+ x = self.stem(x)
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.attnpool(x)
240
+
241
+ return x
242
+
243
+
244
+ class LayerNorm(nn.LayerNorm):
245
+ """Subclass torch's LayerNorm to handle fp16."""
246
+
247
+ def forward(self, x: torch.Tensor):
248
+ orig_type = x.dtype
249
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
250
+ return x.to(orig_type)
251
+
252
+
253
+ class QuickGELU(nn.Module):
254
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
255
+ def forward(self, x: torch.Tensor):
256
+ return x * torch.sigmoid(1.702 * x)
257
+
258
+
259
+ class ResidualAttentionBlock(nn.Module):
260
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
261
+ super().__init__()
262
+
263
+ self.attn = nn.MultiheadAttention(d_model, n_head)
264
+ self.ln_1 = LayerNorm(d_model)
265
+ self.mlp = nn.Sequential(
266
+ OrderedDict(
267
+ [
268
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
269
+ ("gelu", act_layer()),
270
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
271
+ ]
272
+ )
273
+ )
274
+ self.ln_2 = LayerNorm(d_model)
275
+
276
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
278
+
279
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
280
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
281
+ x = x + self.mlp(self.ln_2(x))
282
+ return x
283
+
284
+
285
+ class Transformer(nn.Module):
286
+ def __init__(
287
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
288
+ ):
289
+ super().__init__()
290
+ self.width = width
291
+ self.layers = layers
292
+ self.resblocks = nn.ModuleList(
293
+ [
294
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
295
+ for _ in range(layers)
296
+ ]
297
+ )
298
+
299
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
300
+ for r in self.resblocks:
301
+ x = r(x, attn_mask=attn_mask)
302
+ return x
303
+
304
+
305
+ class VisualTransformer(nn.Module):
306
+ def __init__(
307
+ self,
308
+ image_size: int,
309
+ patch_size: int,
310
+ width: int,
311
+ layers: int,
312
+ heads: int,
313
+ output_dim: int,
314
+ act_layer: Callable = nn.GELU,
315
+ ):
316
+ super().__init__()
317
+ self.image_size = image_size
318
+ self.output_dim = output_dim
319
+ self.conv1 = nn.Conv2d(
320
+ in_channels=3,
321
+ out_channels=width,
322
+ kernel_size=patch_size,
323
+ stride=patch_size,
324
+ bias=False,
325
+ )
326
+
327
+ scale = width**-0.5
328
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
329
+ self.positional_embedding = nn.Parameter(
330
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
331
+ )
332
+ self.ln_pre = LayerNorm(width)
333
+
334
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
335
+
336
+ self.ln_post = LayerNorm(width)
337
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
338
+
339
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
340
+ assert (
341
+ unlocked_groups == 0
342
+ ), "partial locking not currently supported for this model"
343
+ for param in self.parameters():
344
+ param.requires_grad = False
345
+
346
+ def forward(self, x: torch.Tensor):
347
+ x = self.conv1(x) # shape = [*, width, grid, grid]
348
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
349
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
350
+ x = torch.cat(
351
+ [
352
+ self.class_embedding.to(x.dtype)
353
+ + torch.zeros(
354
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
355
+ ),
356
+ x,
357
+ ],
358
+ dim=1,
359
+ ) # shape = [*, grid ** 2 + 1, width]
360
+ x = x + self.positional_embedding.to(x.dtype)
361
+ x = self.ln_pre(x)
362
+
363
+ x = x.permute(1, 0, 2) # NLD -> LND
364
+ x = self.text_branch(x)
365
+ x = x.permute(1, 0, 2) # LND -> NLD
366
+
367
+ x = self.ln_post(x[:, 0, :])
368
+
369
+ if self.proj is not None:
370
+ x = x @ self.proj
371
+
372
+ return x
373
+
374
+
375
+ @dataclass
376
+ class CLAPVisionCfg:
377
+ layers: Union[Tuple[int, int, int, int], int] = 12
378
+ width: int = 768
379
+ patch_size: int = 16
380
+ image_size: Union[Tuple[int, int], int] = 224
381
+ timm_model_name: str = (
382
+ None # a valid model name overrides layers, width, patch_size
383
+ )
384
+ timm_model_pretrained: bool = (
385
+ False # use (imagenet) pretrained weights for named model
386
+ )
387
+ timm_pool: str = (
388
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
389
+ )
390
+ timm_proj: str = (
391
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
392
+ )
393
+
394
+
395
+ # Audio Config Class
396
+ @dataclass
397
+ class CLAPAudioCfp:
398
+ model_type: str = "PANN"
399
+ model_name: str = "Cnn14"
400
+ sample_rate: int = 48000
401
+ # Param
402
+ audio_length: int = 1024
403
+ window_size: int = 1024
404
+ hop_size: int = 1024
405
+ fmin: int = 50
406
+ fmax: int = 14000
407
+ class_num: int = 527
408
+ mel_bins: int = 64
409
+ clip_samples: int = 480000
410
+
411
+
412
+ @dataclass
413
+ class CLAPTextCfg:
414
+ context_length: int
415
+ vocab_size: int
416
+ width: int
417
+ heads: int
418
+ layers: int
419
+ model_type: str
420
+
421
+
422
+ class CLAP(nn.Module):
423
+ def __init__(
424
+ self,
425
+ embed_dim: int,
426
+ audio_cfg: CLAPAudioCfp,
427
+ text_cfg: CLAPTextCfg,
428
+ quick_gelu: bool = False,
429
+ enable_fusion: bool = False,
430
+ fusion_type: str = 'None',
431
+ joint_embed_shape: int = 512,
432
+ mlp_act: str = 'relu',
433
+ ):
434
+ super().__init__()
435
+ if isinstance(audio_cfg, dict):
436
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
437
+ if isinstance(text_cfg, dict):
438
+ text_cfg = CLAPTextCfg(**text_cfg)
439
+
440
+ self.audio_cfg = audio_cfg
441
+ self.text_cfg = text_cfg
442
+ self.enable_fusion = enable_fusion
443
+ self.fusion_type = fusion_type
444
+ self.joint_embed_shape = joint_embed_shape
445
+ self.mlp_act = mlp_act
446
+
447
+
448
+ self.context_length = text_cfg.context_length
449
+
450
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
451
+ # memory efficient in recent PyTorch releases (>= 1.10).
452
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
453
+ act_layer = QuickGELU if quick_gelu else nn.GELU
454
+
455
+ if mlp_act == 'relu':
456
+ mlp_act_layer = nn.ReLU()
457
+ elif mlp_act == 'gelu':
458
+ mlp_act_layer = nn.GELU()
459
+ else:
460
+ raise NotImplementedError
461
+
462
+ # audio branch
463
+ # audio branch parameters
464
+ if audio_cfg.model_type == "PANN":
465
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
466
+ elif audio_cfg.model_type == "HTSAT":
467
+ self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
468
+ else:
469
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
470
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
471
+
472
+ # text branch
473
+ # text branch parameters
474
+ if text_cfg.model_type == "transformer":
475
+ self.text_branch = Transformer(
476
+ width=text_cfg.width,
477
+ layers=text_cfg.layers,
478
+ heads=text_cfg.heads,
479
+ act_layer=act_layer,
480
+ )
481
+ self.vocab_size = text_cfg.vocab_size
482
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
483
+ self.positional_embedding = nn.Parameter(
484
+ torch.empty(self.context_length, text_cfg.width)
485
+ )
486
+ self.ln_final = LayerNorm(text_cfg.width)
487
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
488
+ self.joint_embed_shape,
489
+ self.joint_embed_shape], dropout=0.1)
490
+ self.text_projection = nn.Sequential(
491
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
492
+ mlp_act_layer,
493
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
494
+ )
495
+ elif text_cfg.model_type == "bert":
496
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
497
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
498
+ self.joint_embed_shape,
499
+ self.joint_embed_shape], dropout=0.1)
500
+ self.text_projection = nn.Sequential(
501
+ nn.Linear(768, self.joint_embed_shape),
502
+ mlp_act_layer,
503
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
504
+ )
505
+ elif text_cfg.model_type == "roberta":
506
+ self.text_branch = RobertaModel.from_pretrained('roberta-base')
507
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
508
+ self.joint_embed_shape,
509
+ self.joint_embed_shape], dropout=0.1)
510
+ self.text_projection = nn.Sequential(
511
+ nn.Linear(768, self.joint_embed_shape),
512
+ mlp_act_layer,
513
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
514
+ )
515
+ elif text_cfg.model_type == "bart":
516
+ self.text_branch = BartModel.from_pretrained('facebook/bart-base')
517
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
518
+ self.joint_embed_shape,
519
+ self.joint_embed_shape], dropout=0.1)
520
+ self.text_projection = nn.Sequential(
521
+ nn.Linear(768, self.joint_embed_shape),
522
+ mlp_act_layer,
523
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
524
+ )
525
+ else:
526
+ logging.error(f"Model config for {text_cfg.model_type} not found")
527
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
528
+ self.text_branch_type = text_cfg.model_type
529
+ # text branch parameters
530
+
531
+ # audio branch parameters
532
+ self.audio_transform = MLPLayers(units=[self.joint_embed_shape,
533
+ self.joint_embed_shape,
534
+ self.joint_embed_shape], dropout=0.1)
535
+
536
+ # below here is text branch parameters
537
+
538
+ # ============================================================================================================
539
+ self.audio_projection = nn.Sequential(
540
+ nn.Linear(embed_dim, self.joint_embed_shape),
541
+ mlp_act_layer,
542
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
543
+ )
544
+
545
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
546
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
547
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
548
+
549
+ self.init_text_branch_parameters()
550
+
551
+ def init_text_branch_parameters(self):
552
+ if self.text_branch_type == "transformer":
553
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
554
+ nn.init.normal_(self.positional_embedding, std=0.01)
555
+ proj_std = (self.text_branch.width**-0.5) * (
556
+ (2 * self.text_branch.layers) ** -0.5
557
+ )
558
+ attn_std = self.text_branch.width**-0.5
559
+ fc_std = (2 * self.text_branch.width) ** -0.5
560
+ for block in self.text_branch.resblocks:
561
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
562
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
563
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
564
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
565
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
566
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
567
+ elif self.text_branch_type == "bart":
568
+ width = self.text_branch.shared.weight.shape[-1]
569
+ else:
570
+ width = self.text_branch.width
571
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
572
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
573
+
574
+ # deprecated
575
+ # if hasattr(self.visual, 'init_parameters'):
576
+ # self.visual.init_parameters()
577
+
578
+ # if self.text_projection is not None:
579
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
580
+
581
+ def build_attention_mask(self):
582
+ # lazily create causal attention mask, with full attention between the vision tokens
583
+ # pytorch uses additive attention mask; fill with -inf
584
+ mask = torch.empty(self.context_length, self.context_length)
585
+ mask.fill_(float("-inf"))
586
+ mask.triu_(1) # zero out the lower diagonal
587
+ return mask
588
+
589
+ def encode_audio(self, audio, device):
590
+ return self.audio_branch(audio, mixup_lambda=None, device=device) # mixup lambda still needs to be passed through here
591
+
592
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
593
+ # tmp = {}
594
+ # for k in x[0].keys():
595
+ # tmp[k] = []
596
+ # for i in range(len(x)):
597
+ # tmp[k].append(x[i][k][:77])
598
+ # for k in x[0].keys():
599
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
600
+ # return tmp
601
+
602
+ def encode_text(self, text, device):
603
+ if self.text_branch_type == "transformer":
604
+ text = text.to(device=device, non_blocking=True)
605
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
606
+
607
+ x = x + self.positional_embedding
608
+ x = x.permute(1, 0, 2) # NLD -> LND
609
+ x = self.text_branch(x, attn_mask=self.attn_mask)
610
+ x = x.permute(1, 0, 2) # LND -> NLD
611
+ x = self.ln_final(x)
612
+
613
+ # x.shape = [batch_size, n_ctx, transformer.width]
614
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
615
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
616
+ elif self.text_branch_type == "bert":
617
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
618
+ # text = BatchEncoding(text)
619
+ x = self.text_branch(
620
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
621
+ attention_mask=text["attention_mask"].to(
622
+ device=device, non_blocking=True
623
+ ),
624
+ token_type_ids=text["token_type_ids"].to(
625
+ device=device, non_blocking=True
626
+ ),
627
+ )["pooler_output"]
628
+ x = self.text_projection(x)
629
+ elif self.text_branch_type == "roberta":
630
+ x = self.text_branch(
631
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
632
+ attention_mask=text["attention_mask"].to(
633
+ device=device, non_blocking=True
634
+ ),
635
+ )["pooler_output"]
636
+ x = self.text_projection(x)
637
+ elif self.text_branch_type == "bart":
638
+ x = torch.mean(self.text_branch(
639
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
640
+ attention_mask=text["attention_mask"].to(
641
+ device=device, non_blocking=True
642
+ ),
643
+ )["encoder_last_hidden_state"],axis=1)
644
+ x = self.text_projection(x)
645
+ else:
646
+ logging.error(f"Model type {self.text_branch_type} not found")
647
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
648
+ return x
649
+
650
+ def forward(self, audio, text, device=None):
651
+ """Forward audio and text into the CLAP
652
+
653
+ Parameters
654
+ ----------
655
+ audio: torch.Tensor (batch_size, audio_length)
656
+ the time-domain audio input / the batch of mel_spec and longer list.
657
+ text: torch.Tensor () // need to add
658
+ the text token input
659
+ """
660
+ if device is None:
661
+ if audio is not None:
662
+ device = audio.device
663
+ elif text is not None:
664
+ device = text.device
665
+ if audio is None and text is None:
666
+ # a hack to get the logit scale
667
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
668
+ elif audio is None:
669
+ return self.encode_text(text, device=device)
670
+ elif text is None:
671
+ return self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
672
+ audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
673
+ audio_features = F.normalize(audio_features, dim=-1)
674
+
675
+ text_features = self.encode_text(
676
+ text, device=device
677
+ )
678
+ # print("text_features", text_features)
679
+ # print("text_features.shape", text_features.shape)
680
+ # print("text_features.type", type(text_features))
681
+ text_features = F.normalize(text_features, dim=-1)
682
+
683
+ audio_features_mlp = self.audio_transform(audio_features)
684
+ text_features_mlp = self.text_transform(text_features)
685
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
686
+ return (
687
+ audio_features,
688
+ text_features,
689
+ audio_features_mlp,
690
+ text_features_mlp,
691
+ self.logit_scale_a.exp(),
692
+ self.logit_scale_t.exp(),
693
+ )
694
+
695
+ def get_logit_scale(self):
696
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
697
+
698
+ def get_text_embedding(self, data):
699
+ """Get the text embedding from the model
700
+
701
+ Parameters
702
+ ----------
703
+ data: torch.Tensor
704
+ a tensor of text embedding
705
+
706
+ Returns
707
+ ----------
708
+ text_embed: torch.Tensor
709
+ a tensor of text_embeds (N, D)
710
+
711
+ """
712
+ device = next(self.parameters()).device
713
+ for k in data:
714
+ data[k] = data[k].to(device)
715
+ text_embeds = self.encode_text(data, device=device)
716
+ text_embeds = F.normalize(text_embeds, dim=-1)
717
+
718
+ return text_embeds
719
+
720
+ def get_audio_embedding(self, data):
721
+ """Get the audio embedding from the model
722
+
723
+ Parameters
724
+ ----------
725
+ data: a list of dict
726
+ the audio input dict list from 'get_audio_feature' method
727
+
728
+ Returns
729
+ ----------
730
+ audio_embed: torch.Tensor
731
+ a tensor of audio_embeds (N, D)
732
+
733
+ """
734
+ device = next(self.parameters()).device
735
+ input_dict = {}
736
+ keys = data[0].keys()
737
+ for k in keys:
738
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
739
+ audio_embeds = self.encode_audio(input_dict, device=device)["embedding"]
740
+ audio_embeds = self.audio_projection(audio_embeds)
741
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
742
+ return audio_embeds
743
+
744
+
745
+
746
+ def audio_infer(self, audio, hopsize=None, device=None):
747
+ """Forward one audio and produce the audio embedding
748
+
749
+ Parameters
750
+ ----------
751
+ audio: (audio_length)
752
+ the time-domain audio input, notice that it must be only one input
753
+ hopsize: int
754
+ the overlap hopsize as the sliding window
755
+
756
+ Returns
757
+ ----------
758
+ output_dict: {
759
+ key: [n, (embedding_shape)] if "HTS-AT"
760
+ or
761
+ key: [(embedding_shape)] if "PANN"
762
+ }
763
+ the list of key values of the audio branch
764
+
765
+ """
766
+
767
+ assert not self.training, "the inference mode must be run at eval stage"
768
+ output_dict = {}
769
+ # PANN
770
+ if self.audio_cfg.model_type == "PANN":
771
+ audio_input = audio.unsqueeze(dim=0)
772
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
773
+ elif self.audio_cfg.model_type == "HTSAT":
774
+ # repeat
775
+ audio_len = len(audio)
776
+ k = self.audio_cfg.clip_samples // audio_len
777
+ if k > 1:
778
+ audio = audio.repeat(k)
779
+ audio_len = len(audio)
780
+
781
+ if hopsize is None:
782
+ hopsize = min(hopsize, audio_len)
783
+
784
+ if audio_len > self.audio_cfg.clip_samples:
785
+ audio_input = [
786
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
787
+ for pos in range(
788
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
789
+ )
790
+ ]
791
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
792
+ audio_input = torch.stack(audio_input)
793
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
794
+ else:
795
+ audio_input = audio.unsqueeze(dim=0)
796
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
797
+
798
+ return output_dict
799
+
800
+
801
+ def convert_weights_to_fp16(model: nn.Module):
802
+ """Convert applicable model parameters to fp16"""
803
+
804
+ def _convert_weights_to_fp16(l):
805
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
806
+ l.weight.data = l.weight.data.half()
807
+ if l.bias is not None:
808
+ l.bias.data = l.bias.data.half()
809
+
810
+ if isinstance(l, nn.MultiheadAttention):
811
+ for attr in [
812
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
813
+ "in_proj_bias",
814
+ "bias_k",
815
+ "bias_v",
816
+ ]:
817
+ tensor = getattr(l, attr)
818
+ if tensor is not None:
819
+ tensor.data = tensor.data.half()
820
+
821
+ for name in ["text_projection", "proj"]:
822
+ if hasattr(l, name):
823
+ attr = getattr(l, name)
824
+ if attr is not None:
825
+ attr.data = attr.data.half()
826
+
827
+ model.apply(_convert_weights_to_fp16)
828
+
829
+
830
+ # Ignore the state dict of the vision part
831
+ def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'):
832
+
833
+ embed_dim = model_cfg["embed_dim"]
834
+ audio_cfg = model_cfg["audio_cfg"]
835
+ text_cfg = model_cfg["text_cfg"]
836
+ context_length = state_dict["positional_embedding"].shape[0]
837
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
838
+ transformer_width = state_dict["ln_final.weight"].shape[0]
839
+ transformer_heads = transformer_width // 64
840
+ transformer_layers = len(
841
+ set(
842
+ k.split(".")[2]
843
+ for k in state_dict
844
+ if k.startswith(f"transformer.resblocks")
845
+ )
846
+ )
847
+
848
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
849
+ text_cfg = CLAPTextCfg(**text_cfg)
850
+
851
+ model = CLAP(
852
+ embed_dim,
853
+ audio_cfg=audio_cfg,
854
+ text_cfg=text_cfg,
855
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
856
+ enable_fusion=enable_fusion,
857
+ fusion_type=fusion_type
858
+ )
859
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
860
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
861
+ pop_keys = list(state_dict.keys())[::]
862
+ # pop the visual branch saved weights
863
+ for key in pop_keys:
864
+ if key.startswith("visual."):
865
+ state_dict.pop(key, None)
866
+
867
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
868
+ state_dict.pop(key, None)
869
+
870
+ # not use fp16
871
+ # convert_weights_to_fp16(model)
872
+ model.load_state_dict(state_dict, strict=False)
873
+ return model.eval()
874
+
875
+
876
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
877
+ model.eval()
878
+ audio_length = model.audio_cfg.audio_length
879
+ example_audio = torch.ones((batch_size, audio_length), device=device)
880
+ example_text = torch.zeros(
881
+ (batch_size, model.context_length), dtype=torch.int, device=device
882
+ )
883
+ model = torch.jit.trace_module(
884
+ model,
885
+ inputs=dict(
886
+ forward=(example_audio, example_text),
887
+ encode_text=(example_text,),
888
+ encode_image=(example_audio,),
889
+ ),
890
+ )
891
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
892
+ return model
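A hedged end-to-end sketch: building a CLAP instance directly from one of the JSON configs added below and asking it for audio embeddings. The factory code elsewhere in this commit is the intended entry point; the hand-added text "model_type" key and the relative config path are assumptions for illustration only.

# Sketch only: construct CLAP from HTSAT-tiny.json by hand (the factory module
# normally does this) and embed two random 10 s clips.
import json
import torch
from src.laion_clap.clap_module.model import CLAP

with open("src/laion_clap/clap_module/model_configs/HTSAT-tiny.json") as f:
    cfg = json.load(f)
cfg["text_cfg"]["model_type"] = "transformer"  # assumed; the JSON configs omit this key

model = CLAP(
    embed_dim=cfg["embed_dim"],
    audio_cfg=cfg["audio_cfg"],
    text_cfg=cfg["text_cfg"],
    enable_fusion=False,
).eval()

# get_audio_embedding expects a list of per-clip dicts holding a fixed-length
# 48 kHz waveform under the "waveform" key.
clips = [{"waveform": torch.randn(480000)} for _ in range(2)]
with torch.no_grad():
    audio_embed = model.get_audio_embedding(clips)
print(audio_embed.shape)  # torch.Size([2, 512]): unit-norm vectors in the joint space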
src/laion_clap/clap_module/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
src/laion_clap/clap_module/model_configs/HTSAT-large.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "large"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
src/laion_clap/clap_module/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 768,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "HTSAT",
+         "model_name": "tiny"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-10.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 1024,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn10"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 2048,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 18000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn14"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 2048,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 960000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 360,
+         "fmin": 50,
+         "fmax": 8000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn14"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
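
The "fmax-8k-20s" variant differs from the plain PANN-14 config in three numbers: clip_samples (960000 instead of 480000), hop_size (360 instead of 480), and fmax (8000 Hz instead of 14000 Hz). At the shared 48 kHz sample rate, that clip length works out to 20 seconds, which is where the suffix comes from; a quick sanity check of the arithmetic:

    # Values taken from PANN-14-fmax-8k-20s.json above.
    clip_samples, sample_rate = 960_000, 48_000
    print(clip_samples / sample_rate)  # 20.0 seconds per clip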
src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 2048,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn14"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 4
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 2048,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1536,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn14"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-14.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 2048,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn14"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/PANN-6.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "embed_dim": 512,
+     "audio_cfg": {
+         "audio_length": 1024,
+         "clip_samples": 480000,
+         "mel_bins": 64,
+         "sample_rate": 48000,
+         "window_size": 1024,
+         "hop_size": 480,
+         "fmin": 50,
+         "fmax": 14000,
+         "class_num": 527,
+         "model_type": "PANN",
+         "model_name": "Cnn6"
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "embed_dim": 512,
+     "quick_gelu": true,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             23,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             23,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "embed_dim": 1024,
+     "quick_gelu": true,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             6,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 1024,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": [
+             3,
+             4,
+             6,
+             3
+         ],
+         "width": 64,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 768,
+     "vision_cfg": {
+         "image_size": 384,
+         "layers": [
+             6,
+             8,
+             18,
+             8
+         ],
+         "width": 96,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
+ {
+     "embed_dim": 640,
+     "vision_cfg": {
+         "image_size": 288,
+         "layers": [
+             4,
+             6,
+             10,
+             6
+         ],
+         "width": 80,
+         "patch_size": null
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 640,
+         "heads": 10,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 12,
+         "width": 768,
+         "patch_size": 16
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
+ {
+     "embed_dim": 512,
+     "quick_gelu": true,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 12,
+         "width": 768,
+         "patch_size": 32
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 512,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 12,
+         "width": 768,
+         "patch_size": 32
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 512,
+         "heads": 8,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "embed_dim": 768,
+     "vision_cfg": {
+         "image_size": 224,
+         "layers": 24,
+         "width": 1024,
+         "patch_size": 14
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "width": 768,
+         "heads": 12,
+         "layers": 12
+     }
+ }
src/laion_clap/clap_module/openai.py ADDED
@@ -0,0 +1,129 @@
+ """ OpenAI pretrained model functions
+
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+ """
+
+ import os
+ import warnings
+ from typing import Union, List
+
+ import torch
+
+ from .model import build_model_from_openai_state_dict
+ from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
+
+ __all__ = ["list_openai_models", "load_openai_model"]
+
+
+ def list_openai_models() -> List[str]:
+     """Returns the names of available CLIP models"""
+     return list_pretrained_tag_models('openai')
+
+
+ def load_openai_model(
+         name: str,
+         model_cfg,
+         device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+         jit=True,
+         cache_dir=os.path.expanduser("~/.cache/clip"),
+         enable_fusion: bool = False,
+         fusion_type: str = 'None'
+ ):
+     """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
+
+     Parameters
+     ----------
+     name : str
+         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+     device : Union[str, torch.device]
+         The device to put the loaded model
+     jit : bool
+         Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+     Returns
+     -------
+     model : torch.nn.Module
+         The CLAP model
+     preprocess : Callable[[PIL.Image], torch.Tensor]
+         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+     """
+     if get_pretrained_url(name, 'openai'):
+         model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir)
+     elif os.path.isfile(name):
+         model_path = name
+     else:
+         raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+     try:
+         # loading JIT archive
+         model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+         state_dict = None
+     except RuntimeError:
+         # loading saved state dict
+         if jit:
+             warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+             jit = False
+         state_dict = torch.load(model_path, map_location="cpu")
+
+     if not jit:
+         try:
+             model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device)
+         except KeyError:
+             sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+             model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device)
+
+         if str(device) == "cpu":
+             model.float()
+         return model
+
+     # patch the device names
+     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+     device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+     def patch_device(module):
+         try:
+             graphs = [module.graph] if hasattr(module, "graph") else []
+         except RuntimeError:
+             graphs = []
+
+         if hasattr(module, "forward1"):
+             graphs.append(module.forward1.graph)
+
+         for graph in graphs:
+             for node in graph.findAllNodes("prim::Constant"):
+                 if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+                     node.copyAttributes(device_node)
+
+     model.apply(patch_device)
+     patch_device(model.encode_audio)
+     patch_device(model.encode_text)
+
+     # patch dtype to float32 on CPU
+     if str(device) == "cpu":
+         float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+         float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+         float_node = float_input.node()
+
+         def patch_float(module):
+             try:
+                 graphs = [module.graph] if hasattr(module, "graph") else []
+             except RuntimeError:
+                 graphs = []
+
+             if hasattr(module, "forward1"):
+                 graphs.append(module.forward1.graph)
+
+             for graph in graphs:
+                 for node in graph.findAllNodes("aten::to"):
+                     inputs = list(node.inputs())
+                     for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                         if inputs[i].node()["value"] == 5:
+                             inputs[i].node().copyAttributes(float_node)
+
+         model.apply(patch_float)
+         patch_float(model.encode_audio)
+         patch_float(model.encode_text)
+         model.float()
+
+     model.audio_branch.audio_length = model.audio_cfg.audio_length
+     return model
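
load_openai_model above resolves the checkpoint (download by name or local path), tries a JIT load first, and otherwise rebuilds the model through build_model_from_openai_state_dict. A hedged usage sketch follows; the "ViT-B-16" name, the config path, and the import path (which assumes running from the repository root) are illustrative choices, not prescribed by this commit:

    import json
    import torch
    from src.laion_clap.clap_module.openai import load_openai_model

    # Use a CLAP model config (audio_cfg + text_cfg), e.g. HTSAT-base.json above;
    # the OpenAI name selects which pretrained CLIP checkpoint supplies the text weights.
    with open("src/laion_clap/clap_module/model_configs/HTSAT-base.json") as f:
        model_cfg = json.load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # jit=False takes the state-dict path and rebuilds the model through
    # build_model_from_openai_state_dict.
    model = load_openai_model("ViT-B-16", model_cfg, device=device, jit=False)
    print(type(model))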