Spaces: SongFormer (Running on Zero)
Commit: init
- README.md +186 -0
- app.py +636 -0
- requirements.txt +86 -0
- src/SongFormer/ckpts/md5sum.txt +4 -0
- src/SongFormer/configs/SongFormer.yaml +186 -0
- src/SongFormer/dataset/DatasetAdaper.py +33 -0
- src/SongFormer/dataset/GeminiOnlyLabelAdapter.py +332 -0
- src/SongFormer/dataset/HookTheoryAdapter.py +448 -0
- src/SongFormer/dataset/custom_types.py +14 -0
- src/SongFormer/dataset/label2id.py +163 -0
- src/SongFormer/dataset/msa_info_utils.py +47 -0
- src/SongFormer/eval.sh +22 -0
- src/SongFormer/evaluation/eval_infer_results.py +198 -0
- src/SongFormer/infer.sh +21 -0
- src/SongFormer/infer/infer.py +439 -0
- src/SongFormer/models/SongFormer.py +521 -0
- src/SongFormer/postprocessing/calc_acc.py +82 -0
- src/SongFormer/postprocessing/calc_iou.py +89 -0
- src/SongFormer/postprocessing/functional.py +71 -0
- src/SongFormer/postprocessing/helpers.py +101 -0
- src/SongFormer/train/accelerate_config/single_gpu.yaml +17 -0
- src/SongFormer/utils/average_checkpoints.py +152 -0
- src/SongFormer/utils/convert_res2msa_txt.py +79 -0
- src/SongFormer/utils/fetch_pretrained.py +40 -0
- src/third_party/MuQ/.gitattributes +2 -0
- src/third_party/MuQ/.gitignore +46 -0
- src/third_party/MuQ/.gitmodules +3 -0
- src/third_party/MuQ/LICENSE +21 -0
- src/third_party/MuQ/LICENSE_weights +399 -0
- src/third_party/MuQ/README.md +129 -0
- src/third_party/MuQ/images/muq-logo.jpeg +0 -0
- src/third_party/MuQ/images/radar.jpg +3 -0
- src/third_party/MuQ/images/tab-marble.jpg +3 -0
- src/third_party/MuQ/images/tab-mulan.png +3 -0
- src/third_party/MuQ/images/tagging.jpg +3 -0
- src/third_party/MuQ/requirements.txt +11 -0
- src/third_party/MuQ/setup.py +34 -0
- src/third_party/MuQ/src/muq/__init__.py +2 -0
- src/third_party/MuQ/src/muq/muq/__init__.py +1 -0
- src/third_party/MuQ/src/muq/muq/models/__init__.py +0 -0
- src/third_party/MuQ/src/muq/muq/models/muq_model.py +366 -0
- src/third_party/MuQ/src/muq/muq/modules/__init__.py +2 -0
- src/third_party/MuQ/src/muq/muq/modules/conv.py +77 -0
- src/third_party/MuQ/src/muq/muq/modules/features.py +37 -0
- src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py +2114 -0
- src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py +68 -0
- src/third_party/MuQ/src/muq/muq/modules/rvq.py +314 -0
- src/third_party/MuQ/src/muq/muq/muq.py +90 -0
- src/third_party/MuQ/src/muq/muq_mulan/__init__.py +1 -0
- src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py +0 -0
README.md
ADDED
@@ -0,0 +1,186 @@
---
title: SongFormer
emoji: 🎵
colorFrom: blue
colorTo: indigo
sdk: gradio
python_version: "3.10"
app_file: app.py
tags:
- music-structure-annotation
- transformer
short_description: State-of-the-art music analysis with multi-scale datasets
fullWidth: true
---

<p align="center">
  <img src="figs/logo.png" width="50%" />
</p>

# SONGFORMER: SCALING MUSIC STRUCTURE ANALYSIS WITH HETEROGENEOUS SUPERVISION


[]()
[](https://github.com/ASLP-lab/SongFormer)
[](https://huggingface.co/spaces/ASLP-lab/SongFormer)
[](https://huggingface.co/ASLP-lab/SongFormer)
[](https://huggingface.co/datasets/ASLP-lab/SongFormDB)
[](https://huggingface.co/datasets/ASLP-lab/SongFormBench)
[](https://discord.gg/rwcqh7Em)
[](http://www.npu-aslp.org/)

Chunbo Hao<sup>*</sup>, Ruibin Yuan<sup>*</sup>, Jixun Yao, Qixin Deng, Xinyi Bai, Wei Xue, Lei Xie<sup>†</sup>

----

SongFormer is a music structure analysis framework that leverages multi-resolution self-supervised representations and heterogeneous supervision. It is accompanied by the large-scale multilingual dataset SongFormDB and the high-quality benchmark SongFormBench to foster fair and reproducible research.



## News and Updates

## 📋 To-Do List

- [x] Complete and push inference code to GitHub
- [x] Upload model checkpoint(s) to Hugging Face Hub
- [ ] Upload the paper to arXiv
- [x] Fix README
- [ ] Deploy an out-of-the-box inference version on Hugging Face (via Inference API or Spaces)
- [ ] Publish the package to PyPI for easy installation via `pip`
- [ ] Open-source evaluation code
- [ ] Open-source training code

## Installation

### Setting up the Python Environment

```bash
git clone https://github.com/ASLP-lab/SongFormer.git

# Get the MuQ and MusicFM source code
git submodule update --init --recursive

conda create -n songformer python=3.10 -y
conda activate songformer
```

For users in mainland China, you may need to configure a pip mirror:

```bash
pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
```

Install dependencies:

```bash
pip install -r requirements.txt
```

We tested this setup on Ubuntu 22.04.1 LTS and it works as expected. If installation fails, you may need to relax the version constraints in `requirements.txt`.

### Download Pre-trained Models

```bash
cd src/SongFormer
# For users in mainland China: follow the instructions in fetch_pretrained.py to download via hf-mirror.com
python utils/fetch_pretrained.py
```

After downloading, verify that the md5sum values listed in `src/SongFormer/ckpts/md5sum.txt` match the downloaded files:

```bash
md5sum ckpts/MusicFM/msd_stats.json
md5sum ckpts/MusicFM/pretrained_msd.pt
md5sum ckpts/SongFormer.safetensors
# md5sum ckpts/SongFormer.pt
```
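If you prefer to check the hashes programmatically, here is a minimal sketch. It assumes you run it from `src/SongFormer` and that each line of `ckpts/md5sum.txt` is `<md5> <path relative to ckpts/>`, as in this commit; the script name `verify_ckpts.py` is just a placeholder.

```python
# verify_ckpts.py -- minimal sketch, run from src/SongFormer
# Assumes ckpts/md5sum.txt lines look like "<md5>  <path relative to ckpts/>".
import hashlib
from pathlib import Path

ckpt_dir = Path("ckpts")

def md5_of(path: Path) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

for line in (ckpt_dir / "md5sum.txt").read_text().splitlines():
    if not line.strip():
        continue
    expected, rel_path = line.split(maxsplit=1)
    target = ckpt_dir / rel_path.strip()
    if not target.exists():
        print(f"[MISSING] {target}")
        continue
    status = "OK" if md5_of(target) == expected else "MISMATCH"
    print(f"[{status}] {target}")
```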
## Inference

### 1. One-Click Inference with HuggingFace Space (coming soon)

Available at: [https://huggingface.co/spaces/ASLP-lab/SongFormer](https://huggingface.co/spaces/ASLP-lab/SongFormer)

### 2. Gradio App

First, cd to the project root directory and activate the environment:

```bash
conda activate songformer
```

You can change the server port and listening address in the last line of `app.py` as needed.
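For reference, the launch call at the end of `app.py` in this commit is shown below; change `server_name` and `server_port` there (for example, use `0.0.0.0` to listen on all interfaces).

```python
# last line of app.py -- adjust host/port here
demo.launch(server_name="127.0.0.1", server_port=7891, debug=True)
```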
> If you're using an HTTP proxy, make sure to set:
>
> ```bash
> export no_proxy="localhost, 127.0.0.1, ::1"
> export NO_PROXY="localhost, 127.0.0.1, ::1"
> ```
>
> Otherwise, Gradio may incorrectly assume the service has not started and exit immediately.

When `app.py` is run for the first time, it connects to Hugging Face to download the MuQ-related weights. We recommend creating an empty folder in a convenient location and pointing `export HF_HOME=XXX` at it, so the cache is stored there for easy cleanup and transfer.

Users in mainland China may also need `export HF_ENDPOINT=https://hf-mirror.com`. For details, refer to https://hf-mirror.com/

```bash
python app.py
```

### 3. Python Code

See `src/SongFormer/infer/infer.py`; the corresponding execution script is `src/SongFormer/infer.sh`. This is a ready-to-use, single-machine, multi-process annotation script.

Below are the configurable parameters of the `src/SongFormer/infer.sh` script:

```bash
-i            # Input SCP folder path; each line contains the absolute path to one audio file
-o            # Output directory for annotation results
--model       # Annotation model; the default is 'SongFormer', change it if using a fine-tuned model
--checkpoint  # Path to the model checkpoint file
--config_path # Path to the configuration file
-gn           # Total number of GPUs to use; should match the number specified in CUDA_VISIBLE_DEVICES
-tn           # Number of processes to run per GPU
```

Set the `CUDA_VISIBLE_DEVICES` environment variable to control which GPUs are used.
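The exact SCP layout is defined by `infer.py`, which is not reproduced in this view; based on the description above (one absolute audio path per line), a small helper to build such a list file might look like the sketch below. The name `make_scp.py` and the output path are placeholders.

```python
# make_scp.py -- minimal sketch for building an SCP list file
# Assumption: each line of the SCP file is one absolute audio path, as described above.
import sys
from pathlib import Path

audio_dir = Path(sys.argv[1]).resolve()   # folder containing audio files
out_path = Path(sys.argv[2])              # e.g. scp/all.scp (placeholder)

exts = {".wav", ".mp3", ".flac", ".m4a"}
paths = sorted(p for p in audio_dir.rglob("*") if p.suffix.lower() in exts)

out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text("\n".join(str(p) for p in paths) + "\n")
print(f"Wrote {len(paths)} entries to {out_path}")
```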
### 4. CLI Inference

Coming soon

### 5. Pitfalls

- You may need to modify line 121 in `src/third_party/musicfm/model/musicfm_25hz.py` to:
  `S = torch.load(model_path, weights_only=False)["state_dict"]`

## Training

## Citation

If our work and codebase are useful for you, please cite as:

````
coming soon
````

## License

Our code is released under the CC BY 4.0 License.

## Contact Us

<p align="center">
  <a href="http://www.nwpu-aslp.org/">
    <img src="figs/aslp.png" width="400"/>
  </a>
</p>
app.py
ADDED
@@ -0,0 +1,636 @@
# import os
# import sys

# os.chdir(os.path.join("src", "SongFormer"))
# sys.path.append(os.path.join("..", "third_party"))
# sys.path.append(".")

import os
import sys

# Get the absolute path of this file and the script name
current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file)
script_name = os.path.basename(__file__)
print(f"[INFO] Running script: {script_name}")
print(f"[INFO] This file is located in: {current_dir}")
# Set the working directory to `src/SongFormer` (if that path exists)
songformer_path = os.path.join(current_dir, "src", "SongFormer")
if os.path.exists(songformer_path):
    os.chdir(songformer_path)
    print(f"[INFO] Working directory changed to: {songformer_path}")
else:
    print(f"[WARNING] Target working directory does not exist: {songformer_path}")
# Get the current working directory, i.e. the path after os.chdir
working_dir = os.getcwd()
print(f"[INFO] Current working directory: {working_dir}")
# Add the third-party library path to sys.path (third_party)
third_party_path = os.path.join(current_dir, "third_party")
if os.path.exists(third_party_path):
    sys.path.insert(0, third_party_path)
    print(f"[INFO] Added third-party path to sys.path: {third_party_path}")
else:
    print(f"[WARNING] third_party path does not exist: {third_party_path}")
# Add the current working directory to sys.path (usually src/SongFormer)
sys.path.insert(0, working_dir)
print(f"[INFO] Added current working directory to sys.path: {working_dir}")
# Try adding several paths that may be needed for the musicfm import
musicfm_paths = [
    os.path.join(current_dir, "src"),
    os.path.join(current_dir, "third_party"),
    os.path.join(current_dir, "src", "SongFormer"),
]
for path in musicfm_paths:
    if os.path.exists(path):
        sys.path.insert(0, path)
        print(f"[INFO] Added path to sys.path: {path}")
    else:
        print(f"[DEBUG] Path does not exist, skipping: {path}")
# Optional: print the current state of sys.path
print("\n[DEBUG] Current sys.path:")
for idx, p in enumerate(sys.path):
    print(f"  {idx}: {p}")

# monkey patch to fix issues in msaf
import scipy
import numpy as np

scipy.inf = np.inf

import gradio as gr
import torch
import librosa
import json
import math
import importlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path
from argparse import Namespace
from omegaconf import OmegaConf
from ema_pytorch import EMA
from muq import MuQ
from musicfm.model.musicfm_25hz import MusicFM25Hz
from postprocessing.functional import postprocess_functional_structure
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
from utils.fetch_pretrained import download_all

# Constants
MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]
TIME_DUR = 420
INPUT_SAMPLING_RATE = 24000

# Global model variables
muq_model = None
musicfm_model = None
msa_model = None
device = None


def load_checkpoint(checkpoint_path, device=None):
    """Load checkpoint from path"""
    if device is None:
        device = "cpu"

    if checkpoint_path.endswith(".pt"):
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
    else:
        raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
    return checkpoint


def initialize_models(model_name: str, checkpoint: str, config_path: str):
    """Initialize all models"""
    global muq_model, musicfm_model, msa_model, device

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MuQ
    muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
    muq_model = muq_model.to(device).eval()

    # Load MusicFM
    musicfm_model = MusicFM25Hz(
        is_flash=False,
        stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
        model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
    )
    musicfm_model = musicfm_model.to(device).eval()

    # Load MSA model
    module = importlib.import_module("models." + str(model_name))
    Model = getattr(module, "Model")
    hp = OmegaConf.load(os.path.join("configs", config_path))
    msa_model = Model(hp)

    ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", checkpoint))
    if ckpt.get("model_ema", None) is not None:
        model_ema = EMA(msa_model, include_online_model=False)
        model_ema.load_state_dict(ckpt["model_ema"])
        msa_model.load_state_dict(model_ema.ema_model.state_dict())
    else:
        msa_model.load_state_dict(ckpt["model"])

    msa_model.to(device).eval()

    return hp


def process_audio(audio_path, win_size=420, hop_size=420, num_classes=128):
    """Process audio file and return structure analysis results"""
    global muq_model, musicfm_model, msa_model, device

    if muq_model is None:
        # Models were not initialized yet: fall back to the defaults used in __main__
        # (the original code called initialize_models() without its required arguments)
        hp = initialize_models(
            model_name="SongFormer",
            checkpoint="SongFormer.safetensors",
            config_path="SongFormer.yaml",
        )
    else:
        hp = OmegaConf.load(os.path.join("configs", "SongFormer.yaml"))

    # Load audio
    wav, sr = librosa.load(audio_path, sr=INPUT_SAMPLING_RATE)
    audio = torch.tensor(wav).to(device)

    # Prepare output
    total_len = (
        (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR * TIME_DUR
    ) + TIME_DUR
    total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

    logits = {
        "function_logits": np.zeros([total_frames, num_classes]),
        "boundary_logits": np.zeros([total_frames]),
    }
    logits_num = {
        "function_logits": np.zeros([total_frames, num_classes]),
        "boundary_logits": np.zeros([total_frames]),
    }

    # Prepare label masks
    dataset_id2label_mask = {}
    for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
        dataset_id2label_mask[key] = np.ones(num_classes, dtype=bool)
        dataset_id2label_mask[key][allowed_ids] = False

    lens = 0
    i = 0

    with torch.no_grad():
        while True:
            start_idx = i * INPUT_SAMPLING_RATE
            end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
            if start_idx >= audio.shape[-1]:
                break
            if end_idx - start_idx <= 1024:
                # break (not continue): a tiny trailing remainder would otherwise loop forever
                break

            audio_seg = audio[start_idx:end_idx]

            # Get embeddings
            muq_output = muq_model(audio_seg.unsqueeze(0), output_hidden_states=True)
            muq_embd_420s = muq_output["hidden_states"][10]
            del muq_output
            torch.cuda.empty_cache()

            _, musicfm_hidden_states = musicfm_model.get_predictions(
                audio_seg.unsqueeze(0)
            )
            musicfm_embd_420s = musicfm_hidden_states[10]
            del musicfm_hidden_states
            torch.cuda.empty_cache()

            # Process 30-second segments
            wraped_muq_embd_30s = []
            wraped_musicfm_embd_30s = []

            for idx_30s in range(i, i + hop_size, 30):
                start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
                end_idx_30s = min(
                    (idx_30s + 30) * INPUT_SAMPLING_RATE,
                    audio.shape[-1],
                    (i + hop_size) * INPUT_SAMPLING_RATE,
                )
                if start_idx_30s >= audio.shape[-1]:
                    break
                if end_idx_30s - start_idx_30s <= 1024:
                    continue

                wraped_muq_embd_30s.append(
                    muq_model(
                        audio[start_idx_30s:end_idx_30s].unsqueeze(0),
                        output_hidden_states=True,
                    )["hidden_states"][10]
                )
                torch.cuda.empty_cache()

                wraped_musicfm_embd_30s.append(
                    musicfm_model.get_predictions(
                        audio[start_idx_30s:end_idx_30s].unsqueeze(0)
                    )[1][10]
                )
                torch.cuda.empty_cache()

            if wraped_muq_embd_30s:
                wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
                wraped_musicfm_embd_30s = torch.concatenate(
                    wraped_musicfm_embd_30s, dim=1
                )

            all_embds = [
                wraped_musicfm_embd_30s,
                wraped_muq_embd_30s,
                musicfm_embd_420s,
                muq_embd_420s,
            ]

            # Align embedding lengths
            if len(all_embds) > 1:
                embd_lens = [x.shape[1] for x in all_embds]
                min_embd_len = min(embd_lens)
                for idx in range(len(all_embds)):
                    all_embds[idx] = all_embds[idx][:, :min_embd_len, :]

            embd = torch.concatenate(all_embds, axis=-1)

            # Inference
            dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
            msa_info, chunk_logits = msa_model.infer(
                input_embeddings=embd,
                dataset_ids=dataset_ids,
                label_id_masks=torch.Tensor(
                    dataset_id2label_mask[
                        DATASET_LABEL_TO_DATASET_ID[DATASET_LABEL]
                    ]
                )
                .to(device, dtype=bool)
                .unsqueeze(0)
                .unsqueeze(0),
                with_logits=True,
            )

            # Accumulate logits
            start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
            end_frame = start_frame + min(
                math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
                chunk_logits["boundary_logits"][0].shape[0],
            )

            logits["function_logits"][start_frame:end_frame, :] += (
                chunk_logits["function_logits"][0].detach().cpu().numpy()
            )
            logits["boundary_logits"][start_frame:end_frame] = (
                chunk_logits["boundary_logits"][0].detach().cpu().numpy()
            )
            logits_num["function_logits"][start_frame:end_frame, :] += 1
            logits_num["boundary_logits"][start_frame:end_frame] += 1
            lens += end_frame - start_frame

            i += hop_size

    # Average logits
    logits["function_logits"] /= np.maximum(logits_num["function_logits"], 1)
    logits["boundary_logits"] /= np.maximum(logits_num["boundary_logits"], 1)

    logits["function_logits"] = torch.from_numpy(
        logits["function_logits"][:lens]
    ).unsqueeze(0)
    logits["boundary_logits"] = torch.from_numpy(
        logits["boundary_logits"][:lens]
    ).unsqueeze(0)

    # Post-process
    msa_infer_output = postprocess_functional_structure(logits, hp)

    return logits, msa_infer_output


def format_as_segments(msa_output):
    """Format as list of segments"""
    segments = []
    for idx in range(len(msa_output) - 1):
        segments.append(
            {
                "start": str(round(msa_output[idx][0], 2)),
                "end": str(round(msa_output[idx + 1][0], 2)),
                "label": msa_output[idx][1],
            }
        )
    return segments


def format_as_msa(msa_output):
    """Format as MSA format"""
    lines = []
    for time, label in msa_output:
        lines.append(f"{time:.2f} {label}")
    return "\n".join(lines)


def format_as_json(segments):
    """Format as JSON"""
    return json.dumps(segments, indent=2, ensure_ascii=False)


def create_visualization(
    logits, msa_output, label_num=8, frame_rates=AFTER_DOWNSAMPLING_FRAME_RATES
):
    """Create visualization plot"""
    # Assume ID_TO_LABEL mapping exists
    try:
        from dataset.label2id import ID_TO_LABEL
    except ImportError:
        ID_TO_LABEL = {i: f"Class_{i}" for i in range(128)}

    function_vals = logits["function_logits"].squeeze().cpu().numpy()
    boundary_vals = logits["boundary_logits"].squeeze().cpu().numpy()

    top_classes = np.argsort(function_vals.mean(axis=0))[-label_num:]
    T = function_vals.shape[0]
    time_axis = np.arange(T) / frame_rates

    fig, ax = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

    # Plot function logits
    for cls in top_classes:
        ax[1].plot(
            time_axis,
            function_vals[:, cls],
            label=f"{ID_TO_LABEL.get(cls, f'Class_{cls}')}",
        )

    ax[1].set_title("Top 8 Function Logits by Mean Activation")
    ax[1].set_xlabel("Time (seconds)")
    ax[1].set_ylabel("Logit")
    ax[1].xaxis.set_major_locator(ticker.MultipleLocator(20))
    ax[1].xaxis.set_minor_locator(ticker.MultipleLocator(5))
    ax[1].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1f"))
    ax[1].legend()
    ax[1].grid(True)

    # Plot boundary logits
    ax[0].plot(time_axis, boundary_vals, label="Boundary Logit", color="orange")
    ax[0].set_title("Boundary Logits")
    ax[0].set_ylabel("Logit")
    ax[0].legend()
    ax[0].grid(True)

    # Add vertical lines for markers
    for t_sec, label in msa_output:
        for a in ax:
            a.axvline(x=t_sec, color="red", linestyle="--", linewidth=0.8, alpha=0.7)
        if label != "end":
            ax[1].text(
                t_sec + 0.3,
                ax[1].get_ylim()[1] * 0.85,
                label,
                rotation=90,
                fontsize=8,
                color="red",
            )

    plt.suptitle("Music Structure Analysis - Logits Overview", fontsize=16)
    plt.tight_layout()

    return fig


def rule_post_processing(msa_list):
    if len(msa_list) <= 2:
        return msa_list

    result = msa_list.copy()

    while len(result) > 2:
        first_duration = result[1][0] - result[0][0]
        if first_duration < 1.0 and len(result) > 2:
            result[0] = (result[0][0], result[1][1])
            result = [result[0]] + result[2:]
        else:
            break

    while len(result) > 2:
        last_label_duration = result[-1][0] - result[-2][0]
        if last_label_duration < 1.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    while len(result) > 2:
        if result[0][1] == result[1][1] and result[1][0] <= 10.0:
            result = [(result[0][0], result[0][1])] + result[2:]
        else:
            break

    while len(result) > 2:
        last_duration = result[-1][0] - result[-2][0]
        if result[-2][1] == result[-3][1] and last_duration <= 10.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    return result


def process_and_analyze(audio_file):
    """Main processing function"""

    def format_time(t: float) -> str:
        minutes = int(t // 60)
        seconds = t % 60
        return f"{minutes:02d}:{seconds:06.3f}"

    if audio_file is None:
        return None, "", "", None

    try:
        # Process audio
        logits, msa_output = process_audio(audio_file)
        # Apply rule-based post-processing; skip it by using the CLI inference instead
        msa_output = rule_post_processing(msa_output)
        # Format outputs
        segments = format_as_segments(msa_output)
        msa_format = format_as_msa(msa_output)
        json_format = format_as_json(segments)

        # Create table data
        table_data = [
            [
                f"{float(seg['start']):.2f} ({format_time(float(seg['start']))})",
                f"{float(seg['end']):.2f} ({format_time(float(seg['end']))})",
                seg["label"],
            ]
            for seg in segments
        ]

        # Create visualization
        fig = create_visualization(logits, msa_output)

        return table_data, json_format, msa_format, fig

    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)  # Print the full error to the console
        return None, "", error_msg, None


# Create Gradio interface
with gr.Blocks(
    title="Music Structure Analysis",
    css="""
    .logo-container {
        text-align: center;
        margin-bottom: 20px;
    }
    .links-container {
        display: flex;
        justify-content: center;
        column-gap: 10px;
        margin-bottom: 10px;
    }
    .model-title {
        text-align: center;
        font-size: 24px;
        font-weight: bold;
        margin-bottom: 30px;
    }
    """,
) as demo:
    # Top Logo
    gr.HTML("""
    <div style="display: flex; justify-content: center; align-items: center;">
        <img src="https://raw.githubusercontent.com/ASLP-lab/SongFormer/refs/heads/main/figs/logo.png" style="max-width: 300px; height: auto;" />
    </div>
    """)

    # Model title
    gr.HTML("""
    <div class="model-title">
        SongFormer: Scaling Music Structure Analysis with Heterogeneous Supervision
    </div>
    """)

    # Links
    gr.HTML("""
    <div class="links-container">
        <img src="https://img.shields.io/badge/Python-3.10-brightgreen" alt="Python">
        <img src="https://img.shields.io/badge/License-CC%20BY%204.0-lightblue" alt="License">
        <a href="https://arxiv.org/abs/">
            <img src="https://img.shields.io/badge/arXiv-com.svg?logo=arXiv" alt="arXiv">
        </a>
        <a href="https://github.com/ASLP-lab/SongFormer">
            <img src="https://img.shields.io/badge/GitHub-SongFormer-black" alt="GitHub">
        </a>
        <a href="https://huggingface.co/spaces/ASLP-lab/SongFormer">
            <img src="https://img.shields.io/badge/HuggingFace-space-yellow" alt="HuggingFace Space">
        </a>
        <a href="https://huggingface.co/ASLP-lab/SongFormer">
            <img src="https://img.shields.io/badge/HuggingFace-model-blue" alt="HuggingFace Model">
        </a>
        <a href="https://huggingface.co/datasets/ASLP-lab/SongFormDB">
            <img src="https://img.shields.io/badge/HF%20Dataset-SongFormDB-green" alt="Dataset SongFormDB">
        </a>
        <a href="https://huggingface.co/datasets/ASLP-lab/SongFormBench">
            <img src="https://img.shields.io/badge/HF%20Dataset-SongFormBench-orange" alt="Dataset SongFormBench">
        </a>
        <a href="https://discord.gg/rwcqh7Em">
            <img src="https://img.shields.io/badge/Discord-join%20us-purple?logo=discord&logoColor=white" alt="Discord">
        </a>
        <a href="http://www.npu-aslp.org/">
            <img src="https://img.shields.io/badge/🏫-ASLP-grey?labelColor=lightgrey" alt="ASLP">
        </a>
    </div>
    """)

    # Main input area
    with gr.Row():
        with gr.Column(scale=3):
            audio_input = gr.Audio(
                label="Upload Audio File", type="filepath", elem_id="audio-input"
            )

        with gr.Column(scale=1):
            gr.Markdown("### 📌 Examples")
            gr.Examples(
                examples=[
                    # Add your example audio file paths
                    # ["example1.mp3"],
                    # ["example2.mp3"],
                ],
                inputs=[audio_input],
                label="Click to load example",
            )

    # Analyze button
    with gr.Row():
        analyze_btn = gr.Button(
            "🚀 Analyze Music Structure", variant="primary", scale=1
        )

    # Results display area
    with gr.Row():
        with gr.Column(scale=13):
            segments_table = gr.Dataframe(
                headers=["Start / s (m:s.ms)", "End / s (m:s.ms)", "Label"],
                label="Detected Music Segments",
                interactive=False,
                elem_id="result-table",
            )
        with gr.Column(scale=8):
            with gr.Row():
                with gr.Accordion("📄 JSON Output", open=False):
                    json_output = gr.Textbox(
                        label="JSON Format",
                        lines=15,
                        max_lines=20,
                        interactive=False,
                        show_copy_button=True,
                    )
            with gr.Row():
                with gr.Accordion("📋 MSA Text Output", open=False):
                    msa_output = gr.Textbox(
                        label="MSA Format",
                        lines=15,
                        max_lines=20,
                        interactive=False,
                        show_copy_button=True,
                    )

    # Visualization plot
    with gr.Row():
        plot_output = gr.Plot(label="Activation Curves Visualization")

    gr.HTML("""
    <div style="display: flex; justify-content: center; align-items: center;">
        <img src="https://raw.githubusercontent.com/ASLP-lab/SongFormer/refs/heads/main/figs/aslp.png" style="max-width: 300px; height: auto;" />
    </div>
    """)

    # Set event handlers
    analyze_btn.click(
        fn=process_and_analyze,
        inputs=[audio_input],
        outputs=[segments_table, json_output, msa_output, plot_output],
    )

if __name__ == "__main__":
    # Download pretrained models if not exist
    download_all(use_mirror=False)
    # Initialize models
    print("Initializing models...")
    initialize_models(
        model_name="SongFormer",
        checkpoint="SongFormer.safetensors",
        config_path="SongFormer.yaml",
    )
    print("Models loaded successfully!")

    # Launch interface
    demo.launch(server_name="127.0.0.1", server_port=7891, debug=True)
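For quick local testing without the Gradio UI, the functions in `app.py` above can be called directly. A minimal sketch, assuming the models have been initialized as in the `__main__` block and that `song.mp3` is a placeholder path you supply:

```python
# sketch: run the pipeline above on one file without launching the UI
logits, msa = process_audio("song.mp3")         # placeholder audio path
msa = rule_post_processing(msa)
print(format_as_msa(msa))                       # one "start_time label" line per boundary
print(format_as_json(format_as_segments(msa)))
```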
requirements.txt
ADDED
@@ -0,0 +1,86 @@
# Core Deep Learning Framework
torch==2.4.0
torchaudio==2.4.0
lightning==2.5.1.post0

# ML/DL Libraries
transformers==4.51.1
accelerate==1.5.2
datasets==3.6.0
tokenizers==0.21.1
huggingface-hub==0.30.1
safetensors==0.5.3

# Scientific Computing
numpy==1.25.0
scipy==1.15.2
scikit-learn==1.6.1
pandas==2.2.3

# Audio Processing
librosa==0.11.0
audioread==3.0.1
soundfile==0.13.1
pesq==0.0.4
auraloss==0.4.0
nnAudio==0.3.3
julius==0.2.7
soxr==0.5.0.post1
mir_eval==0.8.2
jams==0.3.4
msaf==0.1.80

# Visualization & Monitoring
matplotlib==3.10.1
seaborn==0.13.2
tensorboard==2.19.0
wandb==0.19.8
gpustat==1.1.1

# Configuration & CLI
hydra-core==1.3.2
omegaconf==2.3.0
fire==0.7.1
click==8.1.8

# Deep Learning Utils
einops==0.8.1
einx==0.3.0
x-transformers==2.4.14
x-clip==0.14.4
ema-pytorch==0.7.7
schedulefree==1.4.1
torchmetrics==1.7.1

# Data Processing
h5py==3.13.0
pyarrow==19.0.1
pillow==11.1.0

# Text Processing
ftfy==6.3.1
regex==2024.11.6
pypinyin==0.54.0
textgrid==1.6.1
pylrc==0.1.2

# Model Management
modelscope==1.27.1

# Utilities
tqdm==4.67.1
loguru==0.7.3
joblib==1.4.2
easydict==1.13
addict==2.4.0
beartype==0.21.0

# Others
triton==3.0.0
muq==0.1.0
vmo==0.30.5
gradio
src/SongFormer/ckpts/md5sum.txt
ADDED
@@ -0,0 +1,4 @@
df930aceac8209818556c4a656a0714c  MusicFM/pretrained_msd.pt
75ab2e47b093e07378f7f703bdb82c14  MusicFM/msd_stats.json
5a24800e12ab357744f8b47e523ba3e6  SongFormer.safetensors
2c66c0bb91364e318e90dbc2d9a79ee2  _SongFormer.pt
src/SongFormer/configs/SongFormer.yaml
ADDED
@@ -0,0 +1,186 @@
# ============================
# Model Configuration
# ============================

input_dim_raw: 4096 # Downsampled fused SSL representation dimension
input_dim: 2048 # Input dimension after the linear layer

# Downsampling module
down_sample_conv_kernel_size: 3
down_sample_conv_stride: 3
down_sample_conv_dropout: 0.1
down_sample_conv_padding: 0

# Transformer module
transformer_encoder_input_dim: 1024
transformer_input_dim: 512
num_transformer_layers: 4
transformer_nhead: 8
transformer_dropout: 0.1

# Task-specific heads
boundary_head_hidden_dims: [128, 64, 8]
function_head_hidden_dims: []

num_classes: 128
num_dataset_classes: 64

# Scheduler
warmup_steps: 300
total_steps: 12010
warmup_max_lr: 0.0001

# Frame rate after downsampling
output_logits_frame_rates: 8.333
# i.e. output_logits_frame_rates = input_embd_frame_rates / downsample_rates, because the padding is 0
downsample_rates: 3
# Frame rate after downsampling, used by the model and the post-processing
frame_rates: 8.333

# EMA config
ema_kwargs:
  {update_after_step: 200}

# ============================
# Loss Functions configuration
# ============================

# Focal loss
label_focal_loss_weight: 0.2

label_focal_loss_alpha: 0.25
label_focal_loss_gamma: 2.0

# Boundary TV loss
boundary_tvloss_weight: 0.05

boundary_tv_loss_beta: 0.6
boundary_tv_loss_lambda: 0.4
boundary_tv_loss_boundary_threshold: 0.01
boundary_tv_loss_reduction_weight: 0.1

loss_weight_section: 0.2
loss_weight_function: 0.8

# ============================
# Training config
# ============================

# Number of neighbors used to augment boundaries in the dataset.
# Example: 1/25*3 * 10s = 1.2s (both sides total 4.2s)
num_neighbors: 10
learn_label: true
learn_segment: true
accumulation_steps: 2
slice_dur: 420
early_stopping_step: 3
local_maxima_filter_size: 3

# ============================
# Dataset config
# ============================

train_dataset:
  _target_: dataset.SongFormerDataset.Dataset
  dataset_abstracts:
    [
      {
        "internal_tmp_id": "SongForm-HX-8Class",
        "dataset_type": "SongForm-HX-8Class",
        "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
        "label_path": "your_data_dir/labels/harmonixset_8class_rule_revision.jsonl",
        "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/train.txt",
        "multiplier": 4,
      },
      {
        "internal_tmp_id": "SongForm-Private",
        "dataset_type": "SongForm-Private",
        "input_embedding_dir": "your_data_dir/30s_420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/30s_420s/Internal_data/muq_hop420/layer_10 your_data_dir/420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/420s/Internal_data/muq_hop420/layer_10",
        "label_path": "your_data_dir/labels/0006_single_layer_transformer_musicfm_muq_along_time_00_5k_v1.jsonl",
        "split_ids_path": "your_data_dir/separated_ids/internal_data_sofa_clean/train.txt",
        "multiplier": 1,
      },
      {
        adapter: HookTheoryAdapter,
        internal_tmp_id: "SongForm-Hook",
        structure_jsonl_paths: [
          "your_data_dir/HookTheoryStructure.train.jsonl"
        ],
        dataset_type: "SongForm-Hook",
        input_embedding_dir: "your_data_dir/30s_420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/30s_420s/HookTheory/muq_hop420/layer_10 your_data_dir/420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/420s/HookTheory/muq_hop420/layer_10",
        split_ids_path: "your_data_dir/separated_ids/hooktheory_separated_ids/train.txt",
        multiplier: 1,
      },
    ]
  hparams:
    output_logits_frame_rates: ${output_logits_frame_rates}
    downsample_rates: ${downsample_rates}
    num_neighbors: ${num_neighbors}
    input_dim: ${input_dim_raw}
    slice_dur: ${slice_dur}
    num_classes: ${num_classes}
    frame_rates: ${frame_rates}

eval_dataset:
  _target_: dataset.SongFormerDataset.Dataset
  dataset_abstracts:
    [
      {
        "internal_tmp_id": "SongForm-HX-8Classs_val",
        "dataset_type": "SongForm-HX-8Class",
        "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
        "label_path": "your_data_dir/processed_data/labels/harmonixset_8class_rule_revision.jsonl",
        "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/val.txt",
        "multiplier": 1,
      },
    ]
  hparams:
    output_logits_frame_rates: ${output_logits_frame_rates}
    downsample_rates: ${downsample_rates}
    num_neighbors: ${num_neighbors}
    input_dim: ${input_dim_raw}
    slice_dur: ${slice_dur}
    num_classes: ${num_classes}
    frame_rates: ${frame_rates}

# ============================
# DataLoader configuration
# ============================

train_dataloader:
  num_workers: 4
  batch_size: 4
  pin_memory: True
  prefetch_factor: 4
  drop_last: True
  persistent_workers: True
  shuffle: true

eval_dataloader:
  num_workers: 0
  batch_size: 1
  shuffle: false

# ============================
# Optimizer configuration
# ============================

optimizer:
  lr: ${warmup_max_lr}
  betas: [0.8, 0.999]
  eps: 1e-08
  weight_decay: 3e-7

# ============================
# Training Run configuration
# ============================

args:
  run_name: SongFormer
  model_name: SongFormer
  save_interval: 800
  eval_interval: 800
  checkpoint_dir: output/SongFormer
  max_epochs: 1000
  max_steps: 12010
  tags: null
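For orientation, `app.py` loads this file with `OmegaConf.load(os.path.join("configs", "SongFormer.yaml"))`. The `_target_` keys follow the Hydra instantiate convention; whether the (not yet released) training code actually uses `hydra.utils.instantiate` is an assumption. A minimal sketch:

```python
# sketch: loading this config; Hydra-style instantiation is an assumption (training code is not in this commit)
from omegaconf import OmegaConf

hp = OmegaConf.load("src/SongFormer/configs/SongFormer.yaml")
print(hp.frame_rates)  # 8.333
print(OmegaConf.to_container(hp.train_dataloader, resolve=True))

# from hydra.utils import instantiate
# train_set = instantiate(hp.train_dataset)  # assumes Hydra-style instantiation of _target_
```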
src/SongFormer/dataset/DatasetAdaper.py
ADDED
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod


class DatasetAdapter(ABC):
    """
    Abstract base class for dataset adapters.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """
        Initialize the dataset adapter with necessary parameters.
        """
        raise NotImplementedError("Subclasses must implement the __init__ method.")

    @abstractmethod
    def get_ids(self):
        """
        Get the IDs of the dataset.
        This method should be implemented by subclasses.

        Returns:
            A list or set of IDs representing the dataset, in the format: ID + start_time.
            Must consider the dataset split, e.g. train, val, test.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def get_item_json(self, *args, **kwargs):
        """
        Get the item JSON representation from the dataset.
        """
        raise NotImplementedError("Subclasses must implement this method.")
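To illustrate the contract this base class defines, a minimal hypothetical adapter is sketched below; the real adapters in this commit (for example `GeminiOnlyLabelAdapter` and `HookTheoryAdapter`) follow the same three-method shape but load embeddings and labels from disk.

```python
# ToyAdapter: hypothetical, for illustration only (not part of the repo).
# Assumes it is imported from src/SongFormer, where dataset/ is on sys.path.
from dataset.DatasetAdaper import DatasetAdapter


class ToyAdapter(DatasetAdapter):
    def __init__(self, ids_with_start_times):
        # IDs in "ID + start_time" form, already restricted to a single split
        self.ids = list(ids_with_start_times)

    def get_ids(self):
        return self.ids

    def get_item_json(self, utt, start_time, end_time):
        # Return whatever dict the training Dataset expects for this slice
        return {"utt": utt, "start_time": start_time, "end_time": end_time}
```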
src/SongFormer/dataset/GeminiOnlyLabelAdapter.py
ADDED
@@ -0,0 +1,332 @@
# 1. It was found that the annotations generated by Gemini are discontinuous between segments
# (possibly differing by more than 1.7 seconds, accounting for approximately 1/4 to 1/3 of the cases).
# 2. Gemini's labels can compete with our SOTA model, but Gemini's boundary metrics are very poor.
# With a tolerance of 3 seconds, they are similar to the metrics of our best model.
import pdb
import random
import os
import logging
from collections import defaultdict
from pathlib import Path
import json
import numpy as np
import math
from .label2id import (
    DATASET_ID_ALLOWED_LABEL_IDS,
    DATASET_LABEL_TO_DATASET_ID,
    ID_TO_LABEL,
    LABEL_TO_ID,
)
from argparse import Namespace
from scipy.ndimage import gaussian_filter1d
from .DatasetAdaper import DatasetAdapter
from omegaconf import ListConfig
import copy

# Use a proper module-level logger (the original `from venv import logger` was an accidental import)
logger = logging.getLogger(__name__)


# Adapter for datasets labeled only by Gemini
class GeminiOnlyLabelAdapter(DatasetAdapter):
    def __init__(self, **kwargs):
        (
            label_paths,
            hparams,
            internal_tmp_id,
            dataset_type,
            input_embedding_dir,
            split_ids_path,
        ) = (
            kwargs["label_paths"],
            kwargs["hparams"],
            kwargs["internal_tmp_id"],
            kwargs["dataset_type"],
            kwargs["input_embedding_dir"],
            kwargs["split_ids_path"],
        )
        self.frame_rates = hparams.frame_rates
        self.hparams = hparams
        self.label_to_id = LABEL_TO_ID
        self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
        self.id_to_label = ID_TO_LABEL
        self.internal_tmp_id = internal_tmp_id
        self.dataset_type = dataset_type
        self.EPS = 1e-6
        self.dataset_id2label_mask = {}
        for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
            self.dataset_id2label_mask[key] = np.ones(
                self.hparams.num_classes, dtype=bool
            )
            self.dataset_id2label_mask[key][allowed_ids] = False

        self.id2segments = {}
        data = self.load_jsonl(label_paths)

        self.input_embedding_dir = input_embedding_dir
        all_input_embedding_dirs = input_embedding_dir.split()

        valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])

        for x in all_input_embedding_dirs:
            valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
        split_ids = []
        with open(split_ids_path) as f:
            for line in f:
                if not line.strip():
                    continue
                split_ids.append(line.strip())
        split_ids = set(split_ids)

        valid_data_ids = [
            x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
        ]
        valid_data_ids = [
            (internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
            for x in valid_data_ids
        ]
        self.valid_data_ids = valid_data_ids
        rng = random.Random(42)
        rng.shuffle(self.valid_data_ids)
        for item in data:
            self.id2segments[item["data_id"]] = item["msa_info"]

    def get_ids_from_dir(self, dir_path: str):
        ids = os.listdir(dir_path)
        ids = [Path(x).stem for x in ids if x.endswith(".npy")]
        return set(ids)

    def time2frame(self, this_time):
        return int(this_time * self.frame_rates)

    def load_jsonl(self, paths):
        data = []
        for path in paths:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    obj = json.loads(line)
                    data.append(obj)
        return data

    def get_ids(self):
        return list(self.valid_data_ids)

    def widen_temporal_events(self, events, num_neighbors):
        def theoretical_gaussian_max(sigma):
            return 1 / (np.sqrt(2 * np.pi) * sigma)

        widen_events = events
        sigma = num_neighbors / 3.0
        smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
        smoothed /= theoretical_gaussian_max(sigma)
        smoothed = np.clip(smoothed, 0, 1)

        return smoothed

    def get_item_json(self, utt, start_time, end_time):
        embd_list = []
        embd_dirs = self.input_embedding_dir.split()
        for embd_dir in embd_dirs:
            if not Path(embd_dir).exists():
                raise FileNotFoundError(
                    f"Embedding directory {embd_dir} does not exist"
                )
            tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
            embd_list.append(tmp)

        # Check that max and min lengths of all representations differ by at most 2
        if len(embd_list) > 1:
            embd_shapes = [x.shape for x in embd_list]
            max_shape = max(embd_shapes, key=lambda x: x[0])
            min_shape = min(embd_shapes, key=lambda x: x[0])
            if abs(max_shape[0] - min_shape[0]) > 2:
                raise ValueError(
                    f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
                )

            for idx in range(len(embd_list)):
                embd_list[idx] = embd_list[idx][: min_shape[0], :]

        input_embedding = np.concatenate(embd_list, axis=-1)

        return_json = self._get_item_json_without_embedding(
            "_".join(utt.split("_")[:-1]), start_time, end_time
        )

        if return_json is None:
            logger.warning(
|
| 158 |
+
f"Skip {utt} because no valid segments found in {start_time} to {end_time}."
|
| 159 |
+
)
|
| 160 |
+
return None
|
| 161 |
+
else:
|
| 162 |
+
return_json["input_embedding"] = input_embedding
|
| 163 |
+
return return_json
|
| 164 |
+
|
| 165 |
+
def get_local_times_labels(self, utt):
|
| 166 |
+
assert utt in self.id2segments, f"utt {utt} not found in id2segments"
|
| 167 |
+
time_datas = [x[0] for x in self.id2segments[utt]]
|
| 168 |
+
time_datas = list(map(float, time_datas))
|
| 169 |
+
label_datas = [
|
| 170 |
+
-1 if x[1] == "end" else self.label_to_id[x[1]]
|
| 171 |
+
for x in self.id2segments[utt]
|
| 172 |
+
]
|
| 173 |
+
return np.array(time_datas), label_datas
|
| 174 |
+
|
| 175 |
+
def _get_item_json_without_embedding(self, utt, start_time, end_time):
|
| 176 |
+
SLICE_DUR = int(math.ceil(end_time - start_time))
|
| 177 |
+
|
| 178 |
+
local_times, local_labels = self.get_local_times_labels(utt)
|
| 179 |
+
|
| 180 |
+
local_times, local_labels = (
|
| 181 |
+
copy.deepcopy(local_times),
|
| 182 |
+
copy.deepcopy(local_labels),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
assert np.all(local_times[:-1] < local_times[1:]), (
|
| 186 |
+
f"time must be sorted, but {utt} is {local_times}"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
local_times = local_times - start_time
|
| 190 |
+
|
| 191 |
+
time_L = max(0.0, float(local_times.min()))
|
| 192 |
+
time_R = min(float(SLICE_DUR), float(local_times.max()))
|
| 193 |
+
# Keep only boundaries that fall strictly inside (time_L, time_R)
|
| 194 |
+
keep_boundarys = (time_L + self.EPS < local_times) & (
|
| 195 |
+
local_times < time_R - self.EPS
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# If no valid boundaries, return None
|
| 199 |
+
if keep_boundarys.sum() <= 0:
|
| 200 |
+
return None
|
| 201 |
+
|
| 202 |
+
mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
|
| 203 |
+
mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
|
| 204 |
+
|
| 205 |
+
true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
|
| 206 |
+
for idx in np.flatnonzero(keep_boundarys):
|
| 207 |
+
true_boundary[self.time2frame(local_times[idx])] = 1
|
| 208 |
+
|
| 209 |
+
true_function = np.zeros(
|
| 210 |
+
[int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
|
| 211 |
+
dtype=float,
|
| 212 |
+
)
|
| 213 |
+
true_function_list = []
|
| 214 |
+
msa_info = []
|
| 215 |
+
last_pos = self.time2frame(time_L)
|
| 216 |
+
for idx in np.flatnonzero(keep_boundarys):
|
| 217 |
+
|
| 218 |
+
true_function[
|
| 219 |
+
last_pos : self.time2frame(local_times[idx]),
|
| 220 |
+
int(local_labels[idx - 1]),
|
| 221 |
+
] = 1
|
| 222 |
+
true_function_list.append(
|
| 223 |
+
[int(x) for x in local_labels[idx - 1]]
|
| 224 |
+
if isinstance(local_labels[idx - 1], list)
|
| 225 |
+
else int(local_labels[idx - 1])
|
| 226 |
+
)
|
| 227 |
+
msa_info.append(
|
| 228 |
+
(
|
| 229 |
+
float(max(local_times[idx - 1], time_L)),
|
| 230 |
+
[str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
|
| 231 |
+
if isinstance(local_labels[idx - 1], list)
|
| 232 |
+
else str(self.id_to_label[int(local_labels[idx - 1])]),
|
| 233 |
+
)
|
| 234 |
+
)
|
| 235 |
+
last_pos = self.time2frame(local_times[idx])
|
| 236 |
+
|
| 237 |
+
# Check last label correctness
|
| 238 |
+
true_function[
|
| 239 |
+
last_pos : self.time2frame(time_R),
|
| 240 |
+
local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
|
| 241 |
+
] = 1
|
| 242 |
+
true_function_list.append(
|
| 243 |
+
[int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
|
| 244 |
+
if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
|
| 245 |
+
else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
msa_info.append(
|
| 249 |
+
(
|
| 250 |
+
float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
|
| 251 |
+
[
|
| 252 |
+
str(self.id_to_label[int(x)])
|
| 253 |
+
for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
|
| 254 |
+
]
|
| 255 |
+
if isinstance(
|
| 256 |
+
local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
|
| 257 |
+
)
|
| 258 |
+
else str(
|
| 259 |
+
self.id_to_label[
|
| 260 |
+
int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
|
| 261 |
+
]
|
| 262 |
+
),
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
# Append the final "end" marker at time_R (TODO: confirm whether this is necessary)
|
| 266 |
+
msa_info.append((float(time_R), "end"))
|
| 267 |
+
|
| 268 |
+
# Add boundary_mask & function_mask
|
| 269 |
+
frame_len = int(SLICE_DUR * self.frame_rates)
|
| 270 |
+
# During loss computation, boundaries are fully masked
|
| 271 |
+
boundary_mask = np.ones([frame_len], dtype=bool)
|
| 272 |
+
function_mask = np.zeros([frame_len], dtype=bool)
|
| 273 |
+
|
| 274 |
+
# Set masks according to msa_info
|
| 275 |
+
for i in range(len(msa_info) - 1):
|
| 276 |
+
seg_start, seg_label = msa_info[i]
|
| 277 |
+
seg_end, _ = msa_info[i + 1]
|
| 278 |
+
start_frame = self.time2frame(seg_start)
|
| 279 |
+
end_frame = self.time2frame(seg_end)
|
| 280 |
+
|
| 281 |
+
# Handle case where label may be string or list
|
| 282 |
+
is_no_label = (
|
| 283 |
+
seg_label == "NO_LABEL"
|
| 284 |
+
if isinstance(seg_label, str)
|
| 285 |
+
else "NO_LABEL" in seg_label
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if is_no_label:
|
| 289 |
+
# function_mask set True
|
| 290 |
+
function_mask[start_frame:end_frame] = True
|
| 291 |
+
|
| 292 |
+
# -------------------------
|
| 293 |
+
# During loss computation, boundaries are fully masked
|
| 294 |
+
boundary_mask = np.ones([frame_len], dtype=bool)
|
| 295 |
+
function_mask = np.zeros([frame_len], dtype=bool)
|
| 296 |
+
|
| 297 |
+
# Set masks according to msa_info
|
| 298 |
+
for i in range(len(msa_info) - 1):
|
| 299 |
+
seg_start, seg_label = msa_info[i]
|
| 300 |
+
seg_end, _ = msa_info[i + 1]
|
| 301 |
+
start_frame = self.time2frame(seg_start)
|
| 302 |
+
end_frame = self.time2frame(seg_end)
|
| 303 |
+
|
| 304 |
+
# Handle case where label may be string or list
|
| 305 |
+
is_no_label = (
|
| 306 |
+
seg_label == "NO_LABEL"
|
| 307 |
+
if isinstance(seg_label, str)
|
| 308 |
+
else "NO_LABEL" in seg_label
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if is_no_label:
|
| 312 |
+
# function_mask set True
|
| 313 |
+
function_mask[start_frame:end_frame] = True
|
| 314 |
+
|
| 315 |
+
# return all things except for input_embedding
|
| 316 |
+
return {
|
| 317 |
+
"data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
|
| 318 |
+
"mask": mask,
|
| 319 |
+
"true_boundary": true_boundary,
|
| 320 |
+
"widen_true_boundary": self.widen_temporal_events(
|
| 321 |
+
true_boundary, num_neighbors=self.hparams.num_neighbors
|
| 322 |
+
),
|
| 323 |
+
"true_function": true_function,
|
| 324 |
+
"true_function_list": true_function_list,
|
| 325 |
+
"msa_info": msa_info,
|
| 326 |
+
"dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
|
| 327 |
+
"label_id_mask": self.dataset_id2label_mask[
|
| 328 |
+
self.dataset_id_to_dataset_id[self.dataset_type]
|
| 329 |
+
],
|
| 330 |
+
"boundary_mask": boundary_mask, # Only effective during loss calculation
|
| 331 |
+
"function_mask": function_mask, # Only effective during loss calculation
|
| 332 |
+
}
|
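The boundary-widening step above turns a sparse 0/1 boundary vector into a soft training target: it smooths with a Gaussian and renormalizes so that an isolated boundary still peaks at 1.0. A minimal standalone sketch of that behavior (the 50-frame vector and num_neighbors=6 are illustrative values, not taken from the configs):

import numpy as np
from scipy.ndimage import gaussian_filter1d

def widen_temporal_events(events, num_neighbors):
    # sigma chosen so ~3 sigma covers the neighborhood, as in the adapter above
    sigma = num_neighbors / 3.0
    smoothed = gaussian_filter1d(events.astype(float), sigma=sigma)
    # divide by the theoretical Gaussian peak 1/(sqrt(2*pi)*sigma) so an
    # isolated boundary maps back to ~1.0, then clip to [0, 1]
    smoothed /= 1.0 / (np.sqrt(2.0 * np.pi) * sigma)
    return np.clip(smoothed, 0.0, 1.0)

events = np.zeros(50)
events[25] = 1.0
widened = widen_temporal_events(events, num_neighbors=6)
print(round(float(widened[25]), 3))  # ~1.0 at the boundary frame, decaying on both sides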
src/SongFormer/dataset/HookTheoryAdapter.py
ADDED
|
@@ -0,0 +1,448 @@
|
| 1 |
+
import random
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
from .label2id import (
|
| 9 |
+
DATASET_ID_ALLOWED_LABEL_IDS,
|
| 10 |
+
DATASET_LABEL_TO_DATASET_ID,
|
| 11 |
+
ID_TO_LABEL,
|
| 12 |
+
LABEL_TO_ID,
|
| 13 |
+
)
|
| 14 |
+
from argparse import Namespace
|
| 15 |
+
from scipy.ndimage import gaussian_filter1d
|
| 16 |
+
from .DatasetAdaper import DatasetAdapter
|
| 17 |
+
from omegaconf import ListConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class HookTheoryAdapter(DatasetAdapter):
|
| 21 |
+
def __init__(self, **kwargs):
|
| 22 |
+
(
|
| 23 |
+
structure_jsonl_paths,
|
| 24 |
+
hparams,
|
| 25 |
+
internal_tmp_id,
|
| 26 |
+
dataset_type,
|
| 27 |
+
input_embedding_dir,
|
| 28 |
+
split_ids_path,
|
| 29 |
+
) = (
|
| 30 |
+
kwargs["structure_jsonl_paths"],
|
| 31 |
+
kwargs["hparams"],
|
| 32 |
+
kwargs["internal_tmp_id"],
|
| 33 |
+
kwargs["dataset_type"],
|
| 34 |
+
kwargs.get("input_embedding_dir", None),
|
| 35 |
+
kwargs.get("split_ids_path", None),
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# basic attrs
|
| 39 |
+
self.frame_rates = hparams.frame_rates
|
| 40 |
+
self.hparams = hparams
|
| 41 |
+
self.label_to_id = LABEL_TO_ID
|
| 42 |
+
self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
|
| 43 |
+
self.id_to_label = ID_TO_LABEL
|
| 44 |
+
self.internal_tmp_id = internal_tmp_id
|
| 45 |
+
self.dataset_type = dataset_type
|
| 46 |
+
self.EPS = 1e-6
|
| 47 |
+
|
| 48 |
+
# build dataset-specific label mask
|
| 49 |
+
self.dataset_id2label_mask = {}
|
| 50 |
+
for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
|
| 51 |
+
self.dataset_id2label_mask[key] = np.ones(
|
| 52 |
+
self.hparams.num_classes, dtype=bool
|
| 53 |
+
)
|
| 54 |
+
self.dataset_id2label_mask[key][allowed_ids] = False
|
| 55 |
+
|
| 56 |
+
assert isinstance(structure_jsonl_paths, (ListConfig, tuple, list))
|
| 57 |
+
|
| 58 |
+
# load segments per audio id
|
| 59 |
+
self.id2segments = defaultdict(list)
|
| 60 |
+
data = self.load_jsonl(structure_jsonl_paths)
|
| 61 |
+
|
| 62 |
+
# input embedding dirs (space-separated)
|
| 63 |
+
self.input_embedding_dir = input_embedding_dir
|
| 64 |
+
all_input_embedding_dirs = input_embedding_dir.split()
|
| 65 |
+
|
| 66 |
+
# get valid ids that exist in all embedding dirs
|
| 67 |
+
valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])
|
| 68 |
+
for x in all_input_embedding_dirs:
|
| 69 |
+
valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
|
| 70 |
+
|
| 71 |
+
# read split ids
|
| 72 |
+
split_ids = []
|
| 73 |
+
with open(split_ids_path) as f:
|
| 74 |
+
for line in f:
|
| 75 |
+
if not line.strip():
|
| 76 |
+
continue
|
| 77 |
+
split_ids.append(line.strip())
|
| 78 |
+
split_ids = set(split_ids)
|
| 79 |
+
|
| 80 |
+
# filter valid ids by split
|
| 81 |
+
valid_data_ids = [
|
| 82 |
+
x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
|
| 83 |
+
]
|
| 84 |
+
valid_data_ids = [
|
| 85 |
+
(internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
|
| 86 |
+
for x in valid_data_ids
|
| 87 |
+
]
|
| 88 |
+
self.valid_data_ids = valid_data_ids
|
| 89 |
+
|
| 90 |
+
rng = random.Random(42)
|
| 91 |
+
rng.shuffle(self.valid_data_ids)
|
| 92 |
+
|
| 93 |
+
for item in data:
|
| 94 |
+
self.id2segments[Path(item["ori_audio_path"]).stem].append(item)
|
| 95 |
+
# logger.info(f"load {len(self.id2segments)} songs from {structure_jsonl_paths}")
|
| 96 |
+
|
| 97 |
+
def get_ids_from_dir(self, dir_path: str):
|
| 98 |
+
ids = os.listdir(dir_path)
|
| 99 |
+
ids = [Path(x).stem for x in ids if x.endswith(".npy")]
|
| 100 |
+
return set(ids)
|
| 101 |
+
|
| 102 |
+
def time2frame(self, this_time):
|
| 103 |
+
# convert time (s) to frame index
|
| 104 |
+
return int(this_time * self.frame_rates)
|
| 105 |
+
|
| 106 |
+
def load_jsonl(self, paths):
|
| 107 |
+
# load list of jsonl files
|
| 108 |
+
data = []
|
| 109 |
+
for path in paths:
|
| 110 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 111 |
+
for line in f:
|
| 112 |
+
line = line.strip()
|
| 113 |
+
if not line:
|
| 114 |
+
continue
|
| 115 |
+
obj = json.loads(line)
|
| 116 |
+
data.append(obj)
|
| 117 |
+
return data
|
| 118 |
+
|
| 119 |
+
def split_and_label(self, query_start, query_end, segments):
|
| 120 |
+
"""
|
| 121 |
+
segments: List of dicts, each with keys: "segment_start", "segment_end", 'label'
|
| 122 |
+
"""
|
| 123 |
+
# Step 1: collect all boundary points (only within query interval)
|
| 124 |
+
points = set([query_start, query_end])
|
| 125 |
+
for seg in segments:
|
| 126 |
+
if query_start <= seg["segment_start"] <= query_end:
|
| 127 |
+
points.add(seg["segment_start"])
|
| 128 |
+
if query_start <= seg["segment_end"] <= query_end:
|
| 129 |
+
points.add(seg["segment_end"])
|
| 130 |
+
sorted_points = sorted(points)
|
| 131 |
+
|
| 132 |
+
result = []
|
| 133 |
+
# Step 2: for each small interval, check which segments cover it
|
| 134 |
+
for i in range(len(sorted_points) - 1):
|
| 135 |
+
part_start = sorted_points[i]
|
| 136 |
+
part_end = sorted_points[i + 1]
|
| 137 |
+
labels = []
|
| 138 |
+
for seg in segments:
|
| 139 |
+
if (
|
| 140 |
+
seg["segment_start"] <= part_start
|
| 141 |
+
and seg["segment_end"] >= part_end
|
| 142 |
+
):
|
| 143 |
+
labels.extend(seg["label"])
|
| 144 |
+
if not labels:
|
| 145 |
+
labels = ["NO_LABEL"]
|
| 146 |
+
result.append(
|
| 147 |
+
{"segment_start": part_start, "segment_end": part_end, "labels": labels}
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# deduplicate labels per interval
|
| 151 |
+
for idx in range(len(result)):
|
| 152 |
+
result[idx]["labels"] = list(set(result[idx]["labels"]))
|
| 153 |
+
return result
|
| 154 |
+
|
| 155 |
+
def merge_small_intervals(self, parts, min_duration=2.0):
|
| 156 |
+
"""
|
| 157 |
+
parts: list of dicts with "segment_start", "segment_end", 'labels'
|
| 158 |
+
Merge intervals shorter than min_duration into neighbor intervals.
|
| 159 |
+
"""
|
| 160 |
+
new_parts = []
|
| 161 |
+
i = 0
|
| 162 |
+
while i < len(parts):
|
| 163 |
+
part = parts[i]
|
| 164 |
+
duration = part["segment_end"] - part["segment_start"]
|
| 165 |
+
if duration < min_duration:
|
| 166 |
+
# decide where to merge
|
| 167 |
+
if len(new_parts) > 0 and (i + 1) < len(parts):
|
| 168 |
+
# randomly choose previous or next
|
| 169 |
+
if random.choice([True, False]):
|
| 170 |
+
prev = new_parts[-1]
|
| 171 |
+
prev["segment_end"] = part["segment_end"]
|
| 172 |
+
else:
|
| 173 |
+
next_part = parts[i + 1]
|
| 174 |
+
next_part["segment_start"] = part["segment_start"]
|
| 175 |
+
# skip adding this part
|
| 176 |
+
elif len(new_parts) > 0:
|
| 177 |
+
# only previous exists - merge into previous
|
| 178 |
+
prev = new_parts[-1]
|
| 179 |
+
prev["segment_end"] = part["segment_end"]
|
| 180 |
+
elif (i + 1) < len(parts):
|
| 181 |
+
# only next exists - merge into next
|
| 182 |
+
next_part = parts[i + 1]
|
| 183 |
+
next_part["segment_start"] = part["segment_start"]
|
| 184 |
+
# else: nothing to merge, drop
|
| 185 |
+
i += 1
|
| 186 |
+
else:
|
| 187 |
+
new_parts.append(part)
|
| 188 |
+
i += 1
|
| 189 |
+
return new_parts
|
| 190 |
+
|
| 191 |
+
def rounding_time(self, segments, num_decimals=3):
|
| 192 |
+
# round segment boundaries to given decimals
|
| 193 |
+
for idx in range(len(segments)):
|
| 194 |
+
segments[idx]["segment_start"] = round(
|
| 195 |
+
segments[idx]["segment_start"], num_decimals
|
| 196 |
+
)
|
| 197 |
+
segments[idx]["segment_end"] = round(
|
| 198 |
+
segments[idx]["segment_end"], num_decimals
|
| 199 |
+
)
|
| 200 |
+
return segments
|
| 201 |
+
|
| 202 |
+
def get_ids(self):
|
| 203 |
+
return list(self.valid_data_ids)
|
| 204 |
+
|
| 205 |
+
def convert_label(self, label: str):
|
| 206 |
+
# map various labels to canonical labels
|
| 207 |
+
mapping = {
|
| 208 |
+
"chorus": "chorus",
|
| 209 |
+
"intro": "intro",
|
| 210 |
+
"bridge": "bridge",
|
| 211 |
+
"verse": "verse",
|
| 212 |
+
"pre-chorus": "pre-chorus",
|
| 213 |
+
"solo": "inst",
|
| 214 |
+
"instrumental": "inst",
|
| 215 |
+
"outro": "outro",
|
| 216 |
+
"NO_LABEL": "NO_LABEL",
|
| 217 |
+
}
|
| 218 |
+
assert label in mapping, f"Unknown label: {label}"
|
| 219 |
+
return mapping[label]
|
| 220 |
+
|
| 221 |
+
def parts_to_label_and_times(self, parts, use_random_tag=True):
|
| 222 |
+
"""
|
| 223 |
+
parts: list of dicts with 'segment_start', 'segment_end', 'labels'
|
| 224 |
+
|
| 225 |
+
if use_random_tag: label will be random from valid labels
|
| 226 |
+
else: label will be all valid labels (labels list)
|
| 227 |
+
|
| 228 |
+
return:
|
| 229 |
+
local_times: np.array of right boundary time points (excluding query_end)
|
| 230 |
+
local_labels: list of label indices corresponding to self.label_to_id
|
| 231 |
+
"""
|
| 232 |
+
local_times = []
|
| 233 |
+
local_labels = []
|
| 234 |
+
|
| 235 |
+
for part in parts:
|
| 236 |
+
local_times.append(part["segment_start"])
|
| 237 |
+
label = random.choice(part["labels"]) if use_random_tag else part["labels"]
|
| 238 |
+
local_labels.append(self.label_to_id[self.convert_label(label)])
|
| 239 |
+
return np.array(local_times), local_labels
|
| 240 |
+
|
| 241 |
+
def get_parts(self, utt, query_start, query_end):
|
| 242 |
+
key = "_".join(utt.split("_")[:-1])
|
| 243 |
+
assert key in self.id2segments
|
| 244 |
+
segments = self.id2segments[key]
|
| 245 |
+
segments = self.rounding_time(segments)
|
| 246 |
+
parts = self.split_and_label(query_start, query_end, segments)
|
| 247 |
+
|
| 248 |
+
# Apply merging twice to remove very short intervals
|
| 249 |
+
new_parts = self.merge_small_intervals(parts, min_duration=2.0)
|
| 250 |
+
new_parts = self.merge_small_intervals(new_parts, min_duration=2.0)
|
| 251 |
+
|
| 252 |
+
return new_parts
|
| 253 |
+
|
| 254 |
+
def widen_temporal_events(self, events, num_neighbors):
|
| 255 |
+
# smooth binary events with a normalized gaussian
|
| 256 |
+
def theoretical_gaussian_max(sigma):
|
| 257 |
+
return 1 / (np.sqrt(2 * np.pi) * sigma)
|
| 258 |
+
|
| 259 |
+
widen_events = events
|
| 260 |
+
sigma = num_neighbors / 3.0
|
| 261 |
+
smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
|
| 262 |
+
smoothed /= theoretical_gaussian_max(sigma)
|
| 263 |
+
smoothed = np.clip(smoothed, 0, 1)
|
| 264 |
+
|
| 265 |
+
return smoothed
|
| 266 |
+
|
| 267 |
+
def get_item_json(self, utt, start_time, end_time):
|
| 268 |
+
# load embeddings from all embedding dirs
|
| 269 |
+
embd_list = []
|
| 270 |
+
embd_dirs = self.input_embedding_dir.split()
|
| 271 |
+
for embd_dir in embd_dirs:
|
| 272 |
+
if not Path(embd_dir).exists():
|
| 273 |
+
raise FileNotFoundError(
|
| 274 |
+
f"Embedding directory {embd_dir} does not exist"
|
| 275 |
+
)
|
| 276 |
+
tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
|
| 277 |
+
embd_list.append(tmp)
|
| 278 |
+
|
| 279 |
+
# Check that max/min length difference across embeddings <= 2
|
| 280 |
+
if len(embd_list) > 1:
|
| 281 |
+
embd_shapes = [x.shape for x in embd_list]
|
| 282 |
+
max_shape = max(embd_shapes, key=lambda x: x[0])
|
| 283 |
+
min_shape = min(embd_shapes, key=lambda x: x[0])
|
| 284 |
+
if abs(max_shape[0] - min_shape[0]) > 2:
|
| 285 |
+
raise ValueError(
|
| 286 |
+
f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
for idx in range(len(embd_list)):
|
| 290 |
+
embd_list[idx] = embd_list[idx][: min_shape[0], :]
|
| 291 |
+
|
| 292 |
+
input_embedding = np.concatenate(embd_list, axis=-1)
|
| 293 |
+
|
| 294 |
+
return_json = self.get_item_json_without_embedding(utt, start_time, end_time)
|
| 295 |
+
if return_json is None:
|
| 296 |
+
return None
|
| 297 |
+
else:
|
| 298 |
+
return_json["input_embedding"] = input_embedding
|
| 299 |
+
return return_json
|
| 300 |
+
|
| 301 |
+
def get_item_json_without_embedding(self, utt, start_time, end_time):
|
| 302 |
+
SLICE_DUR = int(math.ceil(end_time - start_time))
|
| 303 |
+
|
| 304 |
+
local_times, local_labels = self.parts_to_label_and_times(
|
| 305 |
+
self.get_parts(utt, start_time, end_time)
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
assert np.all(local_times[:-1] < local_times[1:]), (
|
| 309 |
+
f"time must be sorted, but {utt} is {local_times}"
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# normalize local times relative to slice start
|
| 313 |
+
local_times = local_times - start_time
|
| 314 |
+
time_L = 0.0
|
| 315 |
+
# here time_R is full slice duration because NO_LABEL may appear
|
| 316 |
+
time_R = float(SLICE_DUR)
|
| 317 |
+
|
| 318 |
+
# determine which boundaries are within (time_L, time_R)
|
| 319 |
+
keep_boundarys = (time_L + self.EPS < local_times) & (
|
| 320 |
+
local_times < time_R - self.EPS
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# if no valid boundary, return None
|
| 324 |
+
if keep_boundarys.sum() <= 0:
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
+
mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
|
| 328 |
+
mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
|
| 329 |
+
|
| 330 |
+
true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
|
| 331 |
+
for idx in np.flatnonzero(keep_boundarys):
|
| 332 |
+
true_boundary[self.time2frame(local_times[idx])] = 1
|
| 333 |
+
|
| 334 |
+
true_function = np.zeros(
|
| 335 |
+
[int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
|
| 336 |
+
dtype=float,
|
| 337 |
+
)
|
| 338 |
+
true_function_list = []
|
| 339 |
+
msa_info = []
|
| 340 |
+
last_pos = self.time2frame(time_L)
|
| 341 |
+
for idx in np.flatnonzero(keep_boundarys):
|
| 342 |
+
# local_labels[idx] might be int or list(int)
|
| 343 |
+
true_function[
|
| 344 |
+
last_pos : self.time2frame(local_times[idx]),
|
| 345 |
+
local_labels[idx - 1],
|
| 346 |
+
] = 1
|
| 347 |
+
true_function_list.append(
|
| 348 |
+
[int(x) for x in local_labels[idx - 1]]
|
| 349 |
+
if isinstance(local_labels[idx - 1], list)
|
| 350 |
+
else int(local_labels[idx - 1])
|
| 351 |
+
)
|
| 352 |
+
msa_info.append(
|
| 353 |
+
(
|
| 354 |
+
float(max(local_times[idx - 1], time_L)),
|
| 355 |
+
[str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
|
| 356 |
+
if isinstance(local_labels[idx - 1], list)
|
| 357 |
+
else str(self.id_to_label[int(local_labels[idx - 1])]),
|
| 358 |
+
)
|
| 359 |
+
)
|
| 360 |
+
last_pos = self.time2frame(local_times[idx])
|
| 361 |
+
|
| 362 |
+
# check last label correctness
|
| 363 |
+
true_function[
|
| 364 |
+
last_pos : self.time2frame(time_R),
|
| 365 |
+
local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
|
| 366 |
+
] = 1
|
| 367 |
+
true_function_list.append(
|
| 368 |
+
[int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
|
| 369 |
+
if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
|
| 370 |
+
else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
|
| 371 |
+
)
|
| 372 |
+
msa_info.append(
|
| 373 |
+
(
|
| 374 |
+
float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
|
| 375 |
+
[
|
| 376 |
+
str(self.id_to_label[int(x)])
|
| 377 |
+
for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
|
| 378 |
+
]
|
| 379 |
+
if isinstance(
|
| 380 |
+
local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
|
| 381 |
+
)
|
| 382 |
+
else str(
|
| 383 |
+
self.id_to_label[
|
| 384 |
+
int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
|
| 385 |
+
]
|
| 386 |
+
),
|
| 387 |
+
)
|
| 388 |
+
)
|
| 389 |
+
# append final "end" marker
|
| 390 |
+
msa_info.append((float(time_R), "end"))
|
| 391 |
+
|
| 392 |
+
# -------------------------
|
| 393 |
+
# boundary_mask & function_mask
|
| 394 |
+
# -------------------------
|
| 395 |
+
frame_len = int(SLICE_DUR * self.frame_rates)
|
| 396 |
+
boundary_mask = np.zeros([frame_len], dtype=bool)
|
| 397 |
+
function_mask = np.zeros([frame_len], dtype=bool)
|
| 398 |
+
|
| 399 |
+
# set masks according to msa_info
|
| 400 |
+
for i in range(len(msa_info) - 1):
|
| 401 |
+
seg_start, seg_label = msa_info[i]
|
| 402 |
+
seg_end, _ = msa_info[i + 1]
|
| 403 |
+
start_frame = self.time2frame(seg_start)
|
| 404 |
+
end_frame = self.time2frame(seg_end)
|
| 405 |
+
|
| 406 |
+
# handle label being string or list
|
| 407 |
+
is_no_label = (
|
| 408 |
+
seg_label == "NO_LABEL"
|
| 409 |
+
if isinstance(seg_label, str)
|
| 410 |
+
else "NO_LABEL" in seg_label
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
if is_no_label:
|
| 414 |
+
# set function_mask True for NO_LABEL regions
|
| 415 |
+
function_mask[start_frame:end_frame] = True
|
| 416 |
+
|
| 417 |
+
# set boundary_mask True for regions >4s away from ends
|
| 418 |
+
left_offset = self.time2frame(seg_start + 4)
|
| 419 |
+
right_offset = self.time2frame(seg_end - 4)
|
| 420 |
+
if i == 0:
|
| 421 |
+
if right_offset > 0:
|
| 422 |
+
boundary_mask[0 : min(right_offset, frame_len)] = True
|
| 423 |
+
elif i == len(msa_info) - 2:
|
| 424 |
+
if left_offset < frame_len:
|
| 425 |
+
boundary_mask[left_offset:frame_len] = True
|
| 426 |
+
elif right_offset > left_offset:
|
| 427 |
+
boundary_mask[left_offset:right_offset] = True
|
| 428 |
+
|
| 429 |
+
# -------------------------
|
| 430 |
+
# return all things except input_embedding
|
| 431 |
+
# -------------------------
|
| 432 |
+
return {
|
| 433 |
+
"data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
|
| 434 |
+
"mask": mask,
|
| 435 |
+
"true_boundary": true_boundary,
|
| 436 |
+
"widen_true_boundary": self.widen_temporal_events(
|
| 437 |
+
true_boundary, num_neighbors=self.hparams.num_neighbors
|
| 438 |
+
),
|
| 439 |
+
"true_function": true_function,
|
| 440 |
+
"true_function_list": true_function_list,
|
| 441 |
+
"msa_info": msa_info,
|
| 442 |
+
"dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
|
| 443 |
+
"label_id_mask": self.dataset_id2label_mask[
|
| 444 |
+
self.dataset_id_to_dataset_id[self.dataset_type]
|
| 445 |
+
],
|
| 446 |
+
"boundary_mask": boundary_mask, # only effective during loss computation
|
| 447 |
+
"function_mask": function_mask, # only effective during loss computation
|
| 448 |
+
}
|
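HookTheory labels only cover the annotated hooks, so split_and_label above cuts the query window at every segment boundary and marks uncovered gaps as NO_LABEL, which the masks later exclude from the loss. A toy, self-contained re-statement of that splitting logic, not the adapter itself (the segment times and labels below are made up):

def split_and_label(query_start, query_end, segments):
    points = {query_start, query_end}
    for seg in segments:
        if query_start <= seg["segment_start"] <= query_end:
            points.add(seg["segment_start"])
        if query_start <= seg["segment_end"] <= query_end:
            points.add(seg["segment_end"])
    pts = sorted(points)
    parts = []
    for a, b in zip(pts[:-1], pts[1:]):
        # a sub-interval inherits the labels of every segment that fully covers it
        labels = sorted({lab for seg in segments
                         if seg["segment_start"] <= a and seg["segment_end"] >= b
                         for lab in seg["label"]}) or ["NO_LABEL"]
        parts.append({"segment_start": a, "segment_end": b, "labels": labels})
    return parts

segments = [
    {"segment_start": 10.0, "segment_end": 20.0, "label": ["verse"]},
    {"segment_start": 20.0, "segment_end": 30.0, "label": ["chorus"]},
]
for p in split_and_label(0.0, 40.0, segments):
    print(p)
# 0-10 -> NO_LABEL, 10-20 -> verse, 20-30 -> chorus, 30-40 -> NO_LABEL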
src/SongFormer/dataset/custom_types.py
ADDED
|
@@ -0,0 +1,14 @@
|
| 1 |
+
"""
|
| 2 |
+
MsaInfo
|
| 3 |
+
A list of (timestamp, label) tuples used to represent music structure
|
| 4 |
+
analysis (MSA). The first element of the tuple is a float timestamp
|
| 5 |
+
(in seconds) and the second is a string label
|
| 6 |
+
|
| 7 |
+
Example
|
| 8 |
+
-------
|
| 9 |
+
>>> msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus")]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import List, Tuple
|
| 13 |
+
|
| 14 |
+
MsaInfo = List[Tuple[float, str]]
|
src/SongFormer/dataset/label2id.py
ADDED
|
@@ -0,0 +1,163 @@
|
| 1 |
+
LABEL_TO_ID = {
|
| 2 |
+
"intro": 0,
|
| 3 |
+
"verse": 1,
|
| 4 |
+
"chorus": 2,
|
| 5 |
+
"bridge": 3,
|
| 6 |
+
"inst": 4,
|
| 7 |
+
"outro": 5,
|
| 8 |
+
"silence": 6,
|
| 9 |
+
"intchorus": 7,
|
| 10 |
+
"prechorus": 8,
|
| 11 |
+
"gtrbreak": 9,
|
| 12 |
+
"solo": 10,
|
| 13 |
+
"quietchorus": 11,
|
| 14 |
+
"bre": 12,
|
| 15 |
+
"break": 13,
|
| 16 |
+
"introverse": 14,
|
| 17 |
+
"mainriff": 15,
|
| 18 |
+
"chorushalf": 16,
|
| 19 |
+
"instintro": 17,
|
| 20 |
+
"gtr": 18,
|
| 21 |
+
"vocaloutro": 19,
|
| 22 |
+
"verse_slow": 20,
|
| 23 |
+
"fadein": 21,
|
| 24 |
+
"saxobeat": 22,
|
| 25 |
+
"transition": 23,
|
| 26 |
+
"verse1a": 24,
|
| 27 |
+
"build": 25,
|
| 28 |
+
"pre-chorus": 26,
|
| 29 |
+
"outroa": 27,
|
| 30 |
+
"bigoutro": 28,
|
| 31 |
+
"fast": 29,
|
| 32 |
+
"instrumentalverse": 30,
|
| 33 |
+
"section": 31,
|
| 34 |
+
"choruspart": 32,
|
| 35 |
+
"instbridge": 33,
|
| 36 |
+
"guitar": 34,
|
| 37 |
+
"instrumental": 35,
|
| 38 |
+
"breakdown": 36,
|
| 39 |
+
"rhythmlessintro": 37,
|
| 40 |
+
"intropt": 38,
|
| 41 |
+
"interlude": 39,
|
| 42 |
+
"postchorus": 40,
|
| 43 |
+
"postverse": 41,
|
| 44 |
+
"opening": 42,
|
| 45 |
+
"altchorus": 43,
|
| 46 |
+
"stutter": 44,
|
| 47 |
+
"oddriff": 45,
|
| 48 |
+
"synth": 46,
|
| 49 |
+
"preverse": 47,
|
| 50 |
+
"quiet": 48,
|
| 51 |
+
"raps": 49,
|
| 52 |
+
"verseinst": 50,
|
| 53 |
+
"instchorus": 51,
|
| 54 |
+
"chorus_instrumental": 52,
|
| 55 |
+
"slowverse": 53,
|
| 56 |
+
"slow": 54,
|
| 57 |
+
"worstthingever": 55,
|
| 58 |
+
"transition2a": 56,
|
| 59 |
+
"miniverse": 57,
|
| 60 |
+
"refrain": 58,
|
| 61 |
+
"introchorus": 59,
|
| 62 |
+
"drumroll": 60,
|
| 63 |
+
"guitarsolo": 61,
|
| 64 |
+
"versepart": 62,
|
| 65 |
+
"chorusinst": 63,
|
| 66 |
+
"ending": 64,
|
| 67 |
+
"no-vocal-intro": 65,
|
| 68 |
+
"no-vocal-interlude": 66,
|
| 69 |
+
"no-vocal-outro": 67,
|
| 70 |
+
"NO_LABEL": 68, # Only referring to cases without labels, this portion of labels will be ignored during the loss calculation process.
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
|
| 74 |
+
|
| 75 |
+
# Reserve 64 embedding positions for dataset identifiers in the model.
|
| 76 |
+
DATASET_LABEL_TO_DATASET_ID = {
|
| 77 |
+
"SongForm-HX-7Class": 0, # Categories after rule mapping for HarmonixSet
|
| 78 |
+
"SongForm-HX-Widen": 1, # Original HarmonixSet
|
| 79 |
+
"SongForm-Private-Raw": 2,
|
| 80 |
+
"SongForm-Private": 3,
|
| 81 |
+
"SongForm-HX-Gemini-Relabeled": 4, # Rule-mapped HarmonixSet corrected by Gemini
|
| 82 |
+
"SongForm-HX-8Class": 5, # Rule-mapped (pre-chorus retained)
|
| 83 |
+
"SongForm-Hook": 6,
|
| 84 |
+
"SongForm-Gem": 7,
|
| 85 |
+
"SongForm-Gem-Only-Label": 8, # Use only segments with labels in SongForm-Gem
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
DATASET_ID_TO_DATASET_LABEL = {v: k for k, v in DATASET_LABEL_TO_DATASET_ID.items()}
|
| 89 |
+
|
| 90 |
+
DATASET_ID_ALLOWED_LABEL_IDS = {
|
| 91 |
+
0: [0, 1, 2, 3, 4, 5, 6],
|
| 92 |
+
1: [
|
| 93 |
+
0,
|
| 94 |
+
1,
|
| 95 |
+
2,
|
| 96 |
+
3,
|
| 97 |
+
4,
|
| 98 |
+
5,
|
| 99 |
+
6,
|
| 100 |
+
7,
|
| 101 |
+
8,
|
| 102 |
+
9,
|
| 103 |
+
10,
|
| 104 |
+
11,
|
| 105 |
+
12,
|
| 106 |
+
13,
|
| 107 |
+
14,
|
| 108 |
+
15,
|
| 109 |
+
16,
|
| 110 |
+
17,
|
| 111 |
+
18,
|
| 112 |
+
19,
|
| 113 |
+
20,
|
| 114 |
+
21,
|
| 115 |
+
22,
|
| 116 |
+
23,
|
| 117 |
+
24,
|
| 118 |
+
25,
|
| 119 |
+
27,
|
| 120 |
+
28,
|
| 121 |
+
29,
|
| 122 |
+
30,
|
| 123 |
+
31,
|
| 124 |
+
32,
|
| 125 |
+
33,
|
| 126 |
+
34,
|
| 127 |
+
35,
|
| 128 |
+
36,
|
| 129 |
+
37,
|
| 130 |
+
38,
|
| 131 |
+
40,
|
| 132 |
+
41,
|
| 133 |
+
42,
|
| 134 |
+
43,
|
| 135 |
+
44,
|
| 136 |
+
45,
|
| 137 |
+
46,
|
| 138 |
+
47,
|
| 139 |
+
48,
|
| 140 |
+
49,
|
| 141 |
+
50,
|
| 142 |
+
51,
|
| 143 |
+
52,
|
| 144 |
+
53,
|
| 145 |
+
54,
|
| 146 |
+
55,
|
| 147 |
+
56,
|
| 148 |
+
57,
|
| 149 |
+
58,
|
| 150 |
+
59,
|
| 151 |
+
60,
|
| 152 |
+
61,
|
| 153 |
+
62,
|
| 154 |
+
63,
|
| 155 |
+
],
|
| 156 |
+
2: [0, 1, 2, 3, 26, 39, 64, 65, 66, 67],
|
| 157 |
+
3: [0, 1, 2, 3, 4, 5, 6, 26, 39, 64, 65, 66, 67],
|
| 158 |
+
4: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 159 |
+
5: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 160 |
+
6: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 161 |
+
7: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 162 |
+
8: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 163 |
+
}
|
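DATASET_ID_ALLOWED_LABEL_IDS is what the adapters above turn into the boolean label_id_mask: entries are True for label ids that must be ignored for a given dataset. A small sketch, assuming for illustration that num_classes is simply len(LABEL_TO_ID) (the training configs may reserve a different value):

import numpy as np
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, LABEL_TO_ID

num_classes = len(LABEL_TO_ID)          # assumption for illustration only
dataset_id = 5                          # SongForm-HX-8Class
mask = np.ones(num_classes, dtype=bool)
mask[DATASET_ID_ALLOWED_LABEL_IDS[dataset_id]] = False
print(np.flatnonzero(~mask))            # -> [ 0  1  2  3  4  5  6 26]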
src/SongFormer/dataset/msa_info_utils.py
ADDED
|
@@ -0,0 +1,47 @@
|
| 1 |
+
from dataset.custom_types import MsaInfo
|
| 2 |
+
from dataset.label2id import LABEL_TO_ID
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def load_msa_info(msa_info_path):
|
| 6 |
+
msa_info: MsaInfo = []
|
| 7 |
+
with open(msa_info_path) as f:
|
| 8 |
+
for line in f:
|
| 9 |
+
line = line.strip()
|
| 10 |
+
if not line:
|
| 11 |
+
continue
|
| 12 |
+
time_, label = line.split()
|
| 13 |
+
time_ = float(time_)
|
| 14 |
+
label = str(label)
|
| 15 |
+
assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
|
| 16 |
+
msa_info.append((time_, label))
|
| 17 |
+
assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
|
| 18 |
+
return msa_info
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_msa_infos(msa_str):
|
| 22 |
+
msa_info: MsaInfo = []
|
| 23 |
+
for line in msa_str:
|
| 24 |
+
line = line.strip()
|
| 25 |
+
if not line:
|
| 26 |
+
continue
|
| 27 |
+
time_, label = line.split()
|
| 28 |
+
time_ = float(time_)
|
| 29 |
+
label = str(label)
|
| 30 |
+
assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
|
| 31 |
+
msa_info.append((time_, label))
|
| 32 |
+
assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
|
| 33 |
+
return msa_info
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def dump_msa_info(msa_info_path, msa_info: MsaInfo):
|
| 37 |
+
with open(msa_info_path, "w") as f:
|
| 38 |
+
for time_, label in msa_info:
|
| 39 |
+
f.write(f"{time_} {label}\n")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def dump_msa_infos(msa_info: MsaInfo):
|
| 43 |
+
mas_strs = []
|
| 44 |
+
for time_, label in msa_info:
|
| 45 |
+
mas_strs.append(f"{round(time_, 2)} {label}")
|
| 46 |
+
|
| 47 |
+
return "\n".join(mas_strs)
|
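A quick round-trip with the helpers above; note that load_msa_infos iterates its argument line by line, so the string produced by dump_msa_infos is passed back in as split lines. This assumes the SongFormer source directory is on PYTHONPATH, as the scripts below set it:

from dataset.msa_info_utils import dump_msa_infos, load_msa_infos

msa = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus"), (60.0, "end")]
msa_str = dump_msa_infos(msa)
print(msa_str)
# 0.0 intro
# 12.5 verse
# 34.0 chorus
# 60.0 end
restored = load_msa_infos(msa_str.splitlines())
print(restored == msa)  # True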
src/SongFormer/eval.sh
ADDED
|
@@ -0,0 +1,22 @@
|
| 1 |
+
export CUDA_VISIBLE_DEVICES=-1
|
| 2 |
+
export PYTHONPATH=${PWD}:$PYTHONPATH
|
| 3 |
+
|
| 4 |
+
export HYDRA_FULL_ERROR=1
|
| 5 |
+
export OMP_NUM_THREADS=1
|
| 6 |
+
export MPI_NUM_THREADS=1
|
| 7 |
+
export NCCL_P2P_DISABLE=1
|
| 8 |
+
export NCCL_IB_DISABLE=1
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
EST_DIR=
|
| 12 |
+
ANN_DIR=
|
| 13 |
+
OUTPUT_DIR=
|
| 14 |
+
echo "$EST_DIR --> $OUTPUT_DIR"
|
| 15 |
+
mkdir -p "$OUTPUT_DIR"
|
| 16 |
+
|
| 17 |
+
python evaluation/eval_infer_results.py \
|
| 18 |
+
--ann_dir $ANN_DIR \
|
| 19 |
+
--est_dir $EST_DIR \
|
| 20 |
+
--output_dir $OUTPUT_DIR \
|
| 21 |
+
--prechorus2what verse
|
| 22 |
+
# --armerge_continuous_segments
|
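EST_DIR and ANN_DIR are left blank above on purpose; both are expected to hold per-track .txt files in the msa_info format parsed by load_msa_info ("<time> <label>" per line, ending with an "end" row). A hypothetical way to produce one such file with the helper shown earlier (the directory and track name are made up):

import os
from dataset.msa_info_utils import dump_msa_info

os.makedirs("ann_dir", exist_ok=True)   # hypothetical directory standing in for ANN_DIR
dump_msa_info(
    "ann_dir/track_001.txt",
    [(0.0, "intro"), (15.2, "verse"), (45.8, "chorus"), (70.0, "end")],
)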
src/SongFormer/evaluation/eval_infer_results.py
ADDED
|
@@ -0,0 +1,198 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import mir_eval
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from dataset.custom_types import MsaInfo
|
| 9 |
+
from dataset.label2id import LABEL_TO_ID
|
| 10 |
+
from dataset.msa_info_utils import load_msa_info
|
| 11 |
+
from msaf.eval import compute_results
|
| 12 |
+
from postprocessing.calc_acc import cal_acc
|
| 13 |
+
from postprocessing.calc_iou import cal_iou
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from loguru import logger
|
| 16 |
+
|
| 17 |
+
LEGAL_LABELS = {
|
| 18 |
+
"end",
|
| 19 |
+
"intro",
|
| 20 |
+
"verse",
|
| 21 |
+
"chorus",
|
| 22 |
+
"bridge",
|
| 23 |
+
"inst",
|
| 24 |
+
"outro",
|
| 25 |
+
"silence",
|
| 26 |
+
"pre-chorus",
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_inters_labels(msa_info: MsaInfo):
|
| 31 |
+
label_ids = np.array([LABEL_TO_ID[x[1]] for x in msa_info[:-1]])
|
| 32 |
+
times = [x[0] for x in msa_info]
|
| 33 |
+
start_times = np.column_stack([np.array(times[:-1]), np.array(times[1:])])
|
| 34 |
+
return start_times, label_ids
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def merge_continuous_segments(segments):
|
| 38 |
+
"""
|
| 39 |
+
Merge continuous segments with the same label.
|
| 40 |
+
|
| 41 |
+
Parameters:
|
| 42 |
+
segments: List of tuples [(start_time, label), ...], where the last element is (end_time, 'end')
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Merged segment list in the same format [(start_time, label), ...], with the last element being (end_time, 'end')
|
| 46 |
+
"""
|
| 47 |
+
if not segments or len(segments) < 2:
|
| 48 |
+
return segments
|
| 49 |
+
|
| 50 |
+
merged = []
|
| 51 |
+
current_start = segments[0][0]
|
| 52 |
+
current_label = segments[0][1]
|
| 53 |
+
|
| 54 |
+
for i in range(1, len(segments)):
|
| 55 |
+
time, label = segments[i]
|
| 56 |
+
|
| 57 |
+
if label == "end":
|
| 58 |
+
if current_label != "end":
|
| 59 |
+
merged.append((current_start, current_label))
|
| 60 |
+
merged.append((time, "end"))
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
if label != current_label:
|
| 64 |
+
merged.append((current_start, current_label))
|
| 65 |
+
current_start = time
|
| 66 |
+
current_label = label
|
| 67 |
+
|
| 68 |
+
return merged
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def main():
|
| 72 |
+
argparser = argparse.ArgumentParser()
|
| 73 |
+
argparser.add_argument("--ann_dir", type=str, required=True)
|
| 74 |
+
argparser.add_argument("--est_dir", type=str, required=True)
|
| 75 |
+
argparser.add_argument("--output_dir", type=str, default="./eval_infer_results")
|
| 76 |
+
argparser.add_argument("--prechorus2what", type=str, default=None)
|
| 77 |
+
argparser.add_argument("--armerge_continuous_segments", action="store_true")
|
| 78 |
+
args = argparser.parse_args()
|
| 79 |
+
|
| 80 |
+
ann_dir = args.ann_dir
|
| 81 |
+
est_dir = args.est_dir
|
| 82 |
+
output_dir = args.output_dir
|
| 83 |
+
if args.armerge_continuous_segments:
|
| 84 |
+
logger.info("Merging continuous segments")
|
| 85 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 86 |
+
|
| 87 |
+
ann_id_lists = [x for x in os.listdir(ann_dir) if x.endswith(".txt")]
|
| 88 |
+
est_id_lists = [x for x in os.listdir(est_dir) if x.endswith(".txt")]
|
| 89 |
+
|
| 90 |
+
common_id_lists = set(ann_id_lists) & set(est_id_lists)
|
| 91 |
+
common_id_lists = list(common_id_lists)
|
| 92 |
+
logger.info(f"Common number of files: {len(common_id_lists)}")
|
| 93 |
+
|
| 94 |
+
results = []
|
| 95 |
+
ious = {}
|
| 96 |
+
|
| 97 |
+
for id in tqdm(common_id_lists):
|
| 98 |
+
try:
|
| 99 |
+
logger.info(f"Processing {id}")
|
| 100 |
+
ann_msa = load_msa_info(os.path.join(ann_dir, id))
|
| 101 |
+
est_msa = load_msa_info(os.path.join(est_dir, id))
|
| 102 |
+
|
| 103 |
+
if args.prechorus2what == "verse":
|
| 104 |
+
ann_msa = [
|
| 105 |
+
(t, "verse") if l == "pre-chorus" else (t, l) for t, l in ann_msa
|
| 106 |
+
]
|
| 107 |
+
est_msa = [
|
| 108 |
+
(t, "verse") if l == "pre-chorus" else (t, l) for t, l in est_msa
|
| 109 |
+
]
|
| 110 |
+
elif args.prechorus2what == "chorus":
|
| 111 |
+
ann_msa = [
|
| 112 |
+
(t, "chorus") if l == "pre-chorus" else (t, l) for t, l in ann_msa
|
| 113 |
+
]
|
| 114 |
+
est_msa = [
|
| 115 |
+
(t, "chorus") if l == "pre-chorus" else (t, l) for t, l in est_msa
|
| 116 |
+
]
|
| 117 |
+
elif args.prechorus2what is not None:
|
| 118 |
+
raise ValueError(f"Unknown prechorus2what: {args.prechorus2what}")
|
| 119 |
+
if args.armerge_continuous_segments:
|
| 120 |
+
ann_msa = merge_continuous_segments(ann_msa)
|
| 121 |
+
est_msa = merge_continuous_segments(est_msa)
|
| 122 |
+
|
| 123 |
+
ann_inter, ann_labels = to_inters_labels(ann_msa)
|
| 124 |
+
est_inter, est_labels = to_inters_labels(est_msa)
|
| 125 |
+
|
| 126 |
+
result = compute_results(
|
| 127 |
+
ann_inter,
|
| 128 |
+
est_inter,
|
| 129 |
+
ann_labels,
|
| 130 |
+
est_labels,
|
| 131 |
+
bins=11,
|
| 132 |
+
est_file="test.txt",
|
| 133 |
+
weight=0.58,
|
| 134 |
+
)
|
| 135 |
+
acc = cal_acc(ann_msa, est_msa, post_digit=3)
|
| 136 |
+
|
| 137 |
+
ious[id] = cal_iou(ann_msa, est_msa)
|
| 138 |
+
result["HitRate_1P"], result["HitRate_1R"], result["HitRate_1F"] = (
|
| 139 |
+
mir_eval.segment.detection(ann_inter, est_inter, window=1, trim=False)
|
| 140 |
+
)
|
| 141 |
+
result.update({"id": Path(id).stem})
|
| 142 |
+
result.update({"acc": acc})
|
| 143 |
+
for v in ious[id]:
|
| 144 |
+
result.update({f"iou-{v['label']}": v["iou"]})
|
| 145 |
+
del result["track_id"]
|
| 146 |
+
del result["ds_name"]
|
| 147 |
+
|
| 148 |
+
results.append(result)
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"Error processing {id}: {e}")
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
df = pd.DataFrame(results)
|
| 154 |
+
df.to_csv(f"{output_dir}/eval_infer.csv", index=False)
|
| 155 |
+
|
| 156 |
+
intsec_dur_total = defaultdict(float)
|
| 157 |
+
uni_dur_total = defaultdict(float)
|
| 158 |
+
|
| 159 |
+
for tid, value in ious.items():
|
| 160 |
+
for item in value:
|
| 161 |
+
label = item["label"]
|
| 162 |
+
intsec_dur_total[label] += item.get("intsec_dur", 0)
|
| 163 |
+
uni_dur_total[label] += item.get("uni_dur", 0)
|
| 164 |
+
|
| 165 |
+
total_intsec = sum(intsec_dur_total.values())
|
| 166 |
+
total_uni = sum(uni_dur_total.values())
|
| 167 |
+
overall_iou = total_intsec / total_uni if total_uni > 0 else 0.0
|
| 168 |
+
|
| 169 |
+
class_ious = {}
|
| 170 |
+
for label in intsec_dur_total:
|
| 171 |
+
intsec = intsec_dur_total[label]
|
| 172 |
+
uni = uni_dur_total[label]
|
| 173 |
+
class_ious[label] = intsec / uni if uni > 0 else 0.0
|
| 174 |
+
|
| 175 |
+
summary = pd.DataFrame(
|
| 176 |
+
[
|
| 177 |
+
{
|
| 178 |
+
"num_samples": len(df),
|
| 179 |
+
"HR.5F": df["HitRate_0.5F"].mean(),
|
| 180 |
+
"HR3F": df["HitRate_3F"].mean(),
|
| 181 |
+
"HR1F": df["HitRate_1F"].mean(),
|
| 182 |
+
"PWF": df["PWF"].mean(),
|
| 183 |
+
"Sf": df["Sf"].mean(),
|
| 184 |
+
"acc": df["acc"].mean(),
|
| 185 |
+
"iou": overall_iou,
|
| 186 |
+
**{f"iou_{k}": v for k, v in class_ious.items()},
|
| 187 |
+
}
|
| 188 |
+
]
|
| 189 |
+
)
|
| 190 |
+
with open(f"{output_dir}/eval_infer_summary.md", "w") as f:
|
| 191 |
+
print(summary.to_markdown(), file=f)
|
| 192 |
+
|
| 193 |
+
summary.to_csv(f"{output_dir}/eval_infer_summary.csv", index=False)
|
| 194 |
+
logger.info(f"Results saved to {output_dir}")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
|
| 198 |
+
main()
|
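When --armerge_continuous_segments is set, adjacent segments that share a label are collapsed before scoring. A small usage sketch of merge_continuous_segments above (importing the module pulls in msaf, mir_eval and friends, so this assumes the evaluation dependencies are installed and the SongFormer source dir is on PYTHONPATH):

from evaluation.eval_infer_results import merge_continuous_segments

segs = [(0.0, "intro"), (10.0, "verse"), (25.0, "verse"), (40.0, "chorus"), (60.0, "end")]
print(merge_continuous_segments(segs))
# [(0.0, 'intro'), (10.0, 'verse'), (40.0, 'chorus'), (60.0, 'end')]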
src/SongFormer/infer.sh
ADDED
|
@@ -0,0 +1,21 @@
|
| 1 |
+
|
| 2 |
+
export CUDA_VISIBLE_DEVICES=
|
| 3 |
+
echo "use gpu ${CUDA_VISIBLE_DEVICES}"
|
| 4 |
+
|
| 5 |
+
export PYTHONPATH=../third_party:$PYTHONPATH
|
| 6 |
+
|
| 7 |
+
export OMP_NUM_THREADS=1
|
| 8 |
+
export MPI_NUM_THREADS=1
|
| 9 |
+
export NCCL_P2P_DISABLE=1
|
| 10 |
+
export NCCL_IB_DISABLE=1
|
| 11 |
+
|
| 12 |
+
python infer/infer.py \
|
| 13 |
+
-i XXX.scp \
|
| 14 |
+
-o XXX_dir \
|
| 15 |
+
--model SongFormer \
|
| 16 |
+
--checkpoint SongFormer.safetensors \
|
| 17 |
+
--config_path SongFormer.yaml \
|
| 18 |
+
-gn 1 \
|
| 19 |
+
-tn 1
|
| 20 |
+
# --debug
|
| 21 |
+
# --no_rule_post_processing
|
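The -i argument is an .scp list with one audio path per line; infer/infer.py (shown next) skips any entry whose stem already has a .json in the output directory. A toy sketch of that filtering, mirroring get_processed_ids / get_processing_ids below (the paths are made up):

from pathlib import Path

scp_lines = ["/data/audio/song_a.mp3", "/data/audio/song_b.mp3"]
processed = {"song_a"}  # stems that already have <stem>.json in the output dir

todo = [p for p in scp_lines if Path(p).stem not in processed]
print(todo)  # ['/data/audio/song_b.mp3']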
src/SongFormer/infer/infer.py
ADDED
|
@@ -0,0 +1,439 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import importlib
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import multiprocessing as mp
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from argparse import Namespace
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# monkey patch to fix issues in msaf
|
| 12 |
+
import scipy
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
scipy.inf = np.inf
|
| 16 |
+
|
| 17 |
+
import librosa
|
| 18 |
+
import torch
|
| 19 |
+
from ema_pytorch import EMA
|
| 20 |
+
from loguru import logger
|
| 21 |
+
from muq import MuQ
|
| 22 |
+
from musicfm.model.musicfm_25hz import MusicFM25Hz
from omegaconf import OmegaConf
from tqdm import tqdm

mp.set_start_method("spawn", force=True)

MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")

BEFORE_DOWNSAMPLING_FRAME_RATES = 25
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333

DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]

TIME_DUR = 420
INPUT_SAMPLING_RATE = 24000

from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
from postprocessing.functional import postprocess_functional_structure


def get_processed_ids(output_path):
    """Get already processed IDs from the output directory"""
    ids = os.listdir(output_path)
    ret = []
    for x in ids:
        if x.endswith(".json"):
            ret.append(x.replace(".json", ""))
    return set(ret)


def get_processing_ids(input_path, processed_ids_set):
    """Get IDs to be processed from the input list file (one audio path per line)"""
    ret = []
    with open(input_path) as f:
        for line in f:
            if line.strip() and Path(line.strip()).stem not in processed_ids_set:
                ret.append(line.strip())
    return ret


def load_checkpoint(checkpoint_path, device=None):
    """Load checkpoint from path"""
    if device is None:
        device = "cpu"

    if checkpoint_path.endswith(".pt"):
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
    else:
        raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
    return checkpoint


def rule_post_processing(msa_list):
    if len(msa_list) <= 2:
        return msa_list

    result = msa_list.copy()

    while len(result) > 2:
        first_duration = result[1][0] - result[0][0]
        if first_duration < 1.0 and len(result) > 2:
            result[0] = (result[0][0], result[1][1])
            result = [result[0]] + result[2:]
        else:
            break

    while len(result) > 2:
        last_label_duration = result[-1][0] - result[-2][0]
        if last_label_duration < 1.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    while len(result) > 2:
        if result[0][1] == result[1][1] and result[1][0] <= 10.0:
            result = [(result[0][0], result[0][1])] + result[2:]
        else:
            break

    while len(result) > 2:
        last_duration = result[-1][0] - result[-2][0]
        if result[-2][1] == result[-3][1] and last_duration <= 10.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    return result


def inference(rank, queue_input: mp.Queue, queue_output: mp.Queue, args):
    """Run inference on the input audio"""
    device = f"cuda:{rank}"

    # MuQ model loading (this will automatically fetch the checkpoint from huggingface)
    muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
    muq = muq.to(device).eval()

    # MusicFM model loading
    musicfm = MusicFM25Hz(
        is_flash=False,
        stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
        model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
    )
    musicfm = musicfm.to(device)
    musicfm.eval()

    # Custom model loading based on the config
    module = importlib.import_module("models." + str(args.model))
    Model = getattr(module, "Model")
    hp = OmegaConf.load(os.path.join("configs", args.config_path))
    model = Model(hp)

    ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", args.checkpoint))
    if ckpt.get("model_ema", None) is not None:
        logger.info("Loading EMA model parameters")
        model_ema = EMA(model, include_online_model=False)
        model_ema.load_state_dict(ckpt["model_ema"])
        model.load_state_dict(model_ema.ema_model.state_dict())
    else:
        logger.info("No EMA model parameters found, using original model")
        model.load_state_dict(ckpt["model"])

    model.to(device)
    model.eval()

    num_classes = args.num_classes
    dataset_id2label_mask = {}

    for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
        dataset_id2label_mask[key] = np.ones(args.num_classes, dtype=bool)
        dataset_id2label_mask[key][allowed_ids] = False

    with torch.no_grad():
        while True:
            item = queue_input.get()
            if not item:
                queue_output.put(None)
                break

            try:
                # Loading the audio file
                wav, sr = librosa.load(item, sr=INPUT_SAMPLING_RATE)
                audio = torch.tensor(wav).to(device)

                win_size = args.win_size
                hop_size = args.hop_size
                total_len = (
                    (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
                ) * TIME_DUR + TIME_DUR
                total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

                logits = {
                    "function_logits": np.zeros([total_frames, num_classes]),
                    "boundary_logits": np.zeros([total_frames]),
                }
                logits_num = {
                    "function_logits": np.zeros([total_frames, num_classes]),
                    "boundary_logits": np.zeros([total_frames]),
                }

                lens = 0
                i = 0
                while True:
                    start_idx = i * INPUT_SAMPLING_RATE
                    end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
                    if start_idx >= audio.shape[-1]:
                        break
                    if end_idx - start_idx <= 1024:
                        continue
                    audio_seg = audio[start_idx:end_idx]

                    # MuQ embedding
                    muq_output = muq(audio_seg.unsqueeze(0), output_hidden_states=True)
                    muq_embd_420s = muq_output["hidden_states"][10]
                    del muq_output
                    torch.cuda.empty_cache()

                    # MusicFM embedding
                    _, musicfm_hidden_states = musicfm.get_predictions(
                        audio_seg.unsqueeze(0)
                    )
                    musicfm_embd_420s = musicfm_hidden_states[10]
                    del musicfm_hidden_states
                    torch.cuda.empty_cache()

                    wraped_muq_embd_30s = []
                    wraped_musicfm_embd_30s = []

                    for idx_30s in range(i, i + hop_size, 30):
                        start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
                        end_idx_30s = min(
                            (idx_30s + 30) * INPUT_SAMPLING_RATE,
                            audio.shape[-1],
                            (i + hop_size) * INPUT_SAMPLING_RATE,
                        )
                        if start_idx_30s >= audio.shape[-1]:
                            break
                        if end_idx_30s - start_idx_30s <= 1024:
                            continue
                        wraped_muq_embd_30s.append(
                            muq(
                                audio[start_idx_30s:end_idx_30s].unsqueeze(0),
                                output_hidden_states=True,
                            )["hidden_states"][10]
                        )
                        torch.cuda.empty_cache()
                        wraped_musicfm_embd_30s.append(
                            musicfm.get_predictions(
                                audio[start_idx_30s:end_idx_30s].unsqueeze(0)
                            )[1][10]
                        )
                        torch.cuda.empty_cache()

                    wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
                    wraped_musicfm_embd_30s = torch.concatenate(
                        wraped_musicfm_embd_30s, dim=1
                    )
                    all_embds = [
                        wraped_musicfm_embd_30s,
                        wraped_muq_embd_30s,
                        musicfm_embd_420s,
                        muq_embd_420s,
                    ]

                    if len(all_embds) > 1:
                        embd_lens = [x.shape[1] for x in all_embds]
                        max_embd_len = max(embd_lens)
                        min_embd_len = min(embd_lens)
                        if abs(max_embd_len - min_embd_len) > 4:
                            raise ValueError(
                                f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
                            )

                        for idx in range(len(all_embds)):
                            all_embds[idx] = all_embds[idx][:, :min_embd_len, :]

                    embd = torch.concatenate(all_embds, axis=-1)

                    dataset_label = DATASET_LABEL
                    dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
                    msa_info, chunk_logits = model.infer(
                        input_embeddings=embd,
                        dataset_ids=dataset_ids,
                        label_id_masks=torch.Tensor(
                            dataset_id2label_mask[
                                DATASET_LABEL_TO_DATASET_ID[dataset_label]
                            ]
                        )
                        .to(device, dtype=bool)
                        .unsqueeze(0)
                        .unsqueeze(0),
                        with_logits=True,
                    )

                    start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
                    end_frame = start_frame + min(
                        math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
                        chunk_logits["boundary_logits"][0].shape[0],
                    )

                    logits["function_logits"][start_frame:end_frame, :] += (
                        chunk_logits["function_logits"][0].detach().cpu().numpy()
                    )
                    logits["boundary_logits"][start_frame:end_frame] = (
                        chunk_logits["boundary_logits"][0].detach().cpu().numpy()
                    )
                    logits_num["function_logits"][start_frame:end_frame, :] += 1
                    logits_num["boundary_logits"][start_frame:end_frame] += 1
                    lens += end_frame - start_frame

                    i += hop_size
                logits["function_logits"] /= logits_num["function_logits"]
                logits["boundary_logits"] /= logits_num["boundary_logits"]

                logits["function_logits"] = torch.from_numpy(
                    logits["function_logits"][:lens]
                ).unsqueeze(0)
                logits["boundary_logits"] = torch.from_numpy(
                    logits["boundary_logits"][:lens]
                ).unsqueeze(0)

                msa_infer_output = postprocess_functional_structure(logits, hp)

                assert msa_infer_output[-1][-1] == "end"
                if not args.no_rule_post_processing:
                    msa_infer_output = rule_post_processing(msa_infer_output)
                msa_json = []
                for idx in range(len(msa_infer_output) - 1):
                    msa_json.append(
                        {
                            "label": msa_infer_output[idx][1],
                            "start": msa_infer_output[idx][0],
                            "end": msa_infer_output[idx + 1][0],
                        }
                    )
                json.dump(
                    msa_json,
                    open(os.path.join(args.output_dir, f"{Path(item).stem}.json"), "w"),
                    indent=4,
                    ensure_ascii=False,
                )

                queue_output.put(None)

            except Exception as e:
                queue_output.put(None)
                logger.error(f"process {rank} error\n{item}\n{e}")


def deal_with_output(output_path, queue_output, length):
    """Handle output data from the queue"""
    pbar = tqdm(range(length), desc="getting inference output")
    for _ in pbar:
        data = queue_output.get()
        if not data:
            continue


def main(args):
    input_path = args.input_path
    output_path = args.output_path
    gpu_num = args.gpu_num
    num_thread_per_gpu = args.num_thread_per_gpu
    debug = args.debug

    os.makedirs(output_path, exist_ok=True)

    processed_ids = get_processed_ids(output_path=output_path)
    processing_ids = get_processing_ids(input_path, processed_ids)

    num_threads = num_thread_per_gpu * gpu_num

    queue_input: mp.Queue = mp.Queue()
    queue_output: mp.Queue = mp.Queue()

    init_args = Namespace(
        output_dir=output_path,
        win_size=420,
        hop_size=420,
        num_classes=128,
        model=args.model,
        checkpoint=args.checkpoint,
        config_path=args.config_path,
        no_rule_post_processing=args.no_rule_post_processing,
    )

    processes = []

    if debug:
        queue_input.put(processing_ids[0])
        queue_input.put(None)

        inference(0, queue_input, queue_output, init_args)

        print("debug exit")
        exit(0)

    for thread_num in range(num_threads):
        rank = thread_num % gpu_num
        print(f"num_threads: {thread_num} on GPU {rank}")
        time.sleep(0.2)
        p = mp.Process(
            target=inference,
            args=(rank, queue_input, queue_output, init_args),
            daemon=True,
        )
        p.start()
        processes.append(p)

    for wav_id in tqdm(processing_ids, desc="add data to queue"):
        queue_input.put(wav_id)

    for _ in range(num_threads):
        queue_input.put(None)

    deal_with_output(output_path, queue_output, len(processing_ids))

    for p in processes:
        p.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input_path", "-i", type=str, required=True, help="Input file path"
    )
    parser.add_argument(
        "--output_path", "-o", type=str, required=True, help="Output file path"
    )
    parser.add_argument(
        "--gpu_num", "-gn", type=int, default=1, help="Number of GPUs, default is 1"
    )
    parser.add_argument(
        "--num_thread_per_gpu",
        "-tn",
        type=int,
        default=1,
        help="Number of threads per GPU, default is 1",
    )
    parser.add_argument("--model", type=str, help="Model to use")
    parser.add_argument("--checkpoint", type=str, help="Checkpoint path")
    parser.add_argument("--config_path", type=str, help="Configuration file path")
    parser.add_argument(
        "--no_rule_post_processing",
        action="store_true",
        help="Disable rule-based post-processing",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    args = parser.parse_args()

    main(args=args)
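A minimal sketch of what the rule-based post-processing above does, assuming rule_post_processing is imported from this script; the segment times and labels are made up for illustration:

# Hypothetical model output: a 0.5 s opening segment followed by two longer sections.
example = [(0.0, "intro"), (0.5, "verse"), (30.0, "chorus"), (60.0, "end")]
print(rule_post_processing(example))
# [(0.0, 'verse'), (30.0, 'chorus'), (60.0, 'end')]  -- the sub-second opening segment is folded into its neighbor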
src/SongFormer/models/SongFormer.py
ADDED
@@ -0,0 +1,521 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from dataset.custom_types import MsaInfo
from msaf.eval import compute_results
from postprocessing.functional import postprocess_functional_structure
from x_transformers import Encoder
import bisect


class Head(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
        super().__init__()
        hidden_dims = hidden_dims or []
        act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
        act_layer = act_layers.get(activation.lower())
        if not act_layer:
            raise ValueError(f"Unsupported activation: {activation}")

        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(act_layer())
        self.net = nn.Sequential(*layers)

    def reset_parameters(self, confidence):
        bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
        self.net[-1].bias.data.fill_(bias_value.item())

    def forward(self, x):
        batch, T, C = x.shape
        x = x.reshape(-1, C)
        x = self.net(x)
        return x.reshape(batch, T, -1)


class WrapedTransformerEncoder(nn.Module):
    def __init__(
        self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.transformer_input_dim = transformer_input_dim

        if input_dim != transformer_input_dim:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, transformer_input_dim),
                nn.LayerNorm(transformer_input_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(transformer_input_dim, transformer_input_dim),
            )
        else:
            self.input_proj = nn.Identity()

        self.transformer = Encoder(
            dim=transformer_input_dim,
            depth=num_layers,
            heads=nhead,
            layer_dropout=dropout,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            rotary_pos_emb=True,
        )

    def forward(self, x, src_key_padding_mask=None):
        """
        The input src_key_padding_mask is a B x T boolean mask, where True indicates masked positions.
        However, in x-transformers, False indicates masked positions.
        Therefore, it needs to be converted so that False represents masked positions.
        """
        x = self.input_proj(x)
        mask = (
            ~torch.tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
            if src_key_padding_mask is not None
            else None
        )
        return self.transformer(x, mask=mask)


def prefix_dict(d, prefix: str):
    if not prefix:
        return d
    return {prefix + key: value for key, value in d.items()}


class TimeDownsample(nn.Module):
    def __init__(
        self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
    ):
        super().__init__()
        self.dim_out = dim_out or dim_in
        assert self.dim_out % 2 == 0

        self.depthwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in,
            bias=False,
        )
        self.pointwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=self.dim_out,
            kernel_size=1,
            bias=False,
        )
        self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
        self.norm1 = nn.LayerNorm(self.dim_out)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        if dim_in != self.dim_out:
            self.residual_conv = nn.Conv1d(
                dim_in, self.dim_out, kernel_size=1, bias=False
            )
        else:
            self.residual_conv = None

    def forward(self, x):
        residual = x  # [B, T, D_in]
        # Convolutional module
        x_c = x.transpose(1, 2)  # [B, D_in, T]
        x_c = self.depthwise_conv(x_c)  # [B, D_in, T_down]
        x_c = self.pointwise_conv(x_c)  # [B, D_out, T_down]

        # Residual module
        res = self.pool(residual.transpose(1, 2))  # [B, D_in, T]
        if self.residual_conv:
            res = self.residual_conv(res)  # [B, D_out, T_down]
        x_c = x_c + res  # [B, D_out, T_down]
        x_c = x_c.transpose(1, 2)  # [B, T_down, D_out]
        x_c = self.norm1(x_c)
        x_c = self.act1(x_c)
        x_c = self.dropout1(x_c)
        return x_c


class AddFuse(nn.Module):
    def __init__(self):
        super(AddFuse, self).__init__()

    def forward(self, x, cond):
        return x + cond


class TVLoss1D(nn.Module):
    def __init__(
        self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
    ):
        """
        Args:
            beta: Exponential parameter for TV loss (recommended 0.5~1.0)
            lambda_tv: Overall weight for TV loss
            boundary_threshold: Label threshold to determine if a region is a "boundary area" (e.g., 0.01)
            reduction_weight: Scaling factor for TV penalty within boundary regions (e.g., 0.1, meaning only 10% penalty)
        """
        super().__init__()
        self.beta = beta
        self.lambda_tv = lambda_tv
        self.boundary_threshold = boundary_threshold
        self.reduction_weight = reduction_weight

    def forward(self, pred, target=None):
        """
        Args:
            pred: (B, T) or (B, T, 1), float boundary scores output by the model
            target: (B, T) or (B, T, 1), ground truth labels (optional, used for spatial weighting if provided)

        Returns:
            scalar: weighted TV loss
        """
        if pred.dim() == 3:
            pred = pred.squeeze(-1)
        if target is not None and target.dim() == 3:
            target = target.squeeze(-1)

        diff = pred[:, 1:] - pred[:, :-1]
        tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)

        if target is None:
            return self.lambda_tv * tv_base.mean()

        left_in_boundary = target[:, :-1] > self.boundary_threshold
        right_in_boundary = target[:, 1:] > self.boundary_threshold
        near_boundary = left_in_boundary | right_in_boundary
        weight_mask = torch.where(
            near_boundary,
            self.reduction_weight * torch.ones_like(tv_base),
            torch.ones_like(tv_base),
        )
        tv_weighted = (tv_base * weight_mask).mean()
        return self.lambda_tv * tv_weighted


class SoftmaxFocalLoss(nn.Module):
    """
    Softmax Focal Loss for single-label multi-class classification.
    Suitable for mutually exclusive classes.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: [B, T, C], raw logits
            targets: [B, T, C] (soft) or [B, T] (hard, dtype=long)
        Returns:
            loss: scalar or [B, T] depending on reduction
        """
        log_probs = F.log_softmax(pred, dim=-1)
        probs = torch.exp(log_probs)

        if targets.dtype == torch.long:
            targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
        else:
            targets_onehot = targets

        p_t = (probs * targets_onehot).sum(dim=-1)
        p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)

        if self.alpha > 0:
            alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
                1 - targets_onehot
            )
            alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
        else:
            alpha_weight = 1.0

        focal_weight = (1 - p_t) ** self.gamma
        ce_loss = -log_probs * targets_onehot
        ce_loss = ce_loss.sum(dim=-1)

        loss = alpha_weight * focal_weight * ce_loss
        return loss


class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.input_norm = nn.LayerNorm(config.input_dim)
        self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
        self.dataset_class_prefix = nn.Embedding(
            num_embeddings=config.num_dataset_classes,
            embedding_dim=config.transformer_encoder_input_dim,
        )
        self.down_sample_conv = TimeDownsample(
            dim_in=config.input_dim,
            dim_out=config.transformer_encoder_input_dim,
            kernel_size=config.down_sample_conv_kernel_size,
            stride=config.down_sample_conv_stride,
            dropout=config.down_sample_conv_dropout,
            padding=config.down_sample_conv_padding,
        )
        self.AddFuse = AddFuse()
        self.transformer = WrapedTransformerEncoder(
            input_dim=config.transformer_encoder_input_dim,
            transformer_input_dim=config.transformer_input_dim,
            num_layers=config.num_transformer_layers,
            nhead=config.transformer_nhead,
            dropout=config.transformer_dropout,
        )
        self.boundary_TVLoss1D = TVLoss1D(
            beta=config.boundary_tv_loss_beta,
            lambda_tv=config.boundary_tv_loss_lambda,
            boundary_threshold=config.boundary_tv_loss_boundary_threshold,
            reduction_weight=config.boundary_tv_loss_reduction_weight,
        )
        self.label_focal_loss = SoftmaxFocalLoss(
            alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
        )
        self.boundary_head = Head(config.transformer_input_dim, 1)
        self.function_head = Head(config.transformer_input_dim, config.num_classes)

    def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
        assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
            "gt_info and msa_info should end with 'end'"
        )
        gt_info_labels = [label for time_, label in gt_info][:-1]
        gt_info_inters = [time_ for time_, label in gt_info]
        gt_info_inters = np.column_stack(
            [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
        )

        msa_info_labels = [label for time_, label in msa_info][:-1]
        msa_info_inters = [time_ for time_, label in msa_info]
        msa_info_inters = np.column_stack(
            [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
        )
        result = compute_results(
            ann_inter=gt_info_inters,
            est_inter=msa_info_inters,
            ann_labels=gt_info_labels,
            est_labels=msa_info_labels,
            bins=11,
            est_file="test.txt",
            weight=0.58,
        )
        return result

    def cal_acc(
        self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
    ):
        ann_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in ann_info
        ]
        est_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in est_info
        ]

        common_start_time = max(ann_info_time[0], est_info_time[0])
        common_end_time = min(ann_info_time[-1], est_info_time[-1])

        time_points = {common_start_time, common_end_time}
        time_points.update(
            {
                time_
                for time_ in ann_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )
        time_points.update(
            {
                time_
                for time_ in est_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )

        time_points = sorted(time_points)
        total_duration, total_score = 0, 0

        for idx in range(len(time_points) - 1):
            duration = time_points[idx + 1] - time_points[idx]
            ann_label = ann_info[
                bisect.bisect_right(ann_info_time, time_points[idx]) - 1
            ][1]
            est_label = est_info[
                bisect.bisect_right(est_info_time, time_points[idx]) - 1
            ][1]
            total_duration += duration
            if ann_label == est_label:
                total_score += duration
        return total_score / total_duration

    def infer_with_metrics(self, batch, prefix: str = None):
        with torch.no_grad():
            logits = self.forward_func(batch)

            losses = self.compute_losses(logits, batch, prefix=None)

            expanded_mask = batch["label_id_masks"].expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )
            gt_info = batch["msa_infos"][0]
            results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)

            ret_results = {
                "loss": losses["loss"].item(),
                "HitRate_3P": results["HitRate_3P"],
                "HitRate_3R": results["HitRate_3R"],
                "HitRate_3F": results["HitRate_3F"],
                "HitRate_0.5P": results["HitRate_0.5P"],
                "HitRate_0.5R": results["HitRate_0.5R"],
                "HitRate_0.5F": results["HitRate_0.5F"],
                "PWF": results["PWF"],
                "PWP": results["PWP"],
                "PWR": results["PWR"],
                "Sf": results["Sf"],
                "So": results["So"],
                "Su": results["Su"],
                "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
            }
            if prefix:
                ret_results = prefix_dict(ret_results, prefix)

            return ret_results

    def infer(
        self,
        input_embeddings,
        dataset_ids,
        label_id_masks,
        prefix: str = None,
        with_logits=False,
    ):
        with torch.no_grad():
            input_embeddings = self.mixed_win_downsample(input_embeddings)
            input_embeddings = self.input_norm(input_embeddings)
            logits = self.down_sample_conv(input_embeddings)

            dataset_prefix = self.dataset_class_prefix(dataset_ids)
            dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
                logits.size(0), 1, -1
            )
            logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
            logits = self.transformer(x=logits, src_key_padding_mask=None)

            function_logits = self.function_head(logits)
            boundary_logits = self.boundary_head(logits).squeeze(-1)

            logits = {
                "function_logits": function_logits,
                "boundary_logits": boundary_logits,
            }

            expanded_mask = label_id_masks.expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )

            return (msa_info, logits) if with_logits else msa_info

    def compute_losses(self, outputs, batch, prefix: str = None):
        loss = 0.0
        losses = {}

        loss_section = F.binary_cross_entropy_with_logits(
            outputs["boundary_logits"],
            batch["widen_true_boundaries"],
            reduction="none",
        )
        loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
            pred=outputs["boundary_logits"],
            target=batch["widen_true_boundaries"],
        )
        loss_function = F.cross_entropy(
            outputs["function_logits"].transpose(1, 2),
            batch["true_functions"].transpose(1, 2),
            reduction="none",
        )
        # input is [B, T, C]
        ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
            pred=outputs["function_logits"], targets=batch["true_functions"]
        )
        loss_function += ttt

        float_masks = (~batch["masks"]).float()
        boundary_mask = batch.get("boundary_mask", None)
        function_mask = batch.get("function_mask", None)
        if boundary_mask is not None:
            boundary_mask = (~boundary_mask).float()
        else:
            boundary_mask = 1

        if function_mask is not None:
            function_mask = (~function_mask).float()
        else:
            function_mask = 1

        loss_section = torch.mean(boundary_mask * float_masks * loss_section)
        loss_function = torch.mean(function_mask * float_masks * loss_function)

        loss_section *= self.config.loss_weight_section
        loss_function *= self.config.loss_weight_function

        if self.config.learn_label:
            loss += loss_function
        if self.config.learn_segment:
            loss += loss_section

        losses.update(
            loss=loss,
            loss_section=loss_section,
            loss_function=loss_function,
        )
        if prefix:
            losses = prefix_dict(losses, prefix)
        return losses

    def forward_func(self, batch):
        input_embeddings = batch["input_embeddings"]
        input_embeddings = self.mixed_win_downsample(input_embeddings)
        input_embeddings = self.input_norm(input_embeddings)
        logits = self.down_sample_conv(input_embeddings)

        dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
        logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
        src_key_padding_mask = batch["masks"]
        logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)

        function_logits = self.function_head(logits)
        boundary_logits = self.boundary_head(logits).squeeze(-1)

        logits = {
            "function_logits": function_logits,
            "boundary_logits": boundary_logits,
        }
        return logits

    def forward(self, batch):
        logits = self.forward_func(batch)
        losses = self.compute_losses(logits, batch, prefix=None)
        return logits, losses["loss"], losses
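A small, hedged usage sketch for SoftmaxFocalLoss with dummy tensors; shapes follow the docstring above, the numbers are arbitrary:

import torch

criterion = SoftmaxFocalLoss(alpha=0.25, gamma=2.0)
logits = torch.randn(2, 8, 10)              # [B, T, C] raw logits
targets = torch.randint(0, 10, (2, 8))      # [B, T] hard labels (long)
per_frame_loss = criterion(logits, targets)  # shape [B, T]; reduce as needed
print(per_frame_loss.mean())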
src/SongFormer/postprocessing/calc_acc.py
ADDED
@@ -0,0 +1,82 @@
import os
import bisect
from dataset.msa_info_utils import (
    load_msa_info,
)
from dataset.custom_types import MsaInfo
import glob
import pdb
import pandas as pd


def cal_acc(ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3):
    if type(ann_info) is str:
        assert os.path.exists(ann_info), f"{ann_info} not exists"
        ann_info = load_msa_info(ann_info)

    if type(est_info) is str:
        assert os.path.exists(est_info), f"{est_info} not exists"
        est_info = load_msa_info(est_info)

    ann_info_time = [
        int(round(time_, post_digit) * (10**post_digit)) for time_, label in ann_info
    ]
    est_info_time = [
        int(round(time_, post_digit) * (10**post_digit)) for time_, label in est_info
    ]

    common_start_time = max(ann_info_time[0], est_info_time[0])
    common_end_time = min(ann_info_time[-1], est_info_time[-1])

    time_points = set()
    time_points.add(common_start_time)
    time_points.add(common_end_time)

    for time_ in ann_info_time:
        if time_ >= common_start_time and time_ <= common_end_time:
            time_points.add(time_)
    for time_ in est_info_time:
        if time_ >= common_start_time and time_ <= common_end_time:
            time_points.add(time_)

    time_points = sorted(list(time_points))
    total_duration = 0
    total_score = 0

    for idx in range(len(time_points) - 1):
        duration = time_points[idx + 1] - time_points[idx]
        ann_label = ann_info[bisect.bisect_right(ann_info_time, time_points[idx]) - 1][
            1
        ]
        est_label = est_info[bisect.bisect_right(est_info_time, time_points[idx]) - 1][
            1
        ]
        total_duration += duration
        if ann_label == est_label:
            total_score += duration
    return total_score / total_duration


if __name__ == "__main__":
    ext_paths = glob.glob("")
    results = []
    for ext_path in ext_paths:
        try:
            ann_path = os.path.join(
                "",
                os.path.basename(ext_path).split(".")[0] + ".txt",
            )
            results.append(
                {
                    "data_id": os.path.basename(ext_path).split(".")[0],
                    "acc": cal_acc(
                        ann_info=ann_path,
                        est_info=ext_path,
                    ),
                }
            )
        except Exception as e:
            print(e)
            continue
    df = pd.DataFrame(results)
    print(df["acc"].mean())
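For illustration, a hand-checkable call to cal_acc on hypothetical annotations; the two labelings agree on 15 of the 20 shared seconds:

ann = [(0.0, "verse"), (10.0, "chorus"), (20.0, "end")]
est = [(0.0, "verse"), (15.0, "chorus"), (20.0, "end")]
print(cal_acc(ann, est))  # 0.75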
src/SongFormer/postprocessing/calc_iou.py
ADDED
@@ -0,0 +1,89 @@
import os
from dataset.custom_types import MsaInfo
from dataset.label2id import LABEL_TO_ID
from pprint import pprint


def load_msa_info(msa_info_path):
    msa_info: MsaInfo = []
    with open(msa_info_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            time_, label = line.split()
            time_ = float(time_)
            label = str(label)
            assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
            msa_info.append((time_, label))
    assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
    return msa_info


def msa_info_to_segments(msa_info):
    # skip the last "end"
    segments = []
    for i in range(len(msa_info) - 1):
        start = msa_info[i][0]
        end = msa_info[i + 1][0]
        label = msa_info[i][1]
        segments.append((start, end, label))
    return segments


def compute_iou_for_label(segments_a, segments_b, label):
    # segments_a, segments_b: [(start, end, label)]
    # only process the current label
    intervals_a = [(s, e) for s, e, l in segments_a if l == label]
    intervals_b = [(s, e) for s, e, l in segments_b if l == label]
    # sum up all intersections between a and b
    intersection = 0.0
    for sa, ea in intervals_a:
        for sb, eb in intervals_b:
            left = max(sa, sb)
            right = min(ea, eb)
            if left < right:
                intersection += right - left
    # union = total length of both sets - overlapping intersection
    length_a = sum([e - s for s, e in intervals_a])
    length_b = sum([e - s for s, e in intervals_b])
    union = length_a + length_b - intersection
    if union == 0:
        return 0.0, 0.0, 0.0
    return intersection / union, intersection, union


def compute_mean_iou(segments_a, segments_b, labels):
    ious = []
    for label in labels:
        iou, intsec_dur, uni_dur = compute_iou_for_label(segments_a, segments_b, label)
        ious.append(
            {"label": label, "iou": iou, "intsec_dur": intsec_dur, "uni_dur": uni_dur}
        )
    return ious


def cal_iou(ann_info, est_info):
    if type(ann_info) is str:
        assert os.path.exists(ann_info), f"{ann_info} not exists"
        ann_info = load_msa_info(ann_info)

    if type(est_info) is str:
        assert os.path.exists(est_info), f"{est_info} not exists"
        est_info = load_msa_info(est_info)

    segments_ann = msa_info_to_segments(ann_info)
    segments_est = msa_info_to_segments(est_info)

    occurred_labels = list(
        set([l for s, e, l in segments_ann]) | set(l for s, e, l in segments_est)
    )

    mean_iou = compute_mean_iou(segments_ann, segments_est, occurred_labels)
    return mean_iou


if __name__ == "__main__":
    ann_info = ""
    est_info = ""
    pprint(cal_iou(ann_info, est_info))
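A quick hand-checkable call to cal_iou on hypothetical annotations; "verse" overlaps for 10 of 12 s and "chorus" for 8 of 10 s (dict order may vary, since the labels come from a set):

ann = [(0.0, "verse"), (10.0, "chorus"), (20.0, "end")]
est = [(0.0, "verse"), (12.0, "chorus"), (20.0, "end")]
pprint(cal_iou(ann, est))
# [{'label': 'verse',  'iou': 0.833..., 'intsec_dur': 10.0, 'uni_dur': 12.0},
#  {'label': 'chorus', 'iou': 0.8,      'intsec_dur': 8.0,  'uni_dur': 10.0}]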
src/SongFormer/postprocessing/functional.py
ADDED
@@ -0,0 +1,71 @@
# This file contains code adapted from the following sources:
# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/functional.py

import numpy as np
import torch
from .helpers import (
    local_maxima,
    peak_picking,
    # event_frames_to_time,
)
from dataset.label2id import LABEL_TO_ID, ID_TO_LABEL
from dataset.custom_types import MsaInfo


def event_frames_to_time(frame_rates, boundary: np.array):
    boundary = np.array(boundary)
    boundary_times = boundary / frame_rates
    return boundary_times


def postprocess_functional_structure(
    logits,
    config,
):
    # pdb.set_trace()
    boundary_logits = logits["boundary_logits"]
    function_logits = logits["function_logits"]

    assert boundary_logits.shape[0] == 1 and function_logits.shape[0] == 1, (
        "Only batch size 1 is supported"
    )
    raw_prob_sections = torch.sigmoid(boundary_logits[0])
    raw_prob_functions = torch.softmax(function_logits[0].transpose(0, 1), dim=0)

    # filter_size=4 * cfg.min_hops_per_beat + 1
    prob_sections, _ = local_maxima(
        raw_prob_sections, filter_size=config.local_maxima_filter_size
    )
    prob_sections = prob_sections.cpu().numpy()

    prob_functions = raw_prob_functions.cpu().numpy()

    boundary_candidates = peak_picking(
        boundary_activation=prob_sections,
        window_past=int(12 * config.frame_rates),  # originally fps
        window_future=int(12 * config.frame_rates),
    )
    boundary = boundary_candidates > 0.0

    duration = len(prob_sections) / config.frame_rates
    pred_boundary_times = event_frames_to_time(
        frame_rates=config.frame_rates, boundary=np.flatnonzero(boundary)
    )
    if pred_boundary_times[0] != 0:
        pred_boundary_times = np.insert(pred_boundary_times, 0, 0)
    if pred_boundary_times[-1] != duration:
        pred_boundary_times = np.append(pred_boundary_times, duration)
    pred_boundaries = np.stack([pred_boundary_times[:-1], pred_boundary_times[1:]]).T

    pred_boundary_indices = np.flatnonzero(boundary)
    pred_boundary_indices = pred_boundary_indices[pred_boundary_indices > 0]
    prob_segment_function = np.split(prob_functions, pred_boundary_indices, axis=1)
    pred_labels = [p.mean(axis=1).argmax().item() for p in prob_segment_function]

    segments: MsaInfo = []
    for (start, end), label in zip(pred_boundaries, pred_labels):
        segment = (float(start), str(ID_TO_LABEL[label]))
        segments.append(segment)

    segments.append((float(pred_boundary_times[-1]), "end"))
    return segments
src/SongFormer/postprocessing/helpers.py
ADDED
@@ -0,0 +1,101 @@
# This file contains code adapted from the following sources:
# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/helpers.py

import numpy as np
import torch.nn.functional as F
import torch
import librosa
from typing import Union
from scipy.signal import argrelextrema
from scipy.interpolate import interp1d
from numpy.lib.stride_tricks import sliding_window_view
from numpy.typing import NDArray


def local_maxima(tensor, filter_size=41):
    assert len(tensor.shape) in (1, 2), "Input tensor should have 1 or 2 dimensions"
    assert filter_size % 2 == 1, "Filter size should be an odd number"

    original_shape = tensor.shape
    if len(original_shape) == 1:
        tensor = tensor.unsqueeze(0)

    # Pad the input array with the minimum value
    padding = filter_size // 2
    padded_arr = F.pad(tensor, (padding, padding), mode="constant", value=-torch.inf)

    # Create a rolling window view of the padded array
    rolling_view = padded_arr.unfold(1, filter_size, 1)

    # Find the indices of the local maxima
    center = filter_size // 2
    local_maxima_mask = torch.eq(
        rolling_view[:, :, center], torch.max(rolling_view, dim=-1).values
    )
    local_maxima_indices = local_maxima_mask.nonzero()

    # Initialize a new PyTorch tensor with zeros and the same shape as the input tensor
    output_arr = torch.zeros_like(tensor)

    # Set the local maxima values in the output tensor
    output_arr[local_maxima_mask] = tensor[local_maxima_mask]

    output_arr = output_arr.reshape(original_shape)

    return output_arr, local_maxima_indices


def local_maxima_numpy(arr, order=20):
    is_batch = len(arr.shape) == 2
    if is_batch:
        return np.stack([local_maxima_numpy(x, order) for x in arr])

    # Define a comparison function for argrelextrema to find local maxima
    compare_func = np.greater

    # Find the indices of the local maxima
    local_maxima_indices = argrelextrema(arr, compare_func, order=order)

    # Initialize a new numpy array with zeros and the same shape as the input array
    output_arr = np.zeros_like(arr)

    # Set the local maxima values in the output array
    output_arr[local_maxima_indices] = arr[local_maxima_indices]

    return output_arr


def peak_picking(boundary_activation, window_past=12, window_future=6):
    # Find local maxima using a sliding window
    window_size = window_past + window_future
    assert window_size % 2 == 0, "window_past + window_future must be even"
    window_size += 1

    # Pad boundary_activation
    boundary_activation_padded = np.pad(
        boundary_activation, (window_past, window_future), mode="constant"
    )
    max_filter = sliding_window_view(boundary_activation_padded, window_size)
    local_maxima = (boundary_activation == np.max(max_filter, axis=-1)) & (
        boundary_activation > 0
    )

    # Compute strength values by subtracting the mean of the past and future windows
    past_window_filter = sliding_window_view(
        boundary_activation_padded[: -(window_future + 1)], window_past
    )
    future_window_filter = sliding_window_view(
        boundary_activation_padded[window_past + 1 :], window_future
    )
    past_mean = np.mean(past_window_filter, axis=-1)
    future_mean = np.mean(future_window_filter, axis=-1)
    strength_values = boundary_activation - ((past_mean + future_mean) / 2)

    # Get boundary candidates and their corresponding strength values
    boundary_candidates = np.flatnonzero(local_maxima)
    strength_values = strength_values[boundary_candidates]

    strength_activations = np.zeros_like(boundary_activation)
    strength_activations[boundary_candidates] = strength_values

    return strength_activations
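A toy check of peak_picking, assuming it is imported from this module; a single spike should survive as the only boundary candidate (window sizes are chosen so their sum is even, as the assert requires):

import numpy as np

act = np.zeros(20)
act[5] = 1.0  # one clear boundary peak
strengths = peak_picking(act, window_past=4, window_future=4)
print(np.flatnonzero(strengths))  # [5]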
src/SongFormer/train/accelerate_config/single_gpu.yaml
ADDED
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
src/SongFormer/utils/average_checkpoints.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import copy
from typing import List, Dict, Any


def average_checkpoints(checkpoint_paths: List[str], output_path: str = None):
    """
    Average the model and model_ema weights from multiple checkpoints

    Parameters:
        checkpoint_paths: List of checkpoint file paths
        output_path: Output path; if None, return the averaged checkpoint dictionary

    Returns:
        Averaged checkpoint dictionary
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    # Load the first checkpoint as the base
    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    # Initialize accumulators
    avg_model_state = copy.deepcopy(avg_checkpoint["model"])
    avg_model_ema_state = None

    if "model_ema" in avg_checkpoint:
        avg_model_ema_state = copy.deepcopy(avg_checkpoint["model_ema"])

    # Accumulate the weights from the other checkpoints
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Accumulate model weights
        for key in avg_model_state.keys():
            if key in ckpt["model"]:
                avg_model_state[key] += ckpt["model"][key]

        # Accumulate model_ema weights (if available)
        if avg_model_ema_state is not None and "model_ema" in ckpt:
            for key in avg_model_ema_state.keys():
                if key in ckpt["model_ema"]:
                    avg_model_ema_state[key] += ckpt["model_ema"][key]

    # Compute the average
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")

    for key in avg_model_state.keys():
        avg_model_state[key] = avg_model_state[key] / num_checkpoints

    if avg_model_ema_state is not None:
        for key in avg_model_ema_state.keys():
            avg_model_ema_state[key] = avg_model_ema_state[key] / num_checkpoints

    # Update the checkpoint dictionary
    avg_checkpoint["model"] = avg_model_state
    if avg_model_ema_state is not None:
        avg_checkpoint["model_ema"] = avg_model_ema_state

    # Save (if an output path is specified)
    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint


def average_checkpoints_memory_efficient(
    checkpoint_paths: List[str], output_path: str = None
):
    """
    Memory-efficient version: load and accumulate checkpoints one by one, suitable for large models
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    # Convert to float32 for better precision
    for key in avg_checkpoint["model"].keys():
        avg_checkpoint["model"][key] = avg_checkpoint["model"][key].float()

    if "model_ema" in avg_checkpoint:
        for key in avg_checkpoint["model_ema"].keys():
            avg_checkpoint["model_ema"][key] = avg_checkpoint["model_ema"][key].float()

    # Load and accumulate checkpoints one by one
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Accumulate model weights
        for key in avg_checkpoint["model"].keys():
            if key in ckpt["model"]:
                avg_checkpoint["model"][key] += ckpt["model"][key].float()

        # Accumulate model_ema weights
        if "model_ema" in avg_checkpoint and "model_ema" in ckpt:
            for key in avg_checkpoint["model_ema"].keys():
                if key in ckpt["model_ema"]:
                    avg_checkpoint["model_ema"][key] += ckpt["model_ema"][key].float()

        # Free memory
        del ckpt
        torch.cuda.empty_cache()

    # Compute the average
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")

    for key in avg_checkpoint["model"].keys():
        avg_checkpoint["model"][key] /= num_checkpoints

    if "model_ema" in avg_checkpoint:
        for key in avg_checkpoint["model_ema"].keys():
            avg_checkpoint["model_ema"][key] /= num_checkpoints

    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint


# Example usage
if __name__ == "__main__":
    # Method 1: Simple usage
    checkpoint_paths = []

    # Average and save
    average_checkpoints(checkpoint_paths, "")

    # Method 2: Get the averaged checkpoint and further process it
    # avg_ckpt = average_checkpoints(checkpoint_paths)
    # print("Averaged checkpoint keys:", avg_ckpt.keys())

    # Method 3: Use the memory-efficient version (suitable for large models)
    # average_checkpoints_memory_efficient(checkpoint_paths, 'averaged_checkpoint_efficient.pt')
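
A minimal usage sketch for the utility above. The checkpoint filenames and the output path are hypothetical and only illustrate the call signature; pick whichever checkpoints you actually want to average.

```python
# Hypothetical example: average the last three training checkpoints.
# Assumes average_checkpoints.py is importable from the current directory.
from average_checkpoints import average_checkpoints

ckpts = [
    "exp/SongFormer/checkpoint_90000.pt",   # placeholder paths
    "exp/SongFormer/checkpoint_95000.pt",
    "exp/SongFormer/checkpoint_100000.pt",
]
avg = average_checkpoints(ckpts, output_path="exp/SongFormer/checkpoint_avg3.pt")
print(sorted(avg.keys()))  # the averaged dict keeps the original keys, e.g. 'model', 'model_ema'
```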
src/SongFormer/utils/convert_res2msa_txt.py
ADDED
|
@@ -0,0 +1,79 @@
import json
import os
from pathlib import Path
import fire


def convert_json_to_format(json_data):
    """Convert JSON data to the specified format"""
    result = []

    # Process the start time and label for each segment
    for segment in json_data:
        start_time = segment["start"]
        label = segment["label"]
        result.append(f"{start_time:.6f} {label}")

    # Add the last end time
    if json_data:
        last_end_time = json_data[-1]["end"]
        result.append(f"{last_end_time:.6f} end")

    return "\n".join(result)


def process_json_files(input_folder, output_folder):
    """Process all JSON files in the input folder"""

    # Create the output folder if it doesn't exist
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    # Get all JSON files
    json_files = [f for f in os.listdir(input_folder) if f.endswith(".json")]

    if not json_files:
        print(f"No JSON files found in {input_folder}")
        return

    print(f"Found {len(json_files)} JSON files")

    # Process each JSON file
    for json_file in json_files:
        input_path = os.path.join(input_folder, json_file)

        try:
            # Read the JSON file
            with open(input_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Convert the format
            converted_data = convert_json_to_format(data)

            # Generate the output filename (replace .json with .txt)
            output_filename = json_file.replace(".json", ".txt")
            output_path = os.path.join(output_folder, output_filename)

            # Write to the output file
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(converted_data)

            print(f"✓ Processed: {json_file} -> {output_filename}")

        except Exception as e:
            print(f"✗ Error processing {json_file}: {str(e)}")


def main(input_folder: str, output_folder: str):
    print(f"Input folder: {input_folder}")
    print(f"Output folder: {output_folder}")
    print("-" * 50)

    # Process the files
    process_json_files(input_folder, output_folder)

    print("-" * 50)
    print("Processing complete!")


if __name__ == "__main__":
    fire.Fire(main)
|
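
For reference, a minimal sketch of the transformation this script performs on a single result. The two segments below are made up purely for illustration; real inputs are the per-song JSON files produced by inference.

```python
# Assumes convert_res2msa_txt.py is importable from the current directory.
from convert_res2msa_txt import convert_json_to_format

segments = [
    {"start": 0.0, "end": 12.8, "label": "intro"},   # hypothetical segment
    {"start": 12.8, "end": 41.6, "label": "verse"},  # hypothetical segment
]
print(convert_json_to_format(segments))
# 0.000000 intro
# 12.800000 verse
# 41.600000 end
```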
src/SongFormer/utils/fetch_pretrained.py
ADDED
|
@@ -0,0 +1,40 @@
|
import os
import requests
from tqdm import tqdm


def download(url, path):
    if os.path.exists(path):
        print(f"File already exists, skipping download: {path}")
        return
    os.makedirs(os.path.dirname(path), exist_ok=True)
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get("content-length", 0))
    with (
        open(path, "wb") as f,
        tqdm(
            desc=path,
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar,
    ):
        for data in response.iter_content(chunk_size=1024):
            size = f.write(data)
            bar.update(size)


# Download the pretrained models, following https://github.com/minzwon/musicfm
download(
    "https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json",
    os.path.join("ckpts", "MusicFM", "msd_stats.json"),
)
download(
    "https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
    os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"),
)

# for Mainland China
# download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/msd_stats.json', os.path.join("ckpts", "MusicFM", "msd_stats.json"))
# download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/pretrained_msd.pt', os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"))
|
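
A small, optional sketch (not part of the script above) for verifying a downloaded checkpoint against an expected MD5 before using it. The expected hash below is a placeholder, not the real checksum of the MusicFM weights.

```python
import hashlib

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 of a file in streaming fashion."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "0123456789abcdef0123456789abcdef"  # placeholder value
actual = md5sum("ckpts/MusicFM/pretrained_msd.pt")
print("OK" if actual == expected else f"Checksum mismatch: {actual}")
```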
src/third_party/MuQ/.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
# Auto detect text files and perform LF normalization
* text=auto
|
src/third_party/MuQ/.gitignore
ADDED
|
@@ -0,0 +1,46 @@
|
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.egg*/
*pyc

# Distribution / packaging
.Python
env/
build/
dist/
*.log

# pyenv
.python-version

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# VSCode settings
.vscode

# IDEA files
.idea

# OSX dir files
.DS_Store

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# custom
open/
src/recipes/pretrain/dataset/music4all/*.json
src/recipes/contrastive_learning/datasets/mtg-jamendo/*.json
runs/
output/
logs
outputs/
|
src/third_party/MuQ/.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
[submodule "src/recipes/pretrain/fairseq"]
	path = src/recipes/pretrain/fairseq
	url = https://github.com/facebookresearch/fairseq
|
src/third_party/MuQ/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
MIT License

Copyright (c) Tencent.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
|
src/third_party/MuQ/LICENSE_weights
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
| 1 |
+
Attribution-NonCommercial 4.0 International
|
| 2 |
+
|
| 3 |
+
=======================================================================
|
| 4 |
+
|
| 5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
+
does not provide legal services or legal advice. Distribution of
|
| 7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
+
other relationship. Creative Commons makes its licenses and related
|
| 9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
+
warranties regarding its licenses, any material licensed under their
|
| 11 |
+
terms and conditions, or any related information. Creative Commons
|
| 12 |
+
disclaims all liability for damages resulting from their use to the
|
| 13 |
+
fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
+
conditions that creators and other rights holders may use to share
|
| 19 |
+
original works of authorship and other material subject to copyright
|
| 20 |
+
and certain other rights specified in the public license below. The
|
| 21 |
+
following considerations are for informational purposes only, are not
|
| 22 |
+
exhaustive, and do not form part of our licenses.
|
| 23 |
+
|
| 24 |
+
Considerations for licensors: Our public licenses are
|
| 25 |
+
intended for use by those authorized to give the public
|
| 26 |
+
permission to use material in ways otherwise restricted by
|
| 27 |
+
copyright and certain other rights. Our licenses are
|
| 28 |
+
irrevocable. Licensors should read and understand the terms
|
| 29 |
+
and conditions of the license they choose before applying it.
|
| 30 |
+
Licensors should also secure all rights necessary before
|
| 31 |
+
applying our licenses so that the public can reuse the
|
| 32 |
+
material as expected. Licensors should clearly mark any
|
| 33 |
+
material not subject to the license. This includes other CC-
|
| 34 |
+
licensed material, or material used under an exception or
|
| 35 |
+
limitation to copyright. More considerations for licensors:
|
| 36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
+
|
| 38 |
+
Considerations for the public: By using one of our public
|
| 39 |
+
licenses, a licensor grants the public permission to use the
|
| 40 |
+
licensed material under specified terms and conditions. If
|
| 41 |
+
the licensor's permission is not necessary for any reason--for
|
| 42 |
+
example, because of any applicable exception or limitation to
|
| 43 |
+
copyright--then that use is not regulated by the license. Our
|
| 44 |
+
licenses grant only permissions under copyright and certain
|
| 45 |
+
other rights that a licensor has authority to grant. Use of
|
| 46 |
+
the licensed material may still be restricted for other
|
| 47 |
+
reasons, including because others have copyright or other
|
| 48 |
+
rights in the material. A licensor may make special requests,
|
| 49 |
+
such as asking that all changes be marked or described.
|
| 50 |
+
Although not required by our licenses, you are encouraged to
|
| 51 |
+
respect those requests where reasonable. More_considerations
|
| 52 |
+
for the public:
|
| 53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
+
|
| 55 |
+
=======================================================================
|
| 56 |
+
|
| 57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 58 |
+
License
|
| 59 |
+
|
| 60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 61 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 63 |
+
License"). To the extent this Public License may be interpreted as a
|
| 64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
| 65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
| 66 |
+
such rights in consideration of benefits the Licensor receives from
|
| 67 |
+
making the Licensed Material available under these terms and
|
| 68 |
+
conditions.
|
| 69 |
+
|
| 70 |
+
Section 1 -- Definitions.
|
| 71 |
+
|
| 72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 73 |
+
Rights that is derived from or based upon the Licensed Material
|
| 74 |
+
and in which the Licensed Material is translated, altered,
|
| 75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 76 |
+
permission under the Copyright and Similar Rights held by the
|
| 77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 78 |
+
Material is a musical work, performance, or sound recording,
|
| 79 |
+
Adapted Material is always produced where the Licensed Material is
|
| 80 |
+
synched in timed relation with a moving image.
|
| 81 |
+
|
| 82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 84 |
+
accordance with the terms and conditions of this Public License.
|
| 85 |
+
|
| 86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 87 |
+
closely related to copyright including, without limitation,
|
| 88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 89 |
+
Rights, without regard to how the rights are labeled or
|
| 90 |
+
categorized. For purposes of this Public License, the rights
|
| 91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 92 |
+
Rights.
|
| 93 |
+
d. Effective Technological Measures means those measures that, in the
|
| 94 |
+
absence of proper authority, may not be circumvented under laws
|
| 95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 97 |
+
agreements.
|
| 98 |
+
|
| 99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 100 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 101 |
+
that applies to Your use of the Licensed Material.
|
| 102 |
+
|
| 103 |
+
f. Licensed Material means the artistic or literary work, database,
|
| 104 |
+
or other material to which the Licensor applied this Public
|
| 105 |
+
License.
|
| 106 |
+
|
| 107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
| 108 |
+
terms and conditions of this Public License, which are limited to
|
| 109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 110 |
+
Licensed Material and that the Licensor has authority to license.
|
| 111 |
+
|
| 112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 113 |
+
under this Public License.
|
| 114 |
+
|
| 115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
| 116 |
+
commercial advantage or monetary compensation. For purposes of
|
| 117 |
+
this Public License, the exchange of the Licensed Material for
|
| 118 |
+
other material subject to Copyright and Similar Rights by digital
|
| 119 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 120 |
+
no payment of monetary compensation in connection with the
|
| 121 |
+
exchange.
|
| 122 |
+
|
| 123 |
+
j. Share means to provide material to the public by any means or
|
| 124 |
+
process that requires permission under the Licensed Rights, such
|
| 125 |
+
as reproduction, public display, public performance, distribution,
|
| 126 |
+
dissemination, communication, or importation, and to make material
|
| 127 |
+
available to the public including in ways that members of the
|
| 128 |
+
public may access the material from a place and at a time
|
| 129 |
+
individually chosen by them.
|
| 130 |
+
|
| 131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
| 132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 134 |
+
as amended and/or succeeded, as well as other essentially
|
| 135 |
+
equivalent rights anywhere in the world.
|
| 136 |
+
|
| 137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
| 138 |
+
under this Public License. Your has a corresponding meaning.
|
| 139 |
+
|
| 140 |
+
Section 2 -- Scope.
|
| 141 |
+
|
| 142 |
+
a. License grant.
|
| 143 |
+
|
| 144 |
+
1. Subject to the terms and conditions of this Public License,
|
| 145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 148 |
+
|
| 149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 150 |
+
in part, for NonCommercial purposes only; and
|
| 151 |
+
|
| 152 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 153 |
+
NonCommercial purposes only.
|
| 154 |
+
|
| 155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 156 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 157 |
+
License does not apply, and You do not need to comply with
|
| 158 |
+
its terms and conditions.
|
| 159 |
+
|
| 160 |
+
3. Term. The term of this Public License is specified in Section
|
| 161 |
+
6(a).
|
| 162 |
+
|
| 163 |
+
4. Media and formats; technical modifications allowed. The
|
| 164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 165 |
+
all media and formats whether now known or hereafter created,
|
| 166 |
+
and to make technical modifications necessary to do so. The
|
| 167 |
+
Licensor waives and/or agrees not to assert any right or
|
| 168 |
+
authority to forbid You from making technical modifications
|
| 169 |
+
necessary to exercise the Licensed Rights, including
|
| 170 |
+
technical modifications necessary to circumvent Effective
|
| 171 |
+
Technological Measures. For purposes of this Public License,
|
| 172 |
+
simply making modifications authorized by this Section 2(a)
|
| 173 |
+
(4) never produces Adapted Material.
|
| 174 |
+
|
| 175 |
+
5. Downstream recipients.
|
| 176 |
+
|
| 177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 178 |
+
recipient of the Licensed Material automatically
|
| 179 |
+
receives an offer from the Licensor to exercise the
|
| 180 |
+
Licensed Rights under the terms and conditions of this
|
| 181 |
+
Public License.
|
| 182 |
+
|
| 183 |
+
b. No downstream restrictions. You may not offer or impose
|
| 184 |
+
any additional or different terms or conditions on, or
|
| 185 |
+
apply any Effective Technological Measures to, the
|
| 186 |
+
Licensed Material if doing so restricts exercise of the
|
| 187 |
+
Licensed Rights by any recipient of the Licensed
|
| 188 |
+
Material.
|
| 189 |
+
|
| 190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 191 |
+
may be construed as permission to assert or imply that You
|
| 192 |
+
are, or that Your use of the Licensed Material is, connected
|
| 193 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 194 |
+
the Licensor or others designated to receive attribution as
|
| 195 |
+
provided in Section 3(a)(1)(A)(i).
|
| 196 |
+
|
| 197 |
+
b. Other rights.
|
| 198 |
+
|
| 199 |
+
1. Moral rights, such as the right of integrity, are not
|
| 200 |
+
licensed under this Public License, nor are publicity,
|
| 201 |
+
privacy, and/or other similar personality rights; however, to
|
| 202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 203 |
+
assert any such rights held by the Licensor to the limited
|
| 204 |
+
extent necessary to allow You to exercise the Licensed
|
| 205 |
+
Rights, but not otherwise.
|
| 206 |
+
|
| 207 |
+
2. Patent and trademark rights are not licensed under this
|
| 208 |
+
Public License.
|
| 209 |
+
|
| 210 |
+
3. To the extent possible, the Licensor waives any right to
|
| 211 |
+
collect royalties from You for the exercise of the Licensed
|
| 212 |
+
Rights, whether directly or through a collecting society
|
| 213 |
+
under any voluntary or waivable statutory or compulsory
|
| 214 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 215 |
+
reserves any right to collect such royalties, including when
|
| 216 |
+
the Licensed Material is used other than for NonCommercial
|
| 217 |
+
purposes.
|
| 218 |
+
|
| 219 |
+
Section 3 -- License Conditions.
|
| 220 |
+
|
| 221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 222 |
+
following conditions.
|
| 223 |
+
|
| 224 |
+
a. Attribution.
|
| 225 |
+
|
| 226 |
+
1. If You Share the Licensed Material (including in modified
|
| 227 |
+
form), You must:
|
| 228 |
+
|
| 229 |
+
a. retain the following if it is supplied by the Licensor
|
| 230 |
+
with the Licensed Material:
|
| 231 |
+
|
| 232 |
+
i. identification of the creator(s) of the Licensed
|
| 233 |
+
Material and any others designated to receive
|
| 234 |
+
attribution, in any reasonable manner requested by
|
| 235 |
+
the Licensor (including by pseudonym if
|
| 236 |
+
designated);
|
| 237 |
+
|
| 238 |
+
ii. a copyright notice;
|
| 239 |
+
|
| 240 |
+
iii. a notice that refers to this Public License;
|
| 241 |
+
|
| 242 |
+
iv. a notice that refers to the disclaimer of
|
| 243 |
+
warranties;
|
| 244 |
+
|
| 245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 246 |
+
extent reasonably practicable;
|
| 247 |
+
|
| 248 |
+
b. indicate if You modified the Licensed Material and
|
| 249 |
+
retain an indication of any previous modifications; and
|
| 250 |
+
|
| 251 |
+
c. indicate the Licensed Material is licensed under this
|
| 252 |
+
Public License, and include the text of, or the URI or
|
| 253 |
+
hyperlink to, this Public License.
|
| 254 |
+
|
| 255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 256 |
+
reasonable manner based on the medium, means, and context in
|
| 257 |
+
which You Share the Licensed Material. For example, it may be
|
| 258 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 259 |
+
hyperlink to a resource that includes the required
|
| 260 |
+
information.
|
| 261 |
+
|
| 262 |
+
3. If requested by the Licensor, You must remove any of the
|
| 263 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 264 |
+
reasonably practicable.
|
| 265 |
+
|
| 266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
| 267 |
+
License You apply must not prevent recipients of the Adapted
|
| 268 |
+
Material from complying with this Public License.
|
| 269 |
+
|
| 270 |
+
Section 4 -- Sui Generis Database Rights.
|
| 271 |
+
|
| 272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 273 |
+
apply to Your use of the Licensed Material:
|
| 274 |
+
|
| 275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 277 |
+
portion of the contents of the database for NonCommercial purposes
|
| 278 |
+
only;
|
| 279 |
+
|
| 280 |
+
b. if You include all or a substantial portion of the database
|
| 281 |
+
contents in a database in which You have Sui Generis Database
|
| 282 |
+
Rights, then the database in which You have Sui Generis Database
|
| 283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
| 284 |
+
|
| 285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 286 |
+
all or a substantial portion of the contents of the database.
|
| 287 |
+
|
| 288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 289 |
+
replace Your obligations under this Public License where the Licensed
|
| 290 |
+
Rights include other Copyright and Similar Rights.
|
| 291 |
+
|
| 292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 293 |
+
|
| 294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 304 |
+
|
| 305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 314 |
+
|
| 315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 316 |
+
above shall be interpreted in a manner that, to the extent
|
| 317 |
+
possible, most closely approximates an absolute disclaimer and
|
| 318 |
+
waiver of all liability.
|
| 319 |
+
|
| 320 |
+
Section 6 -- Term and Termination.
|
| 321 |
+
|
| 322 |
+
a. This Public License applies for the term of the Copyright and
|
| 323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 324 |
+
this Public License, then Your rights under this Public License
|
| 325 |
+
terminate automatically.
|
| 326 |
+
|
| 327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 328 |
+
Section 6(a), it reinstates:
|
| 329 |
+
|
| 330 |
+
1. automatically as of the date the violation is cured, provided
|
| 331 |
+
it is cured within 30 days of Your discovery of the
|
| 332 |
+
violation; or
|
| 333 |
+
|
| 334 |
+
2. upon express reinstatement by the Licensor.
|
| 335 |
+
|
| 336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 337 |
+
right the Licensor may have to seek remedies for Your violations
|
| 338 |
+
of this Public License.
|
| 339 |
+
|
| 340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 341 |
+
Licensed Material under separate terms or conditions or stop
|
| 342 |
+
distributing the Licensed Material at any time; however, doing so
|
| 343 |
+
will not terminate this Public License.
|
| 344 |
+
|
| 345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 346 |
+
License.
|
| 347 |
+
|
| 348 |
+
Section 7 -- Other Terms and Conditions.
|
| 349 |
+
|
| 350 |
+
a. The Licensor shall not be bound by any additional or different
|
| 351 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 352 |
+
|
| 353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 354 |
+
Licensed Material not stated herein are separate from and
|
| 355 |
+
independent of the terms and conditions of this Public License.
|
| 356 |
+
|
| 357 |
+
Section 8 -- Interpretation.
|
| 358 |
+
|
| 359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 361 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 362 |
+
be made without permission under this Public License.
|
| 363 |
+
|
| 364 |
+
b. To the extent possible, if any provision of this Public License is
|
| 365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 366 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 367 |
+
cannot be reformed, it shall be severed from this Public License
|
| 368 |
+
without affecting the enforceability of the remaining terms and
|
| 369 |
+
conditions.
|
| 370 |
+
|
| 371 |
+
c. No term or condition of this Public License will be waived and no
|
| 372 |
+
failure to comply consented to unless expressly agreed to by the
|
| 373 |
+
Licensor.
|
| 374 |
+
|
| 375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 377 |
+
that apply to the Licensor or You, including from the legal
|
| 378 |
+
processes of any jurisdiction or authority.
|
| 379 |
+
|
| 380 |
+
=======================================================================
|
| 381 |
+
|
| 382 |
+
Creative Commons is not a party to its public
|
| 383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 384 |
+
its public licenses to material it publishes and in those instances
|
| 385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
| 386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 388 |
+
material is shared under a Creative Commons public license or as
|
| 389 |
+
otherwise permitted by the Creative Commons policies published at
|
| 390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 392 |
+
of Creative Commons without its prior written consent including,
|
| 393 |
+
without limitation, in connection with any unauthorized modifications
|
| 394 |
+
to any of its public licenses or any other arrangements,
|
| 395 |
+
understandings, or agreements concerning use of licensed material. For
|
| 396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 397 |
+
public licenses.
|
| 398 |
+
|
| 399 |
+
Creative Commons may be contacted at creativecommons.org.
|
src/third_party/MuQ/README.md
ADDED
|
@@ -0,0 +1,129 @@
|
# <img src="images/muq-logo.jpeg" alt="" height="24px"> MuQ & MuQ-MuLan

<div>
<a href='#'><img alt="Static Badge" src="https://img.shields.io/badge/Python-3.8%2B-blue?logo=python&logoColor=white"></a>
<a href='https://arxiv.org/abs/2501.01108'><img alt="Static Badge" src="https://img.shields.io/badge/arXiv-2501.01108-%23b31b1b?logo=arxiv&link=https%3A%2F%2Farxiv.org%2F"></a>
<a href='https://huggingface.co/OpenMuQ'><img alt="Static Badge" src="https://img.shields.io/badge/huggingface-OpenMuQ-%23FFD21E?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2FOpenMuQ"></a>
<a href='https://pytorch.org/'><img alt="Static Badge" src="https://img.shields.io/badge/framework-PyTorch-%23EE4C2C?logo=pytorch"></a>
<a href='https://pypi.org/project/muq'><img alt="Static Badge" src="https://img.shields.io/badge/pip%20install-muq-green?logo=PyPI&logoColor=white&link=https%3A%2F%2Fpypi.org%2Fproject%2Fmuq"></a>
</div>

This is the official repository for the paper *"**MuQ**: Self-Supervised **Mu**sic Representation Learning
with Mel Residual Vector **Q**uantization"*.

In this repo, the following models are released:

- **MuQ**: A large music foundation model pre-trained via Self-Supervised Learning (SSL), achieving SOTA in various MIR tasks.
- **MuQ-MuLan**: A music-text joint embedding model trained via contrastive learning, supporting both English and Chinese texts.

## Overview

We develop **MuQ** for music SSL. MuQ applies our proposed Mel-RVQ as its quantization target and achieves SOTA performance on many music understanding (MIR) tasks.

We also construct **MuQ-MuLan**, a CLIP-like model trained via contrastive learning, which jointly represents music and text as embeddings.

For more details, please refer to our [paper](https://arxiv.org/abs/2501.01108).

<div>
<img src="images/radar.jpg" width="45%" alt="Evaluation on MARBLE Benchmark">
<img src="images/tagging.jpg" width="45%" alt="Evaluation on Zero-shot Music Tagging">
</div>

## Usage

To begin with, please use pip to install the official `muq` lib, and ensure that your `python>=3.8`:
```bash
pip3 install muq
```

To extract music audio features using **MuQ**, you can refer to the following code:
```python
import torch, librosa
from muq import MuQ

device = 'cuda'
wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
wavs = torch.tensor(wav).unsqueeze(0).to(device)

# This will automatically fetch the checkpoint from huggingface
muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
muq = muq.to(device).eval()

with torch.no_grad():
    output = muq(wavs, output_hidden_states=True)

print('Total number of layers: ', len(output.hidden_states))
print('Feature shape: ', output.last_hidden_state.shape)
```

Using **MuQ-MuLan** to extract the music and text embeddings and calculate their similarity:
```python
import torch, librosa
from muq import MuQMuLan

# This will automatically fetch checkpoints from huggingface
device = 'cuda'
mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
mulan = mulan.to(device).eval()

# Extract music embeddings
wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
wavs = torch.tensor(wav).unsqueeze(0).to(device)
with torch.no_grad():
    audio_embeds = mulan(wavs = wavs)

# Extract text embeddings (texts can be in English or Chinese)
texts = ["classical genres, hopeful mood, piano.", "一首适合海边风景的小提琴曲,节奏欢快"]
with torch.no_grad():
    text_embeds = mulan(texts = texts)

# Calculate dot product similarity
sim = mulan.calc_similarity(audio_embeds, text_embeds)
print(sim)
```

> Note that both MuQ and MuQ-MuLan strictly require **24 kHz** audio as input.
> We recommend using **fp32** during MuQ inference to avoid potential NaN issues.

## Performance

<img src="images/tab-marble.jpg" width="100%" style="max-width: 800px" alt="Table MARBLE Benchmark">
<img src="images/tab-mulan.png" width="50%" style="max-width: 400px; margin: 0 25%" alt="Table Mulan Results">

## Model Checkpoints

| Model Name | Parameters | Data | HuggingFace🤗 |
| ----------- | --- | --- | ----------- |
| MuQ | ~300M | MSD dataset | [OpenMuQ/MuQ-large-msd-iter](https://huggingface.co/OpenMuQ/MuQ-large-msd-iter) |
| MuQ-MuLan | ~700M | music-text pairs | [OpenMuQ/MuQ-MuLan-large](https://huggingface.co/OpenMuQ/MuQ-MuLan-large) |

**Note**: The open-sourced MuQ was trained on the Million Song Dataset. Due to differences in dataset size, the open-sourced model may not achieve the same level of performance as reported in the paper. The training recipes can be found [here](./src/recipes).

## License

The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.

The model weights (MuQ-large-msd-iter, MuQ-MuLan-large) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.

## Citation

```
@article{zhu2025muq,
  title={MuQ: Self-Supervised Music Representation Learning with Mel Residual Vector Quantization},
  author={Haina Zhu and Yizhi Zhou and Hangting Chen and Jianwei Yu and Ziyang Ma and Rongzhi Gu and Yi Luo and Wei Tan and Xie Chen},
  journal={arXiv preprint arXiv:2501.01108},
  year={2025}
}
```

## Acknowledgement

We borrow many codes from the following repositories:
- [lucidrains/musiclm-pytorch](https://github.com/lucidrains/musiclm-pytorch)
- [minzwon/musicfm](https://github.com/minzwon/musicfm)

Also, we are especially grateful to the awesome [MARBLE-Benchmark](https://github.com/a43992899/MARBLE-Benchmark).
src/third_party/MuQ/images/muq-logo.jpeg
ADDED
|
src/third_party/MuQ/images/radar.jpg
ADDED
|
Git LFS Details
|
src/third_party/MuQ/images/tab-marble.jpg
ADDED
|
Git LFS Details
|
src/third_party/MuQ/images/tab-mulan.png
ADDED
|
Git LFS Details
|
src/third_party/MuQ/images/tagging.jpg
ADDED
|
Git LFS Details
|
src/third_party/MuQ/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
einops
librosa
nnAudio
numpy
soundfile
torch
torchaudio
tqdm
transformers
easydict
x_clip
src/third_party/MuQ/setup.py
ADDED
|
@@ -0,0 +1,34 @@
|
from setuptools import setup, find_packages

setup(
    name='muq',  # Name of the package
    version='0.1.0',  # Version of the package
    packages=find_packages(where='src'),  # Automatically discover packages under the 'src' directory
    package_dir={'': 'src'},  # Specify the root directory for packages as 'src'
    include_package_data=True,  # Include additional files, such as static files
    install_requires=[  # List of dependencies
        "einops",
        "librosa",
        "nnAudio",
        "numpy",
        "soundfile",
        "torch",
        "torchaudio",
        "tqdm",
        "transformers",
        "easydict",
        "x_clip",
    ],
    author='Haina Zhu',  # Author name
    author_email='juhayna@qq.com',  # Author email address
    description='MuQ: A deep learning model for music and text',  # Short description of the package
    long_description=open('README.md', encoding='utf-8').read(),  # Long description from the README file
    long_description_content_type='text/markdown',  # Format of the long description (Markdown)
    url='https://github.com/tencent-ailab/MuQ',  # Project URL
    classifiers=[
        'Programming Language :: Python :: 3',  # Python 3 support
        'License :: OSI Approved :: MIT License',  # License type
        'Operating System :: OS Independent',  # Supports all operating systems
    ],
    python_requires='>=3.8',  # Supported Python version
)
|
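
A minimal post-install sanity check, assuming the package has been installed (for example with `pip install .` run from this directory); the expected version string comes from the setup.py above.

```python
from importlib.metadata import version

import muq

print(version("muq"))         # expected: 0.1.0, as declared in setup.py
print(muq.MuQ, muq.MuQMuLan)  # classes exported by src/muq/__init__.py
```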
src/third_party/MuQ/src/muq/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
from .muq import MuQ, MuQConfig
from .muq_mulan import MuQMuLan, MuQMuLanConfig
|
src/third_party/MuQ/src/muq/muq/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
from .muq import MuQConfig, MuQ
|
src/third_party/MuQ/src/muq/muq/models/__init__.py
ADDED
|
File without changes
|
src/third_party/MuQ/src/muq/muq/models/muq_model.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import os
|
| 7 |
+
from easydict import EasyDict
|
| 8 |
+
|
| 9 |
+
from ..modules.random_quantizer import RandomProjectionQuantizer
|
| 10 |
+
from ..modules.features import MelSTFT
|
| 11 |
+
from ..modules.conv import Conv2dSubsampling
|
| 12 |
+
|
| 13 |
+
class MuQModel(nn.Module):
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_codebooks=1,
|
| 18 |
+
codebook_dim=16,
|
| 19 |
+
codebook_size=4096,
|
| 20 |
+
features=["melspec_2048"],
|
| 21 |
+
hop_length=240,
|
| 22 |
+
n_mels=128,
|
| 23 |
+
conv_dim=512,
|
| 24 |
+
encoder_dim=1024,
|
| 25 |
+
encoder_depth=12,
|
| 26 |
+
mask_hop=0.4,
|
| 27 |
+
mask_prob=0.6,
|
| 28 |
+
is_flash=False,
|
| 29 |
+
stat=dict(),
|
| 30 |
+
w2v2_config=dict(),
|
| 31 |
+
use_rvq_target=False,
|
| 32 |
+
use_vq_target=False,
|
| 33 |
+
use_encodec_target=False,
|
| 34 |
+
rvq_ckpt_path=None,
|
| 35 |
+
recon_loss_ratio=None,
|
| 36 |
+
label_rate=25,
|
| 37 |
+
rvq_n_codebooks=8,
|
| 38 |
+
rvq_multi_layer_num=1,
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
# global variables
|
| 43 |
+
self.hop_length = hop_length
|
| 44 |
+
self.mask_hop = mask_hop
|
| 45 |
+
self.mask_prob = mask_prob
|
| 46 |
+
self.num_codebooks = num_codebooks
|
| 47 |
+
self.codebook_size = codebook_size
|
| 48 |
+
self.features = features
|
| 49 |
+
self.recon_loss_ratio = recon_loss_ratio
|
| 50 |
+
self.n_fold = int(100//label_rate)
|
| 51 |
+
self.label_rate = label_rate
|
| 52 |
+
|
| 53 |
+
# load feature mean / std stats
|
| 54 |
+
self.stat = stat
|
| 55 |
+
|
| 56 |
+
# feature extractor
|
| 57 |
+
self.preprocessor_melspec_2048 = MelSTFT(
|
| 58 |
+
n_fft=2048, hop_length=hop_length, is_db=True
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# random quantizer
|
| 62 |
+
self.use_rvq_target = use_rvq_target
|
| 63 |
+
self.use_vq_target = use_vq_target
|
| 64 |
+
self.use_encodec_target = use_encodec_target
|
| 65 |
+
|
| 66 |
+
seed = 142
|
| 67 |
+
if self.use_rvq_like_target:
|
| 68 |
+
if use_rvq_target:
|
| 69 |
+
from ..modules.rvq import ResidualVectorQuantize
|
| 70 |
+
|
| 71 |
+
inp_dim = 128*self.n_fold
|
| 72 |
+
self.rvq = ResidualVectorQuantize(
|
| 73 |
+
input_dim = inp_dim,
|
| 74 |
+
n_codebooks = rvq_n_codebooks,
|
| 75 |
+
codebook_size = 1024,
|
| 76 |
+
codebook_dim = 16,
|
| 77 |
+
quantizer_dropout = 0.0,
|
| 78 |
+
use_multi_layer_num = rvq_multi_layer_num,
|
| 79 |
+
)
|
| 80 |
+
elif use_vq_target:
|
| 81 |
+
from ..modules.rvq import VectorQuantize
|
| 82 |
+
|
| 83 |
+
self.rvq = VectorQuantize(
|
| 84 |
+
input_dim = 128*self.n_fold,
|
| 85 |
+
codebook_size = 1024,
|
| 86 |
+
codebook_dim = 8,
|
| 87 |
+
stale_tolerance = 1000,
|
| 88 |
+
mfcc_clustering = False
|
| 89 |
+
)
|
| 90 |
+
elif use_encodec_target:
|
| 91 |
+
from encodec import EncodecModel
|
| 92 |
+
self.rvq = EncodecModel.encodec_model_24khz()
|
| 93 |
+
self.rvq.set_target_bandwidth(6.0)
|
| 94 |
+
for param in self.rvq.parameters():
|
| 95 |
+
param.requires_grad = False
|
| 96 |
+
|
| 97 |
+
if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
|
| 98 |
+
state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
|
| 99 |
+
self.rvq.load_state_dict(state_dict)
|
| 100 |
+
else:
|
| 101 |
+
pass
|
| 102 |
+
# print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
|
| 103 |
+
else:
|
| 104 |
+
for feature in self.features:
|
| 105 |
+
for i in range(num_codebooks):
|
| 106 |
+
setattr(
|
| 107 |
+
self,
|
| 108 |
+
f"quantizer_{feature}", # _{i}
|
| 109 |
+
RandomProjectionQuantizer(
|
| 110 |
+
n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
|
| 111 |
+
),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# two residual convolution layers + one projection layer
|
| 115 |
+
strides_factory = {
|
| 116 |
+
4: [2, 2],
|
| 117 |
+
2: [2, 1]
|
| 118 |
+
}
|
| 119 |
+
self.conv = Conv2dSubsampling(
|
| 120 |
+
1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Conformer
|
| 124 |
+
if is_flash:
|
| 125 |
+
from modules.flash_conformer import (
|
| 126 |
+
Wav2Vec2ConformerEncoder,
|
| 127 |
+
Wav2Vec2ConformerConfig,
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 131 |
+
Wav2Vec2ConformerEncoder,
|
| 132 |
+
Wav2Vec2ConformerConfig,
|
| 133 |
+
)
|
| 134 |
+
config = EasyDict(w2v2_config)
|
| 135 |
+
config.num_hidden_layers = encoder_depth
|
| 136 |
+
config.hidden_size = encoder_dim
|
| 137 |
+
|
| 138 |
+
self.conformer = Wav2Vec2ConformerEncoder(config)
|
| 139 |
+
|
| 140 |
+
self.linear = nn.Linear(encoder_dim, codebook_size) # projection layer
|
| 141 |
+
|
| 142 |
+
# reconstruct melspec
|
| 143 |
+
if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
|
| 144 |
+
self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
|
| 145 |
+
self.recon_loss = nn.MSELoss()
|
| 146 |
+
|
| 147 |
+
# loss function
|
| 148 |
+
self.loss = nn.CrossEntropyLoss()
|
| 149 |
+
|
| 150 |
+
# cls token (used for sequence classification)
|
| 151 |
+
random.seed(seed)
|
| 152 |
+
self.cls_token = nn.Parameter(torch.randn(encoder_dim))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def use_rvq_like_target(self):
|
| 157 |
+
return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
|
| 158 |
+
|
| 159 |
+
def masking(self, x, attention_mask=None):
|
| 160 |
+
"""random masking of 400ms with given probability"""
|
| 161 |
+
mx = x.clone()
|
| 162 |
+
b, t = mx.shape
|
| 163 |
+
len_masking_raw = int(24000 * self.mask_hop)
|
| 164 |
+
len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
|
| 165 |
+
|
| 166 |
+
# get random mask indices
|
| 167 |
+
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
|
| 168 |
+
time_domain_masked_indices = torch.nonzero(
|
| 169 |
+
start_indices.repeat_interleave(len_masking_raw, dim=1)
|
| 170 |
+
)
|
| 171 |
+
token_domain_masked_indices = torch.nonzero(
|
| 172 |
+
start_indices.repeat_interleave(len_masking_token, dim=1)
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# mask with random values
|
| 176 |
+
masking_noise = (
|
| 177 |
+
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
|
| 178 |
+
) # 0 mean 0.1 std
|
| 179 |
+
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
|
| 180 |
+
|
| 181 |
+
return mx, token_domain_masked_indices
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@torch.no_grad()
|
| 185 |
+
def preprocessing(self, x, features):
|
| 186 |
+
"""extract classic audio features"""
|
| 187 |
+
# check precision
|
| 188 |
+
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
| 189 |
+
precision = 16
|
| 190 |
+
else:
|
| 191 |
+
precision = 32
|
| 192 |
+
|
| 193 |
+
out = {}
|
| 194 |
+
for key in features:
|
| 195 |
+
layer = getattr(self, "preprocessor_%s" % key)
|
| 196 |
+
layer.to(x.device)
|
| 197 |
+
dtype = x.dtype
|
| 198 |
+
out[key] = layer(x.float())[..., :-1]
|
| 199 |
+
if precision == 16:
|
| 200 |
+
out[key] = out[key].half()
|
| 201 |
+
if out[key].dtype != dtype:
|
| 202 |
+
out[key].to(dtype=dtype)
|
| 203 |
+
return out
|
| 204 |
+
|
| 205 |
+
def encoder(self, x, *, attention_mask=None, is_features_only=False):
|
| 206 |
+
"""2-layer conv + w2v-conformer"""
|
| 207 |
+
x = self.conv(x)
|
| 208 |
+
mask_indices = None
|
| 209 |
+
if attention_mask is None:
|
| 210 |
+
out = self.conformer(x, output_hidden_states=True)
|
| 211 |
+
else:
|
| 212 |
+
attention_mask = attention_mask.bool()
|
| 213 |
+
skip_n = int(attention_mask.size(-1) / x.size(1))
|
| 214 |
+
attention_mask = attention_mask[:, ::skip_n]
|
| 215 |
+
attention_mask = attention_mask[:, :x.size(1)]
|
| 216 |
+
out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
|
| 217 |
+
hidden_emb = out["hidden_states"]
|
| 218 |
+
last_emb = out["last_hidden_state"]
|
| 219 |
+
logits = self.linear(last_emb)
|
| 220 |
+
interval = self.codebook_size
|
| 221 |
+
logits = {
|
| 222 |
+
key: logits[:, :, i * interval : (i + 1) * interval]
|
| 223 |
+
for i, key in enumerate(self.features)
|
| 224 |
+
}
|
| 225 |
+
return logits, hidden_emb, mask_indices
|
| 226 |
+
|
| 227 |
+
@torch.no_grad()
|
| 228 |
+
def normalize(self, x):
|
| 229 |
+
"""normalize the input audio to have zero mean unit variance"""
|
| 230 |
+
for key in x.keys():
|
| 231 |
+
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
|
| 232 |
+
return x
|
| 233 |
+
|
| 234 |
+
@torch.no_grad()
|
| 235 |
+
def rearrange(self, x):
|
| 236 |
+
"""rearrange the batch to flatten every 4 steps"""
|
| 237 |
+
for key in x.keys():
|
| 238 |
+
if key == "chromagram":
|
| 239 |
+
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 240 |
+
else:
|
| 241 |
+
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
|
| 242 |
+
return x
|
| 243 |
+
|
| 244 |
+
def get_rvq_codes(self, inp, raw_wav):
|
| 245 |
+
if self.use_rvq_target:
|
| 246 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
|
| 247 |
+
return codes
|
| 248 |
+
if self.use_vq_target:
|
| 249 |
+
quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
|
| 250 |
+
return codes.unsqueeze(1)
|
| 251 |
+
if self.use_encodec_target:
|
| 252 |
+
encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
|
| 253 |
+
codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
|
| 254 |
+
if self.label_rate == 25:
|
| 255 |
+
codes = codes[:, :, ::3]
|
| 256 |
+
return codes
|
| 257 |
+
|
| 258 |
+
@torch.no_grad()
|
| 259 |
+
def tokenize(self, x, raw_wav):
|
| 260 |
+
out = {}
|
| 261 |
+
for key in x.keys():
|
| 262 |
+
if self.use_rvq_like_target:
|
| 263 |
+
self.rvq.eval()
|
| 264 |
+
inp = x[key].permute((0, 2, 1))
|
| 265 |
+
codes = self.get_rvq_codes(inp, raw_wav)
|
| 266 |
+
out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1)
|
| 267 |
+
else:
|
| 268 |
+
layer = getattr(self, "quantizer_%s" % key)
|
| 269 |
+
out[key] = layer(x[key])
|
| 270 |
+
return out
|
| 271 |
+
|
| 272 |
+
def get_targets(self, x, label=None):
|
| 273 |
+
if self.use_encodec_target:
|
| 274 |
+
raw_x = x.clone()
|
| 275 |
+
else:
|
| 276 |
+
raw_x = None
|
| 277 |
+
x = self.preprocessing(x, features=self.features)
|
| 278 |
+
x = self.normalize(x)
|
| 279 |
+
x = self.rearrange(x)
|
| 280 |
+
melspec = x['melspec_2048']
|
| 281 |
+
if label is None:
|
| 282 |
+
# Use labels from Mel-RVQ
|
| 283 |
+
target_tokens = self.tokenize(x, raw_x)
|
| 284 |
+
else:
|
| 285 |
+
# Use labels pre-extracted for iteration training
|
| 286 |
+
target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
|
| 287 |
+
return target_tokens, melspec
|
| 288 |
+
|
| 289 |
+
def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
|
| 290 |
+
# preprocessing
|
| 291 |
+
x = self.preprocessing(x, features=["melspec_2048"])
|
| 292 |
+
x = self.normalize(x)
|
| 293 |
+
|
| 294 |
+
# encoding
|
| 295 |
+
logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
|
| 296 |
+
|
| 297 |
+
if return_new_mask:
|
| 298 |
+
return logits, hidden_emb, mask if new_mask is None else new_mask
|
| 299 |
+
else:
|
| 300 |
+
return logits, hidden_emb
|
| 301 |
+
|
| 302 |
+
def get_latent(self, x, layer_ix=12):
|
| 303 |
+
_, hidden_states = self.get_predictions(x)
|
| 304 |
+
emb = hidden_states[layer_ix]
|
| 305 |
+
return emb
|
| 306 |
+
|
| 307 |
+
def compute_nce(self, x, pos, negs):
|
| 308 |
+
neg_is_pos = (pos == negs).all(-1)
|
| 309 |
+
pos = pos.unsqueeze(0)
|
| 310 |
+
targets = torch.cat([pos, negs], dim=0)
|
| 311 |
+
|
| 312 |
+
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
|
| 313 |
+
logits /= 0.1
|
| 314 |
+
if neg_is_pos.any():
|
| 315 |
+
logits[1:][neg_is_pos] = float("-inf")
|
| 316 |
+
logits = logits.transpose(0, 1)
|
| 317 |
+
return logits
|
| 318 |
+
|
| 319 |
+
def get_loss(self, logits, target_tokens, masked_indices):
|
| 320 |
+
losses = {}
|
| 321 |
+
accuracies = {}
|
| 322 |
+
for key in logits.keys():
|
| 323 |
+
if not self.use_rvq_like_target:
|
| 324 |
+
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 325 |
+
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 326 |
+
else:
|
| 327 |
+
Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape
|
| 328 |
+
Batch, N_Codebook_x_SeqLen = target_tokens[key].shape
|
| 329 |
+
N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
|
| 330 |
+
target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
|
| 331 |
+
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 332 |
+
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 333 |
+
masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
|
| 334 |
+
masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
|
| 335 |
+
|
| 336 |
+
losses[key] = self.loss(masked_logits, masked_tokens)
|
| 337 |
+
accuracies[key] = (
|
| 338 |
+
torch.sum(masked_logits.argmax(-1) == masked_tokens)
|
| 339 |
+
/ masked_tokens.numel()
|
| 340 |
+
)
|
| 341 |
+
return losses, accuracies
|
| 342 |
+
|
| 343 |
+
def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
|
| 344 |
+
pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
|
| 345 |
+
target_melspec = melspec[tuple(masked_indices.t())]
|
| 346 |
+
recon_loss = self.recon_loss(pred_melspec, target_melspec)
|
| 347 |
+
return recon_loss
|
| 348 |
+
|
| 349 |
+
def forward(self, x, attention_mask=None, label=None):
|
| 350 |
+
dtype = x.dtype
|
| 351 |
+
# get target feature tokens
|
| 352 |
+
target_tokens, melspec = self.get_targets(x, label=label)
|
| 353 |
+
|
| 354 |
+
# masking
|
| 355 |
+
x, masked_indices = self.masking(x, attention_mask=attention_mask)
|
| 356 |
+
|
| 357 |
+
# forward
|
| 358 |
+
logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
|
| 359 |
+
|
| 360 |
+
# get loss
|
| 361 |
+
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
|
| 362 |
+
|
| 363 |
+
if self.recon_loss_ratio:
|
| 364 |
+
losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
|
| 365 |
+
|
| 366 |
+
return logits, hidden_emb, losses, accuracies
|
src/third_party/MuQ/src/muq/muq/modules/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
src/third_party/MuQ/src/muq/muq/modules/conv.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Res2dModule(nn.Module):
|
| 6 |
+
def __init__(self, idim, odim, stride=(2, 2)):
|
| 7 |
+
super(Res2dModule, self).__init__()
|
| 8 |
+
self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 9 |
+
self.bn1 = nn.BatchNorm2d(odim)
|
| 10 |
+
self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
|
| 11 |
+
self.bn2 = nn.BatchNorm2d(odim)
|
| 12 |
+
self.relu = nn.ReLU()
|
| 13 |
+
|
| 14 |
+
# residual
|
| 15 |
+
self.diff = False
|
| 16 |
+
if (idim != odim) or (stride[0] > 1):
|
| 17 |
+
self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 18 |
+
self.bn3 = nn.BatchNorm2d(odim)
|
| 19 |
+
self.diff = True
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
|
| 23 |
+
if self.diff:
|
| 24 |
+
x = self.bn3(self.conv3(x))
|
| 25 |
+
out = x + out
|
| 26 |
+
out = self.relu(out)
|
| 27 |
+
return out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Conv2dSubsampling(nn.Module):
|
| 31 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
idim (int): Input dimension.
|
| 35 |
+
hdim (int): Hidden dimension.
|
| 36 |
+
odim (int): Output dimension.
|
| 37 |
+
strides (list): Sizes of strides.
|
| 38 |
+
n_bands (int): Number of frequency bands.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
|
| 42 |
+
"""Construct an Conv2dSubsampling object."""
|
| 43 |
+
super(Conv2dSubsampling, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.conv = nn.Sequential(
|
| 46 |
+
Res2dModule(idim, hdim, (2, strides[0])),
|
| 47 |
+
Res2dModule(hdim, hdim, (2, strides[1])),
|
| 48 |
+
)
|
| 49 |
+
self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
"""Subsample x.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
x (torch.Tensor): Input tensor (#batch, idim, time).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 59 |
+
where time' = time // 4.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
if x.dim() == 3:
|
| 63 |
+
x = x.unsqueeze(1) # (b, c, f, t)
|
| 64 |
+
x = self.conv(x)
|
| 65 |
+
x = rearrange(x, "b c f t -> b t (c f)")
|
| 66 |
+
x = self.linear(x)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
import torch
|
| 71 |
+
conv_dim, encoder_dim = 512, 1024
|
| 72 |
+
conv = Conv2dSubsampling(
|
| 73 |
+
1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
|
| 74 |
+
)
|
| 75 |
+
inp = torch.randn((1, 128, 3000))
|
| 76 |
+
out = conv(inp)
|
| 77 |
+
print(out.shape)
|
src/third_party/MuQ/src/muq/muq/modules/features.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchaudio
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MelSTFT:
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
sample_rate=24000,
|
| 10 |
+
n_fft=2048,
|
| 11 |
+
hop_length=240,
|
| 12 |
+
n_mels=128,
|
| 13 |
+
is_db=False,
|
| 14 |
+
):
|
| 15 |
+
super(MelSTFT, self).__init__()
|
| 16 |
+
|
| 17 |
+
# spectrogram
|
| 18 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 19 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# amplitude to decibel
|
| 23 |
+
self.is_db = is_db
|
| 24 |
+
if is_db:
|
| 25 |
+
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 26 |
+
|
| 27 |
+
def __call__(self, waveform):
|
| 28 |
+
if self.is_db:
|
| 29 |
+
return self.amplitude_to_db(self.mel_stft(waveform))
|
| 30 |
+
else:
|
| 31 |
+
return self.mel_stft(waveform)
|
| 32 |
+
|
| 33 |
+
def to(self, device):
|
| 34 |
+
self.mel_stft = self.mel_stft.to(device)
|
| 35 |
+
if self.is_db:
|
| 36 |
+
self.amplitude_to_db = self.amplitude_to_db.to(device)
|
| 37 |
+
return self
|
src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py
ADDED
|
@@ -0,0 +1,2114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Wav2Vec2-Conformer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
| 30 |
+
from transformers.modeling_outputs import (
|
| 31 |
+
BaseModelOutput,
|
| 32 |
+
CausalLMOutput,
|
| 33 |
+
SequenceClassifierOutput,
|
| 34 |
+
TokenClassifierOutput,
|
| 35 |
+
Wav2Vec2BaseModelOutput,
|
| 36 |
+
XVectorOutput,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 39 |
+
from transformers.utils import (
|
| 40 |
+
ModelOutput,
|
| 41 |
+
add_code_sample_docstrings,
|
| 42 |
+
add_start_docstrings,
|
| 43 |
+
add_start_docstrings_to_model_forward,
|
| 44 |
+
logging,
|
| 45 |
+
replace_return_docstrings,
|
| 46 |
+
)
|
| 47 |
+
from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 54 |
+
|
| 55 |
+
# General docstring
|
| 56 |
+
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
|
| 57 |
+
|
| 58 |
+
# Base docstring
|
| 59 |
+
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
|
| 60 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
| 61 |
+
|
| 62 |
+
# CTC docstring
|
| 63 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 64 |
+
_CTC_EXPECTED_LOSS = 64.21
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 68 |
+
"facebook/wav2vec2-conformer-rel-pos-large",
|
| 69 |
+
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
|
| 75 |
+
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
|
| 76 |
+
"""
|
| 77 |
+
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 81 |
+
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
|
| 82 |
+
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
|
| 83 |
+
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 84 |
+
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
|
| 85 |
+
projected quantized states.
|
| 86 |
+
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 87 |
+
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
|
| 88 |
+
target vectors for contrastive loss.
|
| 89 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 90 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 91 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 92 |
+
|
| 93 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 94 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 95 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 96 |
+
sequence_length)`.
|
| 97 |
+
|
| 98 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 99 |
+
heads.
|
| 100 |
+
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 101 |
+
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 102 |
+
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 103 |
+
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
loss: Optional[torch.FloatTensor] = None
|
| 107 |
+
projected_states: torch.FloatTensor = None
|
| 108 |
+
projected_quantized_states: torch.FloatTensor = None
|
| 109 |
+
codevector_perplexity: torch.FloatTensor = None
|
| 110 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 111 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 112 |
+
contrastive_loss: Optional[torch.FloatTensor] = None
|
| 113 |
+
diversity_loss: Optional[torch.FloatTensor] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 117 |
+
def _compute_mask_indices(
|
| 118 |
+
shape: Tuple[int, int],
|
| 119 |
+
mask_prob: float,
|
| 120 |
+
mask_length: int,
|
| 121 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 122 |
+
min_masks: int = 0,
|
| 123 |
+
) -> np.ndarray:
|
| 124 |
+
"""
|
| 125 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 126 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 127 |
+
CPU as part of the preprocessing during training.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 131 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 132 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 133 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 134 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 135 |
+
actual percentage will be smaller.
|
| 136 |
+
mask_length: size of the mask
|
| 137 |
+
min_masks: minimum number of masked spans
|
| 138 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 139 |
+
each batch dimension.
|
| 140 |
+
"""
|
| 141 |
+
batch_size, sequence_length = shape
|
| 142 |
+
|
| 143 |
+
if mask_length < 1:
|
| 144 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 145 |
+
|
| 146 |
+
if mask_length > sequence_length:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 149 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# epsilon is used for probabilistic rounding
|
| 153 |
+
epsilon = np.random.rand(1).item()
|
| 154 |
+
|
| 155 |
+
def compute_num_masked_span(input_length):
|
| 156 |
+
"""Given input length, compute how many spans should be masked"""
|
| 157 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 158 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 159 |
+
|
| 160 |
+
# make sure num masked span <= sequence_length
|
| 161 |
+
if num_masked_span * mask_length > sequence_length:
|
| 162 |
+
num_masked_span = sequence_length // mask_length
|
| 163 |
+
|
| 164 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 165 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 166 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 167 |
+
|
| 168 |
+
return num_masked_span
|
| 169 |
+
|
| 170 |
+
# compute number of masked spans in batch
|
| 171 |
+
input_lengths = (
|
| 172 |
+
attention_mask.sum(-1).detach().tolist()
|
| 173 |
+
if attention_mask is not None
|
| 174 |
+
else [sequence_length for _ in range(batch_size)]
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# SpecAugment mask to fill
|
| 178 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 179 |
+
spec_aug_mask_idxs = []
|
| 180 |
+
|
| 181 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 182 |
+
|
| 183 |
+
if max_num_masked_span == 0:
|
| 184 |
+
return spec_aug_mask
|
| 185 |
+
|
| 186 |
+
for input_length in input_lengths:
|
| 187 |
+
# compute num of masked spans for this input
|
| 188 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 189 |
+
|
| 190 |
+
# get random indices to mask
|
| 191 |
+
spec_aug_mask_idx = np.random.choice(
|
| 192 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 196 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 197 |
+
# Picking first sample just pads those vectors twice.
|
| 198 |
+
if len(spec_aug_mask_idx) == 0:
|
| 199 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 200 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 201 |
+
# token which we can use as a dummy mask id
|
| 202 |
+
dummy_mask_idx = sequence_length - 1
|
| 203 |
+
else:
|
| 204 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 205 |
+
|
| 206 |
+
spec_aug_mask_idx = np.concatenate(
|
| 207 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 208 |
+
)
|
| 209 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 210 |
+
|
| 211 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 212 |
+
|
| 213 |
+
# expand masked indices to masked spans
|
| 214 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 215 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 216 |
+
)
|
| 217 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 218 |
+
|
| 219 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 220 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 221 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 222 |
+
batch_size, max_num_masked_span * mask_length
|
| 223 |
+
)
|
| 224 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 225 |
+
|
| 226 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 227 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 228 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 229 |
+
|
| 230 |
+
# scatter indices to mask
|
| 231 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 232 |
+
|
| 233 |
+
return spec_aug_mask
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
|
| 237 |
+
def _sample_negative_indices(
|
| 238 |
+
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
| 239 |
+
):
|
| 240 |
+
"""
|
| 241 |
+
Sample `num_negatives` vectors from feature vectors.
|
| 242 |
+
"""
|
| 243 |
+
batch_size, sequence_length = features_shape
|
| 244 |
+
|
| 245 |
+
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
| 246 |
+
sequence_length_range = np.arange(sequence_length)
|
| 247 |
+
|
| 248 |
+
# get `num_negatives` random vector indices from the same utterance
|
| 249 |
+
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
| 250 |
+
|
| 251 |
+
mask_time_indices = (
|
| 252 |
+
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
for batch_idx in range(batch_size):
|
| 256 |
+
high = mask_time_indices[batch_idx].sum() - 1
|
| 257 |
+
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
| 258 |
+
|
| 259 |
+
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
| 260 |
+
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
| 261 |
+
# avoid sampling the same positive vector, but keep the distribution uniform
|
| 262 |
+
sampled_indices[sampled_indices >= feature_indices] += 1
|
| 263 |
+
|
| 264 |
+
# remap to actual indices
|
| 265 |
+
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
| 266 |
+
|
| 267 |
+
# correct for batch size
|
| 268 |
+
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
| 269 |
+
|
| 270 |
+
return sampled_negative_indices
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 274 |
+
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
|
| 275 |
+
def __init__(self, config, layer_id=0):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 278 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 279 |
+
|
| 280 |
+
self.conv = nn.Conv1d(
|
| 281 |
+
self.in_conv_dim,
|
| 282 |
+
self.out_conv_dim,
|
| 283 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 284 |
+
stride=config.conv_stride[layer_id],
|
| 285 |
+
bias=config.conv_bias,
|
| 286 |
+
)
|
| 287 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 288 |
+
|
| 289 |
+
def forward(self, hidden_states):
|
| 290 |
+
hidden_states = self.conv(hidden_states)
|
| 291 |
+
hidden_states = self.activation(hidden_states)
|
| 292 |
+
return hidden_states
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 296 |
+
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
|
| 297 |
+
def __init__(self, config, layer_id=0):
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 300 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 301 |
+
|
| 302 |
+
self.conv = nn.Conv1d(
|
| 303 |
+
self.in_conv_dim,
|
| 304 |
+
self.out_conv_dim,
|
| 305 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 306 |
+
stride=config.conv_stride[layer_id],
|
| 307 |
+
bias=config.conv_bias,
|
| 308 |
+
)
|
| 309 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 310 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 311 |
+
|
| 312 |
+
def forward(self, hidden_states):
|
| 313 |
+
hidden_states = self.conv(hidden_states)
|
| 314 |
+
|
| 315 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 316 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 317 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 318 |
+
|
| 319 |
+
hidden_states = self.activation(hidden_states)
|
| 320 |
+
return hidden_states
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 324 |
+
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
|
| 325 |
+
def __init__(self, config, layer_id=0):
|
| 326 |
+
super().__init__()
|
| 327 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 328 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 329 |
+
|
| 330 |
+
self.conv = nn.Conv1d(
|
| 331 |
+
self.in_conv_dim,
|
| 332 |
+
self.out_conv_dim,
|
| 333 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 334 |
+
stride=config.conv_stride[layer_id],
|
| 335 |
+
bias=config.conv_bias,
|
| 336 |
+
)
|
| 337 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 338 |
+
|
| 339 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 340 |
+
|
| 341 |
+
def forward(self, hidden_states):
|
| 342 |
+
hidden_states = self.conv(hidden_states)
|
| 343 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 344 |
+
hidden_states = self.activation(hidden_states)
|
| 345 |
+
return hidden_states
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
|
| 349 |
+
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
|
| 350 |
+
def __init__(self, config):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.conv = nn.Conv1d(
|
| 353 |
+
config.hidden_size,
|
| 354 |
+
config.hidden_size,
|
| 355 |
+
kernel_size=config.num_conv_pos_embeddings,
|
| 356 |
+
padding=config.num_conv_pos_embeddings // 2,
|
| 357 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if is_deepspeed_zero3_enabled():
|
| 361 |
+
import deepspeed
|
| 362 |
+
|
| 363 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 364 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 365 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
| 366 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
| 367 |
+
else:
|
| 368 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 369 |
+
|
| 370 |
+
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
|
| 371 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 372 |
+
|
| 373 |
+
def forward(self, hidden_states):
|
| 374 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 375 |
+
|
| 376 |
+
hidden_states = self.conv(hidden_states)
|
| 377 |
+
hidden_states = self.padding(hidden_states)
|
| 378 |
+
hidden_states = self.activation(hidden_states)
|
| 379 |
+
|
| 380 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 381 |
+
return hidden_states
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
| 385 |
+
"""Rotary positional embedding
|
| 386 |
+
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(self, config):
|
| 390 |
+
super().__init__()
|
| 391 |
+
dim = config.hidden_size // config.num_attention_heads
|
| 392 |
+
base = config.rotary_embedding_base
|
| 393 |
+
|
| 394 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 395 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 396 |
+
self.cached_sequence_length = None
|
| 397 |
+
self.cached_rotary_positional_embedding = None
|
| 398 |
+
|
| 399 |
+
def forward(self, hidden_states):
|
| 400 |
+
sequence_length = hidden_states.shape[1]
|
| 401 |
+
|
| 402 |
+
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
|
| 403 |
+
return self.cached_rotary_positional_embedding
|
| 404 |
+
|
| 405 |
+
self.cached_sequence_length = sequence_length
|
| 406 |
+
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
| 407 |
+
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
| 408 |
+
embeddings = torch.cat((freqs, freqs), dim=-1)
|
| 409 |
+
|
| 410 |
+
cos_embeddings = embeddings.cos()[:, None, None, :]
|
| 411 |
+
sin_embeddings = embeddings.sin()[:, None, None, :]
|
| 412 |
+
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
| 413 |
+
return self.cached_rotary_positional_embedding
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
|
| 417 |
+
"""Relative positional encoding module."""
|
| 418 |
+
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.max_len = config.max_source_positions
|
| 422 |
+
self.d_model = config.hidden_size
|
| 423 |
+
self.pe = None
|
| 424 |
+
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
|
| 425 |
+
|
| 426 |
+
def extend_pe(self, x):
|
| 427 |
+
# Reset the positional encodings
|
| 428 |
+
if self.pe is not None:
|
| 429 |
+
# self.pe contains both positive and negative parts
|
| 430 |
+
# the length of self.pe is 2 * input_len - 1
|
| 431 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 432 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 433 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 434 |
+
return
|
| 435 |
+
# Suppose `i` is the position of query vector and `j` is the
|
| 436 |
+
# position of key vector. We use positive relative positions when keys
|
| 437 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 438 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 439 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 440 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 441 |
+
div_term = torch.exp(
|
| 442 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
|
| 443 |
+
)
|
| 444 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 445 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 446 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 447 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 448 |
+
|
| 449 |
+
# Reverse the order of positive indices and concat both positive and
|
| 450 |
+
# negative indices. This is used to support the shifting trick
|
| 451 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 452 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 453 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 454 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 455 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 456 |
+
|
| 457 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 458 |
+
self.extend_pe(hidden_states)
|
| 459 |
+
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
|
| 460 |
+
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
|
| 461 |
+
relative_position_embeddings = self.pe[:, start_idx:end_idx]
|
| 462 |
+
|
| 463 |
+
return relative_position_embeddings
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 467 |
+
class Wav2Vec2ConformerSamePadLayer(nn.Module):
|
| 468 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 469 |
+
super().__init__()
|
| 470 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 471 |
+
|
| 472 |
+
def forward(self, hidden_states):
|
| 473 |
+
if self.num_pad_remove > 0:
|
| 474 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 475 |
+
return hidden_states
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
|
| 479 |
+
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
|
| 480 |
+
"""Construct the features from raw audio waveform"""
|
| 481 |
+
|
| 482 |
+
def __init__(self, config):
|
| 483 |
+
super().__init__()
|
| 484 |
+
|
| 485 |
+
if config.feat_extract_norm == "group":
|
| 486 |
+
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
|
| 487 |
+
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
|
| 488 |
+
for i in range(config.num_feat_extract_layers - 1)
|
| 489 |
+
]
|
| 490 |
+
elif config.feat_extract_norm == "layer":
|
| 491 |
+
conv_layers = [
|
| 492 |
+
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
| 493 |
+
]
|
| 494 |
+
else:
|
| 495 |
+
raise ValueError(
|
| 496 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 497 |
+
)
|
| 498 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
| 499 |
+
self.gradient_checkpointing = False
|
| 500 |
+
self._requires_grad = True
|
| 501 |
+
|
| 502 |
+
def _freeze_parameters(self):
|
| 503 |
+
for param in self.parameters():
|
| 504 |
+
param.requires_grad = False
|
| 505 |
+
self._requires_grad = False
|
| 506 |
+
|
| 507 |
+
def forward(self, input_values):
|
| 508 |
+
hidden_states = input_values[:, None]
|
| 509 |
+
|
| 510 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 511 |
+
if self._requires_grad and self.training:
|
| 512 |
+
hidden_states.requires_grad = True
|
| 513 |
+
|
| 514 |
+
for conv_layer in self.conv_layers:
|
| 515 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 516 |
+
|
| 517 |
+
def create_custom_forward(module):
|
| 518 |
+
def custom_forward(*inputs):
|
| 519 |
+
return module(*inputs)
|
| 520 |
+
|
| 521 |
+
return custom_forward
|
| 522 |
+
|
| 523 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 524 |
+
create_custom_forward(conv_layer),
|
| 525 |
+
hidden_states,
|
| 526 |
+
)
|
| 527 |
+
else:
|
| 528 |
+
hidden_states = conv_layer(hidden_states)
|
| 529 |
+
|
| 530 |
+
return hidden_states
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
|
| 534 |
+
class Wav2Vec2ConformerFeatureProjection(nn.Module):
|
| 535 |
+
def __init__(self, config):
|
| 536 |
+
super().__init__()
|
| 537 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 538 |
+
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 539 |
+
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
| 540 |
+
|
| 541 |
+
def forward(self, hidden_states):
|
| 542 |
+
# non-projected hidden states are needed for quantization
|
| 543 |
+
norm_hidden_states = self.layer_norm(hidden_states)
|
| 544 |
+
hidden_states = self.projection(norm_hidden_states)
|
| 545 |
+
hidden_states = self.dropout(hidden_states)
|
| 546 |
+
return hidden_states, norm_hidden_states
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
|
| 550 |
+
class Wav2Vec2ConformerFeedForward(nn.Module):
|
| 551 |
+
def __init__(self, config):
|
| 552 |
+
super().__init__()
|
| 553 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 554 |
+
|
| 555 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 556 |
+
if isinstance(config.hidden_act, str):
|
| 557 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 558 |
+
else:
|
| 559 |
+
self.intermediate_act_fn = config.hidden_act
|
| 560 |
+
|
| 561 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 562 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 563 |
+
|
| 564 |
+
def forward(self, hidden_states):
|
| 565 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
| 566 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 567 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
| 568 |
+
|
| 569 |
+
hidden_states = self.output_dense(hidden_states)
|
| 570 |
+
hidden_states = self.output_dropout(hidden_states)
|
| 571 |
+
return hidden_states
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class Wav2Vec2ConformerConvolutionModule(nn.Module):
|
| 575 |
+
"""Convolution block used in the conformer block"""
|
| 576 |
+
|
| 577 |
+
def __init__(self, config):
|
| 578 |
+
super().__init__()
|
| 579 |
+
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
|
| 580 |
+
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
|
| 581 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
| 582 |
+
self.pointwise_conv1 = torch.nn.Conv1d(
|
| 583 |
+
config.hidden_size,
|
| 584 |
+
2 * config.hidden_size,
|
| 585 |
+
kernel_size=1,
|
| 586 |
+
stride=1,
|
| 587 |
+
padding=0,
|
| 588 |
+
bias=False,
|
| 589 |
+
)
|
| 590 |
+
self.glu = torch.nn.GLU(dim=1)
|
| 591 |
+
self.depthwise_conv = torch.nn.Conv1d(
|
| 592 |
+
config.hidden_size,
|
| 593 |
+
config.hidden_size,
|
| 594 |
+
config.conv_depthwise_kernel_size,
|
| 595 |
+
stride=1,
|
| 596 |
+
padding=(config.conv_depthwise_kernel_size - 1) // 2,
|
| 597 |
+
groups=config.hidden_size,
|
| 598 |
+
bias=False,
|
| 599 |
+
)
|
| 600 |
+
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
|
| 601 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 602 |
+
self.pointwise_conv2 = torch.nn.Conv1d(
|
| 603 |
+
config.hidden_size,
|
| 604 |
+
config.hidden_size,
|
| 605 |
+
kernel_size=1,
|
| 606 |
+
stride=1,
|
| 607 |
+
padding=0,
|
| 608 |
+
bias=False,
|
| 609 |
+
)
|
| 610 |
+
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
|
| 611 |
+
|
| 612 |
+
def forward(self, hidden_states):
|
| 613 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 614 |
+
# exchange the temporal dimension and the feature dimension
|
| 615 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 616 |
+
|
| 617 |
+
# GLU mechanism
|
| 618 |
+
# => (batch, 2*channel, dim)
|
| 619 |
+
hidden_states = self.pointwise_conv1(hidden_states)
|
| 620 |
+
# => (batch, channel, dim)
|
| 621 |
+
hidden_states = self.glu(hidden_states)
|
| 622 |
+
|
| 623 |
+
# 1D Depthwise Conv
|
| 624 |
+
hidden_states = self.depthwise_conv(hidden_states)
|
| 625 |
+
hidden_states = self.batch_norm(hidden_states)
|
| 626 |
+
hidden_states = self.activation(hidden_states)
|
| 627 |
+
|
| 628 |
+
hidden_states = self.pointwise_conv2(hidden_states)
|
| 629 |
+
hidden_states = self.dropout(hidden_states)
|
| 630 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 631 |
+
return hidden_states
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class Wav2Vec2ConformerSelfAttention(nn.Module):
    """Construct a Wav2Vec2ConformerSelfAttention object.
    Can be enhanced with rotary or relative position embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.head_size = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.position_embeddings_type = config.position_embeddings_type

        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.dropout_p = config.attention_dropout

        self.is_causal = config.is_causal

        if self.position_embeddings_type == "relative":
            # linear transformation for positional encoding
            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            # these two learnable bias are used in matrix c and matrix d
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
            self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # self-attention mechanism
        batch_size, sequence_length, hidden_size = hidden_states.size()

        # make sure query/key states can be != value states
        query_key_states = hidden_states
        value_states = hidden_states

        if self.position_embeddings_type == "rotary":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
                )
            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)

        # project query_key_states and value_states
        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)

        # => (batch, head, time1, d_k)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
            hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
        probs = None

        # # apply attention_mask if necessary
        # if attention_mask is not None:
        #     scores = scores + attention_mask

        # # => (batch, head, time1, time2)
        # probs = torch.softmax(scores, dim=-1)
        # probs = self.dropout(probs)

        # # => (batch, head, time1, d_k)
        # hidden_states = torch.matmul(probs, value)

        # => (batch, time1, hidden_size)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
        hidden_states = self.linear_out(hidden_states)

        return hidden_states, probs

    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
        batch_size, sequence_length, hidden_size = hidden_states.size()
        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)

        cos = relative_position_embeddings[0, :sequence_length, ...]
        sin = relative_position_embeddings[1, :sequence_length, ...]

        # rotate hidden_states with rotary embeddings
        hidden_states = hidden_states.transpose(0, 1)
        rotated_states_begin = hidden_states[..., : self.head_size // 2]
        rotated_states_end = hidden_states[..., self.head_size // 2 :]
        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
        hidden_states = (hidden_states * cos) + (rotated_states * sin)
        hidden_states = hidden_states.transpose(0, 1)

        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)

        return hidden_states

    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
        # 1. project positional embeddings
        # => (batch, head, 2*time1-1, d_k)
        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
        )
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)

        # 2. Add bias to query
        # => (batch, head, time1, d_k)
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        # 3. attention score: first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # => (batch, head, time1, time2)
        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # 4. then compute matrix b and matrix d
        # => (batch, head, time1, 2*time1-1)
        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)

        # 5. shift matrix b and matrix d
        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]

        # 6. sum matrices
        # => (batch, head, time1, time2)
        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)

        return scores


class Wav2Vec2ConformerEncoderLayer(nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100."""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Feed-forward 1
        self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn1 = Wav2Vec2ConformerFeedForward(config)

        # Self-Attention
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.self_attn = Wav2Vec2ConformerSelfAttention(config)

        # Conformer Convolution
        self.conv_module = Wav2Vec2ConformerConvolutionModule(config)

        # Feed-forward 2
        self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn2 = Wav2Vec2ConformerFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        hidden_states,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        hidden_states = hidden_states

        # 1. Feed-Forward 1 layer
        residual = hidden_states
        hidden_states = self.ffn1_layer_norm(hidden_states)
        hidden_states = self.ffn1(hidden_states)
        hidden_states = hidden_states * 0.5 + residual
        residual = hidden_states

        # 2. Self-Attention layer
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_dropout(hidden_states)
        hidden_states = hidden_states + residual

        # 3. Convolutional Layer
        residual = hidden_states
        hidden_states = self.conv_module(hidden_states)
        hidden_states = residual + hidden_states

        # 4. Feed-Forward 2 Layer
        residual = hidden_states
        hidden_states = self.ffn2_layer_norm(hidden_states)
        hidden_states = self.ffn2(hidden_states)
        hidden_states = hidden_states * 0.5 + residual
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states, attn_weights


class Wav2Vec2ConformerEncoder(nn.Module):
    def __init__(self, config, is_causal=False):
        super().__init__()
        config.is_causal = is_causal
        self.config = config

        if config.position_embeddings_type == "relative":
            self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
        elif config.position_embeddings_type == "rotary":
            self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
        else:
            self.embed_positions = None

        self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states[~attention_mask] = 0.0

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = np.random.uniform(0, 1)

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    # create gradient checkpointing function
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(layer),
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible "
                f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords)
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        if mask is not None:
            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
            marginal_probs = probs.sum(dim=0) / mask.sum()
        else:
            marginal_probs = probs.mean(dim=0)

        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states, mask_time_indices=None):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in differentiable way
            codevector_probs = nn.functional.gumbel_softmax(
                hidden_states.float(), tau=self.temperature, hard=True
            ).type_as(hidden_states)

            # compute perplexity
            codevector_soft_dist = torch.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapter(nn.Module):
    def __init__(self, config):
        super().__init__()

        # feature dim might need to be down-projected
        if config.output_hidden_size != config.hidden_size:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        # down project hidden_states if necessary
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj(hidden_states)
            hidden_states = self.proj_layer_norm(hidden_states)

        hidden_states = hidden_states.transpose(1, 2)

        for layer in self.layers:
            layerdrop_prob = np.random.random()
            if not self.training or (layerdrop_prob > self.layerdrop):
                hidden_states = layer(hidden_states)

        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapterLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.output_hidden_size,
            2 * config.output_hidden_size,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = nn.functional.glu(hidden_states, dim=1)

        return hidden_states


class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2ConformerConfig
    base_model_prefix = "wav2vec2_conformer"
    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
        if isinstance(module, Wav2Vec2ConformerForPreTraining):
            module.project_hid.reset_parameters()
            module.project_q.reset_parameters()
            module.project_hid._is_hf_initialized = True
            module.project_q._is_hf_initialized = True
        # gumbel softmax requires special init
        elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, Wav2Vec2ConformerSelfAttention):
            if hasattr(module, "pos_bias_u"):
                nn.init.xavier_uniform_(module.pos_bias_u)
            if hasattr(module, "pos_bias_v"):
                nn.init.xavier_uniform_(module.pos_bias_v)
        elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations make sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
            module.gradient_checkpointing = value


WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
    Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving etc.).

    This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
    regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

    Parameters:
        config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
            True`. For all models whose processor has `config.return_attention_mask == False`, such as
            [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
            `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
            such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
            that these models also yield slightly different results depending on whether `input_values` is padded or
            not.

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
    def __init__(self, config: Wav2Vec2ConformerConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
        self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        self.encoder = Wav2Vec2ConformerEncoder(config)

        self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Wav2Vec2BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
)
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
    def __init__(self, config: Wav2Vec2ConformerConfig):
        super().__init__(config)
        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)

        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
    def set_gumbel_temperature(self, temperature: int):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training
        """
        self.quantizer.temperature = temperature

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    @staticmethod
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: int = 0.1,
    ):
        """
        Compute logits for contrastive loss based using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
        """
        target_features = torch.cat([target_features, negative_features], dim=0)

        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
            target_features
        )

        # apply temperature
        logits = logits / temperature
        return logits

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.BoolTensor] = None,
        sampled_negative_indices: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
            Required input for pre-training.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
        >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
        ...     _compute_mask_indices,
        ...     _sample_negative_indices,
        ... )
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
        >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1

        >>> # compute masked indices
        >>> batch_size, raw_sequence_length = input_values.shape
        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
        >>> mask_time_indices = _compute_mask_indices(
        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
        ... )
        >>> sampled_negative_indices = _sample_negative_indices(
        ...     features_shape=(batch_size, sequence_length),
        ...     num_negatives=model.config.num_negatives,
        ...     mask_time_indices=mask_time_indices,
        ... )
        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
        >>> sampled_negative_indices = torch.tensor(
        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
        ... )

        >>> with torch.no_grad():
        ...     outputs = model(input_values, mask_time_indices=mask_time_indices)

        >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
        >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

        >>> # show that cosine similarity is much higher than random
        >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
        tensor(True)

        >>> # for contrastive loss training model should be put into train mode
        >>> model = model.train()
        >>> loss = model(
        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
        ... ).loss
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if mask_time_indices is not None:
            mask_time_indices = mask_time_indices.to(torch.bool)

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            mask_time_indices=mask_time_indices,
            return_dict=return_dict,
        )

        # 1. project all transformed features (including masked) to final vq dim
        transformer_features = self.project_hid(outputs[0])

        # 2. quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1])

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        quantized_features, codevector_perplexity = self.quantizer(
            extract_features, mask_time_indices=mask_time_indices
        )
        quantized_features = self.project_q(quantized_features)

        loss = contrastive_loss = diversity_loss = None
        if sampled_negative_indices is not None:
            batch_size, sequence_length, hidden_size = quantized_features.shape

            # for training, we sample negatives
            # 3. sample K negatives (distractors) quantized states for contrastive loss
            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
            # sample negative quantized vectors BTC => (BxT)C
            negative_quantized_features = quantized_features.view(-1, hidden_size)[
                sampled_negative_indices.long().view(-1)
            ]
            negative_quantized_features = negative_quantized_features.view(
                batch_size, sequence_length, -1, hidden_size
            ).permute(2, 0, 1, 3)

            # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
            # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
            logits = self.compute_contrastive_logits(
                quantized_features[None, :],
                negative_quantized_features,
                transformer_features,
                self.config.contrastive_logits_temperature,
            )

            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
            # its cosine similarity will be masked
            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)

            if neg_is_pos.any():
                logits[1:][neg_is_pos] = float("-inf")

            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
            target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()

            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
            # 7. compute diversity loss: \mathbf{L}_d
            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()

            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss

        if not return_dict:
            if loss is not None:
                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]

        return Wav2Vec2ConformerForPreTrainingOutput(
            loss=loss,
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            contrastive_loss=contrastive_loss,
            diversity_loss=diversity_loss,
        )


@add_start_docstrings(
    """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


@add_start_docstrings(
    """
    Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
    tasks like SUPERB Keyword Spotting.
    """,
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1716 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1717 |
+
def __init__(self, config):
|
| 1718 |
+
super().__init__(config)
|
| 1719 |
+
|
| 1720 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1721 |
+
raise ValueError(
|
| 1722 |
+
"Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1723 |
+
)
|
| 1724 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1725 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1726 |
+
if config.use_weighted_layer_sum:
|
| 1727 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1728 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1729 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1730 |
+
|
| 1731 |
+
# Initialize weights and apply final processing
|
| 1732 |
+
self.post_init()
|
| 1733 |
+
|
| 1734 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1735 |
+
def freeze_feature_encoder(self):
|
| 1736 |
+
"""
|
| 1737 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1738 |
+
not be updated during training.
|
| 1739 |
+
"""
|
| 1740 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1741 |
+
|
| 1742 |
+
def freeze_base_model(self):
|
| 1743 |
+
"""
|
| 1744 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1745 |
+
be updated during training. Only the classification head will be updated.
|
| 1746 |
+
"""
|
| 1747 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 1748 |
+
param.requires_grad = False
|
| 1749 |
+
|
| 1750 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1751 |
+
@add_code_sample_docstrings(
|
| 1752 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1753 |
+
output_type=SequenceClassifierOutput,
|
| 1754 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1755 |
+
modality="audio",
|
| 1756 |
+
)
|
| 1757 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1758 |
+
def forward(
|
| 1759 |
+
self,
|
| 1760 |
+
input_values: Optional[torch.Tensor],
|
| 1761 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1762 |
+
output_attentions: Optional[bool] = None,
|
| 1763 |
+
output_hidden_states: Optional[bool] = None,
|
| 1764 |
+
return_dict: Optional[bool] = None,
|
| 1765 |
+
labels: Optional[torch.Tensor] = None,
|
| 1766 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1767 |
+
r"""
|
| 1768 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1769 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1770 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1771 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1772 |
+
"""
|
| 1773 |
+
|
| 1774 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1775 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1776 |
+
|
| 1777 |
+
outputs = self.wav2vec2_conformer(
|
| 1778 |
+
input_values,
|
| 1779 |
+
attention_mask=attention_mask,
|
| 1780 |
+
output_attentions=output_attentions,
|
| 1781 |
+
output_hidden_states=output_hidden_states,
|
| 1782 |
+
return_dict=return_dict,
|
| 1783 |
+
)
|
| 1784 |
+
|
| 1785 |
+
if self.config.use_weighted_layer_sum:
|
| 1786 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1787 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1788 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1789 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1790 |
+
else:
|
| 1791 |
+
hidden_states = outputs[0]
|
| 1792 |
+
|
| 1793 |
+
hidden_states = self.projector(hidden_states)
|
| 1794 |
+
if attention_mask is None:
|
| 1795 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1796 |
+
else:
|
| 1797 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1798 |
+
hidden_states[~padding_mask] = 0.0
|
| 1799 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1800 |
+
|
| 1801 |
+
logits = self.classifier(pooled_output)
|
| 1802 |
+
|
| 1803 |
+
loss = None
|
| 1804 |
+
if labels is not None:
|
| 1805 |
+
loss_fct = CrossEntropyLoss()
|
| 1806 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1807 |
+
|
| 1808 |
+
if not return_dict:
|
| 1809 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1810 |
+
return ((loss,) + output) if loss is not None else output
|
| 1811 |
+
|
| 1812 |
+
return SequenceClassifierOutput(
|
| 1813 |
+
loss=loss,
|
| 1814 |
+
logits=logits,
|
| 1815 |
+
hidden_states=outputs.hidden_states,
|
| 1816 |
+
attentions=outputs.attentions,
|
| 1817 |
+
)
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
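The sequence-classification head above mean-pools the (optionally layer-weighted) encoder states, projects them, and applies a linear classifier. A minimal usage sketch, not part of the commit; the checkpoint id and label count are assumptions, and the vendored copy in this file mirrors the upstream transformers class used here:

# Illustrative sketch only -- checkpoint id and num_labels are placeholder assumptions.
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForSequenceClassification

model_id = "facebook/wav2vec2-conformer-rope-large-960h-ft"  # assumed checkpoint
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ConformerForSequenceClassification.from_pretrained(model_id, num_labels=8)

waveform = torch.randn(16000)  # 1 s of 16 kHz audio
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, num_labels), pooled over time
predicted_class = logits.argmax(-1)
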
@add_start_docstrings(
    """
    Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
            )
        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        self.init_weights()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2_conformer.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        labels = labels.flatten()
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        psi = cos_theta - self.margin

        onehot = nn.functional.one_hot(labels, self.num_labels)
        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
        loss = self.loss(logits, labels)

        return loss

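AMSoftmaxLoss above is an additive-margin softmax: class weights and embeddings are L2-normalized so the logits are cosine similarities, the target-class similarity is reduced by `margin`, and everything is multiplied by `scale` before cross-entropy. A small numeric sketch of that margin trick, with illustrative values only (not part of the commit):

# Minimal sketch of the additive-margin step used by AMSoftmaxLoss.
import torch

cos_theta = torch.tensor([[0.70, 0.40, 0.10]])  # cosine similarity to 3 class centres
labels = torch.tensor([0])
margin, scale = 0.4, 30.0
onehot = torch.nn.functional.one_hot(labels, 3).bool()
logits = scale * torch.where(onehot, cos_theta - margin, cos_theta)
# logits -> [[9., 12., 3.]]: after the margin the true class (0.70 - 0.40 = 0.30) scores
# below class 1 (0.40), so the loss keeps pushing the true-class similarity higher.
loss = torch.nn.functional.cross_entropy(logits, labels)
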
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = nn.functional.unfold(
            hidden_states,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.kernel(hidden_states)

        hidden_states = self.activation(hidden_states)
        return hidden_states

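TDNNLayer implements a dilated 1-D convolution over time by unfolding kernel_size frames and applying a Linear. A small self-contained check of that equivalence under assumed toy dimensions (not part of the commit):

# Sketch: unfold + Linear (TDNN-style) matches an equivalent dilated Conv1d.
import torch
import torch.nn as nn

B, T, D, out, k, dil = 2, 50, 16, 8, 3, 2
x = torch.randn(B, T, D)

linear = nn.Linear(D * k, out)
conv = nn.Conv1d(D, out, kernel_size=k, dilation=dil)
with torch.no_grad():
    # unfold flattens frame-major, so reorder the Linear weight into Conv1d layout (out, D, k)
    conv.weight.copy_(linear.weight.view(out, k, D).permute(0, 2, 1))
    conv.bias.copy_(linear.bias)

unfolded = nn.functional.unfold(x.unsqueeze(1), (k, D), stride=(1, D), dilation=(dil, 1))
tdnn_out = linear(unfolded.transpose(1, 2))          # (B, T - dil*(k-1), out)
conv_out = conv(x.transpose(1, 2)).transpose(1, 2)   # same shape
assert torch.allclose(tdnn_out, conv_out, atol=1e-4)
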
@add_start_docstrings(
    """
    Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2_conformer.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2_conformer.parameters():
            param.requires_grad = False

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
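The x-vector head stacks TDNN layers on the conformer output, pools mean and standard deviation over time, and projects to a fixed-size speaker embedding. A hedged usage sketch with an assumed checkpoint id (not part of the commit):

# Illustrative sketch only -- the checkpoint id is a placeholder assumption.
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForXVector

model_id = "facebook/wav2vec2-conformer-rope-large-960h-ft"  # assumed checkpoint
fe = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ConformerForXVector.from_pretrained(model_id)

wavs = [torch.randn(16000).numpy(), torch.randn(24000).numpy()]
inputs = fe(wavs, sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
    embeddings = model(**inputs).embeddings  # (batch, xvector_output_dim)
similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
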
src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py
ADDED
@@ -0,0 +1,68 @@
import torch
from torch import nn, einsum
from einops import rearrange


class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module

    Some code is borrowed from:
     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    But I did normalization using pre-computed global mean & variance instead of using layer norm.
    """

    def __init__(
        self,
        input_dim,
        codebook_dim,
        codebook_size,
        seed=142,
    ):
        super().__init__()

        # random seed
        torch.manual_seed(seed)

        # randomly initialized projection
        random_projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(random_projection)
        self.register_buffer("random_projection", random_projection)

        # randomly initialized codebook
        codebook = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codebook)
        self.register_buffer("codebook", codebook)

    def codebook_lookup(self, x):
        # reshape
        b = x.shape[0]
        x = rearrange(x, "b n e -> (b n) e")

        # L2 normalization
        normalized_x = nn.functional.normalize(x, dim=1, p=2)
        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # compute distances
        distances = torch.cdist(normalized_codebook, normalized_x)

        # get nearest
        nearest_indices = torch.argmin(distances, dim=0)

        # reshape
        xq = rearrange(nearest_indices, "(b n) -> b n", b=b)

        return xq

    @torch.no_grad()
    def forward(self, x):
        # always eval
        self.eval()

        # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
        x = einsum("b n d, d e -> b n e", x, self.random_projection)

        # codebook lookup
        xq = self.codebook_lookup(x)

        return xq
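RandomProjectionQuantizer maps feature frames to discrete token ids with a frozen random projection and a frozen random codebook (BEST-RQ-style), so it never learns and always runs under no_grad. A minimal usage sketch with illustrative dimensions (not part of the commit):

# Sketch: tokenizing a batch of (already normalized) feature frames.
import torch

quantizer = RandomProjectionQuantizer(input_dim=128, codebook_dim=16, codebook_size=4096)
features = torch.randn(2, 250, 128)   # (batch, frames, input_dim), e.g. normalized mel frames
tokens = quantizer(features)          # (batch, frames) integer ids in [0, 4096)
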
src/third_party/MuQ/src/muq/muq/modules/rvq.py
ADDED
@@ -0,0 +1,314 @@

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
    from torch.nn.utils import weight_norm
except:
    try:
        from torch.nn.utils.parametrizations import weight_norm
    except:
        from torch.nn.utils.parametrize import weight_norm

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.mfcc_clustering = mfcc_clustering

        ProjClass = nn.Identity if mfcc_clustering else WNConv1d
        if n_layer==1:
            self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
            self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
        elif n_layer >= 2:
            ndim_hidden = 128
            self.in_proj = nn.Sequential(
                ProjClass(input_dim, ndim_hidden, kernel_size=1),
                *[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
                nn.ReLU(),
                ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
            )
            self.out_proj = nn.Sequential(
                ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
                nn.ReLU(),
                *[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
                ProjClass(ndim_hidden, input_dim, kernel_size=1),
            )
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
        self.stale_tolerance = stale_tolerance

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space

        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        if(self.training):
            onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float()  # B, T, codebook_size
            stale_codes = (onehots.sum(0).sum(0) == 0).float()
            self.stale_counter = self.stale_counter * stale_codes + stale_codes

            # random replace codes that haven't been used for a while
            replace_code = (self.stale_counter == self.stale_tolerance).float()  # codebook_size
            if replace_code.sum(-1) > 0:
                print("Replace {} codes".format(replace_code.sum(-1)))
                random_input_idx = torch.randperm(encodings.shape[0])
                random_input = encodings[random_input_idx].view(encodings.shape)
                if random_input.shape[0] < self.codebook_size:
                    random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.codebook_size,:].contiguous()  # codebook_size, dim

                self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
                self.stale_counter = self.stale_counter * (1 - replace_code)

        return z_q, indices


class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        stale_tolerance: int = 100,
        use_multi_layer_num:int = 1,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        else:
            n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # if self.training is False and i >= n_quantizers:
            #     break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        encodings = F.one_hot(codes, self.codebook_size).float()  # B N T 1024

        return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1

    def get_loss(self, x, quantized_prompt_embeds, commitment_loss, codebook_loss):
        final_loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
        return final_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
src/third_party/MuQ/src/muq/muq/muq.py
ADDED
@@ -0,0 +1,90 @@
import torch.nn as nn
import torch
from .models.muq_model import MuQModel
from dataclasses import dataclass, field
from typing import List, Optional
from transformers.modeling_outputs import BaseModelOutput
from huggingface_hub import PyTorchModelHubMixin

@dataclass
class MuQConfig:
    label_rate:int = field(default=25)
    num_codebooks:int = field(default=1)
    codebook_dim:int = field(default=16)
    codebook_size:int = field(default=4096)
    features:List[str] = field(default_factory=lambda:["melspec_2048"])
    hop_length:int = field(default=240)
    n_mels:int = field(default=128)
    conv_dim:int = field(default=512)
    encoder_dim:int = field(default=1024)
    encoder_depth:int = field(default=12)
    mask_hop:float = field(default=0.4)
    mask_prob:float = field(default=0.6)
    is_flash:bool = field(default=False)
    stat:Optional[dict] = field(default_factory=dict)
    w2v2_config:Optional[dict] = field(default_factory=dict)
    use_rvq_target:bool = field(default=False)
    use_vq_target:bool = field(default=False)
    use_encodec_target:bool = field(default=False)
    rvq_ckpt_path: Optional[str] = field(default=None)
    recon_loss_ratio: Optional[float] = field(default=None)
    resume_checkpoint: Optional[str] = None
    rvq_n_codebooks:int = field(default=8)
    rvq_multi_layer_num:int = field(default=1)

class MuQ(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: MuQConfig):
        super().__init__()
        if isinstance(config, dict):
            config = MuQConfig(**config)
        self.config = config
        self.model = MuQModel(
            num_codebooks=config.num_codebooks,
            codebook_dim=config.codebook_dim,
            codebook_size=config.codebook_size,
            features=config.features,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            conv_dim=config.conv_dim,
            encoder_dim=config.encoder_dim,
            encoder_depth=config.encoder_depth,
            mask_hop=config.mask_hop,
            mask_prob=config.mask_prob,
            is_flash=config.is_flash,
            stat=config.stat,
            w2v2_config=config.w2v2_config,
            use_rvq_target=config.use_rvq_target,
            use_vq_target=config.use_vq_target,
            use_encodec_target=config.use_encodec_target,
            rvq_ckpt_path=config.rvq_ckpt_path,
            recon_loss_ratio=config.recon_loss_ratio,
            label_rate=config.label_rate,
            rvq_n_codebooks=config.rvq_n_codebooks,
            rvq_multi_layer_num=config.rvq_multi_layer_num,
        )

    def forward(self, x, attention_mask:Optional[torch.Tensor]=None, output_hidden_states:bool=True) ->BaseModelOutput:
        """
        Forward pass through the MuQ model and extract features.

        Args:
            x (torch.Tensor): Input waveform tensor of shape (batch_size, time).
            attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
                Default is None.
            output_hidden_states (bool, optional): Whether to return all hidden states or only the last one.
                Default is True.

        Returns:
            BaseModelOutput: An object containing the last hidden state and optionally all hidden states.
                - last_hidden_state (torch.Tensor): The last hidden state of the model, i.e. extracted MuQ features, of shape (batch_size, sequence_length, hidden_size).
                - hidden_states (tuple(torch.Tensor), optional): A tuple containing all hidden states produced by the model,
                    each of shape (batch_size, sequence_length, hidden_size). Only returned if output_hidden_states is True.
        """
        _, hidden_states = self.model.get_predictions(x, attention_mask=attention_mask, is_features_only=True)
        last_hidden_state = hidden_states[-1]
        if not output_hidden_states:
            return BaseModelOutput(last_hidden_state=last_hidden_state)
        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states
        )
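The MuQ wrapper exposes the SSL encoder through PyTorchModelHubMixin, so pretrained weights load with from_pretrained and the forward pass returns a BaseModelOutput. A usage sketch; the hub id and the 24 kHz input rate are assumptions carried over from the upstream MuQ project, not stated in this diff:

# Sketch: extracting MuQ features from raw audio (hub id and sample rate assumed).
import torch
from muq import MuQ

device = "cuda" if torch.cuda.is_available() else "cpu"
muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter").to(device).eval()

wav = torch.randn(1, 24000 * 10, device=device)  # 10 s of (assumed) 24 kHz audio
with torch.no_grad():
    output = muq(wav, output_hidden_states=True)
print(output.last_hidden_state.shape)  # (batch, frames, encoder_dim)
print(len(output.hidden_states))       # per-layer representations for probing/aggregation
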
src/third_party/MuQ/src/muq/muq_mulan/__init__.py
ADDED
@@ -0,0 +1 @@
from .muq_mulan import MuQMuLan, MuQMuLanConfig, MuLanConfig, ModalModelConfig, TextTransformerConfig, AudioTransformerConfig
src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py
ADDED
File without changes