ASLP-lab committed
Commit 70d8fcf · 1 Parent(s): ef1d5fa
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +186 -0
  2. app.py +636 -0
  3. requirements.txt +86 -0
  4. src/SongFormer/ckpts/md5sum.txt +4 -0
  5. src/SongFormer/configs/SongFormer.yaml +186 -0
  6. src/SongFormer/dataset/DatasetAdaper.py +33 -0
  7. src/SongFormer/dataset/GeminiOnlyLabelAdapter.py +332 -0
  8. src/SongFormer/dataset/HookTheoryAdapter.py +448 -0
  9. src/SongFormer/dataset/custom_types.py +14 -0
  10. src/SongFormer/dataset/label2id.py +163 -0
  11. src/SongFormer/dataset/msa_info_utils.py +47 -0
  12. src/SongFormer/eval.sh +22 -0
  13. src/SongFormer/evaluation/eval_infer_results.py +198 -0
  14. src/SongFormer/infer.sh +21 -0
  15. src/SongFormer/infer/infer.py +439 -0
  16. src/SongFormer/models/SongFormer.py +521 -0
  17. src/SongFormer/postprocessing/calc_acc.py +82 -0
  18. src/SongFormer/postprocessing/calc_iou.py +89 -0
  19. src/SongFormer/postprocessing/functional.py +71 -0
  20. src/SongFormer/postprocessing/helpers.py +101 -0
  21. src/SongFormer/train/accelerate_config/single_gpu.yaml +17 -0
  22. src/SongFormer/utils/average_checkpoints.py +152 -0
  23. src/SongFormer/utils/convert_res2msa_txt.py +79 -0
  24. src/SongFormer/utils/fetch_pretrained.py +40 -0
  25. src/third_party/MuQ/.gitattributes +2 -0
  26. src/third_party/MuQ/.gitignore +46 -0
  27. src/third_party/MuQ/.gitmodules +3 -0
  28. src/third_party/MuQ/LICENSE +21 -0
  29. src/third_party/MuQ/LICENSE_weights +399 -0
  30. src/third_party/MuQ/README.md +129 -0
  31. src/third_party/MuQ/images/muq-logo.jpeg +0 -0
  32. src/third_party/MuQ/images/radar.jpg +3 -0
  33. src/third_party/MuQ/images/tab-marble.jpg +3 -0
  34. src/third_party/MuQ/images/tab-mulan.png +3 -0
  35. src/third_party/MuQ/images/tagging.jpg +3 -0
  36. src/third_party/MuQ/requirements.txt +11 -0
  37. src/third_party/MuQ/setup.py +34 -0
  38. src/third_party/MuQ/src/muq/__init__.py +2 -0
  39. src/third_party/MuQ/src/muq/muq/__init__.py +1 -0
  40. src/third_party/MuQ/src/muq/muq/models/__init__.py +0 -0
  41. src/third_party/MuQ/src/muq/muq/models/muq_model.py +366 -0
  42. src/third_party/MuQ/src/muq/muq/modules/__init__.py +2 -0
  43. src/third_party/MuQ/src/muq/muq/modules/conv.py +77 -0
  44. src/third_party/MuQ/src/muq/muq/modules/features.py +37 -0
  45. src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py +2114 -0
  46. src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py +68 -0
  47. src/third_party/MuQ/src/muq/muq/modules/rvq.py +314 -0
  48. src/third_party/MuQ/src/muq/muq/muq.py +90 -0
  49. src/third_party/MuQ/src/muq/muq_mulan/__init__.py +1 -0
  50. src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py +0 -0
README.md ADDED
@@ -0,0 +1,186 @@
1
+ ---
2
+ title: SongFormer
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ python_version: "3.10"
8
+ app_file: app.py
9
+ tags:
10
+ - music-structure-annotation
11
+ - transformer
12
+ short_description: State-of-the-art music analysis with multi-scale datasets
13
+ fullWidth: true
14
+ ---
15
+
16
+ <p align="center">
17
+ <img src="figs/logo.png" width="50%" />
18
+ </p>
19
+
20
+
21
+ # SONGFORMER: SCALING MUSIC STRUCTURE ANALYSIS WITH HETEROGENEOUS SUPERVISION
22
+
23
+ ![Python](https://img.shields.io/badge/Python-3.10-brightgreen)
24
+ ![License](https://img.shields.io/badge/License-CC%20BY%204.0-lightblue)
25
+ [![arXiv](https://img.shields.io/badge/arXiv-com.svg?logo=arXiv)]()
26
+ [![GitHub](https://img.shields.io/badge/GitHub-SongFormer-black)](https://github.com/ASLP-lab/SongFormer)
27
+ [![HuggingFace Space](https://img.shields.io/badge/HuggingFace-space-yellow)](https://huggingface.co/spaces/ASLP-lab/SongFormer)
28
+ [![HuggingFace Model](https://img.shields.io/badge/HuggingFace-model-blue)](https://huggingface.co/ASLP-lab/SongFormer)
29
+ [![Dataset SongFormDB](https://img.shields.io/badge/HF%20Dataset-SongFormDB-green)](https://huggingface.co/datasets/ASLP-lab/SongFormDB)
30
+ [![Dataset SongFormBench](https://img.shields.io/badge/HF%20Dataset-SongFormBench-orange)](https://huggingface.co/datasets/ASLP-lab/SongFormBench)
31
+ [![Discord](https://img.shields.io/badge/Discord-join%20us-purple?logo=discord&logoColor=white)](https://discord.gg/rwcqh7Em)
32
+ [![lab](https://img.shields.io/badge/🏫-ASLP-grey?labelColor=lightgrey)](http://www.npu-aslp.org/)
33
+
34
+ Chunbo Hao<sup>&ast;</sup>, Ruibin Yuan<sup>&ast;</sup>, Jixun Yao, Qixin Deng, Xinyi Bai, Wei Xue, Lei Xie<sup>&dagger;</sup>
35
+
36
+
37
+ ----
38
+
39
+
40
+ SongFormer is a music structure analysis framework that leverages multi-resolution self-supervised representations and heterogeneous supervision. It is accompanied by the large-scale multilingual dataset SongFormDB and the high-quality benchmark SongFormBench to foster fair and reproducible research.
41
+
42
+ ![](figs/songformer.png)
43
+
44
+ ## News and Updates
45
+
46
+ ## 📋 To-Do List
47
+
48
+ - [x] Complete and push inference code to GitHub
49
+ - [x] Upload model checkpoint(s) to Hugging Face Hub
50
+ - [ ] Upload the paper to arXiv
51
+ - [x] Fix readme
52
+ - [ ] Deploy an out-of-the-box inference version on Hugging Face (via Inference API or Spaces)
53
+ - [ ] Publish the package to PyPI for easy installation via `pip`
54
+ - [ ] Open-source evaluation code
55
+ - [ ] Open-source training code
56
+
57
+ ## Installation
58
+
59
+ ### Setting up Python Environment
60
+
61
+ ```bash
62
+ git clone https://github.com/ASLP-lab/SongFormer.git
63
+ cd SongFormer
64
+ # Get MuQ and MusicFM source code
65
+ git submodule update --init --recursive
66
+
67
+ conda create -n songformer python=3.10 -y
68
+ conda activate songformer
69
+ ```
70
+
71
+ For users in mainland China, you may need to set up a pip mirror source:
72
+
73
+ ```bash
74
+ pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
75
+ ```
76
+
77
+ Install dependencies:
78
+
79
+ ```bash
80
+ pip install -r requirements.txt
81
+ ```
82
+
83
+ We tested this setup on Ubuntu 22.04.1 LTS, where it works normally. If installation fails, you may need to remove the version constraints in `requirements.txt`.
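+
+ For example, one quick (unofficial) way to drop the pins into a separate file and install from that copy instead:
+
+ ```bash
+ # Strip exact version pins into a copy, then install from it
+ sed -E 's/==[^ ]+$//' requirements.txt > requirements.loose.txt
+ pip install -r requirements.loose.txt
+ ```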
84
+
85
+ ### Download Pre-trained Models
86
+
87
+ ```bash
88
+ cd src/SongFormer
89
+ # For users in mainland China: follow the instructions in this .py file to download via hf-mirror.com
90
+ python utils/fetch_pretrained.py
91
+ ```
92
+
93
+ After downloading, verify that the md5sum values in `src/SongFormer/ckpts/md5sum.txt` match the downloaded files:
94
+
95
+ ```bash
96
+ md5sum ckpts/MusicFM/msd_stats.json
97
+ md5sum ckpts/MusicFM/pretrained_msd.pt
98
+ md5sum ckpts/SongFormer.safetensors
99
+ # md5sum ckpts/SongFormer.pt
100
+ ```
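+
+ Alternatively, on systems with GNU coreutils you can check everything against the checksum list in one step (`--ignore-missing` skips entries you have not downloaded, such as the optional `.pt` checkpoint):
+
+ ```bash
+ (cd ckpts && md5sum -c --ignore-missing md5sum.txt)
+ ```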
101
+
102
+ ## Inference
105
+
106
+ ### 1. One-Click Inference with HuggingFace Space (coming soon)
107
+
108
+ Available at: [https://huggingface.co/spaces/ASLP-lab/SongFormer](https://huggingface.co/spaces/ASLP-lab/SongFormer)
109
+
110
+ ### 2. Gradio App
111
+
112
+ First, cd to the project root directory and activate the environment:
113
+
114
+ ```bash
115
+ conda activate songformer
116
+ ```
117
+
118
+ You can modify the server port and listening address in the last line of `app.py` according to your preference.
119
+
120
+ > If you're using an HTTP proxy, please ensure you include:
121
+ >
122
+ > ```bash
123
+ > export no_proxy="localhost, 127.0.0.1, ::1"
124
+ > export NO_PROXY="localhost, 127.0.0.1, ::1"
125
+ > ```
126
+ >
127
+ > Otherwise, Gradio may incorrectly conclude that the service has not started and exit immediately during startup.
128
+
129
+ When `app.py` is run for the first time, it connects to Hugging Face to download the MuQ-related weights. We recommend creating an empty folder in a suitable location and pointing `HF_HOME` to it via `export HF_HOME=XXX`, so that the cache is stored there for easy cleanup and transfer.
130
+
131
+ For users in mainland China, you may also need `export HF_ENDPOINT=https://hf-mirror.com`; for details, refer to https://hf-mirror.com/.
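+
+ A minimal sketch of this environment setup before launching (the cache path is a placeholder; the mirror line is only needed in mainland China):
+
+ ```bash
+ # Keep the Hugging Face cache in a dedicated, easy-to-clean folder
+ export HF_HOME=/path/to/hf_cache
+ # Mainland China only: route Hugging Face downloads through the mirror
+ export HF_ENDPOINT=https://hf-mirror.com
+ ```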
132
+
133
+ ```bash
134
+ python app.py
135
+ ```
136
+
137
+ ### 3. Python Code
138
+
139
+ You can refer to the file `src/SongFormer/infer/infer.py`. The corresponding execution script is located at `src/SongFormer/infer.sh`. This is a ready-to-use, single-machine, multi-process annotation script.
140
+
141
+ Below are the configurable parameters in the `src/SongFormer/infer.sh` script:
142
+
143
+ ```bash
144
+ -i # Input SCP folder path, each line containing the absolute path to one audio file
145
+ -o # Output directory for annotation results
146
+ --model # Annotation model; the default is 'SongFormer', change if using a fine-tuned model
147
+ --checkpoint # Path to the model checkpoint file
148
+ --config_path # Path to the configuration file
149
+ -gn # Total number of GPUs to use — should match the number specified in CUDA_VISIBLE_DEVICES
150
+ -tn # Number of processes to run per GPU
151
+ ```
152
+
153
+ You can control which GPUs are used by setting the `CUDA_VISIBLE_DEVICES` environment variable.
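+
+ For illustration only, here is one way these flags might be combined. The flag names are taken from the list above, while the paths and the direct call to `infer/infer.py` are assumptions; check `infer.sh` for the exact wiring:
+
+ ```bash
+ # Hypothetical invocation; adapt the paths to your setup
+ export CUDA_VISIBLE_DEVICES=0,1
+ python infer/infer.py \
+     -i /path/to/scp_dir \
+     -o /path/to/output_dir \
+     --model SongFormer \
+     --checkpoint ckpts/SongFormer.safetensors \
+     --config_path configs/SongFormer.yaml \
+     -gn 2 \
+     -tn 1
+ ```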
154
+
155
+ ### 4. CLI Inference
156
+
157
+ Coming soon
158
+
159
+ ### 5. Pitfalls
160
+
161
+ - You may need to modify line 121 in `src/third_party/musicfm/model/musicfm_25hz.py` to:
162
+ `S = torch.load(model_path, weights_only=False)["state_dict"]`
163
+
164
+ ## Training
165
+
166
+ ## Citation
167
+
168
+ If our work and codebase are useful to you, please cite:
169
+
170
+ ````
171
+ coming soon
172
+ ````
173
+ ## License
174
+
175
+ Our code is released under CC-BY-4.0 License.
176
+
177
+ ## Contact Us
178
+
179
+
180
+ <p align="center">
181
+ <a href="http://www.nwpu-aslp.org/">
182
+ <img src="figs/aslp.png" width="400"/>
183
+ </a>
184
+ </p>
185
+
186
+
app.py ADDED
@@ -0,0 +1,636 @@
1
+ # import os
2
+ # import sys
3
+
4
+ # os.chdir(os.path.join("src", "SongFormer"))
5
+ # sys.path.append(os.path.join("..", "third_party"))
6
+ # sys.path.append(".")
7
+
8
+ import os
9
+ import sys
10
+ # Get the absolute path of the current file and the script name
11
+ current_file = os.path.abspath(__file__)
12
+ current_dir = os.path.dirname(current_file)
13
+ script_name = os.path.basename(__file__)
14
+ print(f"[INFO] 正在运行脚本:{script_name}")
15
+ print(f"[INFO] 当前文件所在目录为:{current_dir}")
16
+ # 设置工作目录为 `src/SongFormer`(如果该路径存在)
17
+ songformer_path = os.path.join(current_dir, "src", "SongFormer")
18
+ if os.path.exists(songformer_path):
19
+ os.chdir(songformer_path)
20
+ print(f"[INFO] 工作目录已修改为:{songformer_path}")
21
+ else:
22
+ print(f"[WARNING] 目标工作目录不存在:{songformer_path}")
23
+ # 获取当前工作目录,即运行 os.chdir 后的路径
24
+ working_dir = os.getcwd()
25
+ print(f"[INFO] 当前工作目录为:{working_dir}")
26
+ # 添加第三方库路径到 sys.path(third_party)
27
+ third_party_path = os.path.join(current_dir, "third_party")
28
+ if os.path.exists(third_party_path):
29
+ sys.path.insert(0, third_party_path)
30
+ print(f"[INFO] 已添加第三方库路径到 sys.path:{third_party_path}")
31
+ else:
32
+ print(f"[WARNING] third_party 路径不存在:{third_party_path}")
33
+ # 添加当前工作目录到 sys.path(通常是 src/SongFormer)
34
+ sys.path.insert(0, working_dir)
35
+ print(f"[INFO] 已添加当前工作目录到 sys.path:{working_dir}")
36
+ # 尝试添加多个可能用于 musicfm 导入的路径
37
+ musicfm_paths = [
38
+ os.path.join(current_dir, "src"),
39
+ os.path.join(current_dir, "third_party"),
40
+ os.path.join(current_dir, "src", "SongFormer"),
41
+ ]
42
+ for path in musicfm_paths:
43
+ if os.path.exists(path):
44
+ sys.path.insert(0, path)
45
+ print(f"[INFO] 已添加路径到 sys.path:{path}")
46
+ else:
47
+ print(f"[DEBUG] 路径不存在,跳过添加:{path}")
48
+ # 可选:打印 sys.path 的当前状态
49
+ print("\n[DEBUG] 当前 sys.path 设置如下:")
50
+ for idx, p in enumerate(sys.path):
51
+ print(f" {idx}: {p}")
52
+
53
+ # monkey patch to fix issues in msaf
54
+ import scipy
55
+ import numpy as np
56
+
57
+ scipy.inf = np.inf
58
+
59
+ import gradio as gr
60
+ import torch
61
+ import librosa
62
+ import json
63
+ import math
64
+ import importlib
65
+ import matplotlib.pyplot as plt
66
+ import matplotlib.ticker as ticker
67
+ from pathlib import Path
68
+ from argparse import Namespace
69
+ from omegaconf import OmegaConf
70
+ from ema_pytorch import EMA
71
+ from muq import MuQ
72
+ from musicfm.model.musicfm_25hz import MusicFM25Hz
73
+ from postprocessing.functional import postprocess_functional_structure
74
+ from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
75
+ from utils.fetch_pretrained import download_all
76
+
77
+ # Constants
78
+ MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
79
+ BEFORE_DOWNSAMPLING_FRAME_RATES = 25
80
+ AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
81
+ DATASET_LABEL = "SongForm-HX-8Class"
82
+ DATASET_IDS = [5]
83
+ TIME_DUR = 420
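+ # Sliding-window length in seconds (7 min); matches slice_dur in configs/SongFormer.yaml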
84
+ INPUT_SAMPLING_RATE = 24000
85
+
86
+ # Global model variables
87
+ muq_model = None
88
+ musicfm_model = None
89
+ msa_model = None
90
+ device = None
91
+
92
+
93
+ def load_checkpoint(checkpoint_path, device=None):
94
+ """Load checkpoint from path"""
95
+ if device is None:
96
+ device = "cpu"
97
+
98
+ if checkpoint_path.endswith(".pt"):
99
+ checkpoint = torch.load(checkpoint_path, map_location=device)
100
+ elif checkpoint_path.endswith(".safetensors"):
101
+ from safetensors.torch import load_file
102
+
103
+ checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
104
+ else:
105
+ raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
106
+ return checkpoint
107
+
108
+
109
+ def initialize_models(model_name: str, checkpoint: str, config_path: str):
110
+ """Initialize all models"""
111
+ global muq_model, musicfm_model, msa_model, device
112
+
113
+ # Set device
114
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
+
116
+ # Load MuQ
117
+ muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
118
+ muq_model = muq_model.to(device).eval()
119
+
120
+ # Load MusicFM
121
+ musicfm_model = MusicFM25Hz(
122
+ is_flash=False,
123
+ stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
124
+ model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
125
+ )
126
+ musicfm_model = musicfm_model.to(device).eval()
127
+
128
+ # Load MSA model
129
+ module = importlib.import_module("models." + str(model_name))
130
+ Model = getattr(module, "Model")
131
+ hp = OmegaConf.load(os.path.join("configs", config_path))
132
+ msa_model = Model(hp)
133
+
134
+ ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", checkpoint))
135
+ if ckpt.get("model_ema", None) is not None:
136
+ model_ema = EMA(msa_model, include_online_model=False)
137
+ model_ema.load_state_dict(ckpt["model_ema"])
138
+ msa_model.load_state_dict(model_ema.ema_model.state_dict())
139
+ else:
140
+ msa_model.load_state_dict(ckpt["model"])
141
+
142
+ msa_model.to(device).eval()
143
+
144
+ return hp
145
+
146
+
147
+ def process_audio(audio_path, win_size=420, hop_size=420, num_classes=128):
148
+ """Process audio file and return structure analysis results"""
149
+ global muq_model, musicfm_model, msa_model, device
150
+
151
+ if muq_model is None:
152
+ hp = initialize_models("SongFormer", "SongFormer.safetensors", "SongFormer.yaml")
153
+ else:
154
+ hp = OmegaConf.load(os.path.join("configs", "SongFormer.yaml"))
155
+
156
+ # Load audio
157
+ wav, sr = librosa.load(audio_path, sr=INPUT_SAMPLING_RATE)
158
+ audio = torch.tensor(wav).to(device)
159
+
160
+ # Prepare output
161
+ total_len = (
162
+ (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR * TIME_DUR
163
+ ) + TIME_DUR
164
+ total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
165
+
166
+ logits = {
167
+ "function_logits": np.zeros([total_frames, num_classes]),
168
+ "boundary_logits": np.zeros([total_frames]),
169
+ }
170
+ logits_num = {
171
+ "function_logits": np.zeros([total_frames, num_classes]),
172
+ "boundary_logits": np.zeros([total_frames]),
173
+ }
174
+
175
+ # Prepare label masks
176
+ dataset_id2label_mask = {}
177
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
178
+ dataset_id2label_mask[key] = np.ones(num_classes, dtype=bool)
179
+ dataset_id2label_mask[key][allowed_ids] = False
180
+
181
+ lens = 0
182
+ i = 0
183
+
184
+ with torch.no_grad():
185
+ while True:
186
+ start_idx = i * INPUT_SAMPLING_RATE
187
+ end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
188
+ if start_idx >= audio.shape[-1]:
189
+ break
190
+ if end_idx - start_idx <= 1024:
191
+ break  # remaining tail is shorter than 1024 samples; stop to avoid an infinite loop
192
+
193
+ audio_seg = audio[start_idx:end_idx]
194
+
195
+ # Get embeddings
196
+ muq_output = muq_model(audio_seg.unsqueeze(0), output_hidden_states=True)
197
+ muq_embd_420s = muq_output["hidden_states"][10]
198
+ del muq_output
199
+ torch.cuda.empty_cache()
200
+
201
+ _, musicfm_hidden_states = musicfm_model.get_predictions(
202
+ audio_seg.unsqueeze(0)
203
+ )
204
+ musicfm_embd_420s = musicfm_hidden_states[10]
205
+ del musicfm_hidden_states
206
+ torch.cuda.empty_cache()
207
+
208
+ # Process 30-second segments
209
+ wraped_muq_embd_30s = []
210
+ wraped_musicfm_embd_30s = []
211
+
212
+ for idx_30s in range(i, i + hop_size, 30):
213
+ start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
214
+ end_idx_30s = min(
215
+ (idx_30s + 30) * INPUT_SAMPLING_RATE,
216
+ audio.shape[-1],
217
+ (i + hop_size) * INPUT_SAMPLING_RATE,
218
+ )
219
+ if start_idx_30s >= audio.shape[-1]:
220
+ break
221
+ if end_idx_30s - start_idx_30s <= 1024:
222
+ continue
223
+
224
+ wraped_muq_embd_30s.append(
225
+ muq_model(
226
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0),
227
+ output_hidden_states=True,
228
+ )["hidden_states"][10]
229
+ )
230
+ torch.cuda.empty_cache()
231
+
232
+ wraped_musicfm_embd_30s.append(
233
+ musicfm_model.get_predictions(
234
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0)
235
+ )[1][10]
236
+ )
237
+ torch.cuda.empty_cache()
238
+
239
+ if wraped_muq_embd_30s:
240
+ wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
241
+ wraped_musicfm_embd_30s = torch.concatenate(
242
+ wraped_musicfm_embd_30s, dim=1
243
+ )
244
+
245
+ all_embds = [
246
+ wraped_musicfm_embd_30s,
247
+ wraped_muq_embd_30s,
248
+ musicfm_embd_420s,
249
+ muq_embd_420s,
250
+ ]
251
+
252
+ # Align embedding lengths
253
+ if len(all_embds) > 1:
254
+ embd_lens = [x.shape[1] for x in all_embds]
255
+ min_embd_len = min(embd_lens)
256
+ for idx in range(len(all_embds)):
257
+ all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
258
+
259
+ embd = torch.concatenate(all_embds, axis=-1)
260
+
261
+ # Inference
262
+ dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
263
+ msa_info, chunk_logits = msa_model.infer(
264
+ input_embeddings=embd,
265
+ dataset_ids=dataset_ids,
266
+ label_id_masks=torch.Tensor(
267
+ dataset_id2label_mask[
268
+ DATASET_LABEL_TO_DATASET_ID[DATASET_LABEL]
269
+ ]
270
+ )
271
+ .to(device, dtype=bool)
272
+ .unsqueeze(0)
273
+ .unsqueeze(0),
274
+ with_logits=True,
275
+ )
276
+
277
+ # Accumulate logits
278
+ start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
279
+ end_frame = start_frame + min(
280
+ math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
281
+ chunk_logits["boundary_logits"][0].shape[0],
282
+ )
283
+
284
+ logits["function_logits"][start_frame:end_frame, :] += (
285
+ chunk_logits["function_logits"][0].detach().cpu().numpy()
286
+ )
287
+ logits["boundary_logits"][start_frame:end_frame] = (
288
+ chunk_logits["boundary_logits"][0].detach().cpu().numpy()
289
+ )
290
+ logits_num["function_logits"][start_frame:end_frame, :] += 1
291
+ logits_num["boundary_logits"][start_frame:end_frame] += 1
292
+ lens += end_frame - start_frame
293
+
294
+ i += hop_size
295
+
296
+ # Average logits
297
+ logits["function_logits"] /= np.maximum(logits_num["function_logits"], 1)
298
+ logits["boundary_logits"] /= np.maximum(logits_num["boundary_logits"], 1)
299
+
300
+ logits["function_logits"] = torch.from_numpy(
301
+ logits["function_logits"][:lens]
302
+ ).unsqueeze(0)
303
+ logits["boundary_logits"] = torch.from_numpy(
304
+ logits["boundary_logits"][:lens]
305
+ ).unsqueeze(0)
306
+
307
+ # Post-process
308
+ msa_infer_output = postprocess_functional_structure(logits, hp)
309
+
310
+ return logits, msa_infer_output
311
+
312
+
313
+ def format_as_segments(msa_output):
314
+ """Format as list of segments"""
315
+ segments = []
316
+ for idx in range(len(msa_output) - 1):
317
+ segments.append(
318
+ {
319
+ "start": str(round(msa_output[idx][0], 2)),
320
+ "end": str(round(msa_output[idx + 1][0], 2)),
321
+ "label": msa_output[idx][1],
322
+ }
323
+ )
324
+ return segments
325
+
326
+
327
+ def format_as_msa(msa_output):
328
+ """Format as MSA format"""
329
+ lines = []
330
+ for time, label in msa_output:
331
+ lines.append(f"{time:.2f} {label}")
332
+ return "\n".join(lines)
333
+
334
+
335
+ def format_as_json(segments):
336
+ """Format as JSON"""
337
+ return json.dumps(segments, indent=2, ensure_ascii=False)
338
+
339
+
340
+ def create_visualization(
341
+ logits, msa_output, label_num=8, frame_rates=AFTER_DOWNSAMPLING_FRAME_RATES
342
+ ):
343
+ """Create visualization plot"""
344
+ # Assume ID_TO_LABEL mapping exists
345
+ try:
346
+ from dataset.label2id import ID_TO_LABEL
347
+ except:
348
+ ID_TO_LABEL = {i: f"Class_{i}" for i in range(128)}
349
+
350
+ function_vals = logits["function_logits"].squeeze().cpu().numpy()
351
+ boundary_vals = logits["boundary_logits"].squeeze().cpu().numpy()
352
+
353
+ top_classes = np.argsort(function_vals.mean(axis=0))[-label_num:]
354
+ T = function_vals.shape[0]
355
+ time_axis = np.arange(T) / frame_rates
356
+
357
+ fig, ax = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
358
+
359
+ # Plot function logits
360
+ for cls in top_classes:
361
+ ax[1].plot(
362
+ time_axis,
363
+ function_vals[:, cls],
364
+ label=f"{ID_TO_LABEL.get(cls, f'Class_{cls}')}",
365
+ )
366
+
367
+ ax[1].set_title("Top 8 Function Logits by Mean Activation")
368
+ ax[1].set_xlabel("Time (seconds)")
369
+ ax[1].set_ylabel("Logit")
370
+ ax[1].xaxis.set_major_locator(ticker.MultipleLocator(20))
371
+ ax[1].xaxis.set_minor_locator(ticker.MultipleLocator(5))
372
+ ax[1].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1f"))
373
+ ax[1].legend()
374
+ ax[1].grid(True)
375
+
376
+ # Plot boundary logits
377
+ ax[0].plot(time_axis, boundary_vals, label="Boundary Logit", color="orange")
378
+ ax[0].set_title("Boundary Logits")
379
+ ax[0].set_ylabel("Logit")
380
+ ax[0].legend()
381
+ ax[0].grid(True)
382
+
383
+ # Add vertical lines for markers
384
+ for t_sec, label in msa_output:
385
+ for a in ax:
386
+ a.axvline(x=t_sec, color="red", linestyle="--", linewidth=0.8, alpha=0.7)
387
+ if label != "end":
388
+ ax[1].text(
389
+ t_sec + 0.3,
390
+ ax[1].get_ylim()[1] * 0.85,
391
+ label,
392
+ rotation=90,
393
+ fontsize=8,
394
+ color="red",
395
+ )
396
+
397
+ plt.suptitle("Music Structure Analysis - Logits Overview", fontsize=16)
398
+ plt.tight_layout()
399
+
400
+ return fig
401
+
402
+
403
+ def rule_post_processing(msa_list):
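+ """Heuristic cleanup: merge a first/last segment shorter than 1 s and drop a duplicated first/last label that lies within 10 s of the start/end."""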
404
+ if len(msa_list) <= 2:
405
+ return msa_list
406
+
407
+ result = msa_list.copy()
408
+
409
+ while len(result) > 2:
410
+ first_duration = result[1][0] - result[0][0]
411
+ if first_duration < 1.0 and len(result) > 2:
412
+ result[0] = (result[0][0], result[1][1])
413
+ result = [result[0]] + result[2:]
414
+ else:
415
+ break
416
+
417
+ while len(result) > 2:
418
+ last_label_duration = result[-1][0] - result[-2][0]
419
+ if last_label_duration < 1.0:
420
+ result = result[:-2] + [result[-1]]
421
+ else:
422
+ break
423
+
424
+ while len(result) > 2:
425
+ if result[0][1] == result[1][1] and result[1][0] <= 10.0:
426
+ result = [(result[0][0], result[0][1])] + result[2:]
427
+ else:
428
+ break
429
+
430
+ while len(result) > 2:
431
+ last_duration = result[-1][0] - result[-2][0]
432
+ if result[-2][1] == result[-3][1] and last_duration <= 10.0:
433
+ result = result[:-2] + [result[-1]]
434
+ else:
435
+ break
436
+
437
+ return result
438
+
439
+
440
+ def process_and_analyze(audio_file):
441
+ """Main processing function"""
442
+
443
+ def format_time(t: float) -> str:
444
+ minutes = int(t // 60)
445
+ seconds = t % 60
446
+ return f"{minutes:02d}:{seconds:06.3f}"  # mm:ss.mmm
447
+
448
+ if audio_file is None:
449
+ return None, "", "", None
450
+
451
+ try:
452
+ # Process audio
453
+ logits, msa_output = process_audio(audio_file)
454
+ # Apply rule-based post-processing; if it is not needed, use the CLI inference path instead
455
+ msa_output = rule_post_processing(msa_output)
456
+ # Format outputs
457
+ segments = format_as_segments(msa_output)
458
+ msa_format = format_as_msa(msa_output)
459
+ json_format = format_as_json(segments)
460
+
461
+ # Create table data
462
+ table_data = [
463
+ [
464
+ f"{float(seg['start']):.2f} ({format_time(float(seg['start']))})",
465
+ f"{float(seg['end']):.2f} ({format_time(float(seg['end']))})",
466
+ seg["label"],
467
+ ]
468
+ for seg in segments
469
+ ]
470
+
471
+ # Create visualization
472
+ fig = create_visualization(logits, msa_output)
473
+
474
+ return table_data, json_format, msa_format, fig
475
+
476
+ except Exception as e:
477
+ import traceback
478
+
479
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
480
+ print(error_msg)  # print the full traceback to the console
481
+ return None, "", error_msg, None
482
+
483
+
484
+ # Create Gradio interface
485
+ with gr.Blocks(
486
+ title="Music Structure Analysis",
487
+ css="""
488
+ .logo-container {
489
+ text-align: center;
490
+ margin-bottom: 20px;
491
+ }
492
+ .links-container {
493
+ display: flex;
494
+ justify-content: center;
495
+ column-gap: 10px;
496
+ margin-bottom: 10px;
497
+ }
498
+ .model-title {
499
+ text-align: center;
500
+ font-size: 24px;
501
+ font-weight: bold;
502
+ margin-bottom: 30px;
503
+ }
504
+ """,
505
+ ) as demo:
506
+ # Top Logo
507
+ gr.HTML("""
508
+ <div style="display: flex; justify-content: center; align-items: center;">
509
+ <img src="https://raw.githubusercontent.com/ASLP-lab/SongFormer/refs/heads/main/figs/logo.png" style="max-width: 300px; height: auto;" />
510
+ </div>
511
+ """)
512
+
513
+ # Model title
514
+ gr.HTML("""
515
+ <div class="model-title">
516
+ SongFormer: Scaling Music Structure Analysis with Heterogeneous Supervision
517
+ </div>
518
+ """)
519
+
520
+ # Links
521
+ gr.HTML("""
522
+ <div class="links-container">
523
+ <img src="https://img.shields.io/badge/Python-3.10-brightgreen" alt="Python">
524
+ <img src="https://img.shields.io/badge/License-CC%20BY%204.0-lightblue" alt="License">
525
+ <a href="https://arxiv.org/abs/">
526
+ <img src="https://img.shields.io/badge/arXiv-com.svg?logo=arXiv" alt="arXiv">
527
+ </a>
528
+ <a href="https://github.com/ASLP-lab/SongFormer">
529
+ <img src="https://img.shields.io/badge/GitHub-SongFormer-black" alt="GitHub">
530
+ </a>
531
+ <a href="https://huggingface.co/spaces/ASLP-lab/SongFormer">
532
+ <img src="https://img.shields.io/badge/HuggingFace-space-yellow" alt="HuggingFace Space">
533
+ </a>
534
+ <a href="https://huggingface.co/ASLP-lab/SongFormer">
535
+ <img src="https://img.shields.io/badge/HuggingFace-model-blue" alt="HuggingFace Model">
536
+ </a>
537
+ <a href="https://huggingface.co/datasets/ASLP-lab/SongFormDB">
538
+ <img src="https://img.shields.io/badge/HF%20Dataset-SongFormDB-green" alt="Dataset SongFormDB">
539
+ </a>
540
+ <a href="https://huggingface.co/datasets/ASLP-lab/SongFormBench">
541
+ <img src="https://img.shields.io/badge/HF%20Dataset-SongFormBench-orange" alt="Dataset SongFormBench">
542
+ </a>
543
+ <a href="https://discord.gg/rwcqh7Em">
544
+ <img src="https://img.shields.io/badge/Discord-join%20us-purple?logo=discord&logoColor=white" alt="Discord">
545
+ </a>
546
+ <a href="http://www.npu-aslp.org/">
547
+ <img src="https://img.shields.io/badge/🏫-ASLP-grey?labelColor=lightgrey" alt="ASLP">
548
+ </a>
549
+ </div>
550
+ """)
551
+
552
+ # Main input area
553
+ with gr.Row():
554
+ with gr.Column(scale=3):
555
+ audio_input = gr.Audio(
556
+ label="Upload Audio File", type="filepath", elem_id="audio-input"
557
+ )
558
+
559
+ with gr.Column(scale=1):
560
+ gr.Markdown("### 📌 Examples")
561
+ gr.Examples(
562
+ examples=[
563
+ # Add your example audio file paths
564
+ # ["example1.mp3"],
565
+ # ["example2.mp3"],
566
+ ],
567
+ inputs=[audio_input],
568
+ label="Click to load example",
569
+ )
570
+
571
+ # Analyze button
572
+ with gr.Row():
573
+ analyze_btn = gr.Button(
574
+ "🚀 Analyze Music Structure", variant="primary", scale=1
575
+ )
576
+
577
+ # Results display area
578
+ with gr.Row():
579
+ with gr.Column(scale=13):
580
+ segments_table = gr.Dataframe(
581
+ headers=["Start / s (m:s.ms)", "End / s (m:s.ms)", "Label"],
582
+ label="Detected Music Segments",
583
+ interactive=False,
584
+ elem_id="result-table",
585
+ )
586
+ with gr.Column(scale=8):
587
+ with gr.Row():
588
+ with gr.Accordion("📄 JSON Output", open=False):
589
+ json_output = gr.Textbox(
590
+ label="JSON Format",
591
+ lines=15,
592
+ max_lines=20,
593
+ interactive=False,
594
+ show_copy_button=True,
595
+ )
596
+ with gr.Row():
597
+ with gr.Accordion("📋 MSA Text Output", open=False):
598
+ msa_output = gr.Textbox(
599
+ label="MSA Format",
600
+ lines=15,
601
+ max_lines=20,
602
+ interactive=False,
603
+ show_copy_button=True,
604
+ )
605
+
606
+ # Visualization plot
607
+ with gr.Row():
608
+ plot_output = gr.Plot(label="Activation Curves Visualization")
609
+
610
+ gr.HTML("""
611
+ <div style="display: flex; justify-content: center; align-items: center;">
612
+ <img src="https://raw.githubusercontent.com/ASLP-lab/SongFormer/refs/heads/main/figs/aslp.png" style="max-width: 300px; height: auto;" />
613
+ </div>
614
+ """)
615
+
616
+ # Set event handlers
617
+ analyze_btn.click(
618
+ fn=process_and_analyze,
619
+ inputs=[audio_input],
620
+ outputs=[segments_table, json_output, msa_output, plot_output],
621
+ )
622
+
623
+ if __name__ == "__main__":
624
+ # Download pretrained models if not exist
625
+ download_all(use_mirror=False)
626
+ # Initialize models
627
+ print("Initializing models...")
628
+ initialize_models(
629
+ model_name="SongFormer",
630
+ checkpoint="SongFormer.safetensors",
631
+ config_path="SongFormer.yaml",
632
+ )
633
+ print("Models loaded successfully!")
634
+
635
+ # Launch interface
636
+ demo.launch(server_name="127.0.0.1", server_port=7891, debug=True)
requirements.txt ADDED
@@ -0,0 +1,86 @@
1
+ # Core Deep Learning Framework
2
+ torch==2.4.0
3
+ torchaudio==2.4.0
4
+ lightning==2.5.1.post0
5
+
6
+ # ML/DL Libraries
7
+ transformers==4.51.1
8
+ accelerate==1.5.2
9
+ datasets==3.6.0
10
+ tokenizers==0.21.1
11
+ huggingface-hub==0.30.1
12
+ safetensors==0.5.3
13
+
14
+ # Scientific Computing
15
+ numpy==1.25.0
16
+ scipy==1.15.2
17
+ scikit-learn==1.6.1
18
+ pandas==2.2.3
19
+
20
+ # Audio Processing
21
+ librosa==0.11.0
22
+ audioread==3.0.1
23
+ soundfile==0.13.1
24
+ pesq==0.0.4
25
+ auraloss==0.4.0
26
+ nnAudio==0.3.3
27
+ julius==0.2.7
28
+ soxr==0.5.0.post1
29
+ mir_eval==0.8.2
30
+ jams==0.3.4
31
+ msaf==0.1.80
32
+
33
+ # Visualization & Monitoring
34
+ matplotlib==3.10.1
35
+ seaborn==0.13.2
36
+ tensorboard==2.19.0
37
+ wandb==0.19.8
38
+ gpustat==1.1.1
39
+
40
+ # Configuration & CLI
41
+ hydra-core==1.3.2
42
+ omegaconf==2.3.0
43
+ fire==0.7.1
44
+ click==8.1.8
45
+
46
+ # Deep Learning Utils
47
+ einops==0.8.1
48
+ einx==0.3.0
49
+ x-transformers==2.4.14
50
+ x-clip==0.14.4
51
+ ema-pytorch==0.7.7
52
+ schedulefree==1.4.1
53
+ torchmetrics==1.7.1
54
+
55
+ # Data Processing
56
+ h5py==3.13.0
57
+ pyarrow==19.0.1
58
+ pillow==11.1.0
59
+
60
+ # Text Processing
61
+ ftfy==6.3.1
62
+ regex==2024.11.6
63
+ pypinyin==0.54.0
64
+ textgrid==1.6.1
65
+ pylrc==0.1.2
66
+
67
+ # Model Management
68
+ modelscope==1.27.1
69
+
70
+ # Utilities
71
+ tqdm==4.67.1
72
+ loguru==0.7.3
73
+ joblib==1.4.2
74
+ easydict==1.13
75
+ addict==2.4.0
76
+ beartype==0.21.0
77
+
78
+ # Others
79
+ triton==3.0.0
80
+ muq==0.1.0
81
+ vmo==0.30.5
82
+ gradio
src/SongFormer/ckpts/md5sum.txt ADDED
@@ -0,0 +1,4 @@
1
+ df930aceac8209818556c4a656a0714c MusicFM/pretrained_msd.pt
2
+ 75ab2e47b093e07378f7f703bdb82c14 MusicFM/msd_stats.json
3
+ 5a24800e12ab357744f8b47e523ba3e6 SongFormer.safetensors
4
+ 2c66c0bb91364e318e90dbc2d9a79ee2 _SongFormer.pt
src/SongFormer/configs/SongFormer.yaml ADDED
@@ -0,0 +1,186 @@
1
+ # ============================
2
+ # Model Configuration
3
+ # ============================
4
+
5
+ input_dim_raw: 4096 # Downsampled Fused SSL Representation Dimension
6
+ input_dim: 2048 # Input Dimension after Linear Layer
7
+
8
+ # Downsampling Module
9
+ down_sample_conv_kernel_size: 3
10
+ down_sample_conv_stride: 3
11
+ down_sample_conv_dropout: 0.1
12
+ down_sample_conv_padding: 0
13
+
14
+ # Transformer Module
15
+ transformer_encoder_input_dim: 1024
16
+ transformer_input_dim: 512
17
+ num_transformer_layers: 4
18
+ transformer_nhead: 8
19
+ transformer_dropout: 0.1
20
+
21
+ # task-specific heads
22
+ boundary_head_hidden_dims: [128, 64, 8]
23
+ function_head_hidden_dims: []
24
+
25
+ num_classes: 128
26
+ num_dataset_classes: 64
27
+
28
+ # scheduler
29
+ warmup_steps: 300
30
+ total_steps: 12010
31
+ warmup_max_lr: 0.0001
32
+
33
+ # frame rates after downsampling
34
+ output_logits_frame_rates: 8.333
35
+ # i.e., output_logits_frame_rates = input_embd_frame_rates / downsample_rates (25 / 3 ≈ 8.333), since the conv padding is 0.
36
+ downsample_rates: 3
37
+ # frame rates after downsampling, used by model and post process
38
+ frame_rates: 8.333
39
+
40
+ # ema config
41
+ ema_kwargs:
42
+ {update_after_step: 200}
43
+
44
+ # ============================
45
+ # Loss Functions configuration
46
+ # ============================
47
+
48
+ # Focal loss
49
+ label_focal_loss_weight: 0.2
50
+
51
+ label_focal_loss_alpha: 0.25
52
+ label_focal_loss_gamma: 2.0
53
+
54
+ # Boundary TV loss
55
+ boundary_tvloss_weight: 0.05
56
+
57
+ boundary_tv_loss_beta: 0.6
58
+ boundary_tv_loss_lambda: 0.4
59
+ boundary_tv_loss_boundary_threshold: 0.01
60
+ boundary_tv_loss_reduction_weight: 0.1
61
+
62
+ loss_weight_section: 0.2
63
+ loss_weight_function: 0.8
64
+
65
+ # ============================
66
+ # Training config
67
+ # ============================
68
+
69
+ # Number of neighbors used to augment boundaries in the dataset.
70
+ # Example: (1/25)*3 * 10 = 1.2 s per side (2.4 s total across both sides)
71
+ num_neighbors: 10
72
+ learn_label: true
73
+ learn_segment: true
74
+ accumulation_steps: 2
75
+ slice_dur: 420
76
+ early_stopping_step: 3
77
+ local_maxima_filter_size: 3
78
+
79
+ # ============================
80
+ # Dataset config
81
+ # ============================
82
+
83
+ train_dataset:
84
+ _target_: dataset.SongFormerDataset.Dataset
85
+ dataset_abstracts:
86
+ [
87
+ {
88
+ "internal_tmp_id": "SongForm-HX-8Class",
89
+ "dataset_type": "SongForm-HX-8Class",
90
+ "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
91
+ "label_path": "your_data_dir/labels/harmonixset_8class_rule_revision.jsonl",
92
+ "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/train.txt",
93
+ "multiplier": 4,
94
+ },
95
+ {
96
+ "internal_tmp_id": "SongForm-Private",
97
+ "dataset_type": "SongForm-Private",
98
+ "input_embedding_dir": "your_data_dir/30s_420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/30s_420s/Internal_data/muq_hop420/layer_10 your_data_dir/420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/420s/Internal_data/muq_hop420/layer_10",
99
+ "label_path": "your_data_dir/labels/0006_single_layer_transformer_musicfm_muq_along_time_00_5k_v1.jsonl",
100
+ "split_ids_path": "your_data_dir/separated_ids/internal_data_sofa_clean/train.txt",
101
+ "multiplier": 1,
102
+ },
103
+ {
104
+ adapter: HookTheoryAdapter,
105
+ internal_tmp_id: "SongForm-Hook",
106
+ structure_jsonl_paths: [
107
+ "your_data_dir/HookTheoryStructure.train.jsonl"
108
+ ],
109
+ dataset_type: "SongForm-Hook",
110
+ input_embedding_dir: "your_data_dir/30s_420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/30s_420s/HookTheory/muq_hop420/layer_10 your_data_dir/420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/420s/HookTheory/muq_hop420/layer_10",
111
+ split_ids_path: "your_data_dir/separated_ids/hooktheory_separated_ids/train.txt",
112
+ multiplier: 1,
113
+ },
114
+ ]
115
+ hparams:
116
+ output_logits_frame_rates: ${output_logits_frame_rates}
117
+ downsample_rates: ${downsample_rates}
118
+ num_neighbors: ${num_neighbors}
119
+ input_dim: ${input_dim_raw}
120
+ slice_dur: ${slice_dur}
121
+ num_classes: ${num_classes}
122
+ frame_rates: ${frame_rates}
123
+
124
+ eval_dataset:
125
+ _target_: dataset.SongFormerDataset.Dataset
126
+ dataset_abstracts:
127
+ [
128
+ {
129
+ "internal_tmp_id": "SongForm-HX-8Classs_val",
130
+ "dataset_type": "SongForm-HX-8Class",
131
+ "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
132
+ "label_path": "your_data_dir/processed_data/labels/harmonixset_8class_rule_revision.jsonl",
133
+ "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/val.txt",
134
+ "multiplier": 1,
135
+ },
136
+ ]
137
+ hparams:
138
+ output_logits_frame_rates: ${output_logits_frame_rates}
139
+ downsample_rates: ${downsample_rates}
140
+ num_neighbors: ${num_neighbors}
141
+ input_dim: ${input_dim_raw}
142
+ slice_dur: ${slice_dur}
143
+ num_classes: ${num_classes}
144
+ frame_rates: ${frame_rates}
145
+
146
+ # ============================
147
+ # DataLoader configuration
148
+ # ============================
149
+
150
+ train_dataloader:
151
+ num_workers: 4
152
+ batch_size: 4
153
+ pin_memory: True
154
+ prefetch_factor: 4
155
+ drop_last: True
156
+ persistent_workers: True
157
+ shuffle: true
158
+
159
+ eval_dataloader:
160
+ num_workers: 0
161
+ batch_size: 1
162
+ shuffle: false
163
+
164
+ # ============================
165
+ # Optimizer configuration
166
+ # ============================
167
+
168
+ optimizer:
169
+ lr: ${warmup_max_lr}
170
+ betas: [0.8, 0.999]
171
+ eps: 1e-08
172
+ weight_decay: 3e-7
173
+
174
+ # ============================
175
+ # Training Run configuration
176
+ # ============================
177
+
178
+ args:
179
+ run_name: SongFormer
180
+ model_name: SongFormer
181
+ save_interval: 800
182
+ eval_interval: 800
183
+ checkpoint_dir: output/SongFormer
184
+ max_epochs: 1000
185
+ max_steps: 12010
186
+ tags: null
src/SongFormer/dataset/DatasetAdaper.py ADDED
@@ -0,0 +1,33 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class DatasetAdapter(ABC):
5
+ """
6
+ Abstract base class for dataset adapters.
7
+ """
8
+
9
+ @abstractmethod
10
+ def __init__(self, *args, **kwargs):
11
+ """
12
+ Initialize the dataset adapter with necessary parameters.
13
+ """
14
+ raise NotImplementedError("Subclasses must implement the __init__ method.")
15
+
16
+ @abstractmethod
17
+ def get_ids(self):
18
+ """
19
+ Get the IDs of the dataset.
20
+ This method should be implemented by subclasses.
21
+
22
+ Returns:
23
+ A list or set of IDs representing the dataset. In format: ID + start_time
24
+ must cosider the split of dataset, e.g. train, val, test.
25
+ """
26
+ raise NotImplementedError("Subclasses must implement this method.")
27
+
28
+ @abstractmethod
29
+ def get_item_json(self, *args, **kwargs):
30
+ """
31
+ Get the item JSON representation from the dataset.
32
+ """
33
+ raise NotImplementedError("Subclasses must implement this method.")
src/SongFormer/dataset/GeminiOnlyLabelAdapter.py ADDED
@@ -0,0 +1,332 @@
1
+ # 1. It was found that the annotations generated by Gemini are discontinuous between segments
2
+ # (possibly differing by more than 1.7 seconds, accounting for approximately 1/4 to 1/3 of the cases).
3
+ # 2. Gemini's labels can compete with our SOTA model, but Gemini's boundary metrics are very poor.
4
+ # With a tolerance of 3 seconds, they are similar to the metrics of our best model.
5
+ import pdb
6
+ import random
7
+ import os
8
+ from collections import defaultdict
9
+ from pathlib import Path
10
+ import json
11
+ from loguru import logger
12
+ import numpy as np
13
+ import math
14
+ from .label2id import (
15
+ DATASET_ID_ALLOWED_LABEL_IDS,
16
+ DATASET_LABEL_TO_DATASET_ID,
17
+ ID_TO_LABEL,
18
+ LABEL_TO_ID,
19
+ )
20
+ from argparse import Namespace
21
+ from scipy.ndimage import gaussian_filter1d
22
+ from .DatasetAdaper import DatasetAdapter
23
+ from omegaconf import ListConfig
24
+ import copy
25
+
26
+
27
+ # Adapter for datasets labeled only by Gemini
28
+ class GeminiOnlyLabelAdapter(DatasetAdapter):
29
+ def __init__(self, **kwargs):
30
+ (
31
+ label_paths,
32
+ hparams,
33
+ internal_tmp_id,
34
+ dataset_type,
35
+ input_embedding_dir,
36
+ split_ids_path,
37
+ ) = (
38
+ kwargs["label_paths"],
39
+ kwargs["hparams"],
40
+ kwargs["internal_tmp_id"],
41
+ kwargs["dataset_type"],
42
+ kwargs["input_embedding_dir"],
43
+ kwargs["split_ids_path"],
44
+ )
45
+ self.frame_rates = hparams.frame_rates
46
+ self.hparams = hparams
47
+ self.label_to_id = LABEL_TO_ID
48
+ self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
49
+ self.id_to_label = ID_TO_LABEL
50
+ self.internal_tmp_id = internal_tmp_id
51
+ self.dataset_type = dataset_type
52
+ self.EPS = 1e-6
53
+ self.dataset_id2label_mask = {}
54
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
55
+ self.dataset_id2label_mask[key] = np.ones(
56
+ self.hparams.num_classes, dtype=bool
57
+ )
58
+ self.dataset_id2label_mask[key][allowed_ids] = False
59
+
60
+ self.id2segments = {}
61
+ data = self.load_jsonl(label_paths)
62
+
63
+ self.input_embedding_dir = input_embedding_dir
64
+ all_input_embedding_dirs = input_embedding_dir.split()
65
+
66
+ valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])
67
+
68
+ for x in all_input_embedding_dirs:
69
+ valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
70
+ split_ids = []
71
+ with open(split_ids_path) as f:
72
+ for line in f:
73
+ if not line.strip():
74
+ continue
75
+ split_ids.append(line.strip())
76
+ split_ids = set(split_ids)
77
+
78
+ valid_data_ids = [
79
+ x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
80
+ ]
81
+ valid_data_ids = [
82
+ (internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
83
+ for x in valid_data_ids
84
+ ]
85
+ self.valid_data_ids = valid_data_ids
86
+ rng = random.Random(42)
87
+ rng.shuffle(self.valid_data_ids)
88
+ for item in data:
89
+ self.id2segments[item["data_id"]] = item["msa_info"]
90
+
91
+ def get_ids_from_dir(self, dir_path: str):
92
+ ids = os.listdir(dir_path)
93
+ ids = [Path(x).stem for x in ids if x.endswith(".npy")]
94
+ return set(ids)
95
+
96
+ def time2frame(self, this_time):
97
+ return int(this_time * self.frame_rates)
98
+
99
+ def load_jsonl(self, paths):
100
+ data = []
101
+ for path in paths:
102
+ with open(path, "r", encoding="utf-8") as f:
103
+ for line in f:
104
+ line = line.strip()
105
+ if not line:
106
+ continue
107
+ obj = json.loads(line)
108
+ data.append(obj)
109
+ return data
110
+
111
+ def get_ids(self):
112
+ return list(self.valid_data_ids)
113
+
114
+ def widen_temporal_events(self, events, num_neighbors):
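+ """Soften the 0/1 boundary targets with a Gaussian (sigma = num_neighbors / 3), rescaled so an isolated boundary peaks at 1."""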
115
+ def theoretical_gaussian_max(sigma):
116
+ return 1 / (np.sqrt(2 * np.pi) * sigma)
117
+
118
+ widen_events = events
119
+ sigma = num_neighbors / 3.0
120
+ smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
121
+ smoothed /= theoretical_gaussian_max(sigma)
122
+ smoothed = np.clip(smoothed, 0, 1)
123
+
124
+ return smoothed
125
+
126
+ def get_item_json(self, utt, start_time, end_time):
127
+ embd_list = []
128
+ embd_dirs = self.input_embedding_dir.split()
129
+ for embd_dir in embd_dirs:
130
+ if not Path(embd_dir).exists():
131
+ raise FileNotFoundError(
132
+ f"Embedding directory {embd_dir} does not exist"
133
+ )
134
+ tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
135
+ embd_list.append(tmp)
136
+
137
+ # Check that max and min lengths of all representations differ by at most 2
138
+ if len(embd_list) > 1:
139
+ embd_shapes = [x.shape for x in embd_list]
140
+ max_shape = max(embd_shapes, key=lambda x: x[0])
141
+ min_shape = min(embd_shapes, key=lambda x: x[0])
142
+ if abs(max_shape[0] - min_shape[0]) > 2:
143
+ raise ValueError(
144
+ f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
145
+ )
146
+
147
+ for idx in range(len(embd_list)):
148
+ embd_list[idx] = embd_list[idx][: min_shape[0], :]
149
+
150
+ input_embedding = np.concatenate(embd_list, axis=-1)
151
+
152
+ return_json = self._get_item_json_without_embedding(
153
+ "_".join(utt.split("_")[:-1]), start_time, end_time
154
+ )
155
+
156
+ if return_json is None:
157
+ logger.warning(
158
+ f"Skip {utt} because no valid segments found in {start_time} to {end_time}."
159
+ )
160
+ return None
161
+ else:
162
+ return_json["input_embedding"] = input_embedding
163
+ return return_json
164
+
165
+ def get_local_times_labels(self, utt):
166
+ assert utt in self.id2segments, f"utt {utt} not found in id2segments"
167
+ time_datas = [x[0] for x in self.id2segments[utt]]
168
+ time_datas = list(map(float, time_datas))
169
+ label_datas = [
170
+ -1 if x[1] == "end" else self.label_to_id[x[1]]
171
+ for x in self.id2segments[utt]
172
+ ]
173
+ return np.array(time_datas), label_datas
174
+
175
+ def _get_item_json_without_embedding(self, utt, start_time, end_time):
176
+ SLICE_DUR = int(math.ceil(end_time - start_time))
177
+
178
+ local_times, local_labels = self.get_local_times_labels(utt)
179
+
180
+ local_times, local_labels = (
181
+ copy.deepcopy(local_times),
182
+ copy.deepcopy(local_labels),
183
+ )
184
+
185
+ assert np.all(local_times[:-1] < local_times[1:]), (
186
+ f"time must be sorted, but {utt} is {local_times}"
187
+ )
188
+
189
+ local_times = local_times - start_time
190
+
191
+ time_L = max(0.0, float(local_times.min()))
192
+ time_R = min(float(SLICE_DUR), float(local_times.max()))
193
+ # Note whether boundary labels are reachable
194
+ keep_boundarys = (time_L + self.EPS < local_times) & (
195
+ local_times < time_R - self.EPS
196
+ )
197
+
198
+ # If no valid boundaries, return None
199
+ if keep_boundarys.sum() <= 0:
200
+ return None
201
+
202
+ mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
203
+ mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
204
+
205
+ true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
206
+ for idx in np.flatnonzero(keep_boundarys):
207
+ true_boundary[self.time2frame(local_times[idx])] = 1
208
+
209
+ true_function = np.zeros(
210
+ [int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
211
+ dtype=float,
212
+ )
213
+ true_function_list = []
214
+ msa_info = []
215
+ last_pos = self.time2frame(time_L)
216
+ for idx in np.flatnonzero(keep_boundarys):
217
+
218
+ true_function[
219
+ last_pos : self.time2frame(local_times[idx]),
220
+ int(local_labels[idx - 1]),
221
+ ] = 1
222
+ true_function_list.append(
223
+ [int(x) for x in local_labels[idx - 1]]
224
+ if isinstance(local_labels[idx - 1], list)
225
+ else int(local_labels[idx - 1])
226
+ )
227
+ msa_info.append(
228
+ (
229
+ float(max(local_times[idx - 1], time_L)),
230
+ [str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
231
+ if isinstance(local_labels[idx - 1], list)
232
+ else str(self.id_to_label[int(local_labels[idx - 1])]),
233
+ )
234
+ )
235
+ last_pos = self.time2frame(local_times[idx])
236
+
237
+ # Check last label correctness
238
+ true_function[
239
+ last_pos : self.time2frame(time_R),
240
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
241
+ ] = 1
242
+ true_function_list.append(
243
+ [int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
244
+ if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
245
+ else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
246
+ )
247
+
248
+ msa_info.append(
249
+ (
250
+ float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
251
+ [
252
+ str(self.id_to_label[int(x)])
253
+ for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
254
+ ]
255
+ if isinstance(
256
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
257
+ )
258
+ else str(
259
+ self.id_to_label[
260
+ int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
261
+ ]
262
+ ),
263
+ )
264
+ )
265
+ # Append final label at end; decide if it's necessary
266
+ msa_info.append((float(time_R), "end"))
267
+
268
+ # Add boundary_mask & function_mask
269
+ frame_len = int(SLICE_DUR * self.frame_rates)
270
+ # During loss computation, boundaries are fully masked
271
+ boundary_mask = np.ones([frame_len], dtype=bool)
272
+ function_mask = np.zeros([frame_len], dtype=bool)
273
+
274
+ # Set masks according to msa_info
275
+ for i in range(len(msa_info) - 1):
276
+ seg_start, seg_label = msa_info[i]
277
+ seg_end, _ = msa_info[i + 1]
278
+ start_frame = self.time2frame(seg_start)
279
+ end_frame = self.time2frame(seg_end)
280
+
281
+ # Handle case where label may be string or list
282
+ is_no_label = (
283
+ seg_label == "NO_LABEL"
284
+ if isinstance(seg_label, str)
285
+ else "NO_LABEL" in seg_label
286
+ )
287
+
288
+ if is_no_label:
289
+ # function_mask set True
290
+ function_mask[start_frame:end_frame] = True
291
+
315
+ # return all things except for input_embedding
316
+ return {
317
+ "data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
318
+ "mask": mask,
319
+ "true_boundary": true_boundary,
320
+ "widen_true_boundary": self.widen_temporal_events(
321
+ true_boundary, num_neighbors=self.hparams.num_neighbors
322
+ ),
323
+ "true_function": true_function,
324
+ "true_function_list": true_function_list,
325
+ "msa_info": msa_info,
326
+ "dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
327
+ "label_id_mask": self.dataset_id2label_mask[
328
+ self.dataset_id_to_dataset_id[self.dataset_type]
329
+ ],
330
+ "boundary_mask": boundary_mask, # Only effective during loss calculation
331
+ "function_mask": function_mask, # Only effective during loss calculation
332
+ }
src/SongFormer/dataset/HookTheoryAdapter.py ADDED
@@ -0,0 +1,448 @@
1
+ import random
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ import json
6
+ import numpy as np
7
+ import math
8
+ from .label2id import (
9
+ DATASET_ID_ALLOWED_LABEL_IDS,
10
+ DATASET_LABEL_TO_DATASET_ID,
11
+ ID_TO_LABEL,
12
+ LABEL_TO_ID,
13
+ )
14
+ from argparse import Namespace
15
+ from scipy.ndimage import gaussian_filter1d
16
+ from .DatasetAdaper import DatasetAdapter
17
+ from omegaconf import ListConfig
18
+
19
+
20
+ class HookTheoryAdapter(DatasetAdapter):
21
+ def __init__(self, **kwargs):
22
+ (
23
+ structure_jsonl_paths,
24
+ hparams,
25
+ internal_tmp_id,
26
+ dataset_type,
27
+ input_embedding_dir,
28
+ split_ids_path,
29
+ ) = (
30
+ kwargs["structure_jsonl_paths"],
31
+ kwargs["hparams"],
32
+ kwargs["internal_tmp_id"],
33
+ kwargs["dataset_type"],
34
+ kwargs.get("input_embedding_dir", None),
35
+ kwargs.get("split_ids_path", None),
36
+ )
37
+
38
+ # basic attrs
39
+ self.frame_rates = hparams.frame_rates
40
+ self.hparams = hparams
41
+ self.label_to_id = LABEL_TO_ID
42
+ self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
43
+ self.id_to_label = ID_TO_LABEL
44
+ self.internal_tmp_id = internal_tmp_id
45
+ self.dataset_type = dataset_type
46
+ self.EPS = 1e-6
47
+
48
+ # build dataset-specific label mask
49
+ self.dataset_id2label_mask = {}
50
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
51
+ self.dataset_id2label_mask[key] = np.ones(
52
+ self.hparams.num_classes, dtype=bool
53
+ )
54
+ self.dataset_id2label_mask[key][allowed_ids] = False
55
+
56
+ assert isinstance(structure_jsonl_paths, (ListConfig, tuple, list))
57
+
58
+ # load segments per audio id
59
+ self.id2segments = defaultdict(list)
60
+ data = self.load_jsonl(structure_jsonl_paths)
61
+
62
+ # input embedding dirs (space-separated)
63
+ self.input_embedding_dir = input_embedding_dir
64
+ all_input_embedding_dirs = input_embedding_dir.split()
65
+
66
+ # get valid ids that exist in all embedding dirs
67
+ valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])
68
+ for x in all_input_embedding_dirs:
69
+ valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
70
+
71
+ # read split ids
72
+ split_ids = []
73
+ with open(split_ids_path) as f:
74
+ for line in f:
75
+ if not line.strip():
76
+ continue
77
+ split_ids.append(line.strip())
78
+ split_ids = set(split_ids)
79
+
80
+ # filter valid ids by split
81
+ valid_data_ids = [
82
+ x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
83
+ ]
84
+ valid_data_ids = [
85
+ (internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
86
+ for x in valid_data_ids
87
+ ]
88
+ self.valid_data_ids = valid_data_ids
89
+
90
+ rng = random.Random(42)
91
+ rng.shuffle(self.valid_data_ids)
92
+
93
+ for item in data:
94
+ self.id2segments[Path(item["ori_audio_path"]).stem].append(item)
95
+ # logger.info(f"load {len(self.id2segments)} songs from {structure_jsonl_paths}")
96
+
97
+ def get_ids_from_dir(self, dir_path: str):
98
+ ids = os.listdir(dir_path)
99
+ ids = [Path(x).stem for x in ids if x.endswith(".npy")]
100
+ return set(ids)
101
+
102
+ def time2frame(self, this_time):
103
+ # convert time (s) to frame index
104
+ return int(this_time * self.frame_rates)
105
+
106
+ def load_jsonl(self, paths):
107
+ # load list of jsonl files
108
+ data = []
109
+ for path in paths:
110
+ with open(path, "r", encoding="utf-8") as f:
111
+ for line in f:
112
+ line = line.strip()
113
+ if not line:
114
+ continue
115
+ obj = json.loads(line)
116
+ data.append(obj)
117
+ return data
118
+
119
+ def split_and_label(self, query_start, query_end, segments):
120
+ """
121
+ segments: List of dicts, each with keys: "segment_start", "segment_end", 'label'
122
+ """
123
+ # Step 1: collect all boundary points (only within query interval)
124
+ points = set([query_start, query_end])
125
+ for seg in segments:
126
+ if query_start <= seg["segment_start"] <= query_end:
127
+ points.add(seg["segment_start"])
128
+ if query_start <= seg["segment_end"] <= query_end:
129
+ points.add(seg["segment_end"])
130
+ sorted_points = sorted(points)
131
+
132
+ result = []
133
+ # Step 2: for each small interval, check which segments cover it
134
+ for i in range(len(sorted_points) - 1):
135
+ part_start = sorted_points[i]
136
+ part_end = sorted_points[i + 1]
137
+ labels = []
138
+ for seg in segments:
139
+ if (
140
+ seg["segment_start"] <= part_start
141
+ and seg["segment_end"] >= part_end
142
+ ):
143
+ labels.extend(seg["label"])
144
+ if not labels:
145
+ labels = ["NO_LABEL"]
146
+ result.append(
147
+ {"segment_start": part_start, "segment_end": part_end, "labels": labels}
148
+ )
149
+
150
+ # deduplicate labels per interval
151
+ for idx in range(len(result)):
152
+ result[idx]["labels"] = list(set(result[idx]["labels"]))
153
+ return result
154
+
155
+ def merge_small_intervals(self, parts, min_duration=2.0):
156
+ """
157
+ parts: list of dicts with "segment_start", "segment_end", 'labels'
158
+ Merge intervals shorter than min_duration into neighbor intervals.
159
+ """
160
+ new_parts = []
161
+ i = 0
162
+ while i < len(parts):
163
+ part = parts[i]
164
+ duration = part["segment_end"] - part["segment_start"]
165
+ if duration < min_duration:
166
+ # decide where to merge
167
+ if len(new_parts) > 0 and (i + 1) < len(parts):
168
+ # randomly choose previous or next
169
+ if random.choice([True, False]):
170
+ prev = new_parts[-1]
171
+ prev["segment_end"] = part["segment_end"]
172
+ else:
173
+ next_part = parts[i + 1]
174
+ next_part["segment_start"] = part["segment_start"]
175
+ # skip adding this part
176
+ elif len(new_parts) > 0:
177
+ # only previous exists - merge into previous
178
+ prev = new_parts[-1]
179
+ prev["segment_end"] = part["segment_end"]
180
+ elif (i + 1) < len(parts):
181
+ # only next exists - merge into next
182
+ next_part = parts[i + 1]
183
+ next_part["segment_start"] = part["segment_start"]
184
+ # else: nothing to merge, drop
185
+ i += 1
186
+ else:
187
+ new_parts.append(part)
188
+ i += 1
189
+ return new_parts
190
+
191
+ def rounding_time(self, segments, num_decimals=3):
192
+ # round segment boundaries to given decimals
193
+ for idx in range(len(segments)):
194
+ segments[idx]["segment_start"] = round(
195
+ segments[idx]["segment_start"], num_decimals
196
+ )
197
+ segments[idx]["segment_end"] = round(
198
+ segments[idx]["segment_end"], num_decimals
199
+ )
200
+ return segments
201
+
202
+ def get_ids(self):
203
+ return list(self.valid_data_ids)
204
+
205
+ def convert_label(self, label: str):
206
+ # map various labels to canonical labels
207
+ mapping = {
208
+ "chorus": "chorus",
209
+ "intro": "intro",
210
+ "bridge": "bridge",
211
+ "verse": "verse",
212
+ "pre-chorus": "pre-chorus",
213
+ "solo": "inst",
214
+ "instrumental": "inst",
215
+ "outro": "outro",
216
+ "NO_LABEL": "NO_LABEL",
217
+ }
218
+ assert label in mapping, f"Unknown label: {label}"
219
+ return mapping[label]
220
+
221
+ def parts_to_label_and_times(self, parts, use_random_tag=True):
222
+ """
223
+ parts: list of dicts with 'segment_start', 'segment_end', 'labels'
224
+
225
+ if use_random_tag: label will be random from valid labels
226
+ else: label will be all valid labels (labels list)
227
+
228
+ return:
229
+ local_times: np.array of right boundary time points (excluding query_end)
230
+ local_labels: list of label indices corresponding to self.label_to_id
231
+ """
232
+ local_times = []
233
+ local_labels = []
234
+
235
+ for part in parts:
236
+ local_times.append(part["segment_start"])
237
+ label = random.choice(part["labels"]) if use_random_tag else part["labels"]
238
+ local_labels.append(self.label_to_id[self.convert_label(label)])
239
+ return np.array(local_times), local_labels
240
+
241
+ def get_parts(self, utt, query_start, query_end):
242
+ key = "_".join(utt.split("_")[:-1])
243
+ assert key in self.id2segments
244
+ segments = self.id2segments[key]
245
+ segments = self.rounding_time(segments)
246
+ parts = self.split_and_label(query_start, query_end, segments)
247
+
248
+ # Apply merging twice to remove very short intervals
249
+ new_parts = self.merge_small_intervals(parts, min_duration=2.0)
250
+ new_parts = self.merge_small_intervals(new_parts, min_duration=2.0)
251
+
252
+ return new_parts
253
+
254
+ def widen_temporal_events(self, events, num_neighbors):
255
+ # smooth binary events with a normalized gaussian
256
+ def theoretical_gaussian_max(sigma):
257
+ return 1 / (np.sqrt(2 * np.pi) * sigma)
258
+
259
+ widen_events = events
260
+ sigma = num_neighbors / 3.0
261
+ smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
262
+ smoothed /= theoretical_gaussian_max(sigma)
263
+ smoothed = np.clip(smoothed, 0, 1)
264
+
265
+ return smoothed
266
+
267
+ def get_item_json(self, utt, start_time, end_time):
268
+ # load embeddings from all embedding dirs
269
+ embd_list = []
270
+ embd_dirs = self.input_embedding_dir.split()
271
+ for embd_dir in embd_dirs:
272
+ if not Path(embd_dir).exists():
273
+ raise FileNotFoundError(
274
+ f"Embedding directory {embd_dir} does not exist"
275
+ )
276
+ tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
277
+ embd_list.append(tmp)
278
+
279
+ # Check that max/min length difference across embeddings <= 2
280
+ if len(embd_list) > 1:
281
+ embd_shapes = [x.shape for x in embd_list]
282
+ max_shape = max(embd_shapes, key=lambda x: x[0])
283
+ min_shape = min(embd_shapes, key=lambda x: x[0])
284
+ if abs(max_shape[0] - min_shape[0]) > 2:
285
+ raise ValueError(
286
+ f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
287
+ )
288
+
289
+ for idx in range(len(embd_list)):
290
+ embd_list[idx] = embd_list[idx][: min_shape[0], :]
291
+
292
+ input_embedding = np.concatenate(embd_list, axis=-1)
293
+
294
+ return_json = self.get_item_json_without_embedding(utt, start_time, end_time)
295
+ if return_json is None:
296
+ return None
297
+ else:
298
+ return_json["input_embedding"] = input_embedding
299
+ return return_json
300
+
301
+ def get_item_json_without_embedding(self, utt, start_time, end_time):
302
+ SLICE_DUR = int(math.ceil(end_time - start_time))
303
+
304
+ local_times, local_labels = self.parts_to_label_and_times(
305
+ self.get_parts(utt, start_time, end_time)
306
+ )
307
+
308
+ assert np.all(local_times[:-1] < local_times[1:]), (
309
+ f"time must be sorted, but {utt} is {local_times}"
310
+ )
311
+
312
+ # normalize local times relative to slice start
313
+ local_times = local_times - start_time
314
+ time_L = 0.0
315
+ # here time_R is full slice duration because NO_LABEL may appear
316
+ time_R = float(SLICE_DUR)
317
+
318
+ # determine which boundaries are within (time_L, time_R)
319
+ keep_boundarys = (time_L + self.EPS < local_times) & (
320
+ local_times < time_R - self.EPS
321
+ )
322
+
323
+ # if no valid boundary, return None
324
+ if keep_boundarys.sum() <= 0:
325
+ return None
326
+
327
+ mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
328
+ mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
329
+
330
+ true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
331
+ for idx in np.flatnonzero(keep_boundarys):
332
+ true_boundary[self.time2frame(local_times[idx])] = 1
333
+
334
+ true_function = np.zeros(
335
+ [int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
336
+ dtype=float,
337
+ )
338
+ true_function_list = []
339
+ msa_info = []
340
+ last_pos = self.time2frame(time_L)
341
+ for idx in np.flatnonzero(keep_boundarys):
342
+ # local_labels[idx] might be int or list(int)
343
+ true_function[
344
+ last_pos : self.time2frame(local_times[idx]),
345
+ local_labels[idx - 1],
346
+ ] = 1
347
+ true_function_list.append(
348
+ [int(x) for x in local_labels[idx - 1]]
349
+ if isinstance(local_labels[idx - 1], list)
350
+ else int(local_labels[idx - 1])
351
+ )
352
+ msa_info.append(
353
+ (
354
+ float(max(local_times[idx - 1], time_L)),
355
+ [str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
356
+ if isinstance(local_labels[idx - 1], list)
357
+ else str(self.id_to_label[int(local_labels[idx - 1])]),
358
+ )
359
+ )
360
+ last_pos = self.time2frame(local_times[idx])
361
+
362
+ # check last label correctness
363
+ true_function[
364
+ last_pos : self.time2frame(time_R),
365
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
366
+ ] = 1
367
+ true_function_list.append(
368
+ [int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
369
+ if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
370
+ else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
371
+ )
372
+ msa_info.append(
373
+ (
374
+ float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
375
+ [
376
+ str(self.id_to_label[int(x)])
377
+ for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
378
+ ]
379
+ if isinstance(
380
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
381
+ )
382
+ else str(
383
+ self.id_to_label[
384
+ int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
385
+ ]
386
+ ),
387
+ )
388
+ )
389
+ # append final "end" marker
390
+ msa_info.append((float(time_R), "end"))
391
+
392
+ # -------------------------
393
+ # boundary_mask & function_mask
394
+ # -------------------------
395
+ frame_len = int(SLICE_DUR * self.frame_rates)
396
+ boundary_mask = np.zeros([frame_len], dtype=bool)
397
+ function_mask = np.zeros([frame_len], dtype=bool)
398
+
399
+ # set masks according to msa_info
400
+ for i in range(len(msa_info) - 1):
401
+ seg_start, seg_label = msa_info[i]
402
+ seg_end, _ = msa_info[i + 1]
403
+ start_frame = self.time2frame(seg_start)
404
+ end_frame = self.time2frame(seg_end)
405
+
406
+ # handle label being string or list
407
+ is_no_label = (
408
+ seg_label == "NO_LABEL"
409
+ if isinstance(seg_label, str)
410
+ else "NO_LABEL" in seg_label
411
+ )
412
+
413
+ if is_no_label:
414
+ # set function_mask True for NO_LABEL regions
415
+ function_mask[start_frame:end_frame] = True
416
+
417
+ # set boundary_mask True for regions >4s away from ends
418
+ left_offset = self.time2frame(seg_start + 4)
419
+ right_offset = self.time2frame(seg_end - 4)
420
+ if i == 0:
421
+ if right_offset > 0:
422
+ boundary_mask[0 : min(right_offset, frame_len)] = True
423
+ elif i == len(msa_info) - 2:
424
+ if left_offset < frame_len:
425
+ boundary_mask[left_offset:frame_len] = True
426
+ elif right_offset > left_offset:
427
+ boundary_mask[left_offset:right_offset] = True
428
+
429
+ # -------------------------
430
+ # return all things except input_embedding
431
+ # -------------------------
432
+ return {
433
+ "data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
434
+ "mask": mask,
435
+ "true_boundary": true_boundary,
436
+ "widen_true_boundary": self.widen_temporal_events(
437
+ true_boundary, num_neighbors=self.hparams.num_neighbors
438
+ ),
439
+ "true_function": true_function,
440
+ "true_function_list": true_function_list,
441
+ "msa_info": msa_info,
442
+ "dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
443
+ "label_id_mask": self.dataset_id2label_mask[
444
+ self.dataset_id_to_dataset_id[self.dataset_type]
445
+ ],
446
+ "boundary_mask": boundary_mask, # only effective during loss computation
447
+ "function_mask": function_mask, # only effective during loss computation
448
+ }
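
Note: both dataset adapters turn hard boundary frames into soft training targets via `widen_temporal_events`. A minimal standalone sketch of that smoothing step is below; the toy vector length, boundary positions, and `num_neighbors=6` are illustrative assumptions, not values from the configs.

```python
# Sketch of the Gaussian boundary "widening": a 0/1 boundary vector is blurred
# and renormalized so the peak at each boundary frame stays close to 1.0.
import numpy as np
from scipy.ndimage import gaussian_filter1d


def widen_temporal_events(events: np.ndarray, num_neighbors: int) -> np.ndarray:
    sigma = num_neighbors / 3.0
    smoothed = gaussian_filter1d(events.astype(float), sigma=sigma)
    smoothed /= 1 / (np.sqrt(2 * np.pi) * sigma)  # peak height of a unit Gaussian impulse
    return np.clip(smoothed, 0.0, 1.0)


events = np.zeros(50)
events[[10, 30]] = 1.0                 # two boundary frames
soft = widen_temporal_events(events, num_neighbors=6)
print(soft.round(2)[5:16])             # values ramp up to ~1.0 around frame 10
```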
src/SongFormer/dataset/custom_types.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ MsaInfo
3
+ A list of (timestamp, label) tuples used to represent music structure
4
+ analysis (MSA). The first element of the tuple is a float timestamp
5
+ (in seconds) and the second is a string label
6
+
7
+ Example
8
+ -------
9
+ >>> msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus")]
10
+ """
11
+
12
+ from typing import List, Tuple
13
+
14
+ MsaInfo = List[Tuple[float, str]]
src/SongFormer/dataset/label2id.py ADDED
@@ -0,0 +1,163 @@
1
+ LABEL_TO_ID = {
2
+ "intro": 0,
3
+ "verse": 1,
4
+ "chorus": 2,
5
+ "bridge": 3,
6
+ "inst": 4,
7
+ "outro": 5,
8
+ "silence": 6,
9
+ "intchorus": 7,
10
+ "prechorus": 8,
11
+ "gtrbreak": 9,
12
+ "solo": 10,
13
+ "quietchorus": 11,
14
+ "bre": 12,
15
+ "break": 13,
16
+ "introverse": 14,
17
+ "mainriff": 15,
18
+ "chorushalf": 16,
19
+ "instintro": 17,
20
+ "gtr": 18,
21
+ "vocaloutro": 19,
22
+ "verse_slow": 20,
23
+ "fadein": 21,
24
+ "saxobeat": 22,
25
+ "transition": 23,
26
+ "verse1a": 24,
27
+ "build": 25,
28
+ "pre-chorus": 26,
29
+ "outroa": 27,
30
+ "bigoutro": 28,
31
+ "fast": 29,
32
+ "instrumentalverse": 30,
33
+ "section": 31,
34
+ "choruspart": 32,
35
+ "instbridge": 33,
36
+ "guitar": 34,
37
+ "instrumental": 35,
38
+ "breakdown": 36,
39
+ "rhythmlessintro": 37,
40
+ "intropt": 38,
41
+ "interlude": 39,
42
+ "postchorus": 40,
43
+ "postverse": 41,
44
+ "opening": 42,
45
+ "altchorus": 43,
46
+ "stutter": 44,
47
+ "oddriff": 45,
48
+ "synth": 46,
49
+ "preverse": 47,
50
+ "quiet": 48,
51
+ "raps": 49,
52
+ "verseinst": 50,
53
+ "instchorus": 51,
54
+ "chorus_instrumental": 52,
55
+ "slowverse": 53,
56
+ "slow": 54,
57
+ "worstthingever": 55,
58
+ "transition2a": 56,
59
+ "miniverse": 57,
60
+ "refrain": 58,
61
+ "introchorus": 59,
62
+ "drumroll": 60,
63
+ "guitarsolo": 61,
64
+ "versepart": 62,
65
+ "chorusinst": 63,
66
+ "ending": 64,
67
+ "no-vocal-intro": 65,
68
+ "no-vocal-interlude": 66,
69
+ "no-vocal-outro": 67,
70
+ "NO_LABEL": 68, # Only referring to cases without labels, this portion of labels will be ignored during the loss calculation process.
71
+ }
72
+
73
+ ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
74
+
75
+ # Reserve 64 embedding positions for dataset identifiers in the model.
76
+ DATASET_LABEL_TO_DATASET_ID = {
77
+ "SongForm-HX-7Class": 0, # Categories after rule mapping for HarmonixSet
78
+ "SongForm-HX-Widen": 1, # Original HarmonixSet
79
+ "SongForm-Private-Raw": 2,
80
+ "SongForm-Private": 3,
81
+ "SongForm-HX-Gemini-Relabeled": 4, # Rule-mapped HarmonixSet corrected by Gemini
82
+ "SongForm-HX-8Class": 5, # Rule-mapped (pre-chorus retained)
83
+ "SongForm-Hook": 6,
84
+ "SongForm-Gem": 7,
85
+ "SongForm-Gem-Only-Label": 8, # Use only segments with labels in SongForm-Gem
86
+ }
87
+
88
+ DATASET_ID_TO_DATASET_LABEL = {v: k for k, v in DATASET_LABEL_TO_DATASET_ID.items()}
89
+
90
+ DATASET_ID_ALLOWED_LABEL_IDS = {
91
+ 0: [0, 1, 2, 3, 4, 5, 6],
92
+ 1: [
93
+ 0,
94
+ 1,
95
+ 2,
96
+ 3,
97
+ 4,
98
+ 5,
99
+ 6,
100
+ 7,
101
+ 8,
102
+ 9,
103
+ 10,
104
+ 11,
105
+ 12,
106
+ 13,
107
+ 14,
108
+ 15,
109
+ 16,
110
+ 17,
111
+ 18,
112
+ 19,
113
+ 20,
114
+ 21,
115
+ 22,
116
+ 23,
117
+ 24,
118
+ 25,
119
+ 27,
120
+ 28,
121
+ 29,
122
+ 30,
123
+ 31,
124
+ 32,
125
+ 33,
126
+ 34,
127
+ 35,
128
+ 36,
129
+ 37,
130
+ 38,
131
+ 40,
132
+ 41,
133
+ 42,
134
+ 43,
135
+ 44,
136
+ 45,
137
+ 46,
138
+ 47,
139
+ 48,
140
+ 49,
141
+ 50,
142
+ 51,
143
+ 52,
144
+ 53,
145
+ 54,
146
+ 55,
147
+ 56,
148
+ 57,
149
+ 58,
150
+ 59,
151
+ 60,
152
+ 61,
153
+ 62,
154
+ 63,
155
+ ],
156
+ 2: [0, 1, 2, 3, 26, 39, 64, 65, 66, 67],
157
+ 3: [0, 1, 2, 3, 4, 5, 6, 26, 39, 64, 65, 66, 67],
158
+ 4: [0, 1, 2, 3, 4, 5, 6, 26],
159
+ 5: [0, 1, 2, 3, 4, 5, 6, 26],
160
+ 6: [0, 1, 2, 3, 4, 5, 6, 26],
161
+ 7: [0, 1, 2, 3, 4, 5, 6, 26],
162
+ 8: [0, 1, 2, 3, 4, 5, 6, 26],
163
+ }
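
For reference, this is how the tables above are typically consumed by the adapters and by `infer/infer.py`: each dataset id gets a boolean mask over the label space in which `True` marks disallowed label ids, which are later filled with `-inf` before the softmax. A small sketch follows; it assumes `src/SongFormer` is on `PYTHONPATH`, and `num_classes = 128` mirrors the inference script rather than being defined here.

```python
import numpy as np
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID

num_classes = 128  # assumption: matches init_args.num_classes in infer/infer.py
dataset_id2label_mask = {}
for dataset_id, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
    mask = np.ones(num_classes, dtype=bool)   # True = label id is masked out
    mask[allowed_ids] = False                 # allowed ids stay unmasked
    dataset_id2label_mask[dataset_id] = mask

hx8 = DATASET_LABEL_TO_DATASET_ID["SongForm-HX-8Class"]
print(np.flatnonzero(~dataset_id2label_mask[hx8]))  # -> [ 0  1  2  3  4  5  6 26]
```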
src/SongFormer/dataset/msa_info_utils.py ADDED
@@ -0,0 +1,47 @@
1
+ from dataset.custom_types import MsaInfo
2
+ from dataset.label2id import LABEL_TO_ID
3
+
4
+
5
+ def load_msa_info(msa_info_path):
6
+ msa_info: MsaInfo = []
7
+ with open(msa_info_path) as f:
8
+ for line in f:
9
+ line = line.strip()
10
+ if not line:
11
+ continue
12
+ time_, label = line.split()
13
+ time_ = float(time_)
14
+ label = str(label)
15
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
16
+ msa_info.append((time_, label))
17
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
18
+ return msa_info
19
+
20
+
21
+ def load_msa_infos(msa_str):
22
+ msa_info: MsaInfo = []
23
+ for line in msa_str:
24
+ line = line.strip()
25
+ if not line:
26
+ continue
27
+ time_, label = line.split()
28
+ time_ = float(time_)
29
+ label = str(label)
30
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
31
+ msa_info.append((time_, label))
32
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
33
+ return msa_info
34
+
35
+
36
+ def dump_msa_info(msa_info_path, msa_info: MsaInfo):
37
+ with open(msa_info_path, "w") as f:
38
+ for time_, label in msa_info:
39
+ f.write(f"{time_} {label}\n")
40
+
41
+
42
+ def dump_msa_infos(msa_info: MsaInfo):
43
+ mas_strs = []
44
+ for time_, label in msa_info:
45
+ mas_strs.append(f"{round(time_, 2)} {label}")
46
+
47
+ return "\n".join(mas_strs)
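
A quick round-trip sketch for the helpers above; the file name is illustrative and it assumes `src/SongFormer` is on `PYTHONPATH` so the `dataset` package resolves.

```python
from dataset.custom_types import MsaInfo
from dataset.msa_info_utils import dump_msa_info, dump_msa_infos, load_msa_info

msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus"), (51.2, "end")]

dump_msa_info("example.txt", msa)      # writes one "time label" pair per line
print(load_msa_info("example.txt"))    # -> [(0.0, 'intro'), (12.5, 'verse'), (34.0, 'chorus'), (51.2, 'end')]
print(dump_msa_infos(msa))             # same segments as a newline-joined string with rounded times
```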
src/SongFormer/eval.sh ADDED
@@ -0,0 +1,22 @@
1
+ export CUDA_VISIBLE_DEVICES=-1
2
+ export PYTHONPATH=${PWD}:$PYTHONPATH
3
+
4
+ export HYDRA_FULL_ERROR=1
5
+ export OMP_NUM_THREADS=1
6
+ export MPI_NUM_THREADS=1
7
+ export NCCL_P2P_DISABLE=1
8
+ export NCCL_IB_DISABLE=1
9
+
10
+
11
+ EST_DIR=
12
+ ANN_DIR=
13
+ OUTPUT_DIR=
14
+ echo "$EST_DIR --> $OUTPUT_DIR"
15
+ mkdir -p "$OUTPUT_DIR"
16
+
17
+ python evaluation/eval_infer_results.py \
18
+ --ann_dir $ANN_DIR \
19
+ --est_dir $EST_DIR \
20
+ --output_dir $OUTPUT_DIR \
21
+ --prechorus2what verse
22
+ # --armerge_continuous_segments
src/SongFormer/evaluation/eval_infer_results.py ADDED
@@ -0,0 +1,198 @@
1
+ import argparse
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ import mir_eval
6
+ import numpy as np
7
+ import pandas as pd
8
+ from dataset.custom_types import MsaInfo
9
+ from dataset.label2id import LABEL_TO_ID
10
+ from dataset.msa_info_utils import load_msa_info
11
+ from msaf.eval import compute_results
12
+ from postprocessing.calc_acc import cal_acc
13
+ from postprocessing.calc_iou import cal_iou
14
+ from tqdm import tqdm
15
+ from loguru import logger
16
+
17
+ LEGAL_LABELS = {
18
+ "end",
19
+ "intro",
20
+ "verse",
21
+ "chorus",
22
+ "bridge",
23
+ "inst",
24
+ "outro",
25
+ "silence",
26
+ "pre-chorus",
27
+ }
28
+
29
+
30
+ def to_inters_labels(msa_info: MsaInfo):
31
+ label_ids = np.array([LABEL_TO_ID[x[1]] for x in msa_info[:-1]])
32
+ times = [x[0] for x in msa_info]
33
+ start_times = np.column_stack([np.array(times[:-1]), np.array(times[1:])])
34
+ return start_times, label_ids
35
+
36
+
37
+ def merge_continuous_segments(segments):
38
+ """
39
+ Merge continuous segments with the same label.
40
+
41
+ Parameters:
42
+ segments: List of tuples [(start_time, label), ...], where the last element is (end_time, 'end')
43
+
44
+ Returns:
45
+ Merged segment list in the same format [(start_time, label), ...], with the last element being (end_time, 'end')
46
+ """
47
+ if not segments or len(segments) < 2:
48
+ return segments
49
+
50
+ merged = []
51
+ current_start = segments[0][0]
52
+ current_label = segments[0][1]
53
+
54
+ for i in range(1, len(segments)):
55
+ time, label = segments[i]
56
+
57
+ if label == "end":
58
+ if current_label != "end":
59
+ merged.append((current_start, current_label))
60
+ merged.append((time, "end"))
61
+ break
62
+
63
+ if label != current_label:
64
+ merged.append((current_start, current_label))
65
+ current_start = time
66
+ current_label = label
67
+
68
+ return merged
69
+
70
+
71
+ def main():
72
+ argparser = argparse.ArgumentParser()
73
+ argparser.add_argument("--ann_dir", type=str, required=True)
74
+ argparser.add_argument("--est_dir", type=str, required=True)
75
+ argparser.add_argument("--output_dir", type=str, default="./eval_infer_results")
76
+ argparser.add_argument("--prechorus2what", type=str, default=None)
77
+ argparser.add_argument("--armerge_continuous_segments", action="store_true")
78
+ args = argparser.parse_args()
79
+
80
+ ann_dir = args.ann_dir
81
+ est_dir = args.est_dir
82
+ output_dir = args.output_dir
83
+ if args.armerge_continuous_segments:
84
+ logger.info("Merging continuous segments")
85
+ os.makedirs(output_dir, exist_ok=True)
86
+
87
+ ann_id_lists = [x for x in os.listdir(ann_dir) if x.endswith(".txt")]
88
+ est_id_lists = [x for x in os.listdir(est_dir) if x.endswith(".txt")]
89
+
90
+ common_id_lists = set(ann_id_lists) & set(est_id_lists)
91
+ common_id_lists = list(common_id_lists)
92
+ logger.info(f"Common number of files: {len(common_id_lists)}")
93
+
94
+ resultes = []
95
+ ious = {}
96
+
97
+ for id in tqdm(common_id_lists):
98
+ try:
99
+ logger.info(f"Processing {id}")
100
+ ann_msa = load_msa_info(os.path.join(ann_dir, id))
101
+ est_msa = load_msa_info(os.path.join(est_dir, id))
102
+
103
+ if args.prechorus2what == "verse":
104
+ ann_msa = [
105
+ (t, "verse") if l == "pre-chorus" else (t, l) for t, l in ann_msa
106
+ ]
107
+ est_msa = [
108
+ (t, "verse") if l == "pre-chorus" else (t, l) for t, l in est_msa
109
+ ]
110
+ elif args.prechorus2what == "chorus":
111
+ ann_msa = [
112
+ (t, "chorus") if l == "pre-chorus" else (t, l) for t, l in ann_msa
113
+ ]
114
+ est_msa = [
115
+ (t, "chorus") if l == "pre-chorus" else (t, l) for t, l in est_msa
116
+ ]
117
+ elif args.prechorus2what is not None:
118
+ raise ValueError(f"Unknown prechorus2what: {args.prechorus2what}")
119
+ if args.armerge_continuous_segments:
120
+ ann_msa = merge_continuous_segments(ann_msa)
121
+ est_msa = merge_continuous_segments(est_msa)
122
+
123
+ ann_inter, ann_labels = to_inters_labels(ann_msa)
124
+ est_inter, est_labels = to_inters_labels(est_msa)
125
+
126
+ result = compute_results(
127
+ ann_inter,
128
+ est_inter,
129
+ ann_labels,
130
+ est_labels,
131
+ bins=11,
132
+ est_file="test.txt",
133
+ weight=0.58,
134
+ )
135
+ acc = cal_acc(ann_msa, est_msa, post_digit=3)
136
+
137
+ ious[id] = cal_iou(ann_msa, est_msa)
138
+ result["HitRate_1P"], result["HitRate_1R"], result["HitRate_1F"] = (
139
+ mir_eval.segment.detection(ann_inter, est_inter, window=1, trim=False)
140
+ )
141
+ result.update({"id": Path(id).stem})
142
+ result.update({"acc": acc})
143
+ for v in ious[id]:
144
+ result.update({f"iou-{v['label']}": v["iou"]})
145
+ del result["track_id"]
146
+ del result["ds_name"]
147
+
148
+ resultes.append(result)
149
+ except Exception as e:
150
+ logger.error(f"Error processing {id}: {e}")
151
+ continue
152
+
153
+ df = pd.DataFrame(resultes)
154
+ df.to_csv(f"{output_dir}/eval_infer.csv", index=False)
155
+
156
+ intsec_dur_total = defaultdict(float)
157
+ uni_dur_total = defaultdict(float)
158
+
159
+ for tid, value in ious.items():
160
+ for item in value:
161
+ label = item["label"]
162
+ intsec_dur_total[label] += item.get("intsec_dur", 0)
163
+ uni_dur_total[label] += item.get("uni_dur", 0)
164
+
165
+ total_intsec = sum(intsec_dur_total.values())
166
+ total_uni = sum(uni_dur_total.values())
167
+ overall_iou = total_intsec / total_uni if total_uni > 0 else 0.0
168
+
169
+ class_ious = {}
170
+ for label in intsec_dur_total:
171
+ intsec = intsec_dur_total[label]
172
+ uni = uni_dur_total[label]
173
+ class_ious[label] = intsec / uni if uni > 0 else 0.0
174
+
175
+ summary = pd.DataFrame(
176
+ [
177
+ {
178
+ "num_samples": len(df),
179
+ "HR.5F": df["HitRate_0.5F"].mean(),
180
+ "HR3F": df["HitRate_3F"].mean(),
181
+ "HR1F": df["HitRate_1F"].mean(),
182
+ "PWF": df["PWF"].mean(),
183
+ "Sf": df["Sf"].mean(),
184
+ "acc": df["acc"].mean(),
185
+ "iou": overall_iou,
186
+ **{f"iou_{k}": v for k, v in class_ious.items()},
187
+ }
188
+ ]
189
+ )
190
+ with open(f"{output_dir}/eval_infer_summary.md", "w") as f:
191
+ print(summary.to_markdown(), file=f)
192
+
193
+ summary.to_csv(f"{output_dir}/eval_infer_summary.csv", index=False)
194
+ logger.info(f"Results saved to {output_dir}")
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
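
As a sanity check of the remapping and merging steps in `main()`, here is a toy example. Importing this module pulls in `msaf` and `mir_eval`, so it assumes the evaluation dependencies are installed and `src/SongFormer` is on `PYTHONPATH`.

```python
from evaluation.eval_infer_results import merge_continuous_segments

ann = [(0.0, "verse"), (20.0, "pre-chorus"), (35.0, "chorus"), (60.0, "end")]
# --prechorus2what verse maps pre-chorus onto verse ...
ann = [(t, "verse") if l == "pre-chorus" else (t, l) for t, l in ann]
# ... leaving two adjacent "verse" segments, which the merge step collapses.
print(merge_continuous_segments(ann))
# -> [(0.0, 'verse'), (35.0, 'chorus'), (60.0, 'end')]
```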
src/SongFormer/infer.sh ADDED
@@ -0,0 +1,21 @@
1
+
2
+ export CUDA_VISIBLE_DEVICES=
3
+ echo "use gpu ${CUDA_VISIBLE_DEVICES}"
4
+
5
+ export PYTHONPATH=../third_party:$PYTHONPATH
6
+
7
+ export OMP_NUM_THREADS=1
8
+ export MPI_NUM_THREADS=1
9
+ export NCCL_P2P_DISABLE=1
10
+ export NCCL_IB_DISABLE=1
11
+
12
+ python infer/infer.py \
13
+ -i XXX.scp \
14
+ -o XXX_dir \
15
+ --model SongFormer \
16
+ --checkpoint SongFormer.safetensors \
17
+ --config_path SongFormer.yaml \
18
+ -gn 1 \
19
+ -tn 1
20
+ # --debug
21
+ # --no_rule_post_processing
src/SongFormer/infer/infer.py ADDED
@@ -0,0 +1,439 @@
1
+ import argparse
2
+ import importlib
3
+ import json
4
+ import math
5
+ import multiprocessing as mp
6
+ import os
7
+ import time
8
+ from argparse import Namespace
9
+ from pathlib import Path
10
+
11
+ # monkey patch: restore scipy.inf (removed in newer SciPy releases) so that msaf keeps working
12
+ import scipy
13
+ import numpy as np
14
+
15
+ scipy.inf = np.inf
16
+
17
+ import librosa
18
+ import torch
19
+ from ema_pytorch import EMA
20
+ from loguru import logger
21
+ from muq import MuQ
22
+ from musicfm.model.musicfm_25hz import MusicFM25Hz
23
+ from omegaconf import OmegaConf
24
+ from tqdm import tqdm
25
+
26
+ mp.set_start_method("spawn", force=True)
27
+
28
+ MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
29
+
30
+ BEFORE_DOWNSAMPLING_FRAME_RATES = 25
31
+ AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
32
+
33
+ DATASET_LABEL = "SongForm-HX-8Class"
34
+ DATASET_IDS = [5]
35
+
36
+ TIME_DUR = 420
37
+ INPUT_SAMPLING_RATE = 24000
38
+
39
+ from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
40
+ from postprocessing.functional import postprocess_functional_structure
41
+
42
+
43
+ def get_processed_ids(output_path):
44
+ """Get already processed IDs from output directory"""
45
+ ids = os.listdir(output_path)
46
+ ret = []
47
+ for x in ids:
48
+ if x.endswith(".json"):
49
+ ret.append(x.replace(".json", ""))
50
+ return set(ret)
51
+
52
+
53
+ def get_processing_ids(input_path, processed_ids_set):
54
+ """Get IDs to be processed from input directory"""
55
+ ret = []
56
+ with open(input_path) as f:
57
+ for line in f:
58
+ if line.strip() and Path(line.strip()).stem not in processed_ids_set:
59
+ ret.append(line.strip())
60
+ return ret
61
+
62
+
63
+ def load_checkpoint(checkpoint_path, device=None):
64
+ """Load checkpoint from path"""
65
+ if device is None:
66
+ device = "cpu"
67
+
68
+ if checkpoint_path.endswith(".pt"):
69
+ checkpoint = torch.load(checkpoint_path, map_location=device)
70
+ elif checkpoint_path.endswith(".safetensors"):
71
+ from safetensors.torch import load_file
72
+
73
+ checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
74
+ else:
75
+ raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
76
+ return checkpoint
77
+
78
+
79
+ def rule_post_processing(msa_list):
80
+ if len(msa_list) <= 2:
81
+ return msa_list
82
+
83
+ result = msa_list.copy()
84
+
85
+ while len(result) > 2:
86
+ first_duration = result[1][0] - result[0][0]
87
+ if first_duration < 1.0 and len(result) > 2:
88
+ result[0] = (result[0][0], result[1][1])
89
+ result = [result[0]] + result[2:]
90
+ else:
91
+ break
92
+
93
+ while len(result) > 2:
94
+ last_label_duration = result[-1][0] - result[-2][0]
95
+ if last_label_duration < 1.0:
96
+ result = result[:-2] + [result[-1]]
97
+ else:
98
+ break
99
+
100
+ while len(result) > 2:
101
+ if result[0][1] == result[1][1] and result[1][0] <= 10.0:
102
+ result = [(result[0][0], result[0][1])] + result[2:]
103
+ else:
104
+ break
105
+
106
+ while len(result) > 2:
107
+ last_duration = result[-1][0] - result[-2][0]
108
+ if result[-2][1] == result[-3][1] and last_duration <= 10.0:
109
+ result = result[:-2] + [result[-1]]
110
+ else:
111
+ break
112
+
113
+ return result
114
+
115
+
116
+ def inference(rank, queue_input: mp.Queue, queue_output: mp.Queue, args):
117
+ """Run inference on the input audio"""
118
+ device = f"cuda:{rank}"
119
+
120
+ # MuQ model loading (this will automatically fetch the checkpoint from huggingface)
121
+ muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
122
+ muq = muq.to(device).eval()
123
+
124
+ # MusicFM model loading
125
+ musicfm = MusicFM25Hz(
126
+ is_flash=False,
127
+ stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
128
+ model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
129
+ )
130
+ musicfm = musicfm.to(device)
131
+ musicfm.eval()
132
+
133
+ # Custom model loading based on the config
134
+ module = importlib.import_module("models." + str(args.model))
135
+ Model = getattr(module, "Model")
136
+ hp = OmegaConf.load(os.path.join("configs", args.config_path))
137
+ model = Model(hp)
138
+
139
+ ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", args.checkpoint))
140
+ if ckpt.get("model_ema", None) is not None:
141
+ logger.info("Loading EMA model parameters")
142
+ model_ema = EMA(model, include_online_model=False)
143
+ model_ema.load_state_dict(ckpt["model_ema"])
144
+ model.load_state_dict(model_ema.ema_model.state_dict())
145
+ else:
146
+ logger.info("No EMA model parameters found, using original model")
147
+ model.load_state_dict(ckpt["model"])
148
+
149
+ model.to(device)
150
+ model.eval()
151
+
152
+ num_classes = args.num_classes
153
+ dataset_id2label_mask = {}
154
+
155
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
156
+ dataset_id2label_mask[key] = np.ones(args.num_classes, dtype=bool)
157
+ dataset_id2label_mask[key][allowed_ids] = False
158
+
159
+ with torch.no_grad():
160
+ while True:
161
+ item = queue_input.get()
162
+ if not item:
163
+ queue_output.put(None)
164
+ break
165
+
166
+ try:
167
+ # Loading the audio file
168
+ wav, sr = librosa.load(item, sr=INPUT_SAMPLING_RATE)
169
+ audio = torch.tensor(wav).to(device)
170
+
171
+ win_size = args.win_size
172
+ hop_size = args.hop_size
173
+ total_len = (
174
+ (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
175
+ ) * TIME_DUR + TIME_DUR
176
+ total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
177
+
178
+ logits = {
179
+ "function_logits": np.zeros([total_frames, num_classes]),
180
+ "boundary_logits": np.zeros([total_frames]),
181
+ }
182
+ logits_num = {
183
+ "function_logits": np.zeros([total_frames, num_classes]),
184
+ "boundary_logits": np.zeros([total_frames]),
185
+ }
186
+
187
+ lens = 0
188
+ i = 0
189
+ while True:
190
+ start_idx = i * INPUT_SAMPLING_RATE
191
+ end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
192
+ if start_idx >= audio.shape[-1]:
193
+ break
194
+ if end_idx - start_idx <= 1024:
195
+ break  # a leftover tail of <= 1024 samples would otherwise loop forever, since i is not advanced here
196
+ audio_seg = audio[start_idx:end_idx]
197
+
198
+ # MuQ embedding
199
+ muq_output = muq(audio_seg.unsqueeze(0), output_hidden_states=True)
200
+ muq_embd_420s = muq_output["hidden_states"][10]
201
+ del muq_output
202
+ torch.cuda.empty_cache()
203
+
204
+ # MusicFM embedding
205
+ _, musicfm_hidden_states = musicfm.get_predictions(
206
+ audio_seg.unsqueeze(0)
207
+ )
208
+ musicfm_embd_420s = musicfm_hidden_states[10]
209
+ del musicfm_hidden_states
210
+ torch.cuda.empty_cache()
211
+
212
+ wraped_muq_embd_30s = []
213
+ wraped_musicfm_embd_30s = []
214
+
215
+ for idx_30s in range(i, i + hop_size, 30):
216
+ start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
217
+ end_idx_30s = min(
218
+ (idx_30s + 30) * INPUT_SAMPLING_RATE,
219
+ audio.shape[-1],
220
+ (i + hop_size) * INPUT_SAMPLING_RATE,
221
+ )
222
+ if start_idx_30s >= audio.shape[-1]:
223
+ break
224
+ if end_idx_30s - start_idx_30s <= 1024:
225
+ continue
226
+ wraped_muq_embd_30s.append(
227
+ muq(
228
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0),
229
+ output_hidden_states=True,
230
+ )["hidden_states"][10]
231
+ )
232
+ torch.cuda.empty_cache()
233
+ wraped_musicfm_embd_30s.append(
234
+ musicfm.get_predictions(
235
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0)
236
+ )[1][10]
237
+ )
238
+ torch.cuda.empty_cache()
239
+
240
+ wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
241
+ wraped_musicfm_embd_30s = torch.concatenate(
242
+ wraped_musicfm_embd_30s, dim=1
243
+ )
244
+ all_embds = [
245
+ wraped_musicfm_embd_30s,
246
+ wraped_muq_embd_30s,
247
+ musicfm_embd_420s,
248
+ muq_embd_420s,
249
+ ]
250
+
251
+ if len(all_embds) > 1:
252
+ embd_lens = [x.shape[1] for x in all_embds]
253
+ max_embd_len = max(embd_lens)
254
+ min_embd_len = min(embd_lens)
255
+ if abs(max_embd_len - min_embd_len) > 4:
256
+ raise ValueError(
257
+ f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
258
+ )
259
+
260
+ for idx in range(len(all_embds)):
261
+ all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
262
+
263
+ embd = torch.concatenate(all_embds, axis=-1)
264
+
265
+ dataset_label = DATASET_LABEL
266
+ dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
267
+ msa_info, chunk_logits = model.infer(
268
+ input_embeddings=embd,
269
+ dataset_ids=dataset_ids,
270
+ label_id_masks=torch.Tensor(
271
+ dataset_id2label_mask[
272
+ DATASET_LABEL_TO_DATASET_ID[dataset_label]
273
+ ]
274
+ )
275
+ .to(device, dtype=bool)
276
+ .unsqueeze(0)
277
+ .unsqueeze(0),
278
+ with_logits=True,
279
+ )
280
+
281
+ start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
282
+ end_frame = start_frame + min(
283
+ math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
284
+ chunk_logits["boundary_logits"][0].shape[0],
285
+ )
286
+
287
+ logits["function_logits"][start_frame:end_frame, :] += (
288
+ chunk_logits["function_logits"][0].detach().cpu().numpy()
289
+ )
290
+ logits["boundary_logits"][start_frame:end_frame] = (
291
+ chunk_logits["boundary_logits"][0].detach().cpu().numpy()
292
+ )
293
+ logits_num["function_logits"][start_frame:end_frame, :] += 1
294
+ logits_num["boundary_logits"][start_frame:end_frame] += 1
295
+ lens += end_frame - start_frame
296
+
297
+ i += hop_size
298
+ logits["function_logits"] /= logits_num["function_logits"]
299
+ logits["boundary_logits"] /= logits_num["boundary_logits"]
300
+
301
+ logits["function_logits"] = torch.from_numpy(
302
+ logits["function_logits"][:lens]
303
+ ).unsqueeze(0)
304
+ logits["boundary_logits"] = torch.from_numpy(
305
+ logits["boundary_logits"][:lens]
306
+ ).unsqueeze(0)
307
+
308
+ msa_infer_output = postprocess_functional_structure(logits, hp)
309
+
310
+ assert msa_infer_output[-1][-1] == "end"
311
+ if not args.no_rule_post_processing:
312
+ msa_infer_output = rule_post_processing(msa_infer_output)
313
+ msa_json = []
314
+ for idx in range(len(msa_infer_output) - 1):
315
+ msa_json.append(
316
+ {
317
+ "label": msa_infer_output[idx][1],
318
+ "start": msa_infer_output[idx][0],
319
+ "end": msa_infer_output[idx + 1][0],
320
+ }
321
+ )
322
+ json.dump(
323
+ msa_json,
324
+ open(os.path.join(args.output_dir, f"{Path(item).stem}.json"), "w"),
325
+ indent=4,
326
+ ensure_ascii=False,
327
+ )
328
+
329
+ queue_output.put(None)
330
+
331
+ except Exception as e:
332
+ queue_output.put(None)
333
+ logger.error(f"process {rank} error\n{item}\n{e}")
334
+
335
+
336
+ def deal_with_output(output_path, queue_output, length):
337
+ """Handle output data from the queue"""
338
+ pbar = tqdm(range(length), desc="getting inference output")
339
+ for _ in pbar:
340
+ data = queue_output.get()
341
+ if not data:
342
+ continue
343
+
344
+
345
+ def main(args):
346
+ input_path = args.input_path
347
+ output_path = args.output_path
348
+ gpu_num = args.gpu_num
349
+ num_thread_per_gpu = args.num_thread_per_gpu
350
+ debug = args.debug
351
+
352
+ os.makedirs(output_path, exist_ok=True)
353
+
354
+ processed_ids = get_processed_ids(output_path=output_path)
355
+ processing_ids = get_processing_ids(input_path, processed_ids)
356
+
357
+ num_threads = num_thread_per_gpu * gpu_num
358
+
359
+ queue_input: mp.Queue = mp.Queue()
360
+ queue_output: mp.Queue = mp.Queue()
361
+
362
+ init_args = Namespace(
363
+ output_dir=output_path,
364
+ win_size=420,
365
+ hop_size=420,
366
+ num_classes=128,
367
+ model=args.model,
368
+ checkpoint=args.checkpoint,
369
+ config_path=args.config_path,
370
+ no_rule_post_processing=args.no_rule_post_processing,
371
+ )
372
+
373
+ processes = []
374
+
375
+ if debug:
376
+ queue_input.put(processing_ids[0])
377
+ queue_input.put(None)
378
+
379
+ inference(0, queue_input, queue_output, init_args)
380
+
381
+ print("debug exit")
382
+ exit(0)
383
+
384
+ for thread_num in range(num_threads):
385
+ rank = thread_num % gpu_num
386
+ print(f"num_threads: {thread_num} on GPU {rank}")
387
+ time.sleep(0.2)
388
+ p = mp.Process(
389
+ target=inference,
390
+ args=(rank, queue_input, queue_output, init_args),
391
+ daemon=True,
392
+ )
393
+ p.start()
394
+ processes.append(p)
395
+
396
+ for wav_id in tqdm(processing_ids, desc="add data to queue"):
397
+ queue_input.put(wav_id)
398
+
399
+ for _ in range(num_threads):
400
+ queue_input.put(None)
401
+
402
+ deal_with_output(output_path, queue_output, len(processing_ids))
403
+
404
+ for p in processes:
405
+ p.join()
406
+
407
+
408
+ if __name__ == "__main__":
409
+ parser = argparse.ArgumentParser()
410
+
411
+ parser.add_argument(
412
+ "--input_path", "-i", type=str, required=True, help="Input file path"
413
+ )
414
+ parser.add_argument(
415
+ "--output_path", "-o", type=str, required=True, help="Output file path"
416
+ )
417
+ parser.add_argument(
418
+ "--gpu_num", "-gn", type=int, default=1, help="Number of GPUs, default is 1"
419
+ )
420
+ parser.add_argument(
421
+ "--num_thread_per_gpu",
422
+ "-tn",
423
+ type=int,
424
+ default=1,
425
+ help="Number of threads per GPU, default is 1",
426
+ )
427
+ parser.add_argument("--model", type=str, help="Model to use")
428
+ parser.add_argument("--checkpoint", type=str, help="Checkpoint path")
429
+ parser.add_argument("--config_path", type=str, help="Configuration file path")
430
+ parser.add_argument(
431
+ "--no_rule_post_processing",
432
+ action="store_true",
433
+ help="Disable rule-based post-processing",
434
+ )
435
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
436
+
437
+ args = parser.parse_args()
438
+
439
+ main(args=args)
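
The inference loop above merges per-chunk outputs by summing logits into song-length buffers together with per-frame counts, then averaging. A minimal numpy sketch of that accumulation scheme follows; the frame rate is taken from the script, while the chunk boundaries, class count, and random logits are toy assumptions (the script itself uses non-overlapping `win_size = hop_size = 420` s windows).

```python
import numpy as np

frame_rate = 8.333                                   # AFTER_DOWNSAMPLING_FRAME_RATES
num_classes = 8                                      # toy label space
total_frames = int(round(60 * frame_rate))           # a 60 s toy "song"

acc = np.zeros([total_frames, num_classes])
cnt = np.zeros([total_frames, num_classes])

for start_s, end_s in [(0, 40), (20, 60)]:           # two overlapping toy chunks
    s = int(start_s * frame_rate)
    e = int(end_s * frame_rate)
    chunk_logits = np.random.randn(e - s, num_classes)
    acc[s:e] += chunk_logits
    cnt[s:e] += 1

function_logits = acc / np.maximum(cnt, 1)            # overlapping regions are averaged
print(function_logits.shape)
```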
src/SongFormer/models/SongFormer.py ADDED
@@ -0,0 +1,521 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from dataset.custom_types import MsaInfo
6
+ from msaf.eval import compute_results
7
+ from postprocessing.functional import postprocess_functional_structure
8
+ from x_transformers import Encoder
9
+ import bisect
10
+
11
+
12
+ class Head(nn.Module):
13
+ def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
14
+ super().__init__()
15
+ hidden_dims = hidden_dims or []
16
+ act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
17
+ act_layer = act_layers.get(activation.lower())
18
+ if not act_layer:
19
+ raise ValueError(f"Unsupported activation: {activation}")
20
+
21
+ dims = [input_dim] + hidden_dims + [output_dim]
22
+ layers = []
23
+ for i in range(len(dims) - 1):
24
+ layers.append(nn.Linear(dims[i], dims[i + 1]))
25
+ if i < len(dims) - 2:
26
+ layers.append(act_layer())
27
+ self.net = nn.Sequential(*layers)
28
+
29
+ def reset_parameters(self, confidence):
30
+ bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
31
+ self.net[-1].bias.data.fill_(bias_value.item())
32
+
33
+ def forward(self, x):
34
+ batch, T, C = x.shape
35
+ x = x.reshape(-1, C)
36
+ x = self.net(x)
37
+ return x.reshape(batch, T, -1)
38
+
39
+
40
+ class WrapedTransformerEncoder(nn.Module):
41
+ def __init__(
42
+ self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
43
+ ):
44
+ super().__init__()
45
+ self.input_dim = input_dim
46
+ self.transformer_input_dim = transformer_input_dim
47
+
48
+ if input_dim != transformer_input_dim:
49
+ self.input_proj = nn.Sequential(
50
+ nn.Linear(input_dim, transformer_input_dim),
51
+ nn.LayerNorm(transformer_input_dim),
52
+ nn.GELU(),
53
+ nn.Dropout(dropout * 0.5),
54
+ nn.Linear(transformer_input_dim, transformer_input_dim),
55
+ )
56
+ else:
57
+ self.input_proj = nn.Identity()
58
+
59
+ self.transformer = Encoder(
60
+ dim=transformer_input_dim,
61
+ depth=num_layers,
62
+ heads=nhead,
63
+ layer_dropout=dropout,
64
+ attn_dropout=dropout,
65
+ ff_dropout=dropout,
66
+ attn_flash=True,
67
+ rotary_pos_emb=True,
68
+ )
69
+
70
+ def forward(self, x, src_key_padding_mask=None):
71
+ """
72
+ The input src_key_padding_mask is a B x T boolean mask, where True indicates masked positions.
73
+ However, in x-transformers, False indicates masked positions.
74
+ Therefore, it needs to be converted so that False represents masked positions.
75
+ """
76
+ x = self.input_proj(x)
77
+ mask = (
78
+ ~torch.tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
79
+ if src_key_padding_mask is not None
80
+ else None
81
+ )
82
+ return self.transformer(x, mask=mask)
83
+
84
+
85
+ def prefix_dict(d, prefix: str):
86
+ if not prefix:
87
+ return d
88
+ return {prefix + key: value for key, value in d.items()}
89
+
90
+
91
+ class TimeDownsample(nn.Module):
92
+ def __init__(
93
+ self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
94
+ ):
95
+ super().__init__()
96
+ self.dim_out = dim_out or dim_in
97
+ assert self.dim_out % 2 == 0
98
+
99
+ self.depthwise_conv = nn.Conv1d(
100
+ in_channels=dim_in,
101
+ out_channels=dim_in,
102
+ kernel_size=kernel_size,
103
+ stride=stride,
104
+ padding=padding,
105
+ groups=dim_in,
106
+ bias=False,
107
+ )
108
+ self.pointwise_conv = nn.Conv1d(
109
+ in_channels=dim_in,
110
+ out_channels=self.dim_out,
111
+ kernel_size=1,
112
+ bias=False,
113
+ )
114
+ self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
115
+ self.norm1 = nn.LayerNorm(self.dim_out)
116
+ self.act1 = nn.GELU()
117
+ self.dropout1 = nn.Dropout(dropout)
118
+
119
+ if dim_in != self.dim_out:
120
+ self.residual_conv = nn.Conv1d(
121
+ dim_in, self.dim_out, kernel_size=1, bias=False
122
+ )
123
+ else:
124
+ self.residual_conv = None
125
+
126
+ def forward(self, x):
127
+ residual = x # [B, T, D_in]
128
+ # Convolutional module
129
+ x_c = x.transpose(1, 2) # [B, D_in, T]
130
+ x_c = self.depthwise_conv(x_c) # [B, D_in, T_down]
131
+ x_c = self.pointwise_conv(x_c) # [B, D_out, T_down]
132
+
133
+ # Residual module
134
+ res = self.pool(residual.transpose(1, 2)) # [B, D_in, T_down]
135
+ if self.residual_conv:
136
+ res = self.residual_conv(res) # [B, D_out, T_down]
137
+ x_c = x_c + res # [B, D_out, T_down]
138
+ x_c = x_c.transpose(1, 2) # [B, T_down, D_out]
139
+ x_c = self.norm1(x_c)
140
+ x_c = self.act1(x_c)
141
+ x_c = self.dropout1(x_c)
142
+ return x_c
143
+
144
+
145
+ class AddFuse(nn.Module):
146
+ def __init__(self):
147
+ super(AddFuse, self).__init__()
148
+
149
+ def forward(self, x, cond):
150
+ return x + cond
151
+
152
+
153
+ class TVLoss1D(nn.Module):
154
+ def __init__(
155
+ self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
156
+ ):
157
+ """
158
+ Args:
159
+ beta: Exponential parameter for TV loss (recommended 0.5~1.0)
160
+ lambda_tv: Overall weight for TV loss
161
+ boundary_threshold: Label threshold to determine if a region is a "boundary area" (e.g., 0.01)
162
+ reduction_weight: Scaling factor for TV penalty within boundary regions (e.g., 0.1, meaning only 10% penalty)
163
+ """
164
+ super().__init__()
165
+ self.beta = beta
166
+ self.lambda_tv = lambda_tv
167
+ self.boundary_threshold = boundary_threshold
168
+ self.reduction_weight = reduction_weight
169
+
170
+ def forward(self, pred, target=None):
171
+ """
172
+ Args:
173
+ pred: (B, T) or (B, T, 1), float boundary scores output by the model
174
+ target: (B, T) or (B, T, 1), ground truth labels (optional, used for spatial weighting if provided)
175
+
176
+ Returns:
177
+ scalar: weighted TV loss
178
+ """
179
+ if pred.dim() == 3:
180
+ pred = pred.squeeze(-1)
181
+ if target is not None and target.dim() == 3:
182
+ target = target.squeeze(-1)
183
+
184
+ diff = pred[:, 1:] - pred[:, :-1]
185
+ tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)
186
+
187
+ if target is None:
188
+ return self.lambda_tv * tv_base.mean()
189
+
190
+ left_in_boundary = target[:, :-1] > self.boundary_threshold
191
+ right_in_boundary = target[:, 1:] > self.boundary_threshold
192
+ near_boundary = left_in_boundary | right_in_boundary
193
+ weight_mask = torch.where(
194
+ near_boundary,
195
+ self.reduction_weight * torch.ones_like(tv_base),
196
+ torch.ones_like(tv_base),
197
+ )
198
+ tv_weighted = (tv_base * weight_mask).mean()
199
+ return self.lambda_tv * tv_weighted
200
+
201
+
202
+ class SoftmaxFocalLoss(nn.Module):
203
+ """
204
+ Softmax Focal Loss for single-label multi-class classification.
205
+ Suitable for mutually exclusive classes.
206
+ """
207
+
208
+ def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
209
+ super().__init__()
210
+ self.alpha = alpha
211
+ self.gamma = gamma
212
+
213
+ def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
214
+ """
215
+ Args:
216
+ pred: [B, T, C], raw logits
217
+ targets: [B, T, C] (soft) or [B, T] (hard, dtype=long)
218
+ Returns:
219
+ loss: per-frame loss of shape [B, T] (no reduction is applied here)
220
+ """
221
+ log_probs = F.log_softmax(pred, dim=-1)
222
+ probs = torch.exp(log_probs)
223
+
224
+ if targets.dtype == torch.long:
225
+ targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
226
+ else:
227
+ targets_onehot = targets
228
+
229
+ p_t = (probs * targets_onehot).sum(dim=-1)
230
+ p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)
231
+
232
+ if self.alpha > 0:
233
+ alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
234
+ 1 - targets_onehot
235
+ )
236
+ alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
237
+ else:
238
+ alpha_weight = 1.0
239
+
240
+ focal_weight = (1 - p_t) ** self.gamma
241
+ ce_loss = -log_probs * targets_onehot
242
+ ce_loss = ce_loss.sum(dim=-1)
243
+
244
+ loss = alpha_weight * focal_weight * ce_loss
245
+ return loss
246
+
247
+
248
+ class Model(nn.Module):
249
+ def __init__(self, config):
250
+ super().__init__()
251
+ self.config = config
252
+
253
+ self.input_norm = nn.LayerNorm(config.input_dim)
254
+ self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
255
+ self.dataset_class_prefix = nn.Embedding(
256
+ num_embeddings=config.num_dataset_classes,
257
+ embedding_dim=config.transformer_encoder_input_dim,
258
+ )
259
+ self.down_sample_conv = TimeDownsample(
260
+ dim_in=config.input_dim,
261
+ dim_out=config.transformer_encoder_input_dim,
262
+ kernel_size=config.down_sample_conv_kernel_size,
263
+ stride=config.down_sample_conv_stride,
264
+ dropout=config.down_sample_conv_dropout,
265
+ padding=config.down_sample_conv_padding,
266
+ )
267
+ self.AddFuse = AddFuse()
268
+ self.transformer = WrapedTransformerEncoder(
269
+ input_dim=config.transformer_encoder_input_dim,
270
+ transformer_input_dim=config.transformer_input_dim,
271
+ num_layers=config.num_transformer_layers,
272
+ nhead=config.transformer_nhead,
273
+ dropout=config.transformer_dropout,
274
+ )
275
+ self.boundary_TVLoss1D = TVLoss1D(
276
+ beta=config.boundary_tv_loss_beta,
277
+ lambda_tv=config.boundary_tv_loss_lambda,
278
+ boundary_threshold=config.boundary_tv_loss_boundary_threshold,
279
+ reduction_weight=config.boundary_tv_loss_reduction_weight,
280
+ )
281
+ self.label_focal_loss = SoftmaxFocalLoss(
282
+ alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
283
+ )
284
+ self.boundary_head = Head(config.transformer_input_dim, 1)
285
+ self.function_head = Head(config.transformer_input_dim, config.num_classes)
286
+
287
+ def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
288
+ assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
289
+ "gt_info and msa_info should end with 'end'"
290
+ )
291
+ gt_info_labels = [label for time_, label in gt_info][:-1]
292
+ gt_info_inters = [time_ for time_, label in gt_info]
293
+ gt_info_inters = np.column_stack(
294
+ [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
295
+ )
296
+
297
+ msa_info_labels = [label for time_, label in msa_info][:-1]
298
+ msa_info_inters = [time_ for time_, label in msa_info]
299
+ msa_info_inters = np.column_stack(
300
+ [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
301
+ )
302
+ result = compute_results(
303
+ ann_inter=gt_info_inters,
304
+ est_inter=msa_info_inters,
305
+ ann_labels=gt_info_labels,
306
+ est_labels=msa_info_labels,
307
+ bins=11,
308
+ est_file="test.txt",
309
+ weight=0.58,
310
+ )
311
+ return result
312
+
313
+ def cal_acc(
314
+ self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
315
+ ):
316
+ ann_info_time = [
317
+ int(round(time_, post_digit) * (10**post_digit))
318
+ for time_, label in ann_info
319
+ ]
320
+ est_info_time = [
321
+ int(round(time_, post_digit) * (10**post_digit))
322
+ for time_, label in est_info
323
+ ]
324
+
325
+ common_start_time = max(ann_info_time[0], est_info_time[0])
326
+ common_end_time = min(ann_info_time[-1], est_info_time[-1])
327
+
328
+ time_points = {common_start_time, common_end_time}
329
+ time_points.update(
330
+ {
331
+ time_
332
+ for time_ in ann_info_time
333
+ if common_start_time <= time_ <= common_end_time
334
+ }
335
+ )
336
+ time_points.update(
337
+ {
338
+ time_
339
+ for time_ in est_info_time
340
+ if common_start_time <= time_ <= common_end_time
341
+ }
342
+ )
343
+
344
+ time_points = sorted(time_points)
345
+ total_duration, total_score = 0, 0
346
+
347
+ for idx in range(len(time_points) - 1):
348
+ duration = time_points[idx + 1] - time_points[idx]
349
+ ann_label = ann_info[
350
+ bisect.bisect_right(ann_info_time, time_points[idx]) - 1
351
+ ][1]
352
+ est_label = est_info[
353
+ bisect.bisect_right(est_info_time, time_points[idx]) - 1
354
+ ][1]
355
+ total_duration += duration
356
+ if ann_label == est_label:
357
+ total_score += duration
358
+ return total_score / total_duration
359
+
360
+ def infer_with_metrics(self, batch, prefix: str = None):
361
+ with torch.no_grad():
362
+ logits = self.forward_func(batch)
363
+
364
+ losses = self.compute_losses(logits, batch, prefix=None)
365
+
366
+ expanded_mask = batch["label_id_masks"].expand(
367
+ -1, logits["function_logits"].size(1), -1
368
+ )
369
+ logits["function_logits"] = logits["function_logits"].masked_fill(
370
+ expanded_mask, -float("inf")
371
+ )
372
+
373
+ msa_info = postprocess_functional_structure(
374
+ logits=logits, config=self.config
375
+ )
376
+ gt_info = batch["msa_infos"][0]
377
+ results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)
378
+
379
+ ret_results = {
380
+ "loss": losses["loss"].item(),
381
+ "HitRate_3P": results["HitRate_3P"],
382
+ "HitRate_3R": results["HitRate_3R"],
383
+ "HitRate_3F": results["HitRate_3F"],
384
+ "HitRate_0.5P": results["HitRate_0.5P"],
385
+ "HitRate_0.5R": results["HitRate_0.5R"],
386
+ "HitRate_0.5F": results["HitRate_0.5F"],
387
+ "PWF": results["PWF"],
388
+ "PWP": results["PWP"],
389
+ "PWR": results["PWR"],
390
+ "Sf": results["Sf"],
391
+ "So": results["So"],
392
+ "Su": results["Su"],
393
+ "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
394
+ }
395
+ if prefix:
396
+ ret_results = prefix_dict(ret_results, prefix)
397
+
398
+ return ret_results
399
+
400
+ def infer(
401
+ self,
402
+ input_embeddings,
403
+ dataset_ids,
404
+ label_id_masks,
405
+ prefix: str = None,
406
+ with_logits=False,
407
+ ):
408
+ with torch.no_grad():
409
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
410
+ input_embeddings = self.input_norm(input_embeddings)
411
+ logits = self.down_sample_conv(input_embeddings)
412
+
413
+ dataset_prefix = self.dataset_class_prefix(dataset_ids)
414
+ dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
415
+ logits.size(0), 1, -1
416
+ )
417
+ logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
418
+ logits = self.transformer(x=logits, src_key_padding_mask=None)
419
+
420
+ function_logits = self.function_head(logits)
421
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
422
+
423
+ logits = {
424
+ "function_logits": function_logits,
425
+ "boundary_logits": boundary_logits,
426
+ }
427
+
428
+ expanded_mask = label_id_masks.expand(
429
+ -1, logits["function_logits"].size(1), -1
430
+ )
431
+ logits["function_logits"] = logits["function_logits"].masked_fill(
432
+ expanded_mask, -float("inf")
433
+ )
434
+
435
+ msa_info = postprocess_functional_structure(
436
+ logits=logits, config=self.config
437
+ )
438
+
439
+ return (msa_info, logits) if with_logits else msa_info
440
+
441
+ def compute_losses(self, outputs, batch, prefix: str = None):
442
+ loss = 0.0
443
+ losses = {}
444
+
445
+ loss_section = F.binary_cross_entropy_with_logits(
446
+ outputs["boundary_logits"],
447
+ batch["widen_true_boundaries"],
448
+ reduction="none",
449
+ )
450
+ loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
451
+ pred=outputs["boundary_logits"],
452
+ target=batch["widen_true_boundaries"],
453
+ )
454
+ loss_function = F.cross_entropy(
455
+ outputs["function_logits"].transpose(1, 2),
456
+ batch["true_functions"].transpose(1, 2),
457
+ reduction="none",
458
+ )
459
+ # input is [B, T, C]
460
+ ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
461
+ pred=outputs["function_logits"], targets=batch["true_functions"]
462
+ )
463
+ loss_function += ttt
464
+
465
+ float_masks = (~batch["masks"]).float()
466
+ boundary_mask = batch.get("boundary_mask", None)
467
+ function_mask = batch.get("function_mask", None)
468
+ if boundary_mask is not None:
469
+ boundary_mask = (~boundary_mask).float()
470
+ else:
471
+ boundary_mask = 1
472
+
473
+ if function_mask is not None:
474
+ function_mask = (~function_mask).float()
475
+ else:
476
+ function_mask = 1
477
+
478
+ loss_section = torch.mean(boundary_mask * float_masks * loss_section)
479
+ loss_function = torch.mean(function_mask * float_masks * loss_function)
480
+
481
+ loss_section *= self.config.loss_weight_section
482
+ loss_function *= self.config.loss_weight_function
483
+
484
+ if self.config.learn_label:
485
+ loss += loss_function
486
+ if self.config.learn_segment:
487
+ loss += loss_section
488
+
489
+ losses.update(
490
+ loss=loss,
491
+ loss_section=loss_section,
492
+ loss_function=loss_function,
493
+ )
494
+ if prefix:
495
+ losses = prefix_dict(losses, prefix)
496
+ return losses
497
+
498
+ def forward_func(self, batch):
499
+ input_embeddings = batch["input_embeddings"]
500
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
501
+ input_embeddings = self.input_norm(input_embeddings)
502
+ logits = self.down_sample_conv(input_embeddings)
503
+
504
+ dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
505
+ logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
506
+ src_key_padding_mask = batch["masks"]
507
+ logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)
508
+
509
+ function_logits = self.function_head(logits)
510
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
511
+
512
+ logits = {
513
+ "function_logits": function_logits,
514
+ "boundary_logits": boundary_logits,
515
+ }
516
+ return logits
517
+
518
+ def forward(self, batch):
519
+ logits = self.forward_func(batch)
520
+ losses = self.compute_losses(logits, batch, prefix=None)
521
+ return logits, losses["loss"], losses
src/SongFormer/postprocessing/calc_acc.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ import bisect
3
+ from dataset.msa_info_utils import (
4
+ load_msa_info,
5
+ )
6
+ from dataset.custom_types import MsaInfo
7
+ import glob
8
+ import pdb
9
+ import pandas as pd
10
+
11
+
12
+ def cal_acc(ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3):
13
+ if type(ann_info) is str:
14
+ assert os.path.exists(ann_info), f"{ann_info} not exists"
15
+ ann_info = load_msa_info(ann_info)
16
+
17
+ if type(est_info) is str:
18
+ assert os.path.exists(est_info), f"{est_info} not exists"
19
+ est_info = load_msa_info(est_info)
20
+
21
+ ann_info_time = [
22
+ int(round(time_, post_digit) * (10**post_digit)) for time_, label in ann_info
23
+ ]
24
+ est_info_time = [
25
+ int(round(time_, post_digit) * (10**post_digit)) for time_, label in est_info
26
+ ]
27
+
28
+ common_start_time = max(ann_info_time[0], est_info_time[0])
29
+ common_end_time = min(ann_info_time[-1], est_info_time[-1])
30
+
31
+ time_points = set()
32
+ time_points.add(common_start_time)
33
+ time_points.add(common_end_time)
34
+
35
+ for time_ in ann_info_time:
36
+ if time_ >= common_start_time and time_ <= common_end_time:
37
+ time_points.add(time_)
38
+ for time_ in est_info_time:
39
+ if time_ >= common_start_time and time_ <= common_end_time:
40
+ time_points.add(time_)
41
+
42
+ time_points = sorted(list(time_points))
43
+ total_duration = 0
44
+ total_score = 0
45
+
46
+ for idx in range(len(time_points) - 1):
47
+ duration = time_points[idx + 1] - time_points[idx]
48
+ ann_label = ann_info[bisect.bisect_right(ann_info_time, time_points[idx]) - 1][
49
+ 1
50
+ ]
51
+ est_label = est_info[bisect.bisect_right(est_info_time, time_points[idx]) - 1][
52
+ 1
53
+ ]
54
+ total_duration += duration
55
+ if ann_label == est_label:
56
+ total_score += duration
57
+ return total_score / total_duration
58
+
59
+
60
+ if __name__ == "__main__":
61
+ ext_paths = glob.glob("")
62
+ results = []
63
+ for ext_path in ext_paths:
64
+ try:
65
+ ann_path = os.path.join(
66
+ "",
67
+ os.path.basename(ext_path).split(".")[0] + ".txt",
68
+ )
69
+ results.append(
70
+ {
71
+ "data_id": os.path.basename(ext_path).split(".")[0],
72
+ "acc": cal_acc(
73
+ ann_info=ann_path,
74
+ est_info=ext_path,
75
+ ),
76
+ }
77
+ )
78
+ except Exception as e:
79
+ print(e)
80
+ continue
81
+ df = pd.DataFrame(results)
82
+ print(df["acc"].mean())
src/SongFormer/postprocessing/calc_iou.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ from dataset.custom_types import MsaInfo
3
+ from dataset.label2id import LABEL_TO_ID
4
+ from pprint import pprint
5
+
6
+
7
+ def load_msa_info(msa_info_path):
8
+ msa_info: MsaInfo = []
9
+ with open(msa_info_path) as f:
10
+ for line in f:
11
+ line = line.strip()
12
+ if not line:
13
+ continue
14
+ time_, label = line.split()
15
+ time_ = float(time_)
16
+ label = str(label)
17
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
18
+ msa_info.append((time_, label))
19
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
20
+ return msa_info
21
+
22
+
23
+ def msa_info_to_segments(msa_info):
24
+ # skip the last "end"
25
+ segments = []
26
+ for i in range(len(msa_info) - 1):
27
+ start = msa_info[i][0]
28
+ end = msa_info[i + 1][0]
29
+ label = msa_info[i][1]
30
+ segments.append((start, end, label))
31
+ return segments
32
+
33
+
34
+ def compute_iou_for_label(segments_a, segments_b, label):
35
+ # segments_a, segments_b: [(start, end, label)]
36
+ # only process the current label
37
+ intervals_a = [(s, e) for s, e, l in segments_a if l == label]
38
+ intervals_b = [(s, e) for s, e, l in segments_b if l == label]
39
+ # sum up all intersections between a and b
40
+ intersection = 0.0
41
+ for sa, ea in intervals_a:
42
+ for sb, eb in intervals_b:
43
+ left = max(sa, sb)
44
+ right = min(ea, eb)
45
+ if left < right:
46
+ intersection += right - left
47
+ # union = total length of both sets - overlapping intersection
48
+ length_a = sum([e - s for s, e in intervals_a])
49
+ length_b = sum([e - s for s, e in intervals_b])
50
+ union = length_a + length_b - intersection
51
+ if union == 0:
52
+ return 0.0, 0.0, 0.0
53
+ return intersection / union, intersection, union
54
+
55
+
56
+ def compute_mean_iou(segments_a, segments_b, labels):
57
+ ious = []
58
+ for label in labels:
59
+ iou, intsec_dur, uni_dur = compute_iou_for_label(segments_a, segments_b, label)
60
+ ious.append(
61
+ {"label": label, "iou": iou, "intsec_dur": intsec_dur, "uni_dur": uni_dur}
62
+ )
63
+ return ious
64
+
65
+
66
+ def cal_iou(ann_info, est_info):
67
+ if type(ann_info) is str:
68
+ assert os.path.exists(ann_info), f"{ann_info} not exists"
69
+ ann_info = load_msa_info(ann_info)
70
+
71
+ if type(est_info) is str:
72
+ assert os.path.exists(est_info), f"{est_info} not exists"
73
+ est_info = load_msa_info(est_info)
74
+
75
+ segments_ann = msa_info_to_segments(ann_info)
76
+ segments_est = msa_info_to_segments(est_info)
77
+
78
+ occurred_labels = list(
79
+ set([l for s, e, l in segments_ann]) | set(l for s, e, l in segments_est)
80
+ )
81
+
82
+ mean_iou = compute_mean_iou(segments_ann, segments_est, occurred_labels)
83
+ return mean_iou
84
+
85
+
86
+ if __name__ == "__main__":
87
+ ann_info = ""
88
+ est_info = ""
89
+ pprint(cal_iou(ann_info, est_info))
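Like `cal_acc`, `cal_iou` accepts either file paths or in-memory `MsaInfo` sequences and reports, per occurring label, the temporal IoU together with the intersection and union durations. A small illustrative sketch (times and labels are made up):

```python
ann = [(0.0, "intro"), (10.0, "chorus"), (30.0, "end")]
est = [(0.0, "intro"), (12.0, "chorus"), (30.0, "end")]

# e.g. "intro": IoU = 10 / 12 ≈ 0.83, "chorus": IoU = 18 / 20 = 0.90
pprint(cal_iou(ann, est))
```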
src/SongFormer/postprocessing/functional.py ADDED
@@ -0,0 +1,71 @@
1
+ # This file contains code adapted from the following sources:
2
+ # [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/functional.py
3
+
4
+ import numpy as np
5
+ import torch
6
+ from .helpers import (
7
+ local_maxima,
8
+ peak_picking,
9
+ # event_frames_to_time,
10
+ )
11
+ from dataset.label2id import LABEL_TO_ID, ID_TO_LABEL
12
+ from dataset.custom_types import MsaInfo
13
+
14
+
15
+ def event_frames_to_time(frame_rates, boundary: np.array):
16
+ boundary = np.array(boundary)
17
+ boundary_times = boundary / frame_rates
18
+ return boundary_times
19
+
20
+
21
+ def postprocess_functional_structure(
22
+ logits,
23
+ config,
24
+ ):
25
+ # pdb.set_trace()
26
+ boundary_logits = logits["boundary_logits"]
27
+ function_logits = logits["function_logits"]
28
+
29
+ assert boundary_logits.shape[0] == 1 and function_logits.shape[0] == 1, (
30
+ "Only batch size 1 is supported"
31
+ )
32
+ raw_prob_sections = torch.sigmoid(boundary_logits[0])
33
+ raw_prob_functions = torch.softmax(function_logits[0].transpose(0, 1), dim=0)
34
+
35
+ # filter_size=4 * cfg.min_hops_per_beat + 1
36
+ prob_sections, _ = local_maxima(
37
+ raw_prob_sections, filter_size=config.local_maxima_filter_size
38
+ )
39
+ prob_sections = prob_sections.cpu().numpy()
40
+
41
+ prob_functions = raw_prob_functions.cpu().numpy()
42
+
43
+ boundary_candidates = peak_picking(
44
+ boundary_activation=prob_sections,
45
+ window_past=int(12 * config.frame_rates),  # this was originally fps
46
+ window_future=int(12 * config.frame_rates),
47
+ )
48
+ boundary = boundary_candidates > 0.0
49
+
50
+ duration = len(prob_sections) / config.frame_rates
51
+ pred_boundary_times = event_frames_to_time(
52
+ frame_rates=config.frame_rates, boundary=np.flatnonzero(boundary)
53
+ )
54
+ if pred_boundary_times[0] != 0:
55
+ pred_boundary_times = np.insert(pred_boundary_times, 0, 0)
56
+ if pred_boundary_times[-1] != duration:
57
+ pred_boundary_times = np.append(pred_boundary_times, duration)
58
+ pred_boundaries = np.stack([pred_boundary_times[:-1], pred_boundary_times[1:]]).T
59
+
60
+ pred_boundary_indices = np.flatnonzero(boundary)
61
+ pred_boundary_indices = pred_boundary_indices[pred_boundary_indices > 0]
62
+ prob_segment_function = np.split(prob_functions, pred_boundary_indices, axis=1)
63
+ pred_labels = [p.mean(axis=1).argmax().item() for p in prob_segment_function]
64
+
65
+ segments: MsaInfo = []
66
+ for (start, end), label in zip(pred_boundaries, pred_labels):
67
+ segment = (float(start), str(ID_TO_LABEL[label]))
68
+ segments.append(segment)
69
+
70
+ segments.append((float(pred_boundary_times[-1]), "end"))
71
+ return segments
src/SongFormer/postprocessing/helpers.py ADDED
@@ -0,0 +1,101 @@
1
+ # This file contains code adapted from the following sources:
2
+ # [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/helpers.py
3
+
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import librosa
8
+ from typing import Union
9
+ from scipy.signal import argrelextrema
10
+ from scipy.interpolate import interp1d
11
+ from numpy.lib.stride_tricks import sliding_window_view
12
+ from numpy.typing import NDArray
13
+
14
+
15
+ def local_maxima(tensor, filter_size=41):
16
+ assert len(tensor.shape) in (1, 2), "Input tensor should have 1 or 2 dimensions"
17
+ assert filter_size % 2 == 1, "Filter size should be an odd number"
18
+
19
+ original_shape = tensor.shape
20
+ if len(original_shape) == 1:
21
+ tensor = tensor.unsqueeze(0)
22
+
23
+ # Pad the input array with the minimum value
24
+ padding = filter_size // 2
25
+ padded_arr = F.pad(tensor, (padding, padding), mode="constant", value=-torch.inf)
26
+
27
+ # Create a rolling window view of the padded array
28
+ rolling_view = padded_arr.unfold(1, filter_size, 1)
29
+
30
+ # Find the indices of the local maxima
31
+ center = filter_size // 2
32
+ local_maxima_mask = torch.eq(
33
+ rolling_view[:, :, center], torch.max(rolling_view, dim=-1).values
34
+ )
35
+ local_maxima_indices = local_maxima_mask.nonzero()
36
+
37
+ # Initialize a new PyTorch tensor with zeros and the same shape as the input tensor
38
+ output_arr = torch.zeros_like(tensor)
39
+
40
+ # Set the local maxima values in the output tensor
41
+ output_arr[local_maxima_mask] = tensor[local_maxima_mask]
42
+
43
+ output_arr = output_arr.reshape(original_shape)
44
+
45
+ return output_arr, local_maxima_indices
46
+
47
+
48
+ def local_maxima_numpy(arr, order=20):
49
+ is_batch = len(arr.shape) == 2
50
+ if is_batch:
51
+ return np.stack([local_maxima_numpy(x, order) for x in arr])
52
+
53
+ # Define a comparison function for argrelextrema to find local maxima
54
+ compare_func = np.greater
55
+
56
+ # Find the indices of the local maxima
57
+ local_maxima_indices = argrelextrema(arr, compare_func, order=order)
58
+
59
+ # Initialize a new numpy array with zeros and the same shape as the input array
60
+ output_arr = np.zeros_like(arr)
61
+
62
+ # Set the local maxima values in the output array
63
+ output_arr[local_maxima_indices] = arr[local_maxima_indices]
64
+
65
+ return output_arr
66
+
67
+
68
+ def peak_picking(boundary_activation, window_past=12, window_future=6):
69
+ # Find local maxima using a sliding window
70
+ window_size = window_past + window_future
71
+ assert window_size % 2 == 0, "window_past + window_future must be even"
72
+ window_size += 1
73
+
74
+ # Pad boundary_activation
75
+ boundary_activation_padded = np.pad(
76
+ boundary_activation, (window_past, window_future), mode="constant"
77
+ )
78
+ max_filter = sliding_window_view(boundary_activation_padded, window_size)
79
+ local_maxima = (boundary_activation == np.max(max_filter, axis=-1)) & (
80
+ boundary_activation > 0
81
+ )
82
+
83
+ # Compute strength values by subtracting the mean of the past and future windows
84
+ past_window_filter = sliding_window_view(
85
+ boundary_activation_padded[: -(window_future + 1)], window_past
86
+ )
87
+ future_window_filter = sliding_window_view(
88
+ boundary_activation_padded[window_past + 1 :], window_future
89
+ )
90
+ past_mean = np.mean(past_window_filter, axis=-1)
91
+ future_mean = np.mean(future_window_filter, axis=-1)
92
+ strength_values = boundary_activation - ((past_mean + future_mean) / 2)
93
+
94
+ # Get boundary candidates and their corresponding strength values
95
+ boundary_candidates = np.flatnonzero(local_maxima)
96
+ strength_values = strength_values[boundary_candidates]
97
+
98
+ strength_activations = np.zeros_like(boundary_activation)
99
+ strength_activations[boundary_candidates] = strength_values
100
+
101
+ return strength_activations
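To make the peak-picking step concrete, here is a toy sketch on a synthetic boundary-activation curve (the window sizes are arbitrary; in the model they are derived from `config.frame_rates`):

```python
import numpy as np

act = np.zeros(20, dtype=np.float32)
act[5], act[13] = 1.0, 0.6  # two artificial boundary activations

# Non-zero entries mark the picked peaks; their values are the strength scores
# (activation minus the mean of the surrounding past/future windows).
strengths = peak_picking(act, window_past=3, window_future=3)
print(np.flatnonzero(strengths))  # -> [ 5 13]
```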
src/SongFormer/train/accelerate_config/single_gpu.yaml ADDED
@@ -0,0 +1,17 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: 'NO'
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: 'no'
10
+ num_machines: 1
11
+ num_processes: 1
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
src/SongFormer/utils/average_checkpoints.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import copy
3
+ from typing import List, Dict, Any
4
+
5
+
6
+ def average_checkpoints(checkpoint_paths: List[str], output_path: str = None):
7
+ """
8
+ Average the model and model_ema weights from multiple checkpoints
9
+
10
+ Parameters:
11
+ checkpoint_paths: List of checkpoint file paths
12
+ output_path: Output path; if None, return the averaged checkpoint dictionary
13
+
14
+ Returns:
15
+ Averaged checkpoint dictionary
16
+ """
17
+ if not checkpoint_paths:
18
+ raise ValueError("At least one checkpoint path is required")
19
+
20
+ # Load the first checkpoint as the base
21
+ print(f"Loading base checkpoint: {checkpoint_paths[0]}")
22
+ avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
23
+
24
+ if len(checkpoint_paths) == 1:
25
+ if output_path:
26
+ torch.save(avg_checkpoint, output_path)
27
+ return avg_checkpoint
28
+
29
+ # Initialize accumulators
30
+ avg_model_state = copy.deepcopy(avg_checkpoint["model"])
31
+ avg_model_ema_state = None
32
+
33
+ if "model_ema" in avg_checkpoint:
34
+ avg_model_ema_state = copy.deepcopy(avg_checkpoint["model_ema"])
35
+
36
+ # Accumulate the weights from the other checkpoints
37
+ for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
38
+ print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
39
+ ckpt = torch.load(ckpt_path, map_location="cpu")
40
+
41
+ # Accumulate model weights
42
+ for key in avg_model_state.keys():
43
+ if key in ckpt["model"]:
44
+ avg_model_state[key] += ckpt["model"][key]
45
+
46
+ # Accumulate model_ema weights (if available)
47
+ if avg_model_ema_state is not None and "model_ema" in ckpt:
48
+ for key in avg_model_ema_state.keys():
49
+ if key in ckpt["model_ema"]:
50
+ avg_model_ema_state[key] += ckpt["model_ema"][key]
51
+
52
+ # Compute the average
53
+ num_checkpoints = len(checkpoint_paths)
54
+ print(f"Averaging over {num_checkpoints} checkpoints...")
55
+
56
+ for key in avg_model_state.keys():
57
+ avg_model_state[key] = avg_model_state[key] / num_checkpoints
58
+
59
+ if avg_model_ema_state is not None:
60
+ for key in avg_model_ema_state.keys():
61
+ avg_model_ema_state[key] = avg_model_ema_state[key] / num_checkpoints
62
+
63
+ # Update the checkpoint dictionary
64
+ avg_checkpoint["model"] = avg_model_state
65
+ if avg_model_ema_state is not None:
66
+ avg_checkpoint["model_ema"] = avg_model_ema_state
67
+
68
+ # Save (if an output path is specified)
69
+ if output_path:
70
+ print(f"Saving averaged checkpoint to: {output_path}")
71
+ torch.save(avg_checkpoint, output_path)
72
+
73
+ return avg_checkpoint
74
+
75
+
76
+ def average_checkpoints_memory_efficient(
77
+ checkpoint_paths: List[str], output_path: str = None
78
+ ):
79
+ """
80
+ Memory efficient version: Load and process checkpoints one by one, suitable for large models
81
+ """
82
+ if not checkpoint_paths:
83
+ raise ValueError("At least one checkpoint path is required")
84
+
85
+ print(f"Loading base checkpoint: {checkpoint_paths[0]}")
86
+ avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
87
+
88
+ if len(checkpoint_paths) == 1:
89
+ if output_path:
90
+ torch.save(avg_checkpoint, output_path)
91
+ return avg_checkpoint
92
+
93
+ # Convert to float32 for better precision
94
+ for key in avg_checkpoint["model"].keys():
95
+ avg_checkpoint["model"][key] = avg_checkpoint["model"][key].float()
96
+
97
+ if "model_ema" in avg_checkpoint:
98
+ for key in avg_checkpoint["model_ema"].keys():
99
+ avg_checkpoint["model_ema"][key] = avg_checkpoint["model_ema"][key].float()
100
+
101
+ # Load and accumulate checkpoints one by one
102
+ for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
103
+ print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
104
+ ckpt = torch.load(ckpt_path, map_location="cpu")
105
+
106
+ # Accumulate model weights
107
+ for key in avg_checkpoint["model"].keys():
108
+ if key in ckpt["model"]:
109
+ avg_checkpoint["model"][key] += ckpt["model"][key].float()
110
+
111
+ # Accumulate model_ema weights
112
+ if "model_ema" in avg_checkpoint and "model_ema" in ckpt:
113
+ for key in avg_checkpoint["model_ema"].keys():
114
+ if key in ckpt["model_ema"]:
115
+ avg_checkpoint["model_ema"][key] += ckpt["model_ema"][key].float()
116
+
117
+ # Free memory
118
+ del ckpt
119
+ torch.cuda.empty_cache()
120
+
121
+ # Compute the average
122
+ num_checkpoints = len(checkpoint_paths)
123
+ print(f"Averaging over {num_checkpoints} checkpoints...")
124
+
125
+ for key in avg_checkpoint["model"].keys():
126
+ avg_checkpoint["model"][key] /= num_checkpoints
127
+
128
+ if "model_ema" in avg_checkpoint:
129
+ for key in avg_checkpoint["model_ema"].keys():
130
+ avg_checkpoint["model_ema"][key] /= num_checkpoints
131
+
132
+ if output_path:
133
+ print(f"Saving averaged checkpoint to: {output_path}")
134
+ torch.save(avg_checkpoint, output_path)
135
+
136
+ return avg_checkpoint
137
+
138
+
139
+ # Example usage
140
+ if __name__ == "__main__":
141
+ # Method 1: Simple usage
142
+ checkpoint_paths = []
143
+
144
+ # Average and save
145
+ average_checkpoints(checkpoint_paths, "")
146
+
147
+ # Method 2: Get the averaged checkpoint and further process it
148
+ # avg_ckpt = average_checkpoints(checkpoint_paths)
149
+ # print("Averaged checkpoint keys:", avg_ckpt.keys())
150
+
151
+ # Method 3: Use memory-efficient version (suitable for large models)
152
+ # average_checkpoints_memory_efficient(checkpoint_paths, 'averaged_checkpoint_efficient.pt')
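One possible way to drive the averaging utility, assuming the training run writes full checkpoint dictionaries (with a `"model"` and optionally a `"model_ema"` entry) under a glob-able directory; the paths below are placeholders:

```python
import glob

# Hypothetical checkpoint layout; adapt the pattern to your own run directory.
ckpt_paths = sorted(glob.glob("runs/songformer/checkpoint_*.pt"))[-5:]  # last 5 checkpoints
average_checkpoints(ckpt_paths, output_path="runs/songformer/averaged.pt")
```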
src/SongFormer/utils/convert_res2msa_txt.py ADDED
@@ -0,0 +1,79 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ import fire
5
+
6
+
7
+ def convert_json_to_format(json_data):
8
+ """Convert JSON data to the specified format"""
9
+ result = []
10
+
11
+ # Process the start time and label for each segment
12
+ for segment in json_data:
13
+ start_time = segment["start"]
14
+ label = segment["label"]
15
+ result.append(f"{start_time:.6f} {label}")
16
+
17
+ # Add the last end time
18
+ if json_data:
19
+ last_end_time = json_data[-1]["end"]
20
+ result.append(f"{last_end_time:.6f} end")
21
+
22
+ return "\n".join(result)
23
+
24
+
25
+ def process_json_files(input_folder, output_folder):
26
+ """Process all JSON files in the input folder"""
27
+
28
+ # Create the output folder if it doesn't exist
29
+ Path(output_folder).mkdir(parents=True, exist_ok=True)
30
+
31
+ # Get all JSON files
32
+ json_files = [f for f in os.listdir(input_folder) if f.endswith(".json")]
33
+
34
+ if not json_files:
35
+ print(f"No JSON files found in {input_folder}")
36
+ return
37
+
38
+ print(f"Found {len(json_files)} JSON files")
39
+
40
+ # Process each JSON file
41
+ for json_file in json_files:
42
+ input_path = os.path.join(input_folder, json_file)
43
+
44
+ try:
45
+ # Read the JSON file
46
+ with open(input_path, "r", encoding="utf-8") as f:
47
+ data = json.load(f)
48
+
49
+ # Convert the format
50
+ converted_data = convert_json_to_format(data)
51
+
52
+ # Generate the output filename (replace .json with .txt)
53
+ output_filename = json_file.replace(".json", ".txt")
54
+ output_path = os.path.join(output_folder, output_filename)
55
+
56
+ # Write to the output file
57
+ with open(output_path, "w", encoding="utf-8") as f:
58
+ f.write(converted_data)
59
+
60
+ print(f"✓ Processed: {json_file} -> {output_filename}")
61
+
62
+ except Exception as e:
63
+ print(f"✗ Error processing {json_file}: {str(e)}")
64
+
65
+
66
+ def main(input_folder: str, output_folder: str):
67
+ print(f"Input folder: {input_folder}")
68
+ print(f"Output folder: {output_folder}")
69
+ print("-" * 50)
70
+
71
+ # Process the files
72
+ process_json_files(input_folder, output_folder)
73
+
74
+ print("-" * 50)
75
+ print("Processing complete!")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ fire.Fire(main)
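For clarity, the converter expects one JSON file per song shaped like the hypothetical example below and writes the two-column MSA text format with a trailing `end` line:

```python
# input  (some_song.json):  [{"start": 0.0,  "end": 12.3, "label": "intro"},
#                            {"start": 12.3, "end": 45.6, "label": "verse"}]
# output (some_song.txt):   0.000000 intro
#                           12.300000 verse
#                           45.600000 end

process_json_files("results_json/", "results_txt/")  # folder names are placeholders
```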
src/SongFormer/utils/fetch_pretrained.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+
6
+ def download(url, path):
7
+ if os.path.exists(path):
8
+ print(f"File already exists, skipping download: {path}")
9
+ return
10
+ os.makedirs(os.path.dirname(path), exist_ok=True)
11
+ response = requests.get(url, stream=True)
12
+ total_size = int(response.headers.get("content-length", 0))
13
+ with (
14
+ open(path, "wb") as f,
15
+ tqdm(
16
+ desc=path,
17
+ total=total_size,
18
+ unit="iB",
19
+ unit_scale=True,
20
+ unit_divisor=1024,
21
+ ) as bar,
22
+ ):
23
+ for data in response.iter_content(chunk_size=1024):
24
+ size = f.write(data)
25
+ bar.update(size)
26
+
27
+
28
+ # Download the pretrained models as described at https://github.com/minzwon/musicfm
29
+ download(
30
+ "https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json",
31
+ os.path.join("ckpts", "MusicFM", "msd_stats.json"),
32
+ )
33
+ download(
34
+ "https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
35
+ os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"),
36
+ )
37
+
38
+ # for Mainland China
39
+ # download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/msd_stats.json', os.path.join("ckpts", "MusicFM", "msd_stats.json"))
40
+ # download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/pretrained_msd.pt', os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"))
src/third_party/MuQ/.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
src/third_party/MuQ/.gitignore ADDED
@@ -0,0 +1,46 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.egg*/
6
+ *pyc
7
+
8
+ # Distribution / packaging
9
+ .Python
10
+ env/
11
+ build/
12
+ dist/
13
+ *.log
14
+
15
+ # pyenv
16
+ .python-version
17
+
18
+ # dotenv
19
+ .env
20
+
21
+ # virtualenv
22
+ .venv/
23
+ venv/
24
+ ENV/
25
+
26
+ # VSCode settings
27
+ .vscode
28
+
29
+ # IDEA files
30
+ .idea
31
+
32
+ # OSX dir files
33
+ .DS_Store
34
+
35
+ # Sublime Text settings
36
+ *.sublime-workspace
37
+ *.sublime-project
38
+
39
+ # custom
40
+ open/
41
+ src/recipes/pretrain/dataset/music4all/*.json
42
+ src/recipes/contrastive_learning/datasets/mtg-jamendo/*.json
43
+ runs/
44
+ output/
45
+ logs
46
+ outputs/
src/third_party/MuQ/.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "src/recipes/pretrain/fairseq"]
2
+ path = src/recipes/pretrain/fairseq
3
+ url = https://github.com/facebookresearch/fairseq
src/third_party/MuQ/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Tencent.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/third_party/MuQ/LICENSE_weights ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
src/third_party/MuQ/README.md ADDED
@@ -0,0 +1,129 @@
1
+ # <img src="images/muq-logo.jpeg" alt="" height="24px"> MuQ & MuQ-MuLan
2
+
3
+ <div>
4
+ <a href='#'><img alt="Static Badge" src="https://img.shields.io/badge/Python-3.8%2B-blue?logo=python&logoColor=white"></a>
5
+ <a href='https://arxiv.org/abs/2501.01108'><img alt="Static Badge" src="https://img.shields.io/badge/arXiv-2501.01108-%23b31b1b?logo=arxiv&link=https%3A%2F%2Farxiv.org%2F"></a>
6
+ <a href='https://huggingface.co/OpenMuQ'><img alt="Static Badge" src="https://img.shields.io/badge/huggingface-OpenMuQ-%23FFD21E?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2FOpenMuQ"></a>
7
+ <a href='https://pytorch.org/'><img alt="Static Badge" src="https://img.shields.io/badge/framework-PyTorch-%23EE4C2C?logo=pytorch"></a>
8
+ <a href='https://pypi.org/project/muq'><img alt="Static Badge" src="https://img.shields.io/badge/pip%20install-muq-green?logo=PyPI&logoColor=white&link=https%3A%2F%2Fpypi.org%2Fproject%2Fmuq"></a>
9
+ </div>
10
+
11
+ This is the official repository for the paper *"**MuQ**: Self-Supervised **Mu**sic Representation Learning
12
+ with Mel Residual Vector **Q**uantization"*.
13
+
14
+ In this repo, the following models are released:
15
+
16
+ - **MuQ**: A large music foundation model pre-trained via Self-Supervised Learning (SSL), achieving SOTA in various MIR tasks.
17
+ - **MuQ-MuLan**: A music-text joint embedding model trained via contrastive learning, supporting both English and Chinese texts.
18
+
19
+ ## Overview
20
+
21
+ We develop **MuQ** for music SSL. MuQ applies our proposed Mel-RVQ as its quantization target and achieves SOTA performance on many music understanding (MIR) tasks.
22
+
23
+ We also construct **MuQ-MuLan**, a CLIP-like model trained with contrastive learning that jointly embeds music and text.
24
+
25
+ For more details, please refer to our [paper](https://arxiv.org/abs/2501.01108).
26
+
27
+ <div>
28
+ <img src="images/radar.jpg" width="45%" alt="Evaluation on MARBLE Benchmark">
29
+ <img src="images/tagging.jpg" width="45%" alt="Evaluation on Zero-shot Music Tagging">
30
+ </div>
31
+
32
+ ## Usage
33
+
34
+ To get started, install the official `muq` library with pip and make sure your environment has `python>=3.8`:
35
+ ```bash
36
+ pip3 install muq
37
+ ```
38
+
39
+
40
+ To extract music audio features using **MuQ**, you can refer to the following code:
41
+ ```python
42
+ import torch, librosa
43
+ from muq import MuQ
44
+
45
+ device = 'cuda'
46
+ wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
47
+ wavs = torch.tensor(wav).unsqueeze(0).to(device)
48
+
49
+ # This will automatically fetch the checkpoint from huggingface
50
+ muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
51
+ muq = muq.to(device).eval()
52
+
53
+ with torch.no_grad():
54
+ output = muq(wavs, output_hidden_states=True)
55
+
56
+ print('Total number of layers: ', len(output.hidden_states))
57
+ print('Feature shape: ', output.last_hidden_state.shape)
58
+
59
+ ```
60
+
61
+ Using **MuQ-MuLan** to extract the music and text embeddings and calculate the similarity:
62
+ ```python
63
+ import torch, librosa
64
+ from muq import MuQMuLan
65
+
66
+ # This will automatically fetch checkpoints from huggingface
67
+ device = 'cuda'
68
+ mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
69
+ mulan = mulan.to(device).eval()
70
+
71
+ # Extract music embeddings
72
+ wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
73
+ wavs = torch.tensor(wav).unsqueeze(0).to(device)
74
+ with torch.no_grad():
75
+ audio_embeds = mulan(wavs = wavs)
76
+
77
+ # Extract text embeddings (texts can be in English or Chinese)
78
+ texts = ["classical genres, hopeful mood, piano.", "一首适合海边风景的小提琴曲,节奏欢快"]
79
+ with torch.no_grad():
80
+ text_embeds = mulan(texts = texts)
81
+
82
+ # Calculate dot product similarity
83
+ sim = mulan.calc_similarity(audio_embeds, text_embeds)
84
+ print(sim)
85
+ ```
86
+
87
+ > Note that both MuQ and MuQ-MuLan strictly require **24 kHz** audio as input.
88
+ > We recommend using **fp32** during MuQ inference to avoid potential NaN issues.
89
+
90
+
91
+ ## Performance
92
+
93
+ <img src="images/tab-marble.jpg" width="100%" style="max-width: 800px" alt="Table MARBLE Benchmark">
94
+ <img src="images/tab-mulan.png" width="50%" style="max-width: 400px; margin: 0 25%" alt="Table Mulan Results">
95
+
96
+ ## Model Checkpoints
97
+
98
+ | Model Name | Parameters | Data | HuggingFace🤗 |
99
+ | ----------- | --- | --- | ----------- |
100
+ | MuQ | ~300M | MSD dataset | [OpenMuQ/MuQ-large-msd-iter](https://huggingface.co/OpenMuQ/MuQ-large-msd-iter) |
101
+ | MuQ-MuLan | ~700M | music-text pairs | [OpenMuQ/MuQ-MuLan-large](https://huggingface.co/OpenMuQ/MuQ-MuLan-large) |
102
+
103
+ **Note**: The open-sourced MuQ was trained on the Million Song Dataset. Due to differences in dataset size, it may not reach the performance level reported in the paper. The training recipes can be found [here](./src/recipes).
104
+
105
+ ## License
106
+
107
+ The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
108
+
109
+ The model weights (MuQ-large-msd-iter, MuQ-MuLan-large) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
110
+
111
+ ## Citation
112
+
113
+ ```
114
+ @article{zhu2025muq,
115
+ title={MuQ: Self-Supervised Music Representation Learning with Mel Residual Vector Quantization},
116
+ author={Haina Zhu and Yizhi Zhou and Hangting Chen and Jianwei Yu and Ziyang Ma and Rongzhi Gu and Yi Luo and Wei Tan and Xie Chen},
117
+ journal={arXiv preprint arXiv:2501.01108},
118
+ year={2025}
119
+ }
120
+ ```
121
+
122
+ ## Acknowledgement
123
+
124
+ We borrow many codes from the following repositories:
125
+ - [lucidrains/musiclm-pytorch](https://github.com/lucidrains/musiclm-pytorch)
126
+ - [minzwon/musicfm](https://github.com/minzwon/musicfm)
127
+
128
+
129
+ Also, we are especially grateful to the awesome [MARBLE-Benchmark](https://github.com/a43992899/MARBLE-Benchmark).
src/third_party/MuQ/images/muq-logo.jpeg ADDED
src/third_party/MuQ/images/radar.jpg ADDED

src/third_party/MuQ/images/tab-marble.jpg ADDED

src/third_party/MuQ/images/tab-mulan.png ADDED

src/third_party/MuQ/images/tagging.jpg ADDED

src/third_party/MuQ/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ einops
2
+ librosa
3
+ nnAudio
4
+ numpy
5
+ soundfile
6
+ torch
7
+ torchaudio
8
+ tqdm
9
+ transformers
10
+ easydict
11
+ x_clip
src/third_party/MuQ/setup.py ADDED
@@ -0,0 +1,34 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='muq', # Name of the package
5
+ version='0.1.0', # Version of the package
6
+ packages=find_packages(where='src'), # Automatically discover packages under the 'src' directory
7
+ package_dir={'': 'src'}, # Specify the root directory for packages as 'src'
8
+ include_package_data=True, # Include additional files, such as static files
9
+ install_requires=[ # List of dependencies
10
+ "einops",
11
+ "librosa",
12
+ "nnAudio",
13
+ "numpy",
14
+ "soundfile",
15
+ "torch",
16
+ "torchaudio",
17
+ "tqdm",
18
+ "transformers",
19
+ "easydict",
20
+ "x_clip",
21
+ ],
22
+ author='Haina Zhu', # Author name
23
+ author_email='juhayna@qq.com', # Author email address
24
+ description='MuQ: A deep learning model for music and text', # Short description of the package
25
+ long_description=open('README.md', encoding='utf-8').read(), # Long description from the README file
26
+ long_description_content_type='text/markdown', # Format of the long description (Markdown)
27
+ url='https://github.com/tencent-ailab/MuQ', # Project URL
28
+ classifiers=[
29
+ 'Programming Language :: Python :: 3', # Python 3 support
30
+ 'License :: OSI Approved :: MIT License', # License type
31
+ 'Operating System :: OS Independent', # Supports all operating systems
32
+ ],
33
+ python_requires='>=3.8', # Supported Python version
34
+ )
src/third_party/MuQ/src/muq/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .muq import MuQ, MuQConfig
2
+ from .muq_mulan import MuQMuLan, MuQMuLanConfig
src/third_party/MuQ/src/muq/muq/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .muq import MuQConfig, MuQ
src/third_party/MuQ/src/muq/muq/models/__init__.py ADDED
File without changes
src/third_party/MuQ/src/muq/muq/models/muq_model.py ADDED
@@ -0,0 +1,366 @@
1
+ import json
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange
6
+ import os
7
+ from easydict import EasyDict
8
+
9
+ from ..modules.random_quantizer import RandomProjectionQuantizer
10
+ from ..modules.features import MelSTFT
11
+ from ..modules.conv import Conv2dSubsampling
12
+
13
+ class MuQModel(nn.Module):
14
+
15
+ def __init__(
16
+ self,
17
+ num_codebooks=1,
18
+ codebook_dim=16,
19
+ codebook_size=4096,
20
+ features=["melspec_2048"],
21
+ hop_length=240,
22
+ n_mels=128,
23
+ conv_dim=512,
24
+ encoder_dim=1024,
25
+ encoder_depth=12,
26
+ mask_hop=0.4,
27
+ mask_prob=0.6,
28
+ is_flash=False,
29
+ stat=dict(),
30
+ w2v2_config=dict(),
31
+ use_rvq_target=False,
32
+ use_vq_target=False,
33
+ use_encodec_target=False,
34
+ rvq_ckpt_path=None,
35
+ recon_loss_ratio=None,
36
+ label_rate=25,
37
+ rvq_n_codebooks=8,
38
+ rvq_multi_layer_num=1,
39
+ ):
40
+ super().__init__()
41
+
42
+ # global variables
43
+ self.hop_length = hop_length
44
+ self.mask_hop = mask_hop
45
+ self.mask_prob = mask_prob
46
+ self.num_codebooks = num_codebooks
47
+ self.codebook_size = codebook_size
48
+ self.features = features
49
+ self.recon_loss_ratio = recon_loss_ratio
50
+ self.n_fold = int(100//label_rate)
51
+ self.label_rate = label_rate
52
+
53
+ # load feature mean / std stats
54
+ self.stat = stat
55
+
56
+ # feature extractor
57
+ self.preprocessor_melspec_2048 = MelSTFT(
58
+ n_fft=2048, hop_length=hop_length, is_db=True
59
+ )
60
+
61
+ # random quantizer
62
+ self.use_rvq_target = use_rvq_target
63
+ self.use_vq_target = use_vq_target
64
+ self.use_encodec_target = use_encodec_target
65
+
66
+ seed = 142
67
+ if self.use_rvq_like_target:
68
+ if use_rvq_target:
69
+ from ..modules.rvq import ResidualVectorQuantize
70
+
71
+ inp_dim = 128*self.n_fold
72
+ self.rvq = ResidualVectorQuantize(
73
+ input_dim = inp_dim,
74
+ n_codebooks = rvq_n_codebooks,
75
+ codebook_size = 1024,
76
+ codebook_dim = 16,
77
+ quantizer_dropout = 0.0,
78
+ use_multi_layer_num = rvq_multi_layer_num,
79
+ )
80
+ elif use_vq_target:
81
+ from ..modules.rvq import VectorQuantize
82
+
83
+ self.rvq = VectorQuantize(
84
+ input_dim = 128*self.n_fold,
85
+ codebook_size = 1024,
86
+ codebook_dim = 8,
87
+ stale_tolerance = 1000,
88
+ mfcc_clustering = False
89
+ )
90
+ elif use_encodec_target:
91
+ from encodec import EncodecModel
92
+ self.rvq = EncodecModel.encodec_model_24khz()
93
+ self.rvq.set_target_bandwidth(6.0)
94
+ for param in self.rvq.parameters():
95
+ param.requires_grad = False
96
+
97
+ if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
98
+ state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
99
+ self.rvq.load_state_dict(state_dict)
100
+ else:
101
+ pass
102
+ # print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
103
+ else:
104
+ for feature in self.features:
105
+ for i in range(num_codebooks):
106
+ setattr(
107
+ self,
108
+ f"quantizer_{feature}", # _{i}
109
+ RandomProjectionQuantizer(
110
+ n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
111
+ ),
112
+ )
113
+
114
+ # two residual convolution layers + one projection layer
115
+ strides_factory = {
116
+ 4: [2, 2],
117
+ 2: [2, 1]
118
+ }
119
+ self.conv = Conv2dSubsampling(
120
+ 1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
121
+ )
122
+
123
+ # Conformer
124
+ if is_flash:
125
+ from modules.flash_conformer import (
126
+ Wav2Vec2ConformerEncoder,
127
+ Wav2Vec2ConformerConfig,
128
+ )
129
+ else:
130
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
131
+ Wav2Vec2ConformerEncoder,
132
+ Wav2Vec2ConformerConfig,
133
+ )
134
+ config = EasyDict(w2v2_config)
135
+ config.num_hidden_layers = encoder_depth
136
+ config.hidden_size = encoder_dim
137
+
138
+ self.conformer = Wav2Vec2ConformerEncoder(config)
139
+
140
+ self.linear = nn.Linear(encoder_dim, codebook_size) # projection layer
141
+
142
+ # reconstruct melspec
143
+ if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
144
+ self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
145
+ self.recon_loss = nn.MSELoss()
146
+
147
+ # loss function
148
+ self.loss = nn.CrossEntropyLoss()
149
+
150
+ # cls token (used for sequence classification)
151
+ random.seed(seed)
152
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
153
+
154
+
155
+ @property
156
+ def use_rvq_like_target(self):
157
+ return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
158
+
159
+ def masking(self, x, attention_mask=None):
160
+ """random masking of 400ms with given probability"""
161
+ mx = x.clone()
162
+ b, t = mx.shape
163
+ len_masking_raw = int(24000 * self.mask_hop)
164
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
165
+
166
+ # get random mask indices
167
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
168
+ time_domain_masked_indices = torch.nonzero(
169
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
170
+ )
171
+ token_domain_masked_indices = torch.nonzero(
172
+ start_indices.repeat_interleave(len_masking_token, dim=1)
173
+ )
174
+
175
+ # mask with random values
176
+ masking_noise = (
177
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
178
+ ) # 0 mean 0.1 std
179
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
180
+
181
+ return mx, token_domain_masked_indices
182
+
183
+
184
+ @torch.no_grad()
185
+ def preprocessing(self, x, features):
186
+ """extract classic audio features"""
187
+ # check precision
188
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
189
+ precision = 16
190
+ else:
191
+ precision = 32
192
+
193
+ out = {}
194
+ for key in features:
195
+ layer = getattr(self, "preprocessor_%s" % key)
196
+ layer.to(x.device)
197
+ dtype = x.dtype
198
+ out[key] = layer(x.float())[..., :-1]
199
+ if precision == 16:
200
+ out[key] = out[key].half()
201
+ if out[key].dtype != dtype:
202
+ out[key].to(dtype=dtype)
203
+ return out
204
+
205
+ def encoder(self, x, *, attention_mask=None, is_features_only=False):
206
+ """2-layer conv + w2v-conformer"""
207
+ x = self.conv(x)
208
+ mask_indices = None
209
+ if attention_mask is None:
210
+ out = self.conformer(x, output_hidden_states=True)
211
+ else:
212
+ attention_mask = attention_mask.bool()
213
+ skip_n = int(attention_mask.size(-1) / x.size(1))
214
+ attention_mask = attention_mask[:, ::skip_n]
215
+ attention_mask = attention_mask[:, :x.size(1)]
216
+ out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
217
+ hidden_emb = out["hidden_states"]
218
+ last_emb = out["last_hidden_state"]
219
+ logits = self.linear(last_emb)
220
+ interval = self.codebook_size
221
+ logits = {
222
+ key: logits[:, :, i * interval : (i + 1) * interval]
223
+ for i, key in enumerate(self.features)
224
+ }
225
+ return logits, hidden_emb, mask_indices
226
+
227
+ @torch.no_grad()
228
+ def normalize(self, x):
229
+ """normalize the input audio to have zero mean unit variance"""
230
+ for key in x.keys():
231
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
232
+ return x
233
+
234
+ @torch.no_grad()
235
+ def rearrange(self, x):
236
+ """rearrange the batch to flatten every 4 steps"""
237
+ for key in x.keys():
238
+ if key == "chromagram":
239
+ x[key] = rearrange(x[key], "b f t -> b t f")
240
+ else:
241
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
242
+ return x
243
+
244
+ def get_rvq_codes(self, inp, raw_wav):
245
+ if self.use_rvq_target:
246
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
247
+ return codes
248
+ if self.use_vq_target:
249
+ quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
250
+ return codes.unsqueeze(1)
251
+ if self.use_encodec_target:
252
+ encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
253
+ codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
254
+ if self.label_rate == 25:
255
+ codes = codes[:, :, ::3]
256
+ return codes
257
+
258
+ @torch.no_grad()
259
+ def tokenize(self, x, raw_wav):
260
+ out = {}
261
+ for key in x.keys():
262
+ if self.use_rvq_like_target:
263
+ self.rvq.eval()
264
+ inp = x[key].permute((0, 2, 1))
265
+ codes = self.get_rvq_codes(inp, raw_wav)
266
+ out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1)
267
+ else:
268
+ layer = getattr(self, "quantizer_%s" % key)
269
+ out[key] = layer(x[key])
270
+ return out
271
+
272
+ def get_targets(self, x, label=None):
273
+ if self.use_encodec_target:
274
+ raw_x = x.clone()
275
+ else:
276
+ raw_x = None
277
+ x = self.preprocessing(x, features=self.features)
278
+ x = self.normalize(x)
279
+ x = self.rearrange(x)
280
+ melspec = x['melspec_2048']
281
+ if label is None:
282
+ # Use labels from Mel-RVQ
283
+ target_tokens = self.tokenize(x, raw_x)
284
+ else:
285
+ # Use labels pre-extracted for iteration training
286
+ target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
287
+ return target_tokens, melspec
288
+
289
+ def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
290
+ # preprocessing
291
+ x = self.preprocessing(x, features=["melspec_2048"])
292
+ x = self.normalize(x)
293
+
294
+ # encoding
295
+ logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
296
+
297
+ if return_new_mask:
298
+ return logits, hidden_emb, mask if new_mask is None else new_mask
299
+ else:
300
+ return logits, hidden_emb
301
+
302
+ def get_latent(self, x, layer_ix=12):
303
+ _, hidden_states = self.get_predictions(x)
304
+ emb = hidden_states[layer_ix]
305
+ return emb
306
+
307
+ def compute_nce(self, x, pos, negs):
308
+ neg_is_pos = (pos == negs).all(-1)
309
+ pos = pos.unsqueeze(0)
310
+ targets = torch.cat([pos, negs], dim=0)
311
+
312
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
313
+ logits /= 0.1
314
+ if neg_is_pos.any():
315
+ logits[1:][neg_is_pos] = float("-inf")
316
+ logits = logits.transpose(0, 1)
317
+ return logits
318
+
319
+ def get_loss(self, logits, target_tokens, masked_indices):
320
+ losses = {}
321
+ accuracies = {}
322
+ for key in logits.keys():
323
+ if not self.use_rvq_like_target:
324
+ masked_logits = logits[key][tuple(masked_indices.t())]
325
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
326
+ else:
327
+ Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape
328
+ Batch, N_Codebook_x_SeqLen = target_tokens[key].shape
329
+ N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
330
+ target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
331
+ masked_logits = logits[key][tuple(masked_indices.t())]
332
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
333
+ masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
334
+ masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
335
+
336
+ losses[key] = self.loss(masked_logits, masked_tokens)
337
+ accuracies[key] = (
338
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
339
+ / masked_tokens.numel()
340
+ )
341
+ return losses, accuracies
342
+
343
+ def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
344
+ pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
345
+ target_melspec = melspec[tuple(masked_indices.t())]
346
+ recon_loss = self.recon_loss(pred_melspec, target_melspec)
347
+ return recon_loss
348
+
349
+ def forward(self, x, attention_mask=None, label=None):
350
+ dtype = x.dtype
351
+ # get target feature tokens
352
+ target_tokens, melspec = self.get_targets(x, label=label)
353
+
354
+ # masking
355
+ x, masked_indices = self.masking(x, attention_mask=attention_mask)
356
+
357
+ # forward
358
+ logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
359
+
360
+ # get loss
361
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
362
+
363
+ if self.recon_loss_ratio:
364
+ losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
365
+
366
+ return logits, hidden_emb, losses, accuracies
src/third_party/MuQ/src/muq/muq/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
src/third_party/MuQ/src/muq/muq/modules/conv.py ADDED
@@ -0,0 +1,77 @@
1
+ from torch import nn
2
+ from einops import rearrange
3
+
4
+
5
+ class Res2dModule(nn.Module):
6
+ def __init__(self, idim, odim, stride=(2, 2)):
7
+ super(Res2dModule, self).__init__()
8
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
9
+ self.bn1 = nn.BatchNorm2d(odim)
10
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
11
+ self.bn2 = nn.BatchNorm2d(odim)
12
+ self.relu = nn.ReLU()
13
+
14
+ # residual
15
+ self.diff = False
16
+ if (idim != odim) or (stride[0] > 1):
17
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
18
+ self.bn3 = nn.BatchNorm2d(odim)
19
+ self.diff = True
20
+
21
+ def forward(self, x):
22
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
23
+ if self.diff:
24
+ x = self.bn3(self.conv3(x))
25
+ out = x + out
26
+ out = self.relu(out)
27
+ return out
28
+
29
+
30
+ class Conv2dSubsampling(nn.Module):
31
+ """Convolutional 2D subsampling (to 1/4 length).
32
+
33
+ Args:
34
+ idim (int): Input dimension.
35
+ hdim (int): Hidden dimension.
36
+ odim (int): Output dimension.
37
+ strides (list): Sizes of strides.
38
+ n_bands (int): Number of frequency bands.
39
+ """
40
+
41
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
42
+ """Construct an Conv2dSubsampling object."""
43
+ super(Conv2dSubsampling, self).__init__()
44
+
45
+ self.conv = nn.Sequential(
46
+ Res2dModule(idim, hdim, (2, strides[0])),
47
+ Res2dModule(hdim, hdim, (2, strides[1])),
48
+ )
49
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
50
+
51
+ def forward(self, x):
52
+ """Subsample x.
53
+
54
+ Args:
55
+ x (torch.Tensor): Input tensor (#batch, idim, time).
56
+
57
+ Returns:
58
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
59
+ where time' = time // 4.
60
+ """
61
+
62
+ if x.dim() == 3:
63
+ x = x.unsqueeze(1) # (b, c, f, t)
64
+ x = self.conv(x)
65
+ x = rearrange(x, "b c f t -> b t (c f)")
66
+ x = self.linear(x)
67
+ return x
68
+
69
+ if __name__ == '__main__':
70
+ import torch
71
+ conv_dim, encoder_dim = 512, 1024
72
+ conv = Conv2dSubsampling(
73
+ 1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
74
+ )
75
+ inp = torch.randn((1, 128, 3000))
76
+ out = conv(inp)
77
+ print(out.shape)
src/third_party/MuQ/src/muq/muq/modules/features.py ADDED
@@ -0,0 +1,37 @@
1
+ import torchaudio
2
+ from torch import nn
3
+ import torch
4
+
5
+
6
+ class MelSTFT:
7
+ def __init__(
8
+ self,
9
+ sample_rate=24000,
10
+ n_fft=2048,
11
+ hop_length=240,
12
+ n_mels=128,
13
+ is_db=False,
14
+ ):
15
+ super(MelSTFT, self).__init__()
16
+
17
+ # spectrogram
18
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
19
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
20
+ )
21
+
22
+ # amplitude to decibel
23
+ self.is_db = is_db
24
+ if is_db:
25
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
26
+
27
+ def __call__(self, waveform):
28
+ if self.is_db:
29
+ return self.amplitude_to_db(self.mel_stft(waveform))
30
+ else:
31
+ return self.mel_stft(waveform)
32
+
33
+ def to(self, device):
34
+ self.mel_stft = self.mel_stft.to(device)
35
+ if self.is_db:
36
+ self.amplitude_to_db = self.amplitude_to_db.to(device)
37
+ return self
src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py ADDED
@@ -0,0 +1,2114 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Wav2Vec2-Conformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ from torch.nn import functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ CausalLMOutput,
33
+ SequenceClassifierOutput,
34
+ TokenClassifierOutput,
35
+ Wav2Vec2BaseModelOutput,
36
+ XVectorOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ ModelOutput,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ _HIDDEN_STATES_START_POSITION = 2
54
+
55
+ # General docstring
56
+ _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
57
+
58
+ # Base docstring
59
+ _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
61
+
62
+ # CTC docstring
63
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
64
+ _CTC_EXPECTED_LOSS = 64.21
65
+
66
+
67
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
68
+ "facebook/wav2vec2-conformer-rel-pos-large",
69
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
70
+ ]
71
+
72
+
73
+ @dataclass
74
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
75
+ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
76
+ """
77
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
78
+
79
+ Args:
80
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
81
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
82
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
83
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
84
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
85
+ projected quantized states.
86
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
87
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
88
+ target vectors for contrastive loss.
89
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
91
+ shape `(batch_size, sequence_length, hidden_size)`.
92
+
93
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
+ sequence_length)`.
97
+
98
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
+ heads.
100
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
101
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
102
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
103
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
104
+ """
105
+
106
+ loss: Optional[torch.FloatTensor] = None
107
+ projected_states: torch.FloatTensor = None
108
+ projected_quantized_states: torch.FloatTensor = None
109
+ codevector_perplexity: torch.FloatTensor = None
110
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
112
+ contrastive_loss: Optional[torch.FloatTensor] = None
113
+ diversity_loss: Optional[torch.FloatTensor] = None
114
+
115
+
116
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
117
+ def _compute_mask_indices(
118
+ shape: Tuple[int, int],
119
+ mask_prob: float,
120
+ mask_length: int,
121
+ attention_mask: Optional[torch.LongTensor] = None,
122
+ min_masks: int = 0,
123
+ ) -> np.ndarray:
124
+ """
125
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
126
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
127
+ CPU as part of the preprocessing during training.
128
+
129
+ Args:
130
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
131
+ the first element is the batch size and the second element is the length of the axis to span.
132
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
133
+ independently generated mask spans of length `mask_length` is computed by
134
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
135
+ actual percentage will be smaller.
136
+ mask_length: size of the mask
137
+ min_masks: minimum number of masked spans
138
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
139
+ each batch dimension.
140
+ """
141
+ batch_size, sequence_length = shape
142
+
143
+ if mask_length < 1:
144
+ raise ValueError("`mask_length` has to be bigger than 0.")
145
+
146
+ if mask_length > sequence_length:
147
+ raise ValueError(
148
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
149
+ f" and `sequence_length`: {sequence_length}`"
150
+ )
151
+
152
+ # epsilon is used for probabilistic rounding
153
+ epsilon = np.random.rand(1).item()
154
+
155
+ def compute_num_masked_span(input_length):
156
+ """Given input length, compute how many spans should be masked"""
157
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
158
+ num_masked_span = max(num_masked_span, min_masks)
159
+
160
+ # make sure num masked span <= sequence_length
161
+ if num_masked_span * mask_length > sequence_length:
162
+ num_masked_span = sequence_length // mask_length
163
+
164
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
165
+ if input_length - (mask_length - 1) < num_masked_span:
166
+ num_masked_span = max(input_length - (mask_length - 1), 0)
167
+
168
+ return num_masked_span
169
+
170
+ # compute number of masked spans in batch
171
+ input_lengths = (
172
+ attention_mask.sum(-1).detach().tolist()
173
+ if attention_mask is not None
174
+ else [sequence_length for _ in range(batch_size)]
175
+ )
176
+
177
+ # SpecAugment mask to fill
178
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
179
+ spec_aug_mask_idxs = []
180
+
181
+ max_num_masked_span = compute_num_masked_span(sequence_length)
182
+
183
+ if max_num_masked_span == 0:
184
+ return spec_aug_mask
185
+
186
+ for input_length in input_lengths:
187
+ # compute num of masked spans for this input
188
+ num_masked_span = compute_num_masked_span(input_length)
189
+
190
+ # get random indices to mask
191
+ spec_aug_mask_idx = np.random.choice(
192
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
193
+ )
194
+
195
+ # pick first sampled index that will serve as a dummy index to pad vector
196
+ # to ensure same dimension for all batches due to probabilistic rounding
197
+ # Picking first sample just pads those vectors twice.
198
+ if len(spec_aug_mask_idx) == 0:
199
+ # this case can only happen if `input_length` is strictly smaller then
200
+ # `sequence_length` in which case the last token has to be a padding
201
+ # token which we can use as a dummy mask id
202
+ dummy_mask_idx = sequence_length - 1
203
+ else:
204
+ dummy_mask_idx = spec_aug_mask_idx[0]
205
+
206
+ spec_aug_mask_idx = np.concatenate(
207
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
208
+ )
209
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
210
+
211
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
212
+
213
+ # expand masked indices to masked spans
214
+ spec_aug_mask_idxs = np.broadcast_to(
215
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
216
+ )
217
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
218
+
219
+ # add offset to the starting indexes so that indexes now create a span
220
+ offsets = np.arange(mask_length)[None, None, :]
221
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
222
+ batch_size, max_num_masked_span * mask_length
223
+ )
224
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
225
+
226
+ # ensure that we cannot have indices larger than sequence_length
227
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
228
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
229
+
230
+ # scatter indices to mask
231
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
232
+
233
+ return spec_aug_mask
234
+
235
+
236
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
237
+ def _sample_negative_indices(
238
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
239
+ ):
240
+ """
241
+ Sample `num_negatives` vectors from feature vectors.
242
+ """
243
+ batch_size, sequence_length = features_shape
244
+
245
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
246
+ sequence_length_range = np.arange(sequence_length)
247
+
248
+ # get `num_negatives` random vector indices from the same utterance
249
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
250
+
251
+ mask_time_indices = (
252
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
253
+ )
254
+
255
+ for batch_idx in range(batch_size):
256
+ high = mask_time_indices[batch_idx].sum() - 1
257
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
258
+
259
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
260
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
261
+ # avoid sampling the same positive vector, but keep the distribution uniform
262
+ sampled_indices[sampled_indices >= feature_indices] += 1
263
+
264
+ # remap to actual indices
265
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
266
+
267
+ # correct for batch size
268
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
269
+
270
+ return sampled_negative_indices
271
+
272
+
273
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
274
+ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
275
+ def __init__(self, config, layer_id=0):
276
+ super().__init__()
277
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
278
+ self.out_conv_dim = config.conv_dim[layer_id]
279
+
280
+ self.conv = nn.Conv1d(
281
+ self.in_conv_dim,
282
+ self.out_conv_dim,
283
+ kernel_size=config.conv_kernel[layer_id],
284
+ stride=config.conv_stride[layer_id],
285
+ bias=config.conv_bias,
286
+ )
287
+ self.activation = ACT2FN[config.feat_extract_activation]
288
+
289
+ def forward(self, hidden_states):
290
+ hidden_states = self.conv(hidden_states)
291
+ hidden_states = self.activation(hidden_states)
292
+ return hidden_states
293
+
294
+
295
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
296
+ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
297
+ def __init__(self, config, layer_id=0):
298
+ super().__init__()
299
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
300
+ self.out_conv_dim = config.conv_dim[layer_id]
301
+
302
+ self.conv = nn.Conv1d(
303
+ self.in_conv_dim,
304
+ self.out_conv_dim,
305
+ kernel_size=config.conv_kernel[layer_id],
306
+ stride=config.conv_stride[layer_id],
307
+ bias=config.conv_bias,
308
+ )
309
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
310
+ self.activation = ACT2FN[config.feat_extract_activation]
311
+
312
+ def forward(self, hidden_states):
313
+ hidden_states = self.conv(hidden_states)
314
+
315
+ hidden_states = hidden_states.transpose(-2, -1)
316
+ hidden_states = self.layer_norm(hidden_states)
317
+ hidden_states = hidden_states.transpose(-2, -1)
318
+
319
+ hidden_states = self.activation(hidden_states)
320
+ return hidden_states
321
+
322
+
323
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
324
+ class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
325
+ def __init__(self, config, layer_id=0):
326
+ super().__init__()
327
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
328
+ self.out_conv_dim = config.conv_dim[layer_id]
329
+
330
+ self.conv = nn.Conv1d(
331
+ self.in_conv_dim,
332
+ self.out_conv_dim,
333
+ kernel_size=config.conv_kernel[layer_id],
334
+ stride=config.conv_stride[layer_id],
335
+ bias=config.conv_bias,
336
+ )
337
+ self.activation = ACT2FN[config.feat_extract_activation]
338
+
339
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
340
+
341
+ def forward(self, hidden_states):
342
+ hidden_states = self.conv(hidden_states)
343
+ hidden_states = self.layer_norm(hidden_states)
344
+ hidden_states = self.activation(hidden_states)
345
+ return hidden_states
346
+
347
+
348
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
349
+ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.conv = nn.Conv1d(
353
+ config.hidden_size,
354
+ config.hidden_size,
355
+ kernel_size=config.num_conv_pos_embeddings,
356
+ padding=config.num_conv_pos_embeddings // 2,
357
+ groups=config.num_conv_pos_embedding_groups,
358
+ )
359
+
360
+ if is_deepspeed_zero3_enabled():
361
+ import deepspeed
362
+
363
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
364
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
365
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
366
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
367
+ else:
368
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
369
+
370
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
371
+ self.activation = ACT2FN[config.feat_extract_activation]
372
+
373
+ def forward(self, hidden_states):
374
+ hidden_states = hidden_states.transpose(1, 2)
375
+
376
+ hidden_states = self.conv(hidden_states)
377
+ hidden_states = self.padding(hidden_states)
378
+ hidden_states = self.activation(hidden_states)
379
+
380
+ hidden_states = hidden_states.transpose(1, 2)
381
+ return hidden_states
382
+
383
+
384
+ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
385
+ """Rotary positional embedding
386
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
387
+ """
388
+
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ dim = config.hidden_size // config.num_attention_heads
392
+ base = config.rotary_embedding_base
393
+
394
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
395
+ self.register_buffer("inv_freq", inv_freq)
396
+ self.cached_sequence_length = None
397
+ self.cached_rotary_positional_embedding = None
398
+
399
+ def forward(self, hidden_states):
400
+ sequence_length = hidden_states.shape[1]
401
+
402
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
403
+ return self.cached_rotary_positional_embedding
404
+
405
+ self.cached_sequence_length = sequence_length
406
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
407
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
408
+ embeddings = torch.cat((freqs, freqs), dim=-1)
409
+
410
+ cos_embeddings = embeddings.cos()[:, None, None, :]
411
+ sin_embeddings = embeddings.sin()[:, None, None, :]
412
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
413
+ return self.cached_rotary_positional_embedding
414
+
415
+
416
+ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
417
+ """Relative positional encoding module."""
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.max_len = config.max_source_positions
422
+ self.d_model = config.hidden_size
423
+ self.pe = None
424
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
425
+
426
+ def extend_pe(self, x):
427
+ # Reset the positional encodings
428
+ if self.pe is not None:
429
+ # self.pe contains both positive and negative parts
430
+ # the length of self.pe is 2 * input_len - 1
431
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
432
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
433
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
434
+ return
435
+ # Suppose `i` is the position of query vector and `j` is the
436
+ # position of key vector. We use positive relative positions when keys
437
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
438
+ pe_positive = torch.zeros(x.size(1), self.d_model)
439
+ pe_negative = torch.zeros(x.size(1), self.d_model)
440
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
441
+ div_term = torch.exp(
442
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
443
+ )
444
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
445
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
446
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
447
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
448
+
449
+ # Reverse the order of positive indices and concat both positive and
450
+ # negative indices. This is used to support the shifting trick
451
+ # as in https://arxiv.org/abs/1901.02860
452
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
453
+ pe_negative = pe_negative[1:].unsqueeze(0)
454
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
455
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
456
+
457
+ def forward(self, hidden_states: torch.Tensor):
458
+ self.extend_pe(hidden_states)
459
+ start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
460
+ end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
461
+ relative_position_embeddings = self.pe[:, start_idx:end_idx]
462
+
463
+ return relative_position_embeddings
464
+
465
+
466
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
467
+ class Wav2Vec2ConformerSamePadLayer(nn.Module):
468
+ def __init__(self, num_conv_pos_embeddings):
469
+ super().__init__()
470
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
471
+
472
+ def forward(self, hidden_states):
473
+ if self.num_pad_remove > 0:
474
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
475
+ return hidden_states
476
+
477
+
478
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
479
+ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
480
+ """Construct the features from raw audio waveform"""
481
+
482
+ def __init__(self, config):
483
+ super().__init__()
484
+
485
+ if config.feat_extract_norm == "group":
486
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
487
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
488
+ for i in range(config.num_feat_extract_layers - 1)
489
+ ]
490
+ elif config.feat_extract_norm == "layer":
491
+ conv_layers = [
492
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
493
+ ]
494
+ else:
495
+ raise ValueError(
496
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
497
+ )
498
+ self.conv_layers = nn.ModuleList(conv_layers)
499
+ self.gradient_checkpointing = False
500
+ self._requires_grad = True
501
+
502
+ def _freeze_parameters(self):
503
+ for param in self.parameters():
504
+ param.requires_grad = False
505
+ self._requires_grad = False
506
+
507
+ def forward(self, input_values):
508
+ hidden_states = input_values[:, None]
509
+
510
+ # make sure hidden_states require grad for gradient_checkpointing
511
+ if self._requires_grad and self.training:
512
+ hidden_states.requires_grad = True
513
+
514
+ for conv_layer in self.conv_layers:
515
+ if self._requires_grad and self.gradient_checkpointing and self.training:
516
+
517
+ def create_custom_forward(module):
518
+ def custom_forward(*inputs):
519
+ return module(*inputs)
520
+
521
+ return custom_forward
522
+
523
+ hidden_states = torch.utils.checkpoint.checkpoint(
524
+ create_custom_forward(conv_layer),
525
+ hidden_states,
526
+ )
527
+ else:
528
+ hidden_states = conv_layer(hidden_states)
529
+
530
+ return hidden_states
531
+
532
+
533
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
534
+ class Wav2Vec2ConformerFeatureProjection(nn.Module):
535
+ def __init__(self, config):
536
+ super().__init__()
537
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
538
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
539
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
540
+
541
+ def forward(self, hidden_states):
542
+ # non-projected hidden states are needed for quantization
543
+ norm_hidden_states = self.layer_norm(hidden_states)
544
+ hidden_states = self.projection(norm_hidden_states)
545
+ hidden_states = self.dropout(hidden_states)
546
+ return hidden_states, norm_hidden_states
547
+
548
+
549
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
550
+ class Wav2Vec2ConformerFeedForward(nn.Module):
551
+ def __init__(self, config):
552
+ super().__init__()
553
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
554
+
555
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
556
+ if isinstance(config.hidden_act, str):
557
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
558
+ else:
559
+ self.intermediate_act_fn = config.hidden_act
560
+
561
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
562
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
563
+
564
+ def forward(self, hidden_states):
565
+ hidden_states = self.intermediate_dense(hidden_states)
566
+ hidden_states = self.intermediate_act_fn(hidden_states)
567
+ hidden_states = self.intermediate_dropout(hidden_states)
568
+
569
+ hidden_states = self.output_dense(hidden_states)
570
+ hidden_states = self.output_dropout(hidden_states)
571
+ return hidden_states
572
+
573
+
574
+ class Wav2Vec2ConformerConvolutionModule(nn.Module):
575
+ """Convolution block used in the conformer block"""
576
+
577
+ def __init__(self, config):
578
+ super().__init__()
579
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
580
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
581
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
582
+ self.pointwise_conv1 = torch.nn.Conv1d(
583
+ config.hidden_size,
584
+ 2 * config.hidden_size,
585
+ kernel_size=1,
586
+ stride=1,
587
+ padding=0,
588
+ bias=False,
589
+ )
590
+ self.glu = torch.nn.GLU(dim=1)
591
+ self.depthwise_conv = torch.nn.Conv1d(
592
+ config.hidden_size,
593
+ config.hidden_size,
594
+ config.conv_depthwise_kernel_size,
595
+ stride=1,
596
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
597
+ groups=config.hidden_size,
598
+ bias=False,
599
+ )
600
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
601
+ self.activation = ACT2FN[config.hidden_act]
602
+ self.pointwise_conv2 = torch.nn.Conv1d(
603
+ config.hidden_size,
604
+ config.hidden_size,
605
+ kernel_size=1,
606
+ stride=1,
607
+ padding=0,
608
+ bias=False,
609
+ )
610
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
611
+
612
+ def forward(self, hidden_states):
613
+ hidden_states = self.layer_norm(hidden_states)
614
+ # exchange the temporal dimension and the feature dimension
615
+ hidden_states = hidden_states.transpose(1, 2)
616
+
617
+ # GLU mechanism
618
+ # => (batch, 2*channel, dim)
619
+ hidden_states = self.pointwise_conv1(hidden_states)
620
+ # => (batch, channel, dim)
621
+ hidden_states = self.glu(hidden_states)
622
+
623
+ # 1D Depthwise Conv
624
+ hidden_states = self.depthwise_conv(hidden_states)
625
+ hidden_states = self.batch_norm(hidden_states)
626
+ hidden_states = self.activation(hidden_states)
627
+
628
+ hidden_states = self.pointwise_conv2(hidden_states)
629
+ hidden_states = self.dropout(hidden_states)
630
+ hidden_states = hidden_states.transpose(1, 2)
631
+ return hidden_states
632
+
633
+
634
+ class Wav2Vec2ConformerSelfAttention(nn.Module):
635
+ """Construct an Wav2Vec2ConformerSelfAttention object.
636
+ Can be enhanced with rotary or relative position embeddings.
637
+ """
638
+
639
+ def __init__(self, config):
640
+ super().__init__()
641
+
642
+ self.head_size = config.hidden_size // config.num_attention_heads
643
+ self.num_heads = config.num_attention_heads
644
+ self.position_embeddings_type = config.position_embeddings_type
645
+
646
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
647
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
648
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
649
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
650
+
651
+ self.dropout = nn.Dropout(p=config.attention_dropout)
652
+ self.dropout_p = config.attention_dropout
653
+
654
+ self.is_causal = config.is_causal
655
+
656
+ if self.position_embeddings_type == "relative":
657
+ # linear transformation for positional encoding
658
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
659
+ # these two learnable bias are used in matrix c and matrix d
660
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
661
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
662
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
663
+
664
+ def forward(
665
+ self,
666
+ hidden_states: torch.Tensor,
667
+ attention_mask: Optional[torch.Tensor] = None,
668
+ relative_position_embeddings: Optional[torch.Tensor] = None,
669
+ output_attentions: bool = False,
670
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
671
+ # self-attention mechanism
672
+ batch_size, sequence_length, hidden_size = hidden_states.size()
673
+
674
+ # make sure query/key states can be != value states
675
+ query_key_states = hidden_states
676
+ value_states = hidden_states
677
+
678
+ if self.position_embeddings_type == "rotary":
679
+ if relative_position_embeddings is None:
680
+ raise ValueError(
681
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
682
+ )
683
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
684
+
685
+ # project query_key_states and value_states
686
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
687
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
688
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
689
+
690
+ # => (batch, head, time1, d_k)
691
+ query = query.transpose(1, 2)
692
+ key = key.transpose(1, 2)
693
+ value = value.transpose(1, 2)
694
+
695
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
696
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
697
+ probs = None
698
+
699
+ # # apply attention_mask if necessary
700
+ # if attention_mask is not None:
701
+ # scores = scores + attention_mask
702
+
703
+ # # => (batch, head, time1, time2)
704
+ # probs = torch.softmax(scores, dim=-1)
705
+ # probs = self.dropout(probs)
706
+
707
+ # # => (batch, head, time1, d_k)
708
+ # hidden_states = torch.matmul(probs, value)
709
+
710
+ # => (batch, time1, hidden_size)
711
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
712
+ hidden_states = self.linear_out(hidden_states)
713
+
714
+ return hidden_states, probs
715
+
716
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
717
+ batch_size, sequence_length, hidden_size = hidden_states.size()
718
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
719
+
720
+ cos = relative_position_embeddings[0, :sequence_length, ...]
721
+ sin = relative_position_embeddings[1, :sequence_length, ...]
722
+
723
+ # rotate hidden_states with rotary embeddings
724
+ hidden_states = hidden_states.transpose(0, 1)
725
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
726
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
727
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
728
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
729
+ hidden_states = hidden_states.transpose(0, 1)
730
+
731
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
732
+
733
+ return hidden_states
734
+
735
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
736
+ # 1. project positional embeddings
737
+ # => (batch, head, 2*time1-1, d_k)
738
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
739
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
740
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
741
+ )
742
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
743
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
744
+
745
+ # 2. Add bias to query
746
+ # => (batch, head, time1, d_k)
747
+ query = query.transpose(1, 2)
748
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
749
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
750
+
751
+ # 3. attention score: first compute matrix a and matrix c
752
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
753
+ # => (batch, head, time1, time2)
754
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
755
+
756
+ # 4. then compute matrix b and matrix d
757
+ # => (batch, head, time1, 2*time1-1)
758
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
759
+
760
+ # 5. shift matrix b and matrix d
761
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
762
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
763
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
764
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
765
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
766
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
767
+
768
+ # 6. sum matrices
769
+ # => (batch, head, time1, time2)
770
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
771
+
772
+ return scores
773
+
774
+
775
+ class Wav2Vec2ConformerEncoderLayer(nn.Module):
776
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
777
+
778
+ def __init__(self, config):
779
+ super().__init__()
780
+ embed_dim = config.hidden_size
781
+ dropout = config.attention_dropout
782
+
783
+ # Feed-forward 1
784
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
785
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
786
+
787
+ # Self-Attention
788
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
789
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
790
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
791
+
792
+ # Conformer Convolution
793
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
794
+
795
+ # Feed-forward 2
796
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
797
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
798
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
799
+
800
+ def forward(
801
+ self,
802
+ hidden_states,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ relative_position_embeddings: Optional[torch.Tensor] = None,
805
+ output_attentions: bool = False,
806
+ ):
807
+ hidden_states = hidden_states
808
+
809
+ # 1. Feed-Forward 1 layer
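+ # Macaron-style half-step residual: the feed-forward output is scaled by 0.5 before being
+ # added back to the residual, as in the Conformer paper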
810
+ residual = hidden_states
811
+ hidden_states = self.ffn1_layer_norm(hidden_states)
812
+ hidden_states = self.ffn1(hidden_states)
813
+ hidden_states = hidden_states * 0.5 + residual
814
+ residual = hidden_states
815
+
816
+ # 2. Self-Attention layer
817
+ hidden_states = self.self_attn_layer_norm(hidden_states)
818
+ hidden_states, attn_weights = self.self_attn(
819
+ hidden_states=hidden_states,
820
+ attention_mask=attention_mask,
821
+ relative_position_embeddings=relative_position_embeddings,
822
+ output_attentions=output_attentions,
823
+ )
824
+ hidden_states = self.self_attn_dropout(hidden_states)
825
+ hidden_states = hidden_states + residual
826
+
827
+ # 3. Convolutional Layer
828
+ residual = hidden_states
829
+ hidden_states = self.conv_module(hidden_states)
830
+ hidden_states = residual + hidden_states
831
+
832
+ # 4. Feed-Forward 2 Layer
833
+ residual = hidden_states
834
+ hidden_states = self.ffn2_layer_norm(hidden_states)
835
+ hidden_states = self.ffn2(hidden_states)
836
+ hidden_states = hidden_states * 0.5 + residual
837
+ hidden_states = self.final_layer_norm(hidden_states)
838
+
839
+ return hidden_states, attn_weights
840
+
841
+
842
+ class Wav2Vec2ConformerEncoder(nn.Module):
843
+ def __init__(self, config, is_causal=False):
844
+ super().__init__()
845
+ config.is_causal = is_causal
846
+ self.config = config
847
+
848
+ if config.position_embeddings_type == "relative":
849
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
850
+ elif config.position_embeddings_type == "rotary":
851
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
852
+ else:
853
+ self.embed_positions = None
854
+
855
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
856
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
857
+ self.dropout = nn.Dropout(config.hidden_dropout)
858
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
859
+ self.gradient_checkpointing = False
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states,
864
+ attention_mask=None,
865
+ output_attentions=False,
866
+ output_hidden_states=False,
867
+ return_dict=True,
868
+ ):
869
+ all_hidden_states = () if output_hidden_states else None
870
+ all_self_attentions = () if output_attentions else None
871
+
872
+ if attention_mask is not None:
873
+ # make sure padded tokens output 0
874
+ hidden_states[~attention_mask] = 0.0
875
+
876
+ # extend attention_mask
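+ # convert the boolean mask into an additive attention bias: 0 for valid positions and a large
+ # negative value for padded positions, broadcast over heads and query positions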
877
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
878
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
879
+ attention_mask = attention_mask.expand(
880
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
881
+ )
882
+
883
+ hidden_states = self.dropout(hidden_states)
884
+
885
+ if self.embed_positions is not None:
886
+ relative_position_embeddings = self.embed_positions(hidden_states)
887
+ else:
888
+ relative_position_embeddings = None
889
+
890
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
891
+
892
+ for i, layer in enumerate(self.layers):
893
+ if output_hidden_states:
894
+ all_hidden_states = all_hidden_states + (hidden_states,)
895
+
896
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
897
+ dropout_probability = np.random.uniform(0, 1)
898
+
899
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
900
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
901
+ # under deepspeed zero3 all gpus must run in sync
902
+ if self.gradient_checkpointing and self.training:
903
+ # create gradient checkpointing function
904
+ def create_custom_forward(module):
905
+ def custom_forward(*inputs):
906
+ return module(*inputs, output_attentions)
907
+
908
+ return custom_forward
909
+
910
+ layer_outputs = torch.utils.checkpoint.checkpoint(
911
+ create_custom_forward(layer),
912
+ hidden_states,
913
+ attention_mask,
914
+ relative_position_embeddings,
915
+ )
916
+ else:
917
+ layer_outputs = layer(
918
+ hidden_states,
919
+ attention_mask=attention_mask,
920
+ relative_position_embeddings=relative_position_embeddings,
921
+ output_attentions=output_attentions,
922
+ )
923
+ hidden_states = layer_outputs[0]
924
+
925
+ if skip_the_layer:
926
+ layer_outputs = (None, None)
927
+
928
+ if output_attentions:
929
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
930
+
931
+ hidden_states = self.layer_norm(hidden_states)
932
+ if output_hidden_states:
933
+ all_hidden_states = all_hidden_states + (hidden_states,)
934
+
935
+ if not return_dict:
936
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
937
+ return BaseModelOutput(
938
+ last_hidden_state=hidden_states,
939
+ hidden_states=all_hidden_states,
940
+ attentions=all_self_attentions,
941
+ )
942
+
943
+
944
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
945
+ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
946
+ """
947
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
948
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
949
+ """
950
+
951
+ def __init__(self, config):
952
+ super().__init__()
953
+ self.num_groups = config.num_codevector_groups
954
+ self.num_vars = config.num_codevectors_per_group
955
+
956
+ if config.codevector_dim % self.num_groups != 0:
957
+ raise ValueError(
958
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
959
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
960
+ )
961
+
962
+ # storage for codebook variables (codewords)
963
+ self.codevectors = nn.Parameter(
964
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
965
+ )
966
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
967
+
968
+ # can be decayed for training
969
+ self.temperature = 2
970
+
971
+ @staticmethod
972
+ def _compute_perplexity(probs, mask=None):
973
+ if mask is not None:
974
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
975
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
976
+ marginal_probs = probs.sum(dim=0) / mask.sum()
977
+ else:
978
+ marginal_probs = probs.mean(dim=0)
979
+
980
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
981
+ return perplexity
982
+
983
+ def forward(self, hidden_states, mask_time_indices=None):
984
+ batch_size, sequence_length, hidden_size = hidden_states.shape
985
+
986
+ # project to codevector dim
987
+ hidden_states = self.weight_proj(hidden_states)
988
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
989
+
990
+ if self.training:
991
+ # sample code vector probs via gumbel softmax in a differentiable way
992
+ codevector_probs = nn.functional.gumbel_softmax(
993
+ hidden_states.float(), tau=self.temperature, hard=True
994
+ ).type_as(hidden_states)
995
+
996
+ # compute perplexity
997
+ codevector_soft_dist = torch.softmax(
998
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
999
+ )
1000
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
1001
+ else:
1002
+ # take argmax in non-differentiable way
1003
+ # compute hard codevector distribution (one hot)
1004
+ codevector_idx = hidden_states.argmax(dim=-1)
1005
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
1006
+ -1, codevector_idx.view(-1, 1), 1.0
1007
+ )
1008
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
1009
+
1010
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
1011
+
1012
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
1013
+ # use probs to retrieve codevectors
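+ # the (one-hot) probabilities pick one codevector per group; the selected vectors are summed
+ # over the codebook axis and the groups are concatenated back to the full codevector dimension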
1014
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
1015
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
1016
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
1017
+
1018
+ return codevectors, perplexity
1019
+
1020
+
1021
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
1022
+ class Wav2Vec2ConformerAdapter(nn.Module):
1023
+ def __init__(self, config):
1024
+ super().__init__()
1025
+
1026
+ # feature dim might need to be down-projected
1027
+ if config.output_hidden_size != config.hidden_size:
1028
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
1029
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
1030
+ else:
1031
+ self.proj = self.proj_layer_norm = None
1032
+
1033
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
1034
+ self.layerdrop = config.layerdrop
1035
+
1036
+ def forward(self, hidden_states):
1037
+ # down project hidden_states if necessary
1038
+ if self.proj is not None and self.proj_layer_norm is not None:
1039
+ hidden_states = self.proj(hidden_states)
1040
+ hidden_states = self.proj_layer_norm(hidden_states)
1041
+
1042
+ hidden_states = hidden_states.transpose(1, 2)
1043
+
1044
+ for layer in self.layers:
1045
+ layerdrop_prob = np.random.random()
1046
+ if not self.training or (layerdrop_prob > self.layerdrop):
1047
+ hidden_states = layer(hidden_states)
1048
+
1049
+ hidden_states = hidden_states.transpose(1, 2)
1050
+ return hidden_states
1051
+
1052
+
1053
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
1054
+ class Wav2Vec2ConformerAdapterLayer(nn.Module):
1055
+ def __init__(self, config):
1056
+ super().__init__()
1057
+ self.conv = nn.Conv1d(
1058
+ config.output_hidden_size,
1059
+ 2 * config.output_hidden_size,
1060
+ config.adapter_kernel_size,
1061
+ stride=config.adapter_stride,
1062
+ padding=1,
1063
+ )
1064
+
1065
+ def forward(self, hidden_states):
1066
+ hidden_states = self.conv(hidden_states)
1067
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
1068
+
1069
+ return hidden_states
1070
+
1071
+
1072
+ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
1073
+ """
1074
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1075
+ models.
1076
+ """
1077
+
1078
+ config_class = Wav2Vec2ConformerConfig
1079
+ base_model_prefix = "wav2vec2_conformer"
1080
+ main_input_name = "input_values"
1081
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1082
+ supports_gradient_checkpointing = True
1083
+
1084
+ def _init_weights(self, module):
1085
+ """Initialize the weights"""
1086
+ # Wav2Vec2ConformerForPreTraining last 2 linear layers need standard Linear init.
1087
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
1088
+ module.project_hid.reset_parameters()
1089
+ module.project_q.reset_parameters()
1090
+ module.project_hid._is_hf_initialized = True
1091
+ module.project_q._is_hf_initialized = True
1092
+ # gumbel softmax requires special init
1093
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
1094
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
1095
+ module.weight_proj.bias.data.zero_()
1096
+ nn.init.uniform_(module.codevectors)
1097
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
1098
+ if hasattr(module, "pos_bias_u"):
1099
+ nn.init.xavier_uniform_(module.pos_bias_u)
1100
+ if hasattr(module, "pos_bias_v"):
1101
+ nn.init.xavier_uniform_(module.pos_bias_v)
1102
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
1103
+ nn.init.normal_(
1104
+ module.conv.weight,
1105
+ mean=0,
1106
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
1107
+ )
1108
+ nn.init.constant_(module.conv.bias, 0)
1109
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
1110
+ k = math.sqrt(1 / module.projection.in_features)
1111
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
1112
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
1113
+ elif isinstance(module, nn.Linear):
1114
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1115
+
1116
+ if module.bias is not None:
1117
+ module.bias.data.zero_()
1118
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1119
+ module.bias.data.zero_()
1120
+ module.weight.data.fill_(1.0)
1121
+ elif isinstance(module, nn.Conv1d):
1122
+ nn.init.kaiming_normal_(module.weight)
1123
+
1124
+ if module.bias is not None:
1125
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1126
+ nn.init.uniform_(module.bias, a=-k, b=k)
1127
+
1128
+ def _get_feat_extract_output_lengths(
1129
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
1130
+ ):
1131
+ """
1132
+ Computes the output length of the convolutional layers
1133
+ """
1134
+
1135
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
1136
+
1137
+ def _conv_out_length(input_length, kernel_size, stride):
1138
+ # 1D convolutional layer output length formula taken
1139
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1140
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1141
+
1142
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1143
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1144
+
1145
+ if add_adapter:
1146
+ for _ in range(self.config.num_adapter_layers):
1147
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
1148
+
1149
+ return input_lengths
1150
+
1151
+ def _get_feature_vector_attention_mask(
1152
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
1153
+ ):
1154
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
1155
+ # on inference mode.
1156
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
1157
+
1158
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
1159
+ output_lengths = output_lengths.to(torch.long)
1160
+
1161
+ batch_size = attention_mask.shape[0]
1162
+
1163
+ attention_mask = torch.zeros(
1164
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1165
+ )
1166
+ # these two operations make sure that all values before the output length indices are attended to
1167
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
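+ # the scatter above marks the last valid frame of each sample; the flip/cumsum/flip below
+ # back-fills ones from that index down to position 0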
1168
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
1169
+ return attention_mask
1170
+
1171
+ def _set_gradient_checkpointing(self, module, value=False):
1172
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
1173
+ module.gradient_checkpointing = value
1174
+
1175
+
1176
+ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
1177
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
1178
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
1179
+ Auli.
1180
+
1181
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1182
+ library implements for all its models (such as downloading or saving, etc.).
1183
+
1184
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
1185
+ regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
1186
+
1187
+ Parameters:
1188
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
1189
+ Initializing with a config file does not load the weights associated with the model, only the
1190
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1191
+ """
1192
+
1193
+
1194
+ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
1195
+ Args:
1196
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1197
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1198
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1199
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1200
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1201
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1202
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1203
+ 1]`:
1204
+
1205
+ - 1 for tokens that are **not masked**,
1206
+ - 0 for tokens that are **masked**.
1207
+
1208
+ [What are attention masks?](../glossary#attention-mask)
1209
+
1210
+ <Tip warning={true}>
1211
+
1212
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
1213
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
1214
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
1215
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
1216
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
1217
+ that these models also yield slightly different results depending on whether `input_values` is padded or
1218
+ not.
1219
+
1220
+ </Tip>
1221
+
1222
+ output_attentions (`bool`, *optional*):
1223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
+ tensors for more detail.
1225
+ output_hidden_states (`bool`, *optional*):
1226
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
+ more detail.
1228
+ return_dict (`bool`, *optional*):
1229
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
+ """
1231
+
1232
+
1233
+ @add_start_docstrings(
1234
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
1235
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1236
+ )
1237
+ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
1238
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1239
+ super().__init__(config)
1240
+ self.config = config
1241
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
1242
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
1243
+
1244
+ # model only needs masking vector if mask prob is > 0.0
1245
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1246
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
1247
+
1248
+ self.encoder = Wav2Vec2ConformerEncoder(config)
1249
+
1250
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
1256
+ def freeze_feature_encoder(self):
1257
+ """
1258
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1259
+ not be updated during training.
1260
+ """
1261
+ self.feature_extractor._freeze_parameters()
1262
+
1263
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1264
+ def _mask_hidden_states(
1265
+ self,
1266
+ hidden_states: torch.FloatTensor,
1267
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1268
+ attention_mask: Optional[torch.LongTensor] = None,
1269
+ ):
1270
+ """
1271
+ Masks extracted features along time axis and/or along feature axis according to
1272
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
1273
+ """
1274
+
1275
+ # `config.apply_spec_augment` can set masking to False
1276
+ if not getattr(self.config, "apply_spec_augment", True):
1277
+ return hidden_states
1278
+
1279
+ # generate indices & apply SpecAugment along time axis
1280
+ batch_size, sequence_length, hidden_size = hidden_states.size()
1281
+
1282
+ if mask_time_indices is not None:
1283
+ # apply SpecAugment along time axis with given mask_time_indices
1284
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1285
+ elif self.config.mask_time_prob > 0 and self.training:
1286
+ mask_time_indices = _compute_mask_indices(
1287
+ (batch_size, sequence_length),
1288
+ mask_prob=self.config.mask_time_prob,
1289
+ mask_length=self.config.mask_time_length,
1290
+ attention_mask=attention_mask,
1291
+ min_masks=self.config.mask_time_min_masks,
1292
+ )
1293
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1294
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1295
+
1296
+ if self.config.mask_feature_prob > 0 and self.training:
1297
+ # generate indices & apply SpecAugment along feature axis
1298
+ mask_feature_indices = _compute_mask_indices(
1299
+ (batch_size, hidden_size),
1300
+ mask_prob=self.config.mask_feature_prob,
1301
+ mask_length=self.config.mask_feature_length,
1302
+ min_masks=self.config.mask_feature_min_masks,
1303
+ )
1304
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1305
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1306
+ hidden_states[mask_feature_indices] = 0
1307
+
1308
+ return hidden_states
1309
+
1310
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1311
+ @add_code_sample_docstrings(
1312
+ checkpoint=_CHECKPOINT_FOR_DOC,
1313
+ output_type=Wav2Vec2BaseModelOutput,
1314
+ config_class=_CONFIG_FOR_DOC,
1315
+ modality="audio",
1316
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1317
+ )
1318
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
1319
+ def forward(
1320
+ self,
1321
+ input_values: Optional[torch.Tensor],
1322
+ attention_mask: Optional[torch.Tensor] = None,
1323
+ mask_time_indices: Optional[torch.FloatTensor] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
1328
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1329
+ output_hidden_states = (
1330
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1331
+ )
1332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
+
1334
+ extract_features = self.feature_extractor(input_values)
1335
+ extract_features = extract_features.transpose(1, 2)
1336
+
1337
+ if attention_mask is not None:
1338
+ # compute reduced attention_mask corresponding to feature vectors
1339
+ attention_mask = self._get_feature_vector_attention_mask(
1340
+ extract_features.shape[1], attention_mask, add_adapter=False
1341
+ )
1342
+
1343
+ hidden_states, extract_features = self.feature_projection(extract_features)
1344
+ hidden_states = self._mask_hidden_states(
1345
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
1346
+ )
1347
+
1348
+ encoder_outputs = self.encoder(
1349
+ hidden_states,
1350
+ attention_mask=attention_mask,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ )
1355
+
1356
+ hidden_states = encoder_outputs[0]
1357
+
1358
+ if self.adapter is not None:
1359
+ hidden_states = self.adapter(hidden_states)
1360
+
1361
+ if not return_dict:
1362
+ return (hidden_states, extract_features) + encoder_outputs[1:]
1363
+
1364
+ return Wav2Vec2BaseModelOutput(
1365
+ last_hidden_state=hidden_states,
1366
+ extract_features=extract_features,
1367
+ hidden_states=encoder_outputs.hidden_states,
1368
+ attentions=encoder_outputs.attentions,
1369
+ )
1370
+
1371
+
1372
+ @add_start_docstrings(
1373
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
1374
+ )
1375
+ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
1376
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1377
+ def __init__(self, config: Wav2Vec2ConformerConfig):
1378
+ super().__init__(config)
1379
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1380
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
1381
+
1382
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
1383
+
1384
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
1385
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
1386
+
1387
+ # Initialize weights and apply final processing
1388
+ self.post_init()
1389
+
1390
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
1391
+ def set_gumbel_temperature(self, temperature: int):
1392
+ """
1393
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
1394
+ """
1395
+ self.quantizer.temperature = temperature
1396
+
1397
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1398
+ def freeze_feature_encoder(self):
1399
+ """
1400
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1401
+ not be updated during training.
1402
+ """
1403
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1404
+
1405
+ @staticmethod
1406
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
1407
+ def compute_contrastive_logits(
1408
+ target_features: torch.FloatTensor,
1409
+ negative_features: torch.FloatTensor,
1410
+ predicted_features: torch.FloatTensor,
1411
+ temperature: int = 0.1,
1412
+ ):
1413
+ """
1414
+ Compute logits for contrastive loss using cosine similarity as the distance measure between
1415
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
1416
+ """
1417
+ target_features = torch.cat([target_features, negative_features], dim=0)
1418
+
1419
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
1420
+ target_features
1421
+ )
1422
+
1423
+ # apply temperature
1424
+ logits = logits / temperature
1425
+ return logits
1426
+
1427
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1428
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1429
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
1430
+ def forward(
1431
+ self,
1432
+ input_values: Optional[torch.Tensor],
1433
+ attention_mask: Optional[torch.Tensor] = None,
1434
+ mask_time_indices: Optional[torch.BoolTensor] = None,
1435
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
1436
+ output_attentions: Optional[bool] = None,
1437
+ output_hidden_states: Optional[bool] = None,
1438
+ return_dict: Optional[bool] = None,
1439
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
1440
+ r"""
1441
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
1442
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
1443
+ masked extracted features in *config.proj_codevector_dim* space.
1444
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
1445
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
1446
+ Required input for pre-training.
1447
+
1448
+ Returns:
1449
+
1450
+ Example:
1451
+
1452
+ ```python
1453
+ >>> import torch
1454
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
1455
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
1456
+ ... _compute_mask_indices,
1457
+ ... _sample_negative_indices,
1458
+ ... )
1459
+ >>> from datasets import load_dataset
1460
+
1461
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1462
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1463
+
1464
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1465
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
1466
+
1467
+ >>> # compute masked indices
1468
+ >>> batch_size, raw_sequence_length = input_values.shape
1469
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
1470
+ >>> mask_time_indices = _compute_mask_indices(
1471
+ ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
1472
+ ... )
1473
+ >>> sampled_negative_indices = _sample_negative_indices(
1474
+ ... features_shape=(batch_size, sequence_length),
1475
+ ... num_negatives=model.config.num_negatives,
1476
+ ... mask_time_indices=mask_time_indices,
1477
+ ... )
1478
+ >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
1479
+ >>> sampled_negative_indices = torch.tensor(
1480
+ ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
1481
+ ... )
1482
+
1483
+ >>> with torch.no_grad():
1484
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
1485
+
1486
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
1487
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
1488
+
1489
+ >>> # show that cosine similarity is much higher than random
1490
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
1491
+ tensor(True)
1492
+
1493
+ >>> # for contrastive loss training, the model should be put into train mode
1494
+ >>> model = model.train()
1495
+ >>> loss = model(
1496
+ ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
1497
+ ... ).loss
1498
+ ```"""
1499
+
1500
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1501
+
1502
+ if mask_time_indices is not None:
1503
+ mask_time_indices = mask_time_indices.to(torch.bool)
1504
+
1505
+ outputs = self.wav2vec2_conformer(
1506
+ input_values,
1507
+ attention_mask=attention_mask,
1508
+ output_attentions=output_attentions,
1509
+ output_hidden_states=output_hidden_states,
1510
+ mask_time_indices=mask_time_indices,
1511
+ return_dict=return_dict,
1512
+ )
1513
+
1514
+ # 1. project all transformed features (including masked) to final vq dim
1515
+ transformer_features = self.project_hid(outputs[0])
1516
+
1517
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
1518
+ extract_features = self.dropout_features(outputs[1])
1519
+
1520
+ if attention_mask is not None:
1521
+ # compute reduced attention_mask corresponding to feature vectors
1522
+ attention_mask = self._get_feature_vector_attention_mask(
1523
+ extract_features.shape[1], attention_mask, add_adapter=False
1524
+ )
1525
+
1526
+ quantized_features, codevector_perplexity = self.quantizer(
1527
+ extract_features, mask_time_indices=mask_time_indices
1528
+ )
1529
+ quantized_features = self.project_q(quantized_features)
1530
+
1531
+ loss = contrastive_loss = diversity_loss = None
1532
+ if sampled_negative_indices is not None:
1533
+ batch_size, sequence_length, hidden_size = quantized_features.shape
1534
+
1535
+ # for training, we sample negatives
1536
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
1537
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
1538
+ # sample negative quantized vectors BTC => (BxT)C
1539
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
1540
+ sampled_negative_indices.long().view(-1)
1541
+ ]
1542
+ negative_quantized_features = negative_quantized_features.view(
1543
+ batch_size, sequence_length, -1, hidden_size
1544
+ ).permute(2, 0, 1, 3)
1545
+
1546
+ # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
1547
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
1548
+ logits = self.compute_contrastive_logits(
1549
+ quantized_features[None, :],
1550
+ negative_quantized_features,
1551
+ transformer_features,
1552
+ self.config.contrastive_logits_temperature,
1553
+ )
1554
+
1555
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
1556
+ # its cosine similarity will be masked
1557
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
1558
+
1559
+ if neg_is_pos.any():
1560
+ logits[1:][neg_is_pos] = float("-inf")
1561
+
1562
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
1563
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
1564
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
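+ # the true quantized vector is always the first of the (1 + num_negatives) candidates, so the
+ # target class is 0 at masked positions; unmasked positions get -100 and are ignored by cross_entropy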
1565
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
1566
+
1567
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
1568
+ # 7. compute diversity loss: \mathbf{L}_d
1569
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
1570
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
1571
+
1572
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
1573
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
1574
+
1575
+ if not return_dict:
1576
+ if loss is not None:
1577
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1578
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1579
+
1580
+ return Wav2Vec2ConformerForPreTrainingOutput(
1581
+ loss=loss,
1582
+ projected_states=transformer_features,
1583
+ projected_quantized_states=quantized_features,
1584
+ codevector_perplexity=codevector_perplexity,
1585
+ hidden_states=outputs.hidden_states,
1586
+ attentions=outputs.attentions,
1587
+ contrastive_loss=contrastive_loss,
1588
+ diversity_loss=diversity_loss,
1589
+ )
1590
+
1591
+
1592
+ @add_start_docstrings(
1593
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1594
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1595
+ )
1596
+ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
1597
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1598
+ def __init__(self, config):
1599
+ super().__init__(config)
1600
+
1601
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1602
+ self.dropout = nn.Dropout(config.final_dropout)
1603
+
1604
+ if config.vocab_size is None:
1605
+ raise ValueError(
1606
+ f"You are trying to instantiate {self.__class__} with a configuration that "
1607
+ "does not define the vocabulary size of the language model head. Please "
1608
+ "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1609
+ "or define `vocab_size` of your model's configuration."
1610
+ )
1611
+ output_hidden_size = (
1612
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1613
+ )
1614
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1615
+
1616
+ # Initialize weights and apply final processing
1617
+ self.post_init()
1618
+
1619
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1620
+ def freeze_feature_encoder(self):
1621
+ """
1622
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1623
+ not be updated during training.
1624
+ """
1625
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1626
+
1627
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1628
+ @add_code_sample_docstrings(
1629
+ checkpoint=_CHECKPOINT_FOR_DOC,
1630
+ output_type=CausalLMOutput,
1631
+ config_class=_CONFIG_FOR_DOC,
1632
+ expected_output=_CTC_EXPECTED_OUTPUT,
1633
+ expected_loss=_CTC_EXPECTED_LOSS,
1634
+ )
1635
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1636
+ def forward(
1637
+ self,
1638
+ input_values: Optional[torch.Tensor],
1639
+ attention_mask: Optional[torch.Tensor] = None,
1640
+ output_attentions: Optional[bool] = None,
1641
+ output_hidden_states: Optional[bool] = None,
1642
+ return_dict: Optional[bool] = None,
1643
+ labels: Optional[torch.Tensor] = None,
1644
+ ) -> Union[Tuple, CausalLMOutput]:
1645
+ r"""
1646
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1647
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
1648
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1649
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1650
+ config.vocab_size - 1]`.
1651
+ """
1652
+
1653
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1654
+
1655
+ outputs = self.wav2vec2_conformer(
1656
+ input_values,
1657
+ attention_mask=attention_mask,
1658
+ output_attentions=output_attentions,
1659
+ output_hidden_states=output_hidden_states,
1660
+ return_dict=return_dict,
1661
+ )
1662
+
1663
+ hidden_states = outputs[0]
1664
+ hidden_states = self.dropout(hidden_states)
1665
+
1666
+ logits = self.lm_head(hidden_states)
1667
+
1668
+ loss = None
1669
+ if labels is not None:
1670
+ if labels.max() >= self.config.vocab_size:
1671
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1672
+
1673
+ # retrieve loss input_lengths from attention_mask
1674
+ attention_mask = (
1675
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1676
+ )
1677
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1678
+
1679
+ # assuming that padded tokens are filled with -100
1680
+ # when not being attended to
1681
+ labels_mask = labels >= 0
1682
+ target_lengths = labels_mask.sum(-1)
1683
+ flattened_targets = labels.masked_select(labels_mask)
1684
+
1685
+ # ctc_loss doesn't support fp16
1686
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1687
+
1688
+ with torch.backends.cudnn.flags(enabled=False):
1689
+ loss = nn.functional.ctc_loss(
1690
+ log_probs,
1691
+ flattened_targets,
1692
+ input_lengths,
1693
+ target_lengths,
1694
+ blank=self.config.pad_token_id,
1695
+ reduction=self.config.ctc_loss_reduction,
1696
+ zero_infinity=self.config.ctc_zero_infinity,
1697
+ )
1698
+
1699
+ if not return_dict:
1700
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1701
+ return ((loss,) + output) if loss is not None else output
1702
+
1703
+ return CausalLMOutput(
1704
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1705
+ )
1706
+
1707
+
1708
+ @add_start_docstrings(
1709
+ """
1710
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
1711
+ tasks like SUPERB Keyword Spotting.
1712
+ """,
1713
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1714
+ )
1715
+ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
1716
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1717
+ def __init__(self, config):
1718
+ super().__init__(config)
1719
+
1720
+ if hasattr(config, "add_adapter") and config.add_adapter:
1721
+ raise ValueError(
1722
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1723
+ )
1724
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1725
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1726
+ if config.use_weighted_layer_sum:
1727
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1728
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1729
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1730
+
1731
+ # Initialize weights and apply final processing
1732
+ self.post_init()
1733
+
1734
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1735
+ def freeze_feature_encoder(self):
1736
+ """
1737
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1738
+ not be updated during training.
1739
+ """
1740
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1741
+
1742
+ def freeze_base_model(self):
1743
+ """
1744
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1745
+ be updated during training. Only the classification head will be updated.
1746
+ """
1747
+ for param in self.wav2vec2_conformer.parameters():
1748
+ param.requires_grad = False
1749
+
1750
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1751
+ @add_code_sample_docstrings(
1752
+ checkpoint=_CHECKPOINT_FOR_DOC,
1753
+ output_type=SequenceClassifierOutput,
1754
+ config_class=_CONFIG_FOR_DOC,
1755
+ modality="audio",
1756
+ )
1757
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1758
+ def forward(
1759
+ self,
1760
+ input_values: Optional[torch.Tensor],
1761
+ attention_mask: Optional[torch.Tensor] = None,
1762
+ output_attentions: Optional[bool] = None,
1763
+ output_hidden_states: Optional[bool] = None,
1764
+ return_dict: Optional[bool] = None,
1765
+ labels: Optional[torch.Tensor] = None,
1766
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1767
+ r"""
1768
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1769
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1770
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1771
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1772
+ """
1773
+
1774
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1775
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1776
+
1777
+ outputs = self.wav2vec2_conformer(
1778
+ input_values,
1779
+ attention_mask=attention_mask,
1780
+ output_attentions=output_attentions,
1781
+ output_hidden_states=output_hidden_states,
1782
+ return_dict=return_dict,
1783
+ )
1784
+
1785
+ if self.config.use_weighted_layer_sum:
1786
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1787
+ hidden_states = torch.stack(hidden_states, dim=1)
1788
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1789
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1790
+ else:
1791
+ hidden_states = outputs[0]
1792
+
1793
+ hidden_states = self.projector(hidden_states)
1794
+ if attention_mask is None:
1795
+ pooled_output = hidden_states.mean(dim=1)
1796
+ else:
1797
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1798
+ hidden_states[~padding_mask] = 0.0
1799
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1800
+
1801
+ logits = self.classifier(pooled_output)
1802
+
1803
+ loss = None
1804
+ if labels is not None:
1805
+ loss_fct = CrossEntropyLoss()
1806
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1807
+
1808
+ if not return_dict:
1809
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1810
+ return ((loss,) + output) if loss is not None else output
1811
+
1812
+ return SequenceClassifierOutput(
1813
+ loss=loss,
1814
+ logits=logits,
1815
+ hidden_states=outputs.hidden_states,
1816
+ attentions=outputs.attentions,
1817
+ )
1818
+
1819
+
1820
+ @add_start_docstrings(
1821
+ """
1822
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
1823
+ """,
1824
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1825
+ )
1826
+ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
1827
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1828
+ def __init__(self, config):
1829
+ super().__init__(config)
1830
+
1831
+ if hasattr(config, "add_adapter") and config.add_adapter:
1832
+ raise ValueError(
1833
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1834
+ )
1835
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1836
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1837
+ if config.use_weighted_layer_sum:
1838
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1839
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1840
+ self.num_labels = config.num_labels
1841
+
1842
+ self.init_weights()
1843
+
1844
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1845
+ def freeze_feature_encoder(self):
1846
+ """
1847
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1848
+ not be updated during training.
1849
+ """
1850
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1851
+
1852
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
1853
+ def freeze_base_model(self):
1854
+ """
1855
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
1856
+ be updated during training. Only the classification head will be updated.
1857
+ """
1858
+ for param in self.wav2vec2_conformer.parameters():
1859
+ param.requires_grad = False
1860
+
1861
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1862
+ @add_code_sample_docstrings(
1863
+ checkpoint=_CHECKPOINT_FOR_DOC,
1864
+ output_type=TokenClassifierOutput,
1865
+ config_class=_CONFIG_FOR_DOC,
1866
+ modality="audio",
1867
+ )
1868
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
1869
+ def forward(
1870
+ self,
1871
+ input_values: Optional[torch.Tensor],
1872
+ attention_mask: Optional[torch.Tensor] = None,
1873
+ labels: Optional[torch.Tensor] = None,
1874
+ output_attentions: Optional[bool] = None,
1875
+ output_hidden_states: Optional[bool] = None,
1876
+ return_dict: Optional[bool] = None,
1877
+ ) -> Union[Tuple, TokenClassifierOutput]:
1878
+ r"""
1879
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1880
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1881
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1882
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1883
+ """
1884
+
1885
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1886
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1887
+
1888
+ outputs = self.wav2vec2_conformer(
1889
+ input_values,
1890
+ attention_mask=attention_mask,
1891
+ output_attentions=output_attentions,
1892
+ output_hidden_states=output_hidden_states,
1893
+ return_dict=return_dict,
1894
+ )
1895
+
1896
+ if self.config.use_weighted_layer_sum:
1897
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1898
+ hidden_states = torch.stack(hidden_states, dim=1)
1899
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1900
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1901
+ else:
1902
+ hidden_states = outputs[0]
1903
+
1904
+ logits = self.classifier(hidden_states)
1905
+
1906
+ loss = None
1907
+ if labels is not None:
1908
+ loss_fct = CrossEntropyLoss()
1909
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
1910
+
1911
+ if not return_dict:
1912
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1913
+ return output
1914
+
1915
+ return TokenClassifierOutput(
1916
+ loss=loss,
1917
+ logits=logits,
1918
+ hidden_states=outputs.hidden_states,
1919
+ attentions=outputs.attentions,
1920
+ )
1921
+
1922
+
1923
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
1924
+ class AMSoftmaxLoss(nn.Module):
1925
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
1926
+ super(AMSoftmaxLoss, self).__init__()
1927
+ self.scale = scale
1928
+ self.margin = margin
1929
+ self.num_labels = num_labels
1930
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
1931
+ self.loss = nn.CrossEntropyLoss()
1932
+
1933
+ def forward(self, hidden_states, labels):
1934
+ labels = labels.flatten()
1935
+ weight = nn.functional.normalize(self.weight, dim=0)
1936
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
1937
+ cos_theta = torch.mm(hidden_states, weight)
1938
+ psi = cos_theta - self.margin
1939
+
1940
+ onehot = nn.functional.one_hot(labels, self.num_labels)
1941
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
1942
+ loss = self.loss(logits, labels)
1943
+
1944
+ return loss
1945
+
1946
+
1947
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
1948
+ class TDNNLayer(nn.Module):
1949
+ def __init__(self, config, layer_id=0):
1950
+ super().__init__()
1951
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
1952
+ self.out_conv_dim = config.tdnn_dim[layer_id]
1953
+ self.kernel_size = config.tdnn_kernel[layer_id]
1954
+ self.dilation = config.tdnn_dilation[layer_id]
1955
+
1956
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
1957
+ self.activation = nn.ReLU()
1958
+
1959
+ def forward(self, hidden_states):
1960
+ hidden_states = hidden_states.unsqueeze(1)
1961
+ hidden_states = nn.functional.unfold(
1962
+ hidden_states,
1963
+ (self.kernel_size, self.in_conv_dim),
1964
+ stride=(1, self.in_conv_dim),
1965
+ dilation=(self.dilation, 1),
1966
+ )
1967
+ hidden_states = hidden_states.transpose(1, 2)
1968
+ hidden_states = self.kernel(hidden_states)
1969
+
1970
+ hidden_states = self.activation(hidden_states)
1971
+ return hidden_states
1972
+
1973
+
1974
+ @add_start_docstrings(
1975
+ """
1976
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
1977
+ """,
1978
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
1979
+ )
1980
+ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
1981
+ def __init__(self, config):
1982
+ super().__init__(config)
1983
+
1984
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1985
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1986
+ if config.use_weighted_layer_sum:
1987
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1988
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
1989
+
1990
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
1991
+ self.tdnn = nn.ModuleList(tdnn_layers)
1992
+
1993
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
1994
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
1995
+
1996
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
1997
+
1998
+ self.init_weights()
1999
+
2000
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
2001
+ def freeze_feature_encoder(self):
2002
+ """
2003
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
2004
+ not be updated during training.
2005
+ """
2006
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
2007
+
2008
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
2009
+ def freeze_base_model(self):
2010
+ """
2011
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
2012
+ be updated during training. Only the classification head will be updated.
2013
+ """
2014
+ for param in self.wav2vec2_conformer.parameters():
2015
+ param.requires_grad = False
2016
+
2017
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
2018
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
2019
+ """
2020
+ Computes the output length of the TDNN layers
2021
+ """
2022
+
2023
+ def _conv_out_length(input_length, kernel_size, stride):
2024
+ # 1D convolutional layer output length formula taken
2025
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
2026
+ return (input_length - kernel_size) // stride + 1
2027
+
2028
+ for kernel_size in self.config.tdnn_kernel:
2029
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
2030
+
2031
+ return input_lengths
2032
+
2033
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
2034
+ @add_code_sample_docstrings(
2035
+ checkpoint=_CHECKPOINT_FOR_DOC,
2036
+ output_type=XVectorOutput,
2037
+ config_class=_CONFIG_FOR_DOC,
2038
+ modality="audio",
2039
+ )
2040
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
2041
+ def forward(
2042
+ self,
2043
+ input_values: Optional[torch.Tensor],
2044
+ attention_mask: Optional[torch.Tensor] = None,
2045
+ output_attentions: Optional[bool] = None,
2046
+ output_hidden_states: Optional[bool] = None,
2047
+ return_dict: Optional[bool] = None,
2048
+ labels: Optional[torch.Tensor] = None,
2049
+ ) -> Union[Tuple, XVectorOutput]:
2050
+ r"""
2051
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2052
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
2053
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
2054
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
2055
+ """
2056
+
2057
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2058
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
2059
+
2060
+ outputs = self.wav2vec2_conformer(
2061
+ input_values,
2062
+ attention_mask=attention_mask,
2063
+ output_attentions=output_attentions,
2064
+ output_hidden_states=output_hidden_states,
2065
+ return_dict=return_dict,
2066
+ )
2067
+
2068
+ if self.config.use_weighted_layer_sum:
2069
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
2070
+ hidden_states = torch.stack(hidden_states, dim=1)
2071
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
2072
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
2073
+ else:
2074
+ hidden_states = outputs[0]
2075
+
2076
+ hidden_states = self.projector(hidden_states)
2077
+
2078
+ for tdnn_layer in self.tdnn:
2079
+ hidden_states = tdnn_layer(hidden_states)
2080
+
2081
+ # Statistic Pooling
2082
+ if attention_mask is None:
2083
+ mean_features = hidden_states.mean(dim=1)
2084
+ std_features = hidden_states.std(dim=1)
2085
+ else:
2086
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
2087
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
2088
+ mean_features = []
2089
+ std_features = []
2090
+ for i, length in enumerate(tdnn_output_lengths):
2091
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
2092
+ std_features.append(hidden_states[i, :length].std(dim=0))
2093
+ mean_features = torch.stack(mean_features)
2094
+ std_features = torch.stack(std_features)
2095
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
2096
+
2097
+ output_embeddings = self.feature_extractor(statistic_pooling)
2098
+ logits = self.classifier(output_embeddings)
2099
+
2100
+ loss = None
2101
+ if labels is not None:
2102
+ loss = self.objective(logits, labels)
2103
+
2104
+ if not return_dict:
2105
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
2106
+ return ((loss,) + output) if loss is not None else output
2107
+
2108
+ return XVectorOutput(
2109
+ loss=loss,
2110
+ logits=logits,
2111
+ embeddings=output_embeddings,
2112
+ hidden_states=outputs.hidden_states,
2113
+ attentions=outputs.attentions,
2114
+ )
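The forward pass above ends with statistic pooling: the per-utterance mean and standard deviation of the TDNN frames are concatenated before the x-vector projection. A minimal stand-alone sketch of the masked variant, with illustrative names only:

```python
import torch

def masked_statistic_pooling(hidden_states: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Concatenate per-example mean and std over the valid frames only.

    hidden_states: (batch, frames, dim); lengths: (batch,) valid frame counts.
    Mirrors the pooling step in Wav2Vec2ConformerForXVector.forward above.
    """
    pooled = []
    for i, length in enumerate(lengths.tolist()):
        valid = hidden_states[i, :length]
        pooled.append(torch.cat([valid.mean(dim=0), valid.std(dim=0)]))
    return torch.stack(pooled)

feats = torch.randn(2, 7, 4)
print(masked_statistic_pooling(feats, torch.tensor([7, 5])).shape)  # torch.Size([2, 8])
```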
src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ from einops import rearrange
4
+
5
+
6
+ class RandomProjectionQuantizer(nn.Module):
7
+ """
8
+ Random projection and codebook lookup module
9
+
10
+ Some code is borrowed from:
11
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
12
+ Normalization here uses a pre-computed global mean & variance instead of layer norm.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ input_dim,
18
+ codebook_dim,
19
+ codebook_size,
20
+ seed=142,
21
+ ):
22
+ super().__init__()
23
+
24
+ # random seed
25
+ torch.manual_seed(seed)
26
+
27
+ # randomly initialized projection
28
+ random_projection = torch.empty(input_dim, codebook_dim)
29
+ nn.init.xavier_normal_(random_projection)
30
+ self.register_buffer("random_projection", random_projection)
31
+
32
+ # randomly initialized codebook
33
+ codebook = torch.empty(codebook_size, codebook_dim)
34
+ nn.init.normal_(codebook)
35
+ self.register_buffer("codebook", codebook)
36
+
37
+ def codebook_lookup(self, x):
38
+ # reshape
39
+ b = x.shape[0]
40
+ x = rearrange(x, "b n e -> (b n) e")
41
+
42
+ # L2 normalization
43
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
44
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
45
+
46
+ # compute distances
47
+ distances = torch.cdist(normalized_codebook, normalized_x)
48
+
49
+ # get nearest
50
+ nearest_indices = torch.argmin(distances, dim=0)
51
+
52
+ # reshape
53
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
54
+
55
+ return xq
56
+
57
+ @torch.no_grad()
58
+ def forward(self, x):
59
+ # always eval
60
+ self.eval()
61
+
62
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
63
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
64
+
65
+ # codebook lookup
66
+ xq = self.codebook_lookup(x)
67
+
68
+ return xq
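A hedged usage sketch for the module above: project features through the frozen random matrix and read back discrete token ids. The dimensions are illustrative, and the import path assumes the package layout of this repository.

```python
import torch
from muq.muq.modules.random_quantizer import RandomProjectionQuantizer  # assumed import path

quantizer = RandomProjectionQuantizer(input_dim=128, codebook_dim=16, codebook_size=4096)

features = torch.randn(2, 250, 128)   # (batch, frames, input_dim)
token_ids = quantizer(features)       # (batch, frames), values in [0, 4096)
print(token_ids.shape)
```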
src/third_party/MuQ/src/muq/muq/modules/rvq.py ADDED
@@ -0,0 +1,314 @@
1
+
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ try:
10
+ from torch.nn.utils import weight_norm
11
+ except:
12
+ try:
13
+ from torch.nn.utils.parametrizations import weight_norm
14
+ except:
15
+ from torch.nn.utils.parametrize import weight_norm
16
+
17
+ def WNConv1d(*args, **kwargs):
18
+ return weight_norm(nn.Conv1d(*args, **kwargs))
19
+
20
+
21
+ class VectorQuantize(nn.Module):
22
+ """
23
+ Implementation of VQ similar to Karpathy's repo:
24
+ https://github.com/karpathy/deep-vector-quantization
25
+ Additionally uses following tricks from Improved VQGAN
26
+ (https://arxiv.org/pdf/2110.04627.pdf):
27
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
28
+ for improved codebook usage
29
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
30
+ improves training stability
31
+ """
32
+
33
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
34
+ super().__init__()
35
+ self.codebook_size = codebook_size
36
+ self.codebook_dim = codebook_dim
37
+ self.mfcc_clustering = mfcc_clustering
38
+
39
+ ProjClass = nn.Identity if mfcc_clustering else WNConv1d
40
+ if n_layer==1:
41
+ self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
42
+ self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
43
+ elif n_layer >= 2:
44
+ ndim_hidden = 128
45
+ self.in_proj = nn.Sequential(
46
+ ProjClass(input_dim, ndim_hidden, kernel_size=1),
47
+ *[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
48
+ nn.ReLU(),
49
+ ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
50
+ )
51
+ self.out_proj = nn.Sequential(
52
+ ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
53
+ nn.ReLU(),
54
+ *[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
55
+ ProjClass(ndim_hidden, input_dim, kernel_size=1),
56
+ )
57
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
58
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
59
+ self.stale_tolerance = stale_tolerance
60
+
61
+ def forward(self, z):
62
+ """Quantized the input tensor using a fixed codebook and returns
63
+ the corresponding codebook vectors
64
+
65
+ Parameters
66
+ ----------
67
+ z : Tensor[B x D x T]
68
+
69
+ Returns
70
+ -------
71
+ Tensor[B x D x T]
72
+ Quantized continuous representation of input
73
+ Tensor[1]
74
+ Commitment loss to train encoder to predict vectors closer to codebook
75
+ entries
76
+ Tensor[1]
77
+ Codebook loss to update the codebook
78
+ Tensor[B x T]
79
+ Codebook indices (quantized discrete representation of input)
80
+ Tensor[B x D x T]
81
+ Projected latents (continuous representation of input before quantization)
82
+ """
83
+
84
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
85
+
86
+ z_e = self.in_proj(z) # z_e : (B x D x T)
87
+ z_q, indices = self.decode_latents(z_e)
88
+
89
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
90
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
91
+
92
+ z_q = (
93
+ z_e + (z_q - z_e).detach()
94
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
95
+
96
+ z_q = self.out_proj(z_q)
97
+
98
+ return z_q, commitment_loss, codebook_loss, indices, z_e
99
+
100
+ def embed_code(self, embed_id):
101
+ return F.embedding(embed_id, self.codebook.weight)
102
+
103
+ def decode_code(self, embed_id):
104
+ return self.embed_code(embed_id).transpose(1, 2)
105
+
106
+ def decode_latents(self, latents):
107
+ encodings = rearrange(latents, "b d t -> (b t) d")
108
+ codebook = self.codebook.weight # codebook: (N x D)
109
+
110
+ # L2 normalize encodings and codebook (ViT-VQGAN)
111
+ encodings = F.normalize(encodings)
112
+ codebook = F.normalize(codebook)
113
+
114
+ # Compute euclidean distance with codebook
115
+ dist = (
116
+ encodings.pow(2).sum(1, keepdim=True)
117
+ - 2 * encodings @ codebook.t()
118
+ + codebook.pow(2).sum(1, keepdim=True).t()
119
+ )
120
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
121
+ z_q = self.decode_code(indices)
122
+
123
+ if(self.training):
124
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
125
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
126
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
127
+
128
+ # random replace codes that haven't been used for a while
129
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
130
+ if replace_code.sum(-1) > 0:
131
+ print("Replace {} codes".format(replace_code.sum(-1)))
132
+ random_input_idx = torch.randperm(encodings.shape[0])
133
+ random_input = encodings[random_input_idx].view(encodings.shape)
134
+ if random_input.shape[0] < self.codebook_size:
135
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
136
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
137
+
138
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
139
+ self.stale_counter = self.stale_counter * (1 - replace_code)
140
+
141
+ return z_q, indices
142
+
143
+
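The `z_e + (z_q - z_e).detach()` line above is the straight-through estimator: the quantized value is used in the forward pass while gradients flow to `z_e` as if quantization were the identity. A self-contained sketch, with rounding standing in for the codebook lookup:

```python
import torch

z_e = torch.randn(1, 8, 10, requires_grad=True)  # encoder output (B, D, T)
z_q = torch.round(z_e)                           # stand-in for the codebook lookup
z_st = z_e + (z_q - z_e).detach()                # forward: z_q, backward: identity w.r.t. z_e
z_st.sum().backward()
assert torch.allclose(z_e.grad, torch.ones_like(z_e))
```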
144
+ class ResidualVectorQuantize(nn.Module):
145
+ """
146
+ Introduced in SoundStream: An end2end neural audio codec
147
+ https://arxiv.org/abs/2107.03312
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ input_dim: int = 512,
153
+ n_codebooks: int = 9,
154
+ codebook_size: int = 1024,
155
+ codebook_dim: Union[int, list] = 8,
156
+ quantizer_dropout: float = 0.0,
157
+ stale_tolerance: int = 100,
158
+ use_multi_layer_num:int = 1,
159
+ ):
160
+ super().__init__()
161
+ if isinstance(codebook_dim, int):
162
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
163
+
164
+ self.n_codebooks = n_codebooks
165
+ self.codebook_dim = codebook_dim
166
+ self.codebook_size = codebook_size
167
+
168
+ self.quantizers = nn.ModuleList(
169
+ [
170
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
171
+ for i in range(n_codebooks)
172
+ ]
173
+ )
174
+ self.quantizer_dropout = quantizer_dropout
175
+
176
+ def forward(self, z, n_quantizers: int = None):
177
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
178
+ the corresponding codebook vectors
179
+ Parameters
180
+ ----------
181
+ z : Tensor[B x D x T]
182
+ n_quantizers : int, optional
183
+ No. of quantizers to use
184
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
185
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
186
+ when in training mode, and a random number of quantizers is used.
187
+ Returns
188
+ -------
189
+ dict
190
+ A dictionary with the following keys:
191
+
192
+ "z" : Tensor[B x D x T]
193
+ Quantized continuous representation of input
194
+ "codes" : Tensor[B x N x T]
195
+ Codebook indices for each codebook
196
+ (quantized discrete representation of input)
197
+ "latents" : Tensor[B x N*D x T]
198
+ Projected latents (continuous representation of input before quantization)
199
+ "vq/commitment_loss" : Tensor[1]
200
+ Commitment loss to train encoder to predict vectors closer to codebook
201
+ entries
202
+ "vq/codebook_loss" : Tensor[1]
203
+ Codebook loss to update the codebook
204
+ """
205
+ z_q = 0
206
+ residual = z
207
+ commitment_loss = 0
208
+ codebook_loss = 0
209
+
210
+ codebook_indices = []
211
+ latents = []
212
+
213
+ if n_quantizers is None:
214
+ n_quantizers = self.n_codebooks
215
+ if self.training:
216
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
217
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
218
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
219
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
220
+ n_quantizers = n_quantizers.to(z.device)
221
+ else:
222
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
223
+ n_quantizers = n_quantizers.to(z.device)
224
+
225
+ for i, quantizer in enumerate(self.quantizers):
226
+ # if self.training is False and i >= n_quantizers:
227
+ # break
228
+
229
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
230
+ residual
231
+ )
232
+
233
+ # Create mask to apply quantizer dropout
234
+ mask = (
235
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
236
+ )
237
+ z_q = z_q + z_q_i * mask[:, None, None]
238
+ residual = residual - z_q_i
239
+
240
+ # Sum losses
241
+ commitment_loss += (commitment_loss_i * mask).mean()
242
+ codebook_loss += (codebook_loss_i * mask).mean()
243
+
244
+ codebook_indices.append(indices_i)
245
+ latents.append(z_e_i)
246
+
247
+ codes = torch.stack(codebook_indices, dim=1)
248
+ latents = torch.cat(latents, dim=1)
249
+
250
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
251
+
252
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
253
+
254
+ def get_loss(self, x, quantized_prompt_embeds, commitment_loss, codebook_loss):
255
+ final_loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
256
+ return final_loss
257
+
258
+ def from_codes(self, codes: torch.Tensor):
259
+ """Given the quantized codes, reconstruct the continuous representation
260
+ Parameters
261
+ ----------
262
+ codes : Tensor[B x N x T]
263
+ Quantized discrete representation of input
264
+ Returns
265
+ -------
266
+ Tensor[B x D x T]
267
+ Quantized continuous representation of input
268
+ """
269
+ z_q = 0.0
270
+ z_p = []
271
+ n_codebooks = codes.shape[1]
272
+ for i in range(n_codebooks):
273
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
274
+ z_p.append(z_p_i)
275
+
276
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
277
+ z_q = z_q + z_q_i
278
+ return z_q, torch.cat(z_p, dim=1), codes
279
+
280
+ def from_latents(self, latents: torch.Tensor):
281
+ """Given the unquantized latents, reconstruct the
282
+ continuous representation after quantization.
283
+
284
+ Parameters
285
+ ----------
286
+ latents : Tensor[B x N x T]
287
+ Continuous representation of input after projection
288
+
289
+ Returns
290
+ -------
291
+ Tensor[B x D x T]
292
+ Quantized representation of full-projected space
293
+ Tensor[B x D x T]
294
+ Quantized representation of latent space
295
+ """
296
+ z_q = 0
297
+ z_p = []
298
+ codes = []
299
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
300
+
301
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
302
+ 0
303
+ ]
304
+ for i in range(n_codebooks):
305
+ j, k = dims[i], dims[i + 1]
306
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
307
+ z_p.append(z_p_i)
308
+ codes.append(codes_i)
309
+
310
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
311
+ z_q = z_q + z_q_i
312
+
313
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
314
+
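A hedged usage sketch for `ResidualVectorQuantize` as defined above; the sizes are illustrative and the checkpoint-free instantiation is only meant to show the shapes of the six returned values and the `from_codes` round trip.

```python
import torch

rvq = ResidualVectorQuantize(input_dim=1024, n_codebooks=8, codebook_size=1024, codebook_dim=16)
rvq.eval()

z = torch.randn(2, 1024, 50)  # (B, D, T) continuous features
with torch.no_grad():
    z_q, codes, latents, commit_loss, cb_loss, n_used = rvq(z)

print(z_q.shape)    # torch.Size([2, 1024, 50]) quantized features
print(codes.shape)  # torch.Size([2, 8, 50]) one code index per codebook per frame

recon, _, _ = rvq.from_codes(codes)  # decode discrete codes back to the feature space
```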
src/third_party/MuQ/src/muq/muq/muq.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ from .models.muq_model import MuQModel
4
+ from dataclasses import dataclass, field
5
+ from typing import List, Optional
6
+ from transformers.modeling_outputs import BaseModelOutput
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+ @dataclass
10
+ class MuQConfig:
11
+ label_rate:int = field(default=25)
12
+ num_codebooks:int = field(default=1)
13
+ codebook_dim:int = field(default=16)
14
+ codebook_size:int = field(default=4096)
15
+ features:List[str] = field(default_factory=lambda:["melspec_2048"])
16
+ hop_length:int = field(default=240)
17
+ n_mels:int = field(default=128)
18
+ conv_dim:int = field(default=512)
19
+ encoder_dim:int = field(default=1024)
20
+ encoder_depth:int = field(default=12)
21
+ mask_hop:float = field(default=0.4)
22
+ mask_prob:float = field(default=0.6)
23
+ is_flash:bool = field(default=False)
24
+ stat:Optional[dict] = field(default_factory=dict)
25
+ w2v2_config:Optional[dict] = field(default_factory=dict)
26
+ use_rvq_target:bool = field(default=False)
27
+ use_vq_target:bool = field(default=False)
28
+ use_encodec_target:bool = field(default=False)
29
+ rvq_ckpt_path: Optional[str] = field(default=None)
30
+ recon_loss_ratio: Optional[float] = field(default=None)
31
+ resume_checkpoint: Optional[str] = None
32
+ rvq_n_codebooks:int = field(default=8)
33
+ rvq_multi_layer_num:int = field(default=1)
34
+
35
+ class MuQ(nn.Module, PyTorchModelHubMixin):
36
+ def __init__(self, config: MuQConfig):
37
+ super().__init__()
38
+ if isinstance(config, dict):
39
+ config = MuQConfig(**config)
40
+ self.config = config
41
+ self.model = MuQModel(
42
+ num_codebooks=config.num_codebooks,
43
+ codebook_dim=config.codebook_dim,
44
+ codebook_size=config.codebook_size,
45
+ features=config.features,
46
+ hop_length=config.hop_length,
47
+ n_mels=config.n_mels,
48
+ conv_dim=config.conv_dim,
49
+ encoder_dim=config.encoder_dim,
50
+ encoder_depth=config.encoder_depth,
51
+ mask_hop=config.mask_hop,
52
+ mask_prob=config.mask_prob,
53
+ is_flash=config.is_flash,
54
+ stat=config.stat,
55
+ w2v2_config=config.w2v2_config,
56
+ use_rvq_target=config.use_rvq_target,
57
+ use_vq_target=config.use_vq_target,
58
+ use_encodec_target=config.use_encodec_target,
59
+ rvq_ckpt_path=config.rvq_ckpt_path,
60
+ recon_loss_ratio=config.recon_loss_ratio,
61
+ label_rate=config.label_rate,
62
+ rvq_n_codebooks=config.rvq_n_codebooks,
63
+ rvq_multi_layer_num=config.rvq_multi_layer_num,
64
+ )
65
+
66
+ def forward(self, x, attention_mask:Optional[torch.Tensor]=None, output_hidden_states:bool=True) ->BaseModelOutput:
67
+ """
68
+ Forward pass through the MuQ model to extract features.
69
+
70
+ Args:
71
+ x (torch.Tensor): Input waveform tensor of shape (batch_size, time).
72
+ attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
73
+ Default is None.
74
+ output_hidden_states (bool, optional): Whether to return all hidden states or only the last one.
75
+ Default is True.
76
+
77
+ Returns:
78
+ BaseModelOutput: An object containing the last hidden state and optionally all hidden states.
79
+ - last_hidden_state (torch.Tensor): The last hidden state of the model, i.e. extracted MuQ features, of shape (batch_size, sequence_length, hidden_size).
80
+ - hidden_states (tuple(torch.Tensor), optional): A tuple containing all hidden states produced by the model,
81
+ each of shape (batch_size, sequence_length, hidden_size). Only returned if output_hidden_states is True.
82
+ """
83
+ _, hidden_states = self.model.get_predictions(x, attention_mask=attention_mask, is_features_only=True)
84
+ last_hidden_state = hidden_states[-1]
85
+ if not output_hidden_states:
86
+ return BaseModelOutput(last_hidden_state=last_hidden_state)
87
+ return BaseModelOutput(
88
+ last_hidden_state=last_hidden_state,
89
+ hidden_states=hidden_states
90
+ )
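Since `MuQ` mixes in `PyTorchModelHubMixin`, a hosted checkpoint can in principle be loaded with `from_pretrained`; the repository id and the 24 kHz sample rate below are assumptions for illustration, not guaranteed by this file.

```python
import torch

muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")  # assumed checkpoint id
muq.eval()

wav = torch.randn(1, 24000 * 10)  # ~10 s of audio, assuming a 24 kHz model sample rate
with torch.no_grad():
    out = muq(wav, output_hidden_states=True)

print(out.last_hidden_state.shape)  # (1, frames, encoder_dim)
print(len(out.hidden_states))       # embeddings from every layer
```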
src/third_party/MuQ/src/muq/muq_mulan/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .muq_mulan import MuQMuLan, MuQMuLanConfig, MuLanConfig, ModalModelConfig, TextTransformerConfig, AudioTransformerConfig
src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py ADDED
File without changes