SpeechT5 upload
This view is limited to 50 files because it contains too many changes.
- SpeechT5 +0 -1
- SpeechT5/CODE_OF_CONDUCT.md +9 -0
- SpeechT5/LICENSE +21 -0
- SpeechT5/README.md +267 -0
- SpeechT5/SECURITY.md +41 -0
- SpeechT5/Speech2C/README.md +145 -0
- SpeechT5/Speech2C/speech2c/__init__.py +1 -0
- SpeechT5/Speech2C/speech2c/config/base_100h.yaml +93 -0
- SpeechT5/Speech2C/speech2c/config/base_10h.yaml +104 -0
- SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml +100 -0
- SpeechT5/Speech2C/speech2c/criterions/__init__.py +10 -0
- SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py +404 -0
- SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py +261 -0
- SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py +145 -0
- SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py +93 -0
- SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py +341 -0
- SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py +35 -0
- SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py +485 -0
- SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py +215 -0
- SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py +278 -0
- SpeechT5/Speech2C/speech2c/models/speech2c.py +321 -0
- SpeechT5/Speech2C/speech2c/models/speech2c_asr.py +276 -0
- SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py +25 -0
- SpeechT5/Speech2C/speech2c/squence_generator.py +1028 -0
- SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py +91 -0
- SpeechT5/Speech2S/README.md +64 -0
- SpeechT5/Speech2S/speech2s/__init__.py +1 -0
- SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml +101 -0
- SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml +102 -0
- SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml +100 -0
- SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml +153 -0
- SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml +159 -0
- SpeechT5/Speech2S/speech2s/criterions/__init__.py +9 -0
- SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py +414 -0
- SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py +384 -0
- SpeechT5/Speech2S/speech2s/data/concat_dataset.py +129 -0
- SpeechT5/Speech2S/speech2s/data/hubert_dataset.py +597 -0
- SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py +669 -0
- SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py +172 -0
- SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py +368 -0
- SpeechT5/Speech2S/speech2s/models/__init__.py +0 -0
- SpeechT5/Speech2S/speech2s/models/speechut.py +785 -0
- SpeechT5/Speech2S/speech2s/models/speechut_asr.py +165 -0
- SpeechT5/Speech2S/speech2s/models/speechut_st.py +221 -0
- SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py +25 -0
- SpeechT5/Speech2S/speech2s/modules/__init__.py +27 -0
- SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py +93 -0
- SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py +69 -0
- SpeechT5/Speech2S/speech2s/modules/multihead_attention.py +346 -0
- SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py +33 -0
SpeechT5
DELETED
@@ -1 +0,0 @@
- Subproject commit 8b5ade783571e63450aaa5507444150dcb08fa94
SpeechT5/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
SpeechT5/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
SpeechT5/README.md
ADDED
@@ -0,0 +1,267 @@
# SpeechT5

Unified-modal speech-text pre-training for spoken language processing:

> [**SpeechT5**](https://arxiv.org/abs/2110.07205) (```ACL 2022```): **SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing**

> [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**

> [**YiTrans**](https://arxiv.org/abs/2206.05777) (```IWSLT 2022```): **The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task**

> [**SpeechUT**](https://arxiv.org/abs/2210.03730) (```EMNLP 2022```): **SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training**

> [**SpeechLM**](https://arxiv.org/abs/2209.15329) (```Arxiv 2022```): **SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data**

> [**Speech2S**](https://arxiv.org/abs/2210.17027) (```ICASSP 2023```): **Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**

> [**Prosody-SpeechT5**](https://ieeexplore.ieee.org/document/10096530/) (```ICASSP 2023```): **Prosody-aware SpeechT5 for Expressive Neural TTS**

> [**VATLM**](https://arxiv.org/abs/2211.11275) (```IEEE Transactions on Multimedia```): **VATLM: Visual-Audio-Text Pre-Training with Unified Masked Prediction for Speech Representation Learning**

> [**VALL-E X**](https://arxiv.org/abs/2303.03926) (```Arxiv 2023```): **Speak Foreign Languages with Your Own Voice: Cross-Lingual Neural Codec Language Modeling**

> [**VioLA**](https://arxiv.org/abs/2305.16107) (```Arxiv 2023```): **VioLA: Unified Codec Language Models for Speech Recognition, Synthesis, and Translation**

<!-- Model introductions, evaluation results, and model inference instructions are located in the corresponding folders. The source code is [https://github.com/microsoft/SpeechT5/tree/main/ModelName]. -->

## Update

- May, 2023: VioLA [**Arxiv**](https://arxiv.org/abs/2305.16107).
- May, 2023: [**VATLM**](https://arxiv.org/abs/2211.11275) was accepted by IEEE Transactions on Multimedia.
- March, 2023: VALL-E X [**Arxiv**](https://arxiv.org/abs/2303.03926) and [**Demo**](https://aka.ms/vallex).
- February, 2023: [**Speech2S**](https://arxiv.org/abs/2210.17027) and [**Prosody-SpeechT5**](https://ieeexplore.ieee.org/document/10096530/) were accepted by ICASSP 2023.
- [HuggingFace Integration] February, 2023: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) models are on [**HuggingFace**](https://huggingface.co/blog/speecht5) (see the usage sketch after this list).
- [Model Release] November, 2022: [**VATLM**](https://github.com/microsoft/SpeechT5/tree/main/VATLM) models are released.
- November, 2022: VATLM [**Arxiv**](https://arxiv.org/abs/2211.11275).
- November, 2022: Speech2S [**Arxiv**](https://arxiv.org/abs/2210.17027).
- [Model Release] October, 2022: [**SpeechUT**](https://github.com/microsoft/SpeechT5/tree/main/SpeechUT) models are released.
- October, 2022: [**SpeechUT**](https://arxiv.org/abs/2210.03730) was accepted by EMNLP 2022.
- [Model Release] October, 2022: [**SpeechLM**](https://github.com/microsoft/SpeechT5/tree/main/SpeechLM) models are released.
- September, 2022: SpeechLM [**Arxiv**](https://arxiv.org/abs/2209.15329).
- [Evaluation] June, 2022: The end-to-end ST system [**YiTrans**](https://arxiv.org/abs/2206.05777) achieved top results on [**IWSLT 2022**](https://iwslt.org/2022/offline) shared tasks.
- June, 2022: [**Speech2C**](https://www.isca-speech.org/archive/interspeech_2022/ao22_interspeech.html) was accepted by INTERSPEECH 2022.
- [Model Release] May, 2022: [**Speech2C**](https://github.com/microsoft/SpeechT5/tree/main/Speech2C) models are released.
- [Model Release] April, 2022: [**SpeechT5**](https://github.com/microsoft/SpeechT5/tree/main/SpeechT5) models are released.
- March, 2022: Speech2C [**Arxiv**](https://arxiv.org/abs/2203.17113).
- February, 2022: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) was accepted by ACL 2022.
- October, 2021: SpeechT5 [**Arxiv**](https://arxiv.org/abs/2110.07205).
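As a minimal sketch of the HuggingFace integration mentioned above (following the linked blog post, and assuming `transformers>=4.27`, `datasets`, and `soundfile` are installed; the checkpoint names and the x-vector dataset come from that post and may change):

```python
# Minimal SpeechT5 TTS sketch via HuggingFace transformers (see https://huggingface.co/blog/speecht5).
import soundfile as sf
import torch
from datasets import load_dataset
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text="SpeechT5 is a unified-modal encoder-decoder model.", return_tensors="pt")

# A 512-dim x-vector selects the target speaker; here one is taken from a public CMU ARCTIC set.
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

speech = model.generate_speech(inputs["input_ids"], speaker_embedding, vocoder=vocoder)
sf.write("speecht5_tts_example.wav", speech.numpy(), samplerate=16000)
```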
## Pre-Trained Models

| Model | Pre-training Dataset | Fine-tuning Dataset | Download |
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
| SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base.pt)<br /> [Google Drive](https://drive.google.com/file/d/1Sq00uZ1pw6Z4OUaqhOWzQEJxIVWgAO5U/view?usp=sharing) |
| SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base_asr.pt)<br /> [Google Drive](https://drive.google.com/file/d/1qLKJ81JPWOGf1MHfjSmgtZyqqTqgI6kT/view?usp=sharing) |
| SpeechT5 Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [Google Drive](https://drive.google.com/file/d/1M79b1jetSPOVxWVMIX-y0URvDjNskZKp/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google Drive](https://drive.google.com/file/d/1iJvhSGghNrMT-wAY1nwVu2YaYuTy1pxx/view?usp=sharing) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1mH3N7iKMWYk3rSBJErQPYf3x5ugqDq5x/view?usp=sharing) |
| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google Drive](https://drive.google.com/file/d/1eblW8U8f9t-NTuCNRrNHwr-8BeLAUAmQ/view?usp=sharing) |
| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1vXyO5DolbiWiTYZ6pkkKQsu2wJetaPlv/view?usp=sharing) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_ende.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enca.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enar.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_entr.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Google Drive](https://drive.google.com/file/d/1QjLIgTJKIylVIp5hUkfSjGPtz8Xo7Lky/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1YZQDVv096o8Opt0RBnkRiZXYPRDqKZnP/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Google Drive](https://drive.google.com/file/d/1qYygNWSc11TQbBI1OzC4ChlR-dNh8t9S/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Google Drive](https://drive.google.com/file/d/162U88mwso2aVfzzPkEM2nP_vwTpcb57T/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Google Drive](https://drive.google.com/file/d/1lbTSRXewEeb2t45URunD6EiJcbniyjWW/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Google Drive](https://drive.google.com/file/d/1Er4I_jHS175pQQph223yKtiiLQ378VvH/view?usp=sharing) |
| SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4asr_32gpu_1accum/checkpoint_298_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A39%3A48Z&se=2024-03-09T01%3A39%3A00Z&sr=b&sp=r&sig=l3gJS1D%2BJfLfNfS3z33WjmSMGrOBJ63CvqGGewC6WeU%3D) |
| SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_base_asr100h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A41%3A22Z&se=2024-03-09T01%3A41%3A00Z&sr=b&sp=r&sig=%2B9lpGrqtZXa%2F6n1uZT%2Biey54ky31bYKSJytgfnBbbN4%3D) |
| SpeechUT Large (ASR) | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/large_speechut4asr_32gpu_4accum/checkpoint_22_400k.pt?sv=2020-04-08&st=2023-03-08T01%3A42%3A10Z&se=2024-03-09T01%3A42%3A00Z&sr=b&sp=r&sig=TZNcsHQAqapyj%2BAvpHtl749kZy9flTkWm8P5L4W26qs%3D) |
| SpeechUT Large (ASR) | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_large_asr960h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A02Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=PmO%2BgSAMXRgMC7GfpS4c%2BrDPsfJGekqUzD5AJm7RrYU%3D) |
| SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/checkpoint_217_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A47Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=XDEesMdGQ027j7YtpSql1kZtwgfv39gruOuWYlKlJ7w%3D) |
| SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | [En-De MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/fineutne_ende_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A44%3A15Z&se=2024-03-09T01%3A44%3A00Z&sr=b&sp=r&sig=8dcenahRg46EJdwiHUalVBJgKra6JoSN7tUxdLAwzOM%3D) |
| SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/checkpoint_204_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A16Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=hWoCM0y0SGZTD4CznC%2F5CejFczkqDYTOaISmlhCAYAU%3D) |
| SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | [En-Es MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/fineutne_enes_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A41Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=KGfzgKfKkDVQI0JxxnS%2BsYdBQzhUqFLQAVYG0OSGBtk%3D) |
| SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/checkpoint_297_600000.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A09Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=1eqpXMLCjWpfyd7AiOHGzfk%2B8ZYqWwVWdHk1GqXgoeg%3D) |
| SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | [En-Fr MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/fineutne_enfr_checkpoint.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A34Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=i3jMAqvL1Vp7DRjACAbrdoQKhlv2Cmi40%2F14SJ6%2BoiU%3D) |
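The checkpoints above are fairseq training states rather than HuggingFace weights. A quick, hypothetical way to inspect one before wiring it into fine-tuning (the filename is a placeholder for whichever file you downloaded, and newer PyTorch versions may additionally require `weights_only=False`):

```python
# Inspect a downloaded fairseq-style checkpoint (filename is hypothetical).
# Fairseq .pt files are plain torch pickles holding the model state dict plus its config.
import torch

ckpt = torch.load("speecht5_base.pt", map_location="cpu")  # older torch; newer may need weights_only=False
print(sorted(ckpt.keys()))   # typically includes 'model' and 'cfg' (or 'args')
print(len(ckpt["model"]))    # number of parameter tensors in the state dict
```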
## SpeechT5 Introduction

Motivated by the success of T5 (Text-To-Text Transfer Transformer) in pre-trained natural language processing models, we propose a unified-modal SpeechT5 framework that explores the encoder-decoder pre-training for self-supervised speech/text representation learning.
The SpeechT5 framework consists of a shared encoder-decoder network and six modal-specific (speech/text) pre/post-nets.
After preprocessing the input speech/text through the pre-nets, the shared encoder-decoder network models the sequence-to-sequence transformation, and then the post-nets generate the output in the speech/text modality based on the output of the decoder.

<img src="SpeechT5/speecht5_framework.png" alt="se" width="1000" />

Leveraging large-scale unlabeled speech and text data, we pre-train SpeechT5 to learn a unified-modal representation, hoping to improve the modeling capability for both speech and text.
To align the textual and speech information into this unified semantic space, we propose a cross-modal vector quantization approach that randomly mixes up speech/text states with latent units as the interface between encoder and decoder.
Extensive evaluations show the superiority of the proposed SpeechT5 framework on a wide variety of spoken language processing tasks, including automatic speech recognition, speech synthesis, speech translation, voice conversion, speech enhancement, and speaker identification.

<!--
Model introductions, evaluation results, and model inference instructions are located in the corresponding folders. The source code is here [https://github.com/microsoft/SpeechT5/tree/main/SpeechT5].
-->
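To make that data flow concrete, here is a deliberately simplified PyTorch sketch of the layout described above; the class, argument names, and dimensions are illustrative assumptions, not the fairseq implementation shipped in the SpeechT5 folder:

```python
# Illustrative sketch: shared encoder-decoder plus modality-specific pre/post-nets.
# Module names and sizes are hypothetical.
import torch.nn as nn


class UnifiedEncoderDecoderSketch(nn.Module):
    def __init__(self, d_model=768, vocab_size=10000, n_mels=80):
        super().__init__()
        # Modality-specific pre-nets map text tokens or speech features into the shared space.
        self.pre_nets = nn.ModuleDict({
            "text": nn.Embedding(vocab_size, d_model),
            "speech": nn.Linear(n_mels, d_model),
        })
        # A single encoder-decoder backbone is shared by all speech/text tasks.
        self.backbone = nn.Transformer(d_model=d_model, batch_first=True)
        # Modality-specific post-nets turn decoder states into text logits or mel frames.
        self.post_nets = nn.ModuleDict({
            "text": nn.Linear(d_model, vocab_size),
            "speech": nn.Linear(d_model, n_mels),
        })

    def forward(self, src, tgt, src_modality, tgt_modality):
        enc_in = self.pre_nets[src_modality](src)
        dec_in = self.pre_nets[tgt_modality](tgt)
        dec_out = self.backbone(enc_in, dec_in)
        return self.post_nets[tgt_modality](dec_out)
```

In this sketch, ASR corresponds to `src_modality="speech", tgt_modality="text"` and TTS to the reverse, while the same `backbone` parameters serve both directions.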
## SpeechT5 Downstream Task Performance

We evaluate our models on typical spoken language processing tasks, including automatic speech recognition, text to speech, speech to text translation, voice conversion, speech enhancement, and speaker identification.

### Automatic Speech Recognition

Evaluation on the [LibriSpeech](http://www.openslr.org/12) dataset

| Model | LM | dev-clean | dev-other | test-clean | test-other |
| ------------- | ------------- | ------ | ----- | ---- | ---- |
| wav2vec 2.0 Base | - | 6.1 | 13.5 | 6.1 | 13.3 |
| HuBERT Base | - | 5.5 | 13.1 | 5.8 | 13.3 |
| Baseline (w/o CTC) | - | 5.8 | 12.3 | 6.2 | 12.3 |
| Baseline | - | 4.9 | 11.7 | 5.0 | 11.9 |
| SpeechT5 (w/o CTC) | - | 5.4 | 10.7 | 5.8 | 10.7 |
| **SpeechT5** | - | **4.3** | **10.3** | **4.4** | **10.4** |
| DiscreteBERT | 4-gram | 4.0 | 10.9 | 4.5 | 12.1 |
| wav2vec 2.0 Base | 4-gram | 2.7 | 7.9 | 3.4 | 8.0 |
| HuBERT Base | 4-gram | 2.7 | 7.8 | 3.4 | 8.1 |
| wav2vec 2.0 Base | Transf. | 2.2 | 6.3 | 2.6 | 6.3 |
| Baseline | Transf. | 2.3 | 6.3 | 2.5 | 6.3 |
| **SpeechT5** | Transf. | **2.1** | **5.5** | **2.4** | **5.8** |

### Text-to-Speech

Evaluation on the [LibriTTS](http://www.openslr.org/60/) dataset

| Model | Naturalness | MOS | CMOS |
| ------------- | ------------ | ------ | ----- |
| Ground Truth | - | 3.87 | - |
| Baseline | 2.76 | 3.56 | 0 |
| **SpeechT5** | 2.91 | **3.65** | **+0.290** |

### Speech Translation

Evaluation on the [MUST-C v1](https://ict.fbk.eu/must-c/) dataset

| Model | EN-DE | EN-FR |
| ------------- | ------------ | ------ |
| Fairseq ST | 22.70 | 32.90 |
| ESPnet ST | 22.91 | 32.69 |
| Adapter Tuning | 24.63 | 34.98 |
| Baseline | 23.43 | 33.76 |
| SpeechT5 (w/o initializing decoder) | 24.44 | 34.5 |
| **SpeechT5** | **25.18** | **35.30** |

### Voice Conversion

Evaluation on the [CMU Arctic](http://www.festvox.org/cmu_arctic/) dataset

| Model | WER (bdl to slt) | WER (clb to slt) | MCD (bdl to slt) | MCD (clb to slt) |
| ------------- | ------ | ----- | ---- | ---- |
| VTN w/ ASR | 11.1 | 10.9 | 6.5 | 6.11 |
| VTN w/ TTS | 7.6 | 9.1 | 6.33 | 13.3 |
| Many-to-many VTN | - | - | 6.13 | 5.97 |
| Baseline | 21.5 | 10.8 | 6.26 | 6.16 |
| **SpeechT5** | **7.8** | **6.4** | **5.93** | **5.87** |

### Speech Enhancement

Evaluation on the [WSJ0 Hipster AmbientMixtures (WHAM!)](http://wham.whisper.ai/) dataset

| Model | WER |
| ------------- | ------------ |
| Ground Truth Speech | 3.2 |
| Noisy Speech | 76.1 |
| Baseline | 10.9 |
| **SpeechT5** | **8.9** |

### Speaker Identification

Evaluation on the [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) dataset

| Model | Acc |
| ------------- | ------------ |
| SUPERB, wav2vec 2.0 Base | 75.18% |
| SUPERB, HuBERT Base | 81.42% |
| SUPERB, HuBERT Large | 90.33% |
| SpeechNet, single task | 86.00% |
| SpeechNet, multi-task with TTS | 87.90% |
| Thin ResNet-34 | 89.00% |
| Baseline | 91.92% |
| **SpeechT5** | **96.49%** |

## License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [ESPnet](https://github.com/espnet/espnet) projects.

[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)

### Reference

If you find our work useful in your research, please cite the following paper:

```bibtex
@article{Ao2021SpeechT5,
  title   = {SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing},
  author  = {Junyi Ao and Rui Wang and Long Zhou and Chengyi Wang and Shuo Ren and Yu Wu and Shujie Liu and Tom Ko and Qing Li and Yu Zhang and Zhihua Wei and Yao Qian and Jinyu Li and Furu Wei},
  eprint={2110.07205},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  year={2021}
}
```

```bibtex
@article{Ao2022Speech2C,
  title   = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
  author  = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
  eprint={2203.17113},
  archivePrefix={arXiv},
  primaryClass={cs.SD},
  year={2022}
}
```

```bibtex
@article{Zhang2022Yitrans,
  title   = {The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task},
  author  = {Zhang, Ziqiang and Ao, Junyi and Zhou, Long and Liu, Shujie and Wei, Furu and Li, Jinyu},
  eprint={2206.05777},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  year={2022}
}
```

```bibtex
@article{zhang2022speechut,
  title   = {SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training},
  author  = {Zhang, Ziqiang and Zhou, Long and Ao, Junyi and Liu, Shujie and Dai, Lirong and Li, Jinyu and Wei, Furu},
  eprint={2210.03730},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  year={2022}
}
```

```bibtex
@article{zhang2022speechlm,
  title   = {SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data},
  author  = {Zhang, Ziqiang and Chen, Sanyuan and Zhou, Long and Wu, Yu and Ren, Shuo and Liu, Shujie and Yao, Zhuoyuan and Gong, Xun and Dai, Lirong and Li, Jinyu and Wei, Furu},
  eprint={2209.15329},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  year={2022}
}
```

### Contact Information

For help or issues using SpeechT5 models, please submit a GitHub issue.

For other communications related to SpeechT5, please contact Long Zhou (`lozhou@microsoft.com`).
SpeechT5/SECURITY.md
ADDED
@@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->
SpeechT5/Speech2C/README.md
ADDED
@@ -0,0 +1,145 @@
# Speech2C

> [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**

## Pre-Trained and Fine-tuned Models

| Model | Pre-training Dataset | Fine-tuning Dataset | Download |
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |

## Language Model and Vocabulary

| Model | Dataset | Download | Vocabulary |
| :------: | :------: | :---: | :--------: |
| LM | [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [Model](https://drive.google.com/file/d/1UDCcNJT1DlquSRw0iRAXH6GHlf6zK6-8/view?usp=sharing) | [Vocabulary](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt) |

## Setup

```
git submodule update --init Speech2C/fairseq
cd Speech2C/
pip install --editable fairseq/
```

## Data Preparation

Please follow the steps of data preparation for HuBERT described [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation).
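As a concrete (made-up) illustration of what that HuBERT-style preparation produces, the task expects a `.tsv` audio manifest plus parallel label files such as `.km`; the snippet below writes a toy pair in that layout. All paths, sample counts, and unit IDs here are invented for illustration:

```python
# Hypothetical HuBERT-style manifests. train.tsv: first line is the audio root, then
# "relative_path<TAB>num_samples"; train.km: one line of space-separated discrete units
# per utterance, in the same order as the tsv.
from pathlib import Path

data_dir = Path("manifests")
data_dir.mkdir(exist_ok=True)

(data_dir / "train.tsv").write_text(
    "/path/to/LibriSpeech/train-clean-100\n"
    "103/1240/103-1240-0000.flac\t225360\n"
    "103/1240/103-1240-0001.flac\t255120\n"
)
(data_dir / "train.km").write_text(
    "21 21 5 5 5 88 88 317 317 2\n"
    "9 9 9 43 43 43 501 501 77 77\n"
)
```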
## Pre-Training

```
DATA_DIR=
LABEL_DIR=
FAIRSEQ_PATH=

python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
  --config-dir speech2c/config \
  --config-name speech2c_base_librispeech \
  task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} task.labels='["km"]' \
  model.label_rate=50 common.user_dir=SpeechT5/Speech2C/speech2c
```

## Finetune

```
DATA_DIR=
LABEL_DIR=
FAIRSEQ_PATH=
W2V_PATH=
CONFIG_NAME=

python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
  --config-dir speech2c/config \
  --config-name ${CONFIG_NAME} \
  task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \
  model.w2v_path=${W2V_PATH} common.user_dir=SpeechT5/Speech2C/speech2c
```

## Inference

Note that joint CTC and decoder inference is only supported when the batch size is 1.

```
FAIRSEQ_PATH=
DATA_DIR=
LABEL_DIR=
BEAM_SIZE=
CTC_WEIGHT=
TEST_SET=
CHECKPOINT_PATH=
W2V_PATH=

python ${FAIRSEQ_PATH}/fairseq_cli/generate.py ${DATA_DIR} \
  --label-dir ${LABEL_DIR} \
  --path ${CHECKPOINT_PATH} \
  --user-dir SpeechT5/Speech2C/speech2c \
  --model-overrides "{'w2v_path': '${W2V_PATH}'}" \
  --gen-subset ${TEST_SET} \
  --task speech2c_pretraining \
  --post-process letter \
  --add-decoder \
  --labels '["ltr"]' \
  --fine-tuning \
  --scoring wer \
  --max-len-a 0 \
  --max-len-b 620 \
  --pad-audio \
  --random-crop \
  --ctc-weight ${CTC_WEIGHT} \
  --max-tokens 8000000 \
  --beam ${BEAM_SIZE} \
  --single-target
```
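The `--ctc-weight` option interpolates the CTC score with the attention-decoder score during beam search. The toy function below sketches that standard hybrid CTC/attention combination; it is an illustration only, with made-up numbers, and the actual logic lives in `speech2c/squence_generator.py` and `speech2c/models/modules/ctc_prefix_score.py`:

```python
# Toy illustration of joint CTC/attention scoring at decode time: a candidate next token is
# ranked by a weighted sum of the decoder log-prob and the gain in CTC prefix log-prob.
def joint_score(att_logprob: float, ctc_prefix_logprob_new: float,
                ctc_prefix_logprob_old: float, ctc_weight: float) -> float:
    """Combine attention and CTC scores for one hypothesis extension."""
    ctc_gain = ctc_prefix_logprob_new - ctc_prefix_logprob_old
    return (1.0 - ctc_weight) * att_logprob + ctc_weight * ctc_gain


# Example: a token the decoder likes (-0.2) but CTC finds implausible (-3.0 vs. -0.5 before).
print(joint_score(att_logprob=-0.2, ctc_prefix_logprob_new=-3.0,
                  ctc_prefix_logprob_old=-0.5, ctc_weight=0.3))  # -0.89
```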
## Results on LibriSpeech

### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 10hr subset

| Model | LM | test-clean | test-other |
| ------------- | ------------- | ---- | ---- |
| wav2vec 2.0 Base | - | 11.1 | 17.6 |
| HuBERT Base | - | 10.1 | 16.8 |
| **Speech2C** | - | **7.8** | **13.1** |
| wav2vec 2.0 Base | 4-gram | 4.3 | 9.5 |
| wav2vec 2.0 Base | Transf. | 3.2 | 7.8 |
| HuBERT Base | 4-gram | 4.3 | 9.4 |
| **Speech2C** | **Transf.** | **3.1** | **7.0** |

### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 100hr subset

| Model | LM | test-clean | test-other |
| ------------- | ------------- | ---- | ---- |
| wav2vec 2.0 Base | - | 6.1 | 13.3 |
| wav2vec 2.0 Large | - | 4.7 | 9.0 |
| HuBERT Base | - | 6.3 | 13.2 |
| SpeechT5 | - | 4.4 | 10.4 |
| Baseline | - | 5.0 | 11.9 |
| **Speech2C** | - | **4.3** | **9.0** |
| wav2vec 2.0 Base | 4-gram | 3.4 | 8.0 |
| wav2vec 2.0 Base | Transf. | 2.6 | 6.3 |
| HuBERT Base | 4-gram | 3.4 | 8.1 |
| SpeechT5 | Transf. | 2.4 | 5.8 |
| Baseline | Transf. | 2.5 | 6.3 |
| **Speech2C** | **Transf.** | **2.4** | **5.2** |

## License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on [FAIRSEQ](https://github.com/pytorch/fairseq).

[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)

## Reference

If you find our work useful in your research, please cite the following paper:

```bibtex
@article{Ao2022Speech2C,
  title   = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
  author  = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
  eprint={2203.17113},
  archivePrefix={arXiv},
  primaryClass={cs.SD},
  year={2022}
}
```
SpeechT5/Speech2C/speech2c/__init__.py
ADDED
@@ -0,0 +1 @@
from . import data, tasks, criterions, models  # noqa
SpeechT5/Speech2C/speech2c/config/base_100h.yaml
ADDED
@@ -0,0 +1,93 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337

checkpoint:
  no_epoch_checkpoints: true
  best_checkpoint_metric: dec_accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 1
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: speech2c_pretraining
  data: ???
  fine_tuning: true
  label_dir: ???
  normalize: false  # must be consistent with pre-training
  labels: ["ltr"]
  single_target: true
  add_decoder: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 3200000
  skip_invalid_size_inputs_valid_test: true
  train_subset: train_100h
  valid_subset: dev_other

criterion:
  _name: ctc_ce
  zero_infinity: true

optimization:
  max_update: 80000
  lr: [0.00004]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: speech2c_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.65
  mask_channel_prob: 0.5
  mask_channel_length: 64
  layerdrop: 0.1
  decoder_layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 25000

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/config/base_10h.yaml
ADDED
@@ -0,0 +1,104 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337

checkpoint:
  save_interval: 5
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: dec_accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 1
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: speech2c_pretraining
  data: ???
  fine_tuning: true
  label_dir: ???
  normalize: false  # must be consistent with pre-training
  labels: ["ltr"]
  single_target: true
  add_decoder: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 3200000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: ${model.freeze_finetune_updates}
  validate_interval: 5
  train_subset: train_10h
  valid_subset: dev_other

criterion:
  _name: ctc_ce
  zero_infinity: true

optimization:
  max_update: 25000
  lr: [2e-5]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: speech2c_ctc
  w2v_path: ???
  apply_mask: true
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  decoder_layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml
ADDED
@@ -0,0 +1,100 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  seed: 1337
  tensorboard_logdir: tblog

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

distributed_training:
  ddp_backend: no_c10d
  distributed_backend: 'nccl'
  distributed_world_size: 32
  distributed_port: 29671
  nprocs_per_node: 8
  find_unused_parameters: true

task:
  _name: speech2c_pretraining
  data: ???
  label_dir: ???
  labels: ???
  label_rate: ${model.label_rate}
  sample_rate: 16000
  max_sample_size: 250000
  min_sample_size: 32000
  pad_audio: false
  random_crop: true
  normalize: false  # must be consistent with extractor
  add_decoder: true

dataset:
  num_workers: 6
  max_tokens: 1400000
  skip_invalid_size_inputs_valid_test: true
  validate_interval: 5
  validate_interval_updates: 10000

criterion:
  _name: speech2c
  pred_masked_weight: 1.0
  pred_nomask_weight: 0.0
  loss_weights: [10,]

optimization:
  max_update: 400000
  lr: [0.0005]
  clip_norm: 10.0

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: speech2c
  label_rate: ???
  skip_masked: false
  skip_nomask: false
  mask_prob: 0.80
  extractor_mode: default
  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
  final_dim: 256
  encoder_layerdrop: 0.05
  dropout_input: 0.1
  dropout_features: 0.1
  dropout: 0.1
  attention_dropout: 0.1
  feature_grad_mult: 0.1
  untie_final_proj: true
  activation_dropout: 0.0
  use_rel_pos_enc: true
  decoder_dict_size: -1

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/criterions/__init__.py
ADDED
@@ -0,0 +1,10 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        criterion_name = file[: file.find(".py")]
        importlib.import_module(
            "speech2c.criterions." + criterion_name
        )
SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py
ADDED
@@ -0,0 +1,404 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round


@dataclass
class CtcCeCriterionConfig(FairseqDataclass):
    zero_infinity: bool = field(
        default=False,
        metadata={"help": "zero inf loss when source length <= target length"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    post_process: str = field(
        default="letter",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
            "wordpiece, BPE symbols, etc. "
            "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
        },
    )
    wer_lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "lexicon to use with wer_kenlm_model"},
    )
    wer_lm_weight: float = field(
        default=2.0,
        metadata={"help": "lm weight to use with wer_kenlm_model"},
    )
    wer_word_score: float = field(
        default=-1.0,
        metadata={"help": "lm word score to use with wer_kenlm_model"},
    )

    wer_args: Optional[str] = field(
        default=None,
        metadata={
            "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
        },
    )

    dec_weight: float = field(
        default=0.5,
        metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
    )
    report_accuracy: bool = field(
        default=True,
        metadata={"help": "report decoder accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    label_smoothing: float = field(
        default=0.1,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )


@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
class CtcCeCriterion(FairseqCriterion):
    def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
        super().__init__(task)
        self.blank_idx = (
            task.target_dictionary.index(task.blank_symbol)
            if hasattr(task, "blank_symbol")
            else 0
        )
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = cfg.post_process

        if cfg.wer_args is not None:
            (
                cfg.wer_kenlm_model,
                cfg.wer_lexicon,
                cfg.wer_lm_weight,
                cfg.wer_word_score,
            ) = eval(cfg.wer_args)

        if cfg.wer_kenlm_model is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = cfg.wer_kenlm_model
            dec_args.lexicon = cfg.wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = cfg.wer_lm_weight
            dec_args.word_score = cfg.wer_word_score
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = cfg.zero_infinity
        self.sentence_avg = cfg.sentence_avg

        self.dec_weight = cfg.dec_weight
        self.report_accuracy = cfg.report_accuracy
        self.ignore_prefix_size = cfg.ignore_prefix_size
        self.eps = cfg.label_smoothing

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            if net_output["padding_mask"] is not None:
                non_padding_mask = ~net_output["padding_mask"]
                input_lengths = non_padding_mask.long().sum(-1)
            else:
                input_lengths = lprobs.new_full(
                    (lprobs.size(1),), lprobs.size(0), dtype=torch.long
                )

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        if "target_lengths" in sample:
            target_lengths = sample["target_lengths"]
        else:
            target_lengths = pad_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens

        logging_output = {}
        if "decoder_target" in sample:
            dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
            dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
            logging_output["ctc_loss"] = loss.item()
            loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
            logging_output["dec_loss"] = dec_loss.item()
            logging_output["dec_nll_loss"] = dec_nll_loss.item()
            logging_output["dec_sample_size"] = dec_sample_size

            if self.report_accuracy:
                n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
                logging_output["dec_n_correct"] = utils.item(n_correct.data)
                logging_output["total"] = utils.item(total.data)

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
|
224 |
+
if len(decoded) < 1:
|
225 |
+
decoded = None
|
226 |
+
else:
|
227 |
+
decoded = decoded[0]
|
228 |
+
if len(decoded) < 1:
|
229 |
+
decoded = None
|
230 |
+
else:
|
231 |
+
decoded = decoded[0]
|
232 |
+
|
233 |
+
p = (t != self.task.target_dictionary.pad()) & (
|
234 |
+
t != self.task.target_dictionary.eos()
|
235 |
+
)
|
236 |
+
targ = t[p]
|
237 |
+
targ_units = self.task.target_dictionary.string(targ)
|
238 |
+
targ_units_arr = targ.tolist()
|
239 |
+
|
240 |
+
toks = lp.argmax(dim=-1).unique_consecutive()
|
241 |
+
pred_units_arr = toks[toks != self.blank_idx].tolist()
|
242 |
+
|
243 |
+
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
|
244 |
+
c_len += len(targ_units_arr)
|
245 |
+
|
246 |
+
targ_words = post_process(targ_units, self.post_process).split()
|
247 |
+
|
248 |
+
pred_units = self.task.target_dictionary.string(pred_units_arr)
|
249 |
+
pred_words_raw = post_process(pred_units, self.post_process).split()
|
250 |
+
|
251 |
+
if decoded is not None and "words" in decoded:
|
252 |
+
pred_words = decoded["words"]
|
253 |
+
w_errs += editdistance.eval(pred_words, targ_words)
|
254 |
+
wv_errs += editdistance.eval(pred_words_raw, targ_words)
|
255 |
+
else:
|
256 |
+
dist = editdistance.eval(pred_words_raw, targ_words)
|
257 |
+
w_errs += dist
|
258 |
+
wv_errs += dist
|
259 |
+
|
260 |
+
w_len += len(targ_words)
|
261 |
+
|
262 |
+
logging_output["wv_errors"] = wv_errs
|
263 |
+
logging_output["w_errors"] = w_errs
|
264 |
+
logging_output["w_total"] = w_len
|
265 |
+
logging_output["c_errors"] = c_err
|
266 |
+
logging_output["c_total"] = c_len
|
267 |
+
|
268 |
+
return loss, sample_size, logging_output
|
269 |
+
|
270 |
+
def compute_ce_loss(self, model, net_output, sample, reduce=True):
|
271 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
272 |
+
loss, nll_loss = label_smoothed_nll_loss(
|
273 |
+
lprobs,
|
274 |
+
target,
|
275 |
+
self.eps,
|
276 |
+
ignore_index=self.pad_idx,
|
277 |
+
reduce=reduce,
|
278 |
+
)
|
279 |
+
return loss, nll_loss
|
280 |
+
|
281 |
+
def compute_accuracy(self, model, net_output, sample):
|
282 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
283 |
+
mask = target.ne(self.pad_idx)
|
284 |
+
n_correct = torch.sum(
|
285 |
+
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
286 |
+
)
|
287 |
+
total = torch.sum(mask)
|
288 |
+
return n_correct, total
|
289 |
+
|
290 |
+
def get_lprobs_and_target(self, model, net_output, sample):
|
291 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
292 |
+
target = sample["decoder_target"]
|
293 |
+
if self.ignore_prefix_size > 0:
|
294 |
+
if getattr(lprobs, "batch_first", False):
|
295 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
296 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
297 |
+
else:
|
298 |
+
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
299 |
+
target = target[self.ignore_prefix_size :, :].contiguous()
|
300 |
+
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
|
301 |
+
|
302 |
+
|
303 |
+
@staticmethod
|
304 |
+
def reduce_metrics(logging_outputs) -> None:
|
305 |
+
"""Aggregate logging outputs from data parallel training."""
|
306 |
+
|
307 |
+
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
308 |
+
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
309 |
+
nsentences = utils.item(
|
310 |
+
sum(log.get("nsentences", 0) for log in logging_outputs)
|
311 |
+
)
|
312 |
+
sample_size = utils.item(
|
313 |
+
sum(log.get("sample_size", 0) for log in logging_outputs)
|
314 |
+
)
|
315 |
+
|
316 |
+
metrics.log_scalar(
|
317 |
+
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
318 |
+
)
|
319 |
+
metrics.log_scalar("ntokens", ntokens)
|
320 |
+
metrics.log_scalar("nsentences", nsentences)
|
321 |
+
if sample_size != ntokens:
|
322 |
+
metrics.log_scalar(
|
323 |
+
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
|
324 |
+
)
|
325 |
+
|
326 |
+
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
|
327 |
+
metrics.log_scalar("_c_errors", c_errors)
|
328 |
+
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
|
329 |
+
metrics.log_scalar("_c_total", c_total)
|
330 |
+
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
|
331 |
+
metrics.log_scalar("_w_errors", w_errors)
|
332 |
+
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
|
333 |
+
metrics.log_scalar("_wv_errors", wv_errors)
|
334 |
+
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
|
335 |
+
metrics.log_scalar("_w_total", w_total)
|
336 |
+
|
337 |
+
if c_total > 0:
|
338 |
+
metrics.log_derived(
|
339 |
+
"uer",
|
340 |
+
lambda meters: safe_round(
|
341 |
+
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
|
342 |
+
)
|
343 |
+
if meters["_c_total"].sum > 0
|
344 |
+
else float("nan"),
|
345 |
+
)
|
346 |
+
if w_total > 0:
|
347 |
+
metrics.log_derived(
|
348 |
+
"wer",
|
349 |
+
lambda meters: safe_round(
|
350 |
+
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
351 |
+
)
|
352 |
+
if meters["_w_total"].sum > 0
|
353 |
+
else float("nan"),
|
354 |
+
)
|
355 |
+
metrics.log_derived(
|
356 |
+
"raw_wer",
|
357 |
+
lambda meters: safe_round(
|
358 |
+
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
359 |
+
)
|
360 |
+
if meters["_w_total"].sum > 0
|
361 |
+
else float("nan"),
|
362 |
+
)
|
363 |
+
|
364 |
+
if "dec_loss" in logging_outputs[0]:
|
365 |
+
ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
|
366 |
+
dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
|
367 |
+
dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
|
368 |
+
dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
|
369 |
+
metrics.log_scalar(
|
370 |
+
"dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
|
371 |
+
)
|
372 |
+
metrics.log_scalar(
|
373 |
+
"ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
|
374 |
+
)
|
375 |
+
metrics.log_scalar(
|
376 |
+
"dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
|
377 |
+
)
|
378 |
+
metrics.log_derived(
|
379 |
+
"dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
|
380 |
+
)
|
381 |
+
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
382 |
+
if total > 0:
|
383 |
+
metrics.log_scalar("total", total)
|
384 |
+
n_correct = utils.item(
|
385 |
+
sum(log.get("dec_n_correct", 0) for log in logging_outputs)
|
386 |
+
)
|
387 |
+
metrics.log_scalar("dec_n_correct", n_correct)
|
388 |
+
metrics.log_derived(
|
389 |
+
"dec_accuracy",
|
390 |
+
lambda meters: round(
|
391 |
+
meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
|
392 |
+
)
|
393 |
+
if meters["total"].sum > 0
|
394 |
+
else float("nan"),
|
395 |
+
)
|
396 |
+
|
397 |
+
@staticmethod
|
398 |
+
def logging_outputs_can_be_summed() -> bool:
|
399 |
+
"""
|
400 |
+
Whether the logging outputs returned by `forward` can be summed
|
401 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
402 |
+
to True will improves distributed training speed.
|
403 |
+
"""
|
404 |
+
return True
|
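The fine-tuning criterion above mixes the encoder-side CTC loss with the decoder's label-smoothed cross-entropy, rescaling the CE term by sample_size / dec_sample_size so both terms share the CTC denominator. A minimal, self-contained sketch of that interpolation (the tensors and sizes below are made-up stand-ins, not values from this repository):

import torch

def joint_ctc_ce(ctc_loss: torch.Tensor, dec_loss: torch.Tensor,
                 sample_size: int, dec_sample_size: int,
                 dec_weight: float = 0.5) -> torch.Tensor:
    # mirrors: loss = (1 - w) * ctc + w * ce * sample_size / dec_sample_size
    return (1 - dec_weight) * ctc_loss + dec_weight * dec_loss * sample_size / dec_sample_size

# illustrative call with dummy summed losses and token counts
print(joint_ctc_ce(torch.tensor(120.0), torch.tensor(80.0), sample_size=300, dec_sample_size=100))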
SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py
ADDED
@@ -0,0 +1,261 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import re
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.criterions.hubert_criterion import HubertCriterionConfig

@dataclass
class Speech2cCriterionConfig(HubertCriterionConfig):
    dec_weight: float = field(
        default=1.0,
        metadata={"help": "weights for decoder CE Loss, loss will be (hubert_loss + dec_weight * CE_Loss)"},
    )
    report_accuracy: bool = field(
        default=True,
        metadata={"help": "report decoder accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )


@register_criterion("speech2c", dataclass=Speech2cCriterionConfig)
class Speech2cCriterion(FairseqCriterion):
    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None, dec_weight=1.0, report_accuracy=False, ignore_prefix_size=0, label_smoothing=0.0):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys
        self.dec_weight = dec_weight
        self.report_accuracy = report_accuracy
        self.ignore_prefix_size = ignore_prefix_size
        self.eps = label_smoothing
        self.padding_idx = task.dictionaries[0].pad()

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        if "decoder_target" in sample:
            dec_sample_size = sample["dec_ntokens"]
            dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
            loss = loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
            logging_output["dec_loss"] = dec_loss.item()
            logging_output["dec_nll_loss"] = dec_nll_loss.item()
            logging_output["dec_sample_size"] = dec_sample_size

            if self.report_accuracy:
                n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
                logging_output["dec_n_correct"] = utils.item(n_correct.data)
                logging_output["total"] = utils.item(total.data)

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk]))

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    def compute_ce_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = sample["decoder_target"]
        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
                target = target[:, self.ignore_prefix_size :].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
                target = target[self.ignore_prefix_size :, :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

        if "dec_loss" in logging_outputs[0]:
            dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
            dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
            dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
            metrics.log_scalar(
                "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
            )
            metrics.log_scalar(
                "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
            )
            metrics.log_derived(
                "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
            )
            total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
            if total > 0:
                metrics.log_scalar("total", total)
                n_correct = utils.item(
                    sum(log.get("dec_n_correct", 0) for log in logging_outputs)
                )
                metrics.log_scalar("dec_n_correct", n_correct)
                metrics.log_derived(
                    "dec_accuracy",
                    lambda meters: round(
                        meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
                    )
                    if meters["total"].sum > 0
                    else float("nan"),
                )

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
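In the pre-training criterion above, the inner compute_correct helper treats class index 0 as the positive target (HuBERT-style logits place the correct codeword first), so frame accuracy is the count of frames whose argmax is 0, with all-equal logits excluded as ties. A small standalone sketch of that counting rule on dummy logits (shapes and values are illustrative only):

import torch

def count_correct(logits: torch.Tensor):
    # logits: (num_frames, num_candidates); index 0 is the positive entry
    if logits.numel() == 0:
        return 0, 0
    is_max = logits.argmax(-1) == 0
    is_min = logits.argmin(-1) == 0          # argmax == argmin only when all logits are equal (a tie)
    correct = (is_max & ~is_min).long().sum().item()
    return correct, is_max.numel()

print(count_correct(torch.tensor([[2.0, 1.0], [0.5, 3.0], [1.0, 1.0]])))  # -> (1, 3)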
SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py
ADDED
@@ -0,0 +1,145 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import logging
from typing import Any, List, Optional, Union

import torch
from fairseq.data import data_utils, Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset
logger = logging.getLogger(__name__)


class Speech2cDataset(HubertDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        add_decoder: bool = False,
        fine_tuning: bool = False,
    ):
        super().__init__(
            manifest_path,
            sample_rate,
            label_paths,
            label_rates,
            pad_list,
            eos_list,
            label_processors,
            max_keep_sample_size,
            min_keep_sample_size,
            max_sample_size,
            shuffle,
            pad_audio,
            normalize,
            store_labels,
            random_crop,
            single_target
        )

        self.tgt_dict = tgt_dict
        self.add_decoder = add_decoder
        self.fine_tuning = fine_tuning

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        if self.add_decoder:
            if self.fine_tuning:
                decoder_label = [
                    torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                decoder_label = [
                    torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive(), torch.tensor([self.tgt_dict.eos()])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
            dec_ntokens = sum(x.size(0) for x in decoder_label)
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
            prev_output_tokens = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            net_input = {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
            }
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
                "decoder_target": decoder_target,
                "decoder_target_lengths": decoder_target_lengths,
                "dec_ntokens": dec_ntokens,
            }
        else:
            net_input = {"source": collated_audios, "padding_mask": padding_mask}
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
            }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch
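The collater above builds the decoder input by shifting each target sequence right: collate_tokens with move_eos_to_beginning=True places EOS at position 0, so the decoder predicts token t from tokens before t (teacher forcing). A rough standalone illustration of that shift-and-pad behaviour in plain PyTorch, not the fairseq helper itself; it assumes every sequence already ends in EOS, and the pad/eos ids below are made up:

import torch

def shift_and_pad(seqs, pad_id=1, eos_id=2):
    # emulate collate_tokens(..., move_eos_to_beginning=True) for eos-terminated sequences
    max_len = max(s.size(0) for s in seqs)
    out = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    for i, s in enumerate(seqs):
        out[i, 0] = eos_id                # eos moves to the front
        out[i, 1:s.size(0)] = s[:-1]      # remaining tokens shift right by one
    return out

print(shift_and_pad([torch.tensor([5, 6, 2]), torch.tensor([7, 2])]))
# tensor([[2, 5, 6],
#         [2, 7, 1]])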
SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py
ADDED
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import numpy as np
import six


class CTCPrefixScore(object):
    """Compute CTC label sequence scores
    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state
        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in six.moves.range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels
        :param y : prefix label sequence
        :param cs : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in six.moves.range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in six.moves.range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)
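CTCPrefixScore is meant to be driven step by step during joint CTC/attention beam search: build it from a (T, vocab) matrix of frame-level CTC log-probabilities, take the blank-only initial state, then query candidate continuations of the current prefix. A toy usage sketch assuming the class above is in scope; the vocabulary size, token indices, and random log-probabilities are illustrative, not from this repository:

import numpy as np

T, vocab = 6, 5                                # 6 frames; token 0 = blank, token 4 = eos (toy setup)
logits = np.random.randn(T, vocab).astype(np.float32)
x = logits - np.log(np.exp(logits).sum(-1, keepdims=True))   # normalize to log-probs per frame

scorer = CTCPrefixScore(x, blank=0, eos=4, xp=np)
state = scorer.initial_state()                  # CTC state of the empty (all-blank) prefix
y = [4]                                         # current prefix: just the start symbol
cs = np.array([1, 2, 3, 4])                     # candidate next labels to score
log_psi, new_states = scorer(y, cs, state)      # prefix scores and per-label CTC states
print(log_psi.shape, new_states.shape)          # (4,) (4, 6, 2)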
SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py
ADDED
@@ -0,0 +1,341 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
from typing import Dict, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import utils
|
15 |
+
from torch import Tensor
|
16 |
+
|
17 |
+
from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
|
18 |
+
|
19 |
+
|
20 |
+
class MultiheadAttention(FairseqMultiheadAttention):
|
21 |
+
"""Multi-headed attention.
|
22 |
+
|
23 |
+
See "Attention Is All You Need" for more details.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
embed_dim,
|
29 |
+
num_heads,
|
30 |
+
kdim=None,
|
31 |
+
vdim=None,
|
32 |
+
dropout=0.0,
|
33 |
+
bias=True,
|
34 |
+
add_bias_kv=False,
|
35 |
+
add_zero_attn=False,
|
36 |
+
self_attention=False,
|
37 |
+
encoder_decoder_attention=False,
|
38 |
+
q_noise=0.0,
|
39 |
+
qn_block_size=8,
|
40 |
+
):
|
41 |
+
super().__init__(
|
42 |
+
embed_dim,
|
43 |
+
num_heads,
|
44 |
+
kdim,
|
45 |
+
vdim,
|
46 |
+
dropout,
|
47 |
+
bias,
|
48 |
+
add_bias_kv,
|
49 |
+
add_zero_attn,
|
50 |
+
self_attention,
|
51 |
+
encoder_decoder_attention,
|
52 |
+
q_noise,
|
53 |
+
qn_block_size,
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
query,
|
59 |
+
key: Optional[Tensor],
|
60 |
+
value: Optional[Tensor],
|
61 |
+
key_padding_mask: Optional[Tensor] = None,
|
62 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
63 |
+
need_weights: bool = True,
|
64 |
+
static_kv: bool = False,
|
65 |
+
attn_mask: Optional[Tensor] = None,
|
66 |
+
before_softmax: bool = False,
|
67 |
+
need_head_weights: bool = False,
|
68 |
+
position_bias: Optional[Tensor] = None,
|
69 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
70 |
+
"""Input shape: Time x Batch x Channel
|
71 |
+
|
72 |
+
Args:
|
73 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
74 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
75 |
+
padding elements are indicated by 1s.
|
76 |
+
need_weights (bool, optional): return the attention weights,
|
77 |
+
averaged over heads (default: False).
|
78 |
+
attn_mask (ByteTensor, optional): typically used to
|
79 |
+
implement causal attention, where the mask prevents the
|
80 |
+
attention from looking forward in time (default: None).
|
81 |
+
before_softmax (bool, optional): return the raw attention
|
82 |
+
weights and values before the attention softmax.
|
83 |
+
need_head_weights (bool, optional): return the attention
|
84 |
+
weights for each head. Implies *need_weights*. Default:
|
85 |
+
return the average attention weights over all heads.
|
86 |
+
"""
|
87 |
+
if need_head_weights:
|
88 |
+
need_weights = True
|
89 |
+
|
90 |
+
is_tpu = query.device.type == "xla"
|
91 |
+
|
92 |
+
tgt_len, bsz, embed_dim = query.size()
|
93 |
+
src_len = tgt_len
|
94 |
+
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
95 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
96 |
+
if key is not None:
|
97 |
+
src_len, key_bsz, _ = key.size()
|
98 |
+
if not torch.jit.is_scripting():
|
99 |
+
assert key_bsz == bsz
|
100 |
+
assert value is not None
|
101 |
+
assert src_len, bsz == value.shape[:2]
|
102 |
+
|
103 |
+
if (
|
104 |
+
not self.onnx_trace
|
105 |
+
and not is_tpu # don't use PyTorch version on TPUs
|
106 |
+
and incremental_state is None
|
107 |
+
and not static_kv
|
108 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
109 |
+
# treats bias in linear module as method.
|
110 |
+
and not torch.jit.is_scripting()
|
111 |
+
and position_bias is None
|
112 |
+
):
|
113 |
+
assert key is not None and value is not None
|
114 |
+
return F.multi_head_attention_forward(
|
115 |
+
query,
|
116 |
+
key,
|
117 |
+
value,
|
118 |
+
self.embed_dim,
|
119 |
+
self.num_heads,
|
120 |
+
torch.empty([0]),
|
121 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
122 |
+
self.bias_k,
|
123 |
+
self.bias_v,
|
124 |
+
self.add_zero_attn,
|
125 |
+
self.dropout_module.p,
|
126 |
+
self.out_proj.weight,
|
127 |
+
self.out_proj.bias,
|
128 |
+
self.training or self.dropout_module.apply_during_inference,
|
129 |
+
key_padding_mask,
|
130 |
+
need_weights,
|
131 |
+
attn_mask,
|
132 |
+
use_separate_proj_weight=True,
|
133 |
+
q_proj_weight=self.q_proj.weight,
|
134 |
+
k_proj_weight=self.k_proj.weight,
|
135 |
+
v_proj_weight=self.v_proj.weight,
|
136 |
+
)
|
137 |
+
|
138 |
+
if incremental_state is not None:
|
139 |
+
saved_state = self._get_input_buffer(incremental_state)
|
140 |
+
if saved_state is not None and "prev_key" in saved_state:
|
141 |
+
# previous time steps are cached - no need to recompute
|
142 |
+
# key and value if they are static
|
143 |
+
if static_kv:
|
144 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
145 |
+
key = value = None
|
146 |
+
else:
|
147 |
+
saved_state = None
|
148 |
+
|
149 |
+
if self.self_attention:
|
150 |
+
q = self.q_proj(query)
|
151 |
+
k = self.k_proj(query)
|
152 |
+
v = self.v_proj(query)
|
153 |
+
elif self.encoder_decoder_attention:
|
154 |
+
# encoder-decoder attention
|
155 |
+
q = self.q_proj(query)
|
156 |
+
if key is None:
|
157 |
+
assert value is None
|
158 |
+
k = v = None
|
159 |
+
else:
|
160 |
+
k = self.k_proj(key)
|
161 |
+
v = self.v_proj(key)
|
162 |
+
|
163 |
+
else:
|
164 |
+
assert key is not None and value is not None
|
165 |
+
q = self.q_proj(query)
|
166 |
+
k = self.k_proj(key)
|
167 |
+
v = self.v_proj(value)
|
168 |
+
q *= self.scaling
|
169 |
+
|
170 |
+
if self.bias_k is not None:
|
171 |
+
assert self.bias_v is not None
|
172 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
173 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
174 |
+
if attn_mask is not None:
|
175 |
+
attn_mask = torch.cat(
|
176 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
177 |
+
)
|
178 |
+
if key_padding_mask is not None:
|
179 |
+
key_padding_mask = torch.cat(
|
180 |
+
[
|
181 |
+
key_padding_mask,
|
182 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
183 |
+
],
|
184 |
+
dim=1,
|
185 |
+
)
|
186 |
+
|
187 |
+
q = (
|
188 |
+
q.contiguous()
|
189 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
190 |
+
.transpose(0, 1)
|
191 |
+
)
|
192 |
+
if k is not None:
|
193 |
+
k = (
|
194 |
+
k.contiguous()
|
195 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
196 |
+
.transpose(0, 1)
|
197 |
+
)
|
198 |
+
if v is not None:
|
199 |
+
v = (
|
200 |
+
v.contiguous()
|
201 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
202 |
+
.transpose(0, 1)
|
203 |
+
)
|
204 |
+
|
205 |
+
if saved_state is not None:
|
206 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
207 |
+
if "prev_key" in saved_state:
|
208 |
+
_prev_key = saved_state["prev_key"]
|
209 |
+
assert _prev_key is not None
|
210 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
211 |
+
if static_kv:
|
212 |
+
k = prev_key
|
213 |
+
else:
|
214 |
+
assert k is not None
|
215 |
+
k = torch.cat([prev_key, k], dim=1)
|
216 |
+
src_len = k.size(1)
|
217 |
+
if "prev_value" in saved_state:
|
218 |
+
_prev_value = saved_state["prev_value"]
|
219 |
+
assert _prev_value is not None
|
220 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
221 |
+
if static_kv:
|
222 |
+
v = prev_value
|
223 |
+
else:
|
224 |
+
assert v is not None
|
225 |
+
v = torch.cat([prev_value, v], dim=1)
|
226 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
227 |
+
if "prev_key_padding_mask" in saved_state:
|
228 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
229 |
+
assert k is not None and v is not None
|
230 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
231 |
+
key_padding_mask=key_padding_mask,
|
232 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
233 |
+
batch_size=bsz,
|
234 |
+
src_len=k.size(1),
|
235 |
+
static_kv=static_kv,
|
236 |
+
)
|
237 |
+
|
238 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
239 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
240 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
241 |
+
# In this branch incremental_state is never None
|
242 |
+
assert incremental_state is not None
|
243 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
244 |
+
assert k is not None
|
245 |
+
assert k.size(1) == src_len
|
246 |
+
|
247 |
+
# This is part of a workaround to get around fork/join parallelism
|
248 |
+
# not supporting Optional types.
|
249 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
250 |
+
key_padding_mask = None
|
251 |
+
|
252 |
+
if key_padding_mask is not None:
|
253 |
+
assert key_padding_mask.size(0) == bsz
|
254 |
+
assert key_padding_mask.size(1) == src_len
|
255 |
+
|
256 |
+
if self.add_zero_attn:
|
257 |
+
assert v is not None
|
258 |
+
src_len += 1
|
259 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
260 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
261 |
+
if attn_mask is not None:
|
262 |
+
attn_mask = torch.cat(
|
263 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
264 |
+
)
|
265 |
+
if key_padding_mask is not None:
|
266 |
+
key_padding_mask = torch.cat(
|
267 |
+
[
|
268 |
+
key_padding_mask,
|
269 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
270 |
+
key_padding_mask
|
271 |
+
),
|
272 |
+
],
|
273 |
+
dim=1,
|
274 |
+
)
|
275 |
+
|
276 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
277 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
278 |
+
|
279 |
+
if position_bias is not None: ## first order
|
280 |
+
## position_bias: [241, 241, 64]
|
281 |
+
#print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
|
282 |
+
reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
|
283 |
+
#print ("reshape_q: ", reshape_q.size())
|
284 |
+
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
|
285 |
+
#print ("B: ", B.size()) ## [241, 492, 241]
|
286 |
+
#B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
|
287 |
+
B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
|
288 |
+
#print ("B 2: ", B.size())
|
289 |
+
attn_weights += B
|
290 |
+
|
291 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
292 |
+
|
293 |
+
if attn_mask is not None:
|
294 |
+
attn_mask = attn_mask.unsqueeze(0)
|
295 |
+
if self.onnx_trace:
|
296 |
+
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
297 |
+
attn_weights += attn_mask
|
298 |
+
|
299 |
+
if key_padding_mask is not None:
|
300 |
+
# don't attend to padding symbols
|
301 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
302 |
+
if not is_tpu:
|
303 |
+
attn_weights = attn_weights.masked_fill(
|
304 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
305 |
+
float("-inf"),
|
306 |
+
)
|
307 |
+
else:
|
308 |
+
attn_weights = attn_weights.transpose(0, 2)
|
309 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
310 |
+
attn_weights = attn_weights.transpose(0, 2)
|
311 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
312 |
+
|
313 |
+
if before_softmax:
|
314 |
+
return attn_weights, v
|
315 |
+
|
316 |
+
attn_weights_float = utils.softmax(
|
317 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
318 |
+
)
|
319 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
320 |
+
attn_probs = self.dropout_module(attn_weights)
|
321 |
+
|
322 |
+
assert v is not None
|
323 |
+
attn = torch.bmm(attn_probs, v)
|
324 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
325 |
+
if self.onnx_trace and attn.size(1) == 1:
|
326 |
+
# when ONNX tracing a single decoder step (sequence length == 1)
|
327 |
+
# the transpose is a no-op copy before view, thus unnecessary
|
328 |
+
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
329 |
+
else:
|
330 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
331 |
+
attn = self.out_proj(attn)
|
332 |
+
attn_weights: Optional[Tensor] = None
|
333 |
+
if need_weights:
|
334 |
+
attn_weights = attn_weights_float.view(
|
335 |
+
bsz, self.num_heads, tgt_len, src_len
|
336 |
+
).transpose(1, 0)
|
337 |
+
if not need_head_weights:
|
338 |
+
# average attention weights over heads
|
339 |
+
attn_weights = attn_weights.mean(dim=0)
|
340 |
+
|
341 |
+
return attn, attn_weights
|
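The attention module above adds a relative-position term on top of the usual content scores: the query is projected against the table of relative key embeddings and the result is added to q·kᵀ before the softmax. A shape-level sketch of that extra term, written outside the fairseq module with illustrative dimensions and random tensors rather than the module's actual buffers:

import torch

bsz_heads, tgt_len, head_dim = 8, 12, 16
q = torch.randn(bsz_heads, tgt_len, head_dim)                 # (B*H, T, d) scaled queries
position_bias = torch.randn(tgt_len, tgt_len, head_dim)       # (T, T, d) relative key embeddings

content = torch.bmm(q, q.transpose(1, 2))                     # stand-in for q @ k^T content scores
reshape_q = q.transpose(0, 1)                                  # (T, B*H, d)
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))   # (T, B*H, T) positional scores
B = B.transpose(0, 1)                                          # back to (B*H, T, T)
attn_weights = content + B                                     # bias added before softmax/masking
print(attn_weights.shape)                                      # torch.Size([8, 12, 12])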
SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py
ADDED
@@ -0,0 +1,35 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import torch

class RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()

        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
        self.embed_v = embed_v


    def forward(self, pos_seq, incremental_state=None):
        pos_seq[pos_seq < -self.maxlen] = -self.maxlen
        pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
        pos_seq = pos_seq + self.maxlen

        if incremental_state is not None:
            pos_seq = pos_seq[-1:]

        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None
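The module above maps clamped relative offsets (key position minus query position) to learned embeddings; callers build the offset matrix once per forward pass and feed the returned key embeddings into attention as the position bias. A small usage sketch assuming the class above is in scope; the sequence length, maxlen, and dimension are arbitrary:

import torch

seq_len, head_dim = 10, 16
pos_emb = RelativePositionalEncoding(d_model=head_dim, maxlen=8)

positions = torch.arange(seq_len)
pos_seq = positions[:, None] - positions[None, :]   # (seq_len, seq_len) relative offsets, clamped to [-8, 7]
pos_k, pos_v = pos_emb(pos_seq)                     # pos_v is None unless embed_v=True
print(pos_k.shape)                                   # torch.Size([10, 10, 16])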
SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py
ADDED
@@ -0,0 +1,485 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
from typing import Any, Dict, List, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from fairseq import utils
|
16 |
+
from fairseq.distributed import fsdp_wrap
|
17 |
+
from fairseq.models import FairseqIncrementalDecoder
|
18 |
+
from fairseq.models.transformer import TransformerConfig
|
19 |
+
from fairseq.models.transformer.transformer_decoder import module_name_fordropout, Linear
|
20 |
+
from fairseq.modules import (
|
21 |
+
AdaptiveSoftmax,
|
22 |
+
BaseLayer,
|
23 |
+
FairseqDropout,
|
24 |
+
LayerDropModuleList,
|
25 |
+
LayerNorm,
|
26 |
+
PositionalEmbedding,
|
27 |
+
SinusoidalPositionalEmbedding,
|
28 |
+
)
|
29 |
+
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
30 |
+
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
31 |
+
from torch import Tensor
|
32 |
+
|
33 |
+
|
34 |
+
from speech2c.models.modules.transformer_decoder_layer import TransformerDecoderLayerBase
|
35 |
+
from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
|
36 |
+
|
37 |
+
|
38 |
+
class TransformerDecoderBase(FairseqIncrementalDecoder):
|
39 |
+
"""
|
40 |
+
Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
|
41 |
+
is a :class:`TransformerDecoderLayer`.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
args (argparse.Namespace): parsed command-line arguments
|
45 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
46 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
47 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
48 |
+
(default: False).
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
cfg,
|
54 |
+
dictionary,
|
55 |
+
embed_tokens,
|
56 |
+
no_encoder_attn=False,
|
57 |
+
output_projection=None,
|
58 |
+
use_rel_pos_enc=False,
|
59 |
+
):
|
60 |
+
self.cfg = cfg
|
61 |
+
super().__init__(dictionary)
|
62 |
+
self.register_buffer("version", torch.Tensor([3]))
|
63 |
+
self._future_mask = torch.empty(0)
|
64 |
+
|
65 |
+
self.dropout_module = FairseqDropout(
|
66 |
+
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
|
67 |
+
)
|
68 |
+
self.decoder_layerdrop = cfg.decoder.layerdrop
|
69 |
+
self.share_input_output_embed = cfg.share_decoder_input_output_embed
|
70 |
+
|
71 |
+
input_embed_dim = embed_tokens.embedding_dim
|
72 |
+
embed_dim = cfg.decoder.embed_dim
|
73 |
+
self.embed_dim = embed_dim
|
74 |
+
self.output_embed_dim = cfg.decoder.output_dim
|
75 |
+
|
76 |
+
self.padding_idx = embed_tokens.padding_idx
|
77 |
+
self.max_target_positions = cfg.max_target_positions
|
78 |
+
|
79 |
+
self.embed_tokens = embed_tokens
|
80 |
+
|
81 |
+
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
|
82 |
+
|
83 |
+
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
|
84 |
+
self.quant_noise = apply_quant_noise_(
|
85 |
+
nn.Linear(embed_dim, embed_dim, bias=False),
|
86 |
+
cfg.quant_noise.pq,
|
87 |
+
cfg.quant_noise.pq_block_size,
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
self.quant_noise = None
|
91 |
+
|
92 |
+
self.project_in_dim = (
|
93 |
+
Linear(input_embed_dim, embed_dim, bias=False)
|
94 |
+
if embed_dim != input_embed_dim
|
95 |
+
else None
|
96 |
+
)
|
97 |
+
self.embed_positions = (
|
98 |
+
PositionalEmbedding(
|
99 |
+
self.max_target_positions,
|
100 |
+
embed_dim,
|
101 |
+
self.padding_idx,
|
102 |
+
learned=cfg.decoder.learned_pos,
|
103 |
+
)
|
104 |
+
if not cfg.no_token_positional_embeddings
|
105 |
+
else None
|
106 |
+
)
|
107 |
+
if cfg.layernorm_embedding:
|
108 |
+
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
|
109 |
+
else:
|
110 |
+
self.layernorm_embedding = None
|
111 |
+
|
112 |
+
self.cross_self_attention = cfg.cross_self_attention
|
113 |
+
|
114 |
+
self.use_rel_pos_enc = use_rel_pos_enc
|
115 |
+
if self.decoder_layerdrop > 0.0:
|
116 |
+
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
|
117 |
+
else:
|
118 |
+
self.layers = nn.ModuleList([])
|
119 |
+
self.layers.extend(
|
120 |
+
[
|
121 |
+
self.build_decoder_layer(cfg, no_encoder_attn)
|
122 |
+
for _ in range(cfg.decoder.layers)
|
123 |
+
]
|
124 |
+
)
|
125 |
+
self.num_layers = len(self.layers)
|
126 |
+
|
127 |
+
if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
|
128 |
+
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
|
129 |
+
else:
|
130 |
+
self.layer_norm = None
|
131 |
+
|
132 |
+
self.project_out_dim = (
|
133 |
+
Linear(embed_dim, self.output_embed_dim, bias=False)
|
134 |
+
if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
|
135 |
+
else None
|
136 |
+
)
|
137 |
+
|
138 |
+
self.adaptive_softmax = None
|
139 |
+
self.output_projection = output_projection
|
140 |
+
if self.output_projection is None:
|
141 |
+
self.build_output_projection(cfg, dictionary, embed_tokens)
|
142 |
+
|
143 |
+
if self.use_rel_pos_enc:
|
144 |
+
self.pos_emb = RelativePositionalEncoding(self.embed_dim // cfg.decoder.attention_heads, 24)
|
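The relative positional table is only created when use_rel_pos_enc is set, and it is sized per attention head (embed_dim // attention_heads) with a maximum relative distance of 24 decoder positions. The module itself lives in speech2c.models.modules.relative_pos_enc and is not shown here; the following is only a minimal sketch of how such a clamped relative-offset embedding could look, with the clamping window and lookup being illustrative assumptions rather than the actual implementation.

import torch
import torch.nn as nn

class ToyRelativePositionalEncoding(nn.Module):
    """Hypothetical stand-in for RelativePositionalEncoding(d_head, maxlen); not the actual module."""

    def __init__(self, d_head: int, maxlen: int = 24):
        super().__init__()
        self.maxlen = maxlen
        # one embedding per clamped relative offset in [-maxlen, maxlen]
        self.emb = nn.Embedding(2 * maxlen + 1, d_head)

    def forward(self, pos_seq: torch.Tensor):
        # pos_seq[i, j] = i - j; clamp to the supported window and shift to be non-negative
        idx = pos_seq.clamp(-self.maxlen, self.maxlen) + self.maxlen
        pos_k = self.emb(idx)   # (T, T, d_head) bias applied on the key side
        return pos_k, None      # a value-side bias is omitted in this sketch

T = 6
pos_seq = torch.arange(T)[:, None] - torch.arange(T)[None, :]
pos_k, _ = ToyRelativePositionalEncoding(d_head=64)(pos_seq)
print(pos_k.shape)              # torch.Size([6, 6, 64])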
145 |
+
|
146 |
+
def build_output_projection(self, cfg, dictionary, embed_tokens):
|
147 |
+
if cfg.adaptive_softmax_cutoff is not None:
|
148 |
+
self.adaptive_softmax = AdaptiveSoftmax(
|
149 |
+
len(dictionary),
|
150 |
+
self.output_embed_dim,
|
151 |
+
utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
|
152 |
+
dropout=cfg.adaptive_softmax_dropout,
|
153 |
+
adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
|
154 |
+
factor=cfg.adaptive_softmax_factor,
|
155 |
+
tie_proj=cfg.tie_adaptive_proj,
|
156 |
+
)
|
157 |
+
elif self.share_input_output_embed:
|
158 |
+
self.output_projection = nn.Linear(
|
159 |
+
self.embed_tokens.weight.shape[1],
|
160 |
+
self.embed_tokens.weight.shape[0],
|
161 |
+
bias=False,
|
162 |
+
)
|
163 |
+
self.output_projection.weight = self.embed_tokens.weight
|
164 |
+
else:
|
165 |
+
self.output_projection = nn.Linear(
|
166 |
+
self.output_embed_dim, len(dictionary), bias=False
|
167 |
+
)
|
168 |
+
nn.init.normal_(
|
169 |
+
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
|
170 |
+
)
|
171 |
+
num_base_layers = cfg.base_layers
|
172 |
+
for i in range(num_base_layers):
|
173 |
+
self.layers.insert(
|
174 |
+
((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
|
175 |
+
BaseLayer(cfg),
|
176 |
+
)
|
177 |
+
|
178 |
+
def build_decoder_layer(self, cfg, no_encoder_attn=False):
|
179 |
+
layer = TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
|
180 |
+
checkpoint = cfg.checkpoint_activations
|
181 |
+
if checkpoint:
|
182 |
+
offload_to_cpu = cfg.offload_activations
|
183 |
+
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
|
184 |
+
# if we are checkpointing, enforce that FSDP always wraps the
|
185 |
+
# checkpointed layer, regardless of layer size
|
186 |
+
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
|
187 |
+
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
|
188 |
+
return layer
|
189 |
+
|
190 |
+
def forward(
|
191 |
+
self,
|
192 |
+
prev_output_tokens,
|
193 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
194 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
195 |
+
features_only: bool = False,
|
196 |
+
full_context_alignment: bool = False,
|
197 |
+
alignment_layer: Optional[int] = None,
|
198 |
+
alignment_heads: Optional[int] = None,
|
199 |
+
src_lengths: Optional[Any] = None,
|
200 |
+
return_all_hiddens: bool = False,
|
201 |
+
):
|
202 |
+
"""
|
203 |
+
Args:
|
204 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
205 |
+
`(batch, tgt_len)`, for teacher forcing
|
206 |
+
encoder_out (optional): output from the encoder, used for
|
207 |
+
encoder-side attention, should be of size T x B x C
|
208 |
+
incremental_state (dict): dictionary used for storing state during
|
209 |
+
:ref:`Incremental decoding`
|
210 |
+
features_only (bool, optional): only return features without
|
211 |
+
applying output layer (default: False).
|
212 |
+
full_context_alignment (bool, optional): don't apply
|
213 |
+
auto-regressive mask to self-attention (default: False).
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
tuple:
|
217 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
218 |
+
- a dictionary with any model-specific outputs
|
219 |
+
"""
|
220 |
+
|
221 |
+
x, extra = self.extract_features(
|
222 |
+
prev_output_tokens,
|
223 |
+
encoder_out=encoder_out,
|
224 |
+
incremental_state=incremental_state,
|
225 |
+
full_context_alignment=full_context_alignment,
|
226 |
+
alignment_layer=alignment_layer,
|
227 |
+
alignment_heads=alignment_heads,
|
228 |
+
)
|
229 |
+
|
230 |
+
if not features_only:
|
231 |
+
x = self.output_layer(x)
|
232 |
+
return x, extra
|
233 |
+
|
234 |
+
def extract_features_scriptable(
|
235 |
+
self,
|
236 |
+
prev_output_tokens,
|
237 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
238 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
239 |
+
full_context_alignment: bool = False,
|
240 |
+
alignment_layer: Optional[int] = None,
|
241 |
+
alignment_heads: Optional[int] = None,
|
242 |
+
):
|
243 |
+
"""
|
244 |
+
Similar to *forward* but only return features.
|
245 |
+
|
246 |
+
Includes several features from "Jointly Learning to Align and
|
247 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
248 |
+
|
249 |
+
Args:
|
250 |
+
full_context_alignment (bool, optional): don't apply
|
251 |
+
auto-regressive mask to self-attention (default: False).
|
252 |
+
alignment_layer (int, optional): return mean alignment over
|
253 |
+
heads at this layer (default: last layer).
|
254 |
+
alignment_heads (int, optional): only average alignment over
|
255 |
+
this many heads (default: all heads).
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
tuple:
|
259 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
260 |
+
- a dictionary with any model-specific outputs
|
261 |
+
"""
|
262 |
+
bs, slen = prev_output_tokens.size()
|
263 |
+
if alignment_layer is None:
|
264 |
+
alignment_layer = self.num_layers - 1
|
265 |
+
|
266 |
+
enc: Optional[Tensor] = None
|
267 |
+
padding_mask: Optional[Tensor] = None
|
268 |
+
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
|
269 |
+
enc = encoder_out["encoder_out"][0]
|
270 |
+
assert (
|
271 |
+
enc.size()[1] == bs
|
272 |
+
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
|
273 |
+
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
|
274 |
+
padding_mask = encoder_out["encoder_padding_mask"][0]
|
275 |
+
|
276 |
+
# embed positions
|
277 |
+
positions = None
|
278 |
+
if self.embed_positions is not None:
|
279 |
+
positions = self.embed_positions(
|
280 |
+
prev_output_tokens, incremental_state=incremental_state
|
281 |
+
)
|
282 |
+
|
283 |
+
if incremental_state is not None:
|
284 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
285 |
+
if positions is not None:
|
286 |
+
positions = positions[:, -1:]
|
287 |
+
|
288 |
+
# embed tokens and positions
|
289 |
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
290 |
+
|
291 |
+
if self.quant_noise is not None:
|
292 |
+
x = self.quant_noise(x)
|
293 |
+
|
294 |
+
if self.project_in_dim is not None:
|
295 |
+
x = self.project_in_dim(x)
|
296 |
+
|
297 |
+
if positions is not None:
|
298 |
+
x += positions
|
299 |
+
|
300 |
+
if self.layernorm_embedding is not None:
|
301 |
+
x = self.layernorm_embedding(x)
|
302 |
+
|
303 |
+
x = self.dropout_module(x)
|
304 |
+
|
305 |
+
# B x T x C -> T x B x C
|
306 |
+
x = x.transpose(0, 1)
|
307 |
+
if self.use_rel_pos_enc:
|
308 |
+
pos_seq = torch.arange(0, slen).long().to(x.device)
|
309 |
+
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
310 |
+
pos_k, _ = self.pos_emb(pos_seq, incremental_state)
|
311 |
+
else:
|
312 |
+
pos_k = None
|
313 |
+
|
314 |
+
self_attn_padding_mask: Optional[Tensor] = None
|
315 |
+
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
|
316 |
+
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
|
317 |
+
|
318 |
+
# decoder layers
|
319 |
+
attn: Optional[Tensor] = None
|
320 |
+
inner_states: List[Optional[Tensor]] = [x]
|
321 |
+
for idx, layer in enumerate(self.layers):
|
322 |
+
if incremental_state is None and not full_context_alignment:
|
323 |
+
self_attn_mask = self.buffered_future_mask(x)
|
324 |
+
else:
|
325 |
+
self_attn_mask = None
|
326 |
+
|
327 |
+
x, layer_attn, _ = layer(
|
328 |
+
x,
|
329 |
+
enc,
|
330 |
+
padding_mask,
|
331 |
+
incremental_state,
|
332 |
+
self_attn_mask=self_attn_mask,
|
333 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
334 |
+
need_attn=bool((idx == alignment_layer)),
|
335 |
+
need_head_weights=bool((idx == alignment_layer)),
|
336 |
+
pos_bias=pos_k,
|
337 |
+
)
|
338 |
+
inner_states.append(x)
|
339 |
+
if layer_attn is not None and idx == alignment_layer:
|
340 |
+
attn = layer_attn.float().to(x)
|
341 |
+
|
342 |
+
if attn is not None:
|
343 |
+
if alignment_heads is not None:
|
344 |
+
attn = attn[:alignment_heads]
|
345 |
+
|
346 |
+
# average probabilities over heads
|
347 |
+
attn = attn.mean(dim=0)
|
348 |
+
|
349 |
+
if self.layer_norm is not None:
|
350 |
+
x = self.layer_norm(x)
|
351 |
+
|
352 |
+
# T x B x C -> B x T x C
|
353 |
+
x = x.transpose(0, 1)
|
354 |
+
|
355 |
+
if self.project_out_dim is not None:
|
356 |
+
x = self.project_out_dim(x)
|
357 |
+
|
358 |
+
return x, {"attn": [attn], "inner_states": inner_states}
|
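Within extract_features_scriptable, the relative-position branch first builds a (T, T) matrix of signed offsets (pos_seq[:, None] - pos_seq[None, :]) and only then looks up the per-head bias. A tiny standalone demo of that index construction:

import torch

T = 4
pos_seq = torch.arange(T)
rel = pos_seq[:, None] - pos_seq[None, :]
# rel[i, j] == i - j, e.g. for T = 4:
# tensor([[ 0, -1, -2, -3],
#         [ 1,  0, -1, -2],
#         [ 2,  1,  0, -1],
#         [ 3,  2,  1,  0]])
print(rel)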
359 |
+
|
360 |
+
def output_layer(self, features):
|
361 |
+
"""Project features to the vocabulary size."""
|
362 |
+
if self.adaptive_softmax is None:
|
363 |
+
# project back to size of vocabulary
|
364 |
+
return self.output_projection(features)
|
365 |
+
else:
|
366 |
+
return features
|
367 |
+
|
368 |
+
def max_positions(self):
|
369 |
+
"""Maximum output length supported by the decoder."""
|
370 |
+
if self.embed_positions is None:
|
371 |
+
return self.max_target_positions
|
372 |
+
return min(self.max_target_positions, self.embed_positions.max_positions)
|
373 |
+
|
374 |
+
def buffered_future_mask(self, tensor):
|
375 |
+
dim = tensor.size(0)
|
376 |
+
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
|
377 |
+
if (
|
378 |
+
self._future_mask.size(0) == 0
|
379 |
+
or (not self._future_mask.device == tensor.device)
|
380 |
+
or self._future_mask.size(0) < dim
|
381 |
+
):
|
382 |
+
self._future_mask = torch.triu(
|
383 |
+
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
|
384 |
+
)
|
385 |
+
self._future_mask = self._future_mask.to(tensor)
|
386 |
+
return self._future_mask[:dim, :dim]
|
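buffered_future_mask lazily builds and caches an upper-triangular -inf mask, re-slicing it to the current length on every call so autoregressive self-attention cannot attend to future positions. The same construction in isolation, using float("-inf") directly instead of fairseq's fill_with_neg_inf helper:

import torch

def future_mask(dim: int) -> torch.Tensor:
    # zeros on and below the diagonal, -inf strictly above it
    return torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)

# adding this mask to the attention scores blocks every position j > i
print(future_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0.,   -inf],
#         [0., 0.,   0.]])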
387 |
+
|
388 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
389 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
390 |
+
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
|
391 |
+
weights_key = "{}.embed_positions.weights".format(name)
|
392 |
+
if weights_key in state_dict:
|
393 |
+
del state_dict[weights_key]
|
394 |
+
state_dict[
|
395 |
+
"{}.embed_positions._float_tensor".format(name)
|
396 |
+
] = torch.FloatTensor(1)
|
397 |
+
|
398 |
+
if f"{name}.output_projection.weight" not in state_dict:
|
399 |
+
if self.share_input_output_embed:
|
400 |
+
embed_out_key = f"{name}.embed_tokens.weight"
|
401 |
+
else:
|
402 |
+
embed_out_key = f"{name}.embed_out"
|
403 |
+
if embed_out_key in state_dict:
|
404 |
+
state_dict[f"{name}.output_projection.weight"] = state_dict[
|
405 |
+
embed_out_key
|
406 |
+
]
|
407 |
+
if not self.share_input_output_embed:
|
408 |
+
del state_dict[embed_out_key]
|
409 |
+
|
410 |
+
for i in range(self.num_layers):
|
411 |
+
# update layer norms
|
412 |
+
layer_norm_map = {
|
413 |
+
"0": "self_attn_layer_norm",
|
414 |
+
"1": "encoder_attn_layer_norm",
|
415 |
+
"2": "final_layer_norm",
|
416 |
+
}
|
417 |
+
for old, new in layer_norm_map.items():
|
418 |
+
for m in ("weight", "bias"):
|
419 |
+
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
|
420 |
+
if k in state_dict:
|
421 |
+
state_dict[
|
422 |
+
"{}.layers.{}.{}.{}".format(name, i, new, m)
|
423 |
+
] = state_dict[k]
|
424 |
+
del state_dict[k]
|
425 |
+
|
426 |
+
version_key = "{}.version".format(name)
|
427 |
+
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
|
428 |
+
# earlier checkpoints did not normalize after the stack of layers
|
429 |
+
self.layer_norm = None
|
430 |
+
self.normalize = False
|
431 |
+
state_dict[version_key] = torch.Tensor([1])
|
432 |
+
|
433 |
+
return state_dict
|
434 |
+
|
435 |
+
|
436 |
+
class TransformerDecoder(TransformerDecoderBase):
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
args,
|
440 |
+
dictionary,
|
441 |
+
embed_tokens,
|
442 |
+
no_encoder_attn=False,
|
443 |
+
output_projection=None,
|
444 |
+
):
|
445 |
+
self.args = args
|
446 |
+
super().__init__(
|
447 |
+
TransformerConfig.from_namespace(args),
|
448 |
+
dictionary,
|
449 |
+
embed_tokens,
|
450 |
+
no_encoder_attn=no_encoder_attn,
|
451 |
+
output_projection=output_projection,
|
452 |
+
use_rel_pos_enc=args.use_rel_pos_enc,
|
453 |
+
)
|
454 |
+
|
455 |
+
def build_output_projection(self, args, dictionary, embed_tokens):
|
456 |
+
super().build_output_projection(
|
457 |
+
TransformerConfig.from_namespace(args), dictionary, embed_tokens
|
458 |
+
)
|
459 |
+
|
460 |
+
def build_decoder_layer(self, args, no_encoder_attn=False):
|
461 |
+
return super().build_decoder_layer(
|
462 |
+
TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
|
463 |
+
)
|
464 |
+
|
465 |
+
class TransformerDecoderScriptable(TransformerDecoder):
|
466 |
+
def extract_features(
|
467 |
+
self,
|
468 |
+
prev_output_tokens,
|
469 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
470 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
471 |
+
full_context_alignment: bool = False,
|
472 |
+
alignment_layer: Optional[int] = None,
|
473 |
+
alignment_heads: Optional[int] = None,
|
474 |
+
):
|
475 |
+
# call scriptable method from parent class
|
476 |
+
x, _ = self.extract_features_scriptable(
|
477 |
+
prev_output_tokens,
|
478 |
+
encoder_out,
|
479 |
+
incremental_state,
|
480 |
+
full_context_alignment,
|
481 |
+
alignment_layer,
|
482 |
+
alignment_heads,
|
483 |
+
)
|
484 |
+
return x, None
|
485 |
+
|
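The decoder's forward takes prev_output_tokens of shape (batch, tgt_len) for teacher forcing, i.e. the gold target sequence shifted right so that position t is asked to predict token t. A minimal sketch of that shift, assuming eos doubles as the begin-of-sequence marker the way fairseq's collaters arrange it:

import torch

def shift_right(target: torch.Tensor, eos: int) -> torch.Tensor:
    # target: (batch, tgt_len) gold tokens ending in eos
    prev_output_tokens = target.roll(shifts=1, dims=1)
    prev_output_tokens[:, 0] = eos
    return prev_output_tokens

tgt = torch.tensor([[5, 6, 7, 2]])   # 2 = eos
print(shift_right(tgt, eos=2))       # tensor([[2, 5, 6, 7]])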
SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
ADDED
@@ -0,0 +1,215 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
from typing import Dict, List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import Tensor
|
14 |
+
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
|
15 |
+
from fairseq.modules import LayerNorm
|
16 |
+
|
17 |
+
from speech2c.models.modules.multihead_attention import MultiheadAttention
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
|
21 |
+
"""Decoder layer block.
|
22 |
+
|
23 |
+
In the original paper each operation (multi-head attention, encoder
|
24 |
+
attention or FFN) is postprocessed with: `dropout -> add residual ->
|
25 |
+
layernorm`. In the tensor2tensor code they suggest that learning is more
|
26 |
+
robust when preprocessing each layer with layernorm and postprocessing with:
|
27 |
+
`dropout -> add residual`. We default to the approach in the paper, but the
|
28 |
+
tensor2tensor approach can be enabled by setting
|
29 |
+
*cfg.decoder.normalize_before* to ``True``.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
args (argparse.Namespace): parsed command-line arguments
|
33 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
34 |
+
(default: False).
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
|
39 |
+
):
|
40 |
+
super().__init__(
|
41 |
+
cfg,
|
42 |
+
no_encoder_attn,
|
43 |
+
add_bias_kv,
|
44 |
+
add_zero_attn,
|
45 |
+
)
|
46 |
+
|
47 |
+
if has_relative_attention_bias:
|
48 |
+
self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
|
49 |
+
|
50 |
+
def build_self_attention(
|
51 |
+
self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
|
52 |
+
):
|
53 |
+
return MultiheadAttention(
|
54 |
+
embed_dim,
|
55 |
+
cfg.decoder.attention_heads,
|
56 |
+
dropout=cfg.attention_dropout,
|
57 |
+
add_bias_kv=add_bias_kv,
|
58 |
+
add_zero_attn=add_zero_attn,
|
59 |
+
self_attention=not cfg.cross_self_attention,
|
60 |
+
q_noise=self.quant_noise,
|
61 |
+
qn_block_size=self.quant_noise_block_size,
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
x,
|
67 |
+
encoder_out: Optional[torch.Tensor] = None,
|
68 |
+
encoder_padding_mask: Optional[torch.Tensor] = None,
|
69 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
70 |
+
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
|
71 |
+
prev_attn_state: Optional[List[torch.Tensor]] = None,
|
72 |
+
self_attn_mask: Optional[torch.Tensor] = None,
|
73 |
+
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
74 |
+
need_attn: bool = False,
|
75 |
+
need_head_weights: bool = False,
|
76 |
+
pos_bias=None,
|
77 |
+
):
|
78 |
+
"""
|
79 |
+
Args:
|
80 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
81 |
+
encoder_padding_mask (ByteTensor, optional): binary
|
82 |
+
ByteTensor of shape `(batch, src_len)` where padding
|
83 |
+
elements are indicated by ``1``.
|
84 |
+
need_attn (bool, optional): return attention weights
|
85 |
+
need_head_weights (bool, optional): return attention weights
|
86 |
+
for each head (default: return average over heads).
|
87 |
+
Returns:
|
88 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
89 |
+
"""
|
90 |
+
if need_head_weights:
|
91 |
+
need_attn = True
|
92 |
+
|
93 |
+
residual = x
|
94 |
+
if self.normalize_before:
|
95 |
+
x = self.self_attn_layer_norm(x)
|
96 |
+
if pos_bias is not None:
|
97 |
+
pos_bias = self.norm_k(pos_bias)
|
98 |
+
if prev_self_attn_state is not None:
|
99 |
+
prev_key, prev_value = prev_self_attn_state[:2]
|
100 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
101 |
+
"prev_key": prev_key,
|
102 |
+
"prev_value": prev_value,
|
103 |
+
}
|
104 |
+
if len(prev_self_attn_state) >= 3:
|
105 |
+
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
|
106 |
+
assert incremental_state is not None
|
107 |
+
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
108 |
+
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
|
109 |
+
if self.cross_self_attention and not (
|
110 |
+
incremental_state is not None
|
111 |
+
and _self_attn_input_buffer is not None
|
112 |
+
and "prev_key" in _self_attn_input_buffer
|
113 |
+
):
|
114 |
+
if self_attn_mask is not None:
|
115 |
+
assert encoder_out is not None
|
116 |
+
self_attn_mask = torch.cat(
|
117 |
+
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
|
118 |
+
)
|
119 |
+
if self_attn_padding_mask is not None:
|
120 |
+
if encoder_padding_mask is None:
|
121 |
+
assert encoder_out is not None
|
122 |
+
encoder_padding_mask = self_attn_padding_mask.new_zeros(
|
123 |
+
encoder_out.size(1), encoder_out.size(0)
|
124 |
+
)
|
125 |
+
self_attn_padding_mask = torch.cat(
|
126 |
+
(encoder_padding_mask, self_attn_padding_mask), dim=1
|
127 |
+
)
|
128 |
+
assert encoder_out is not None
|
129 |
+
y = torch.cat((encoder_out, x), dim=0)
|
130 |
+
else:
|
131 |
+
y = x
|
132 |
+
|
133 |
+
x, attn = self.self_attn(
|
134 |
+
query=x,
|
135 |
+
key=y,
|
136 |
+
value=y,
|
137 |
+
key_padding_mask=self_attn_padding_mask,
|
138 |
+
incremental_state=incremental_state,
|
139 |
+
need_weights=False,
|
140 |
+
attn_mask=self_attn_mask,
|
141 |
+
position_bias=pos_bias,
|
142 |
+
)
|
143 |
+
if self.c_attn is not None:
|
144 |
+
tgt_len, bsz = x.size(0), x.size(1)
|
145 |
+
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
|
146 |
+
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
|
147 |
+
x = x.reshape(tgt_len, bsz, self.embed_dim)
|
148 |
+
if self.attn_ln is not None:
|
149 |
+
x = self.attn_ln(x)
|
150 |
+
x = self.dropout_module(x)
|
151 |
+
x = self.residual_connection(x, residual)
|
152 |
+
if not self.normalize_before:
|
153 |
+
x = self.self_attn_layer_norm(x)
|
154 |
+
|
155 |
+
if self.encoder_attn is not None and encoder_out is not None:
|
156 |
+
residual = x
|
157 |
+
if self.normalize_before:
|
158 |
+
x = self.encoder_attn_layer_norm(x)
|
159 |
+
if prev_attn_state is not None:
|
160 |
+
prev_key, prev_value = prev_attn_state[:2]
|
161 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
162 |
+
"prev_key": prev_key,
|
163 |
+
"prev_value": prev_value,
|
164 |
+
}
|
165 |
+
if len(prev_attn_state) >= 3:
|
166 |
+
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
167 |
+
assert incremental_state is not None
|
168 |
+
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
169 |
+
|
170 |
+
x, attn = self.encoder_attn(
|
171 |
+
query=x,
|
172 |
+
key=encoder_out,
|
173 |
+
value=encoder_out,
|
174 |
+
key_padding_mask=encoder_padding_mask,
|
175 |
+
incremental_state=incremental_state,
|
176 |
+
static_kv=True,
|
177 |
+
need_weights=need_attn or (not self.training and self.need_attn),
|
178 |
+
need_head_weights=need_head_weights,
|
179 |
+
)
|
180 |
+
x = self.dropout_module(x)
|
181 |
+
x = self.residual_connection(x, residual)
|
182 |
+
if not self.normalize_before:
|
183 |
+
x = self.encoder_attn_layer_norm(x)
|
184 |
+
|
185 |
+
residual = x
|
186 |
+
if self.normalize_before:
|
187 |
+
x = self.final_layer_norm(x)
|
188 |
+
|
189 |
+
x = self.activation_fn(self.fc1(x))
|
190 |
+
x = self.activation_dropout_module(x)
|
191 |
+
if self.ffn_layernorm is not None:
|
192 |
+
x = self.ffn_layernorm(x)
|
193 |
+
x = self.fc2(x)
|
194 |
+
x = self.dropout_module(x)
|
195 |
+
if self.w_resid is not None:
|
196 |
+
residual = torch.mul(self.w_resid, residual)
|
197 |
+
x = self.residual_connection(x, residual)
|
198 |
+
if not self.normalize_before:
|
199 |
+
x = self.final_layer_norm(x)
|
200 |
+
if self.onnx_trace and incremental_state is not None:
|
201 |
+
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
202 |
+
assert saved_state is not None
|
203 |
+
if self_attn_padding_mask is not None:
|
204 |
+
self_attn_state = [
|
205 |
+
saved_state["prev_key"],
|
206 |
+
saved_state["prev_value"],
|
207 |
+
saved_state["prev_key_padding_mask"],
|
208 |
+
]
|
209 |
+
else:
|
210 |
+
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
|
211 |
+
return x, attn, self_attn_state
|
212 |
+
return x, attn, None
|
213 |
+
|
214 |
+
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
|
215 |
+
self.need_attn = need_attn
|
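During incremental decoding, the layer packs prev_key/prev_value into the incremental_state buffer via _set_input_buffer, so each generation step attends over cached keys and values instead of recomputing the whole prefix. A toy illustration of that caching pattern (not the fairseq MultiheadAttention API, just the idea):

import torch

cache = {}  # plays the role of incremental_state for a single layer

def step(new_k: torch.Tensor, new_v: torch.Tensor):
    # append this step's key/value to whatever was cached on previous steps
    if "prev_key" in cache:
        new_k = torch.cat([cache["prev_key"], new_k], dim=0)
        new_v = torch.cat([cache["prev_value"], new_v], dim=0)
    cache["prev_key"], cache["prev_value"] = new_k, new_v
    return new_k, new_v

k1, _ = step(torch.randn(1, 8), torch.randn(1, 8))
k2, _ = step(torch.randn(1, 8), torch.randn(1, 8))
print(k1.shape, k2.shape)  # torch.Size([1, 8]) torch.Size([2, 8])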
SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py
ADDED
@@ -0,0 +1,278 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from fairseq import utils
|
17 |
+
from fairseq.dataclass import ChoiceEnum
|
18 |
+
from fairseq.modules import (
|
19 |
+
LayerNorm,
|
20 |
+
MultiheadAttention,
|
21 |
+
SamePad,
|
22 |
+
)
|
23 |
+
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
24 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
25 |
+
from fairseq.utils import index_put
|
26 |
+
from fairseq.distributed import fsdp_wrap
|
27 |
+
from fairseq.models.wav2vec.utils import pad_to_multiple
|
28 |
+
from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
|
29 |
+
|
30 |
+
from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
|
31 |
+
from speech2c.models.modules.multihead_attention import MultiheadAttention
|
32 |
+
|
33 |
+
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
34 |
+
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
|
35 |
+
|
36 |
+
|
37 |
+
class TransformerEncoder(W2vTransformerEncoder):
|
38 |
+
def __init__(self, args):
|
39 |
+
super().__init__(args)
|
40 |
+
|
41 |
+
self.dropout = args.dropout
|
42 |
+
self.embedding_dim = args.encoder_embed_dim
|
43 |
+
self.required_seq_len_multiple = args.required_seq_len_multiple
|
44 |
+
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
|
45 |
+
|
46 |
+
self.pos_conv = nn.Conv1d(
|
47 |
+
self.embedding_dim,
|
48 |
+
self.embedding_dim,
|
49 |
+
kernel_size=args.conv_pos,
|
50 |
+
padding=args.conv_pos // 2,
|
51 |
+
groups=args.conv_pos_groups,
|
52 |
+
)
|
53 |
+
dropout = 0
|
54 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
55 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
56 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
57 |
+
|
58 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
59 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
60 |
+
|
61 |
+
layers = []
|
62 |
+
for _ in range(args.encoder_layers):
|
63 |
+
layer = TransformerSentenceEncoderLayer(
|
64 |
+
embedding_dim=self.embedding_dim,
|
65 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
66 |
+
num_attention_heads=args.encoder_attention_heads,
|
67 |
+
dropout=self.dropout,
|
68 |
+
attention_dropout=args.attention_dropout,
|
69 |
+
activation_dropout=args.activation_dropout,
|
70 |
+
activation_fn=args.activation_fn,
|
71 |
+
layer_norm_first=args.layer_norm_first,
|
72 |
+
has_relative_attention_bias=self.use_rel_pos_enc,
|
73 |
+
)
|
74 |
+
if args.checkpoint_activations:
|
75 |
+
layer = fsdp_wrap(layer)
|
76 |
+
layer = checkpoint_wrapper(layer)
|
77 |
+
layers.append(layer)
|
78 |
+
self.layers = nn.ModuleList(layers)
|
79 |
+
|
80 |
+
self.layer_norm_first = args.layer_norm_first
|
81 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
82 |
+
self.layerdrop = args.encoder_layerdrop
|
83 |
+
if self.use_rel_pos_enc:
|
84 |
+
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
|
85 |
+
|
86 |
+
|
87 |
+
self.apply(init_bert_params)
|
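Instead of absolute positional embeddings, the encoder applies a grouped, weight-normalised Conv1d over the feature sequence (kernel size conv_pos, conv_pos_groups groups), trims the extra frame with SamePad, applies a GELU, and adds the result back to the input. A self-contained sketch with made-up sizes; the trimming below mimics fairseq's SamePad for an even kernel:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, kernel, groups = 768, 128, 16          # illustrative values only
conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=groups)
conv = nn.utils.weight_norm(conv, name="weight", dim=2)

x = torch.randn(2, 50, dim)                  # (B, T, C)
y = conv(x.transpose(1, 2))                  # (B, C, T + 1) for an even kernel
y = y[:, :, :-1]                             # SamePad-style trim back to T
y = F.gelu(y).transpose(1, 2)
out = x + y                                  # positional information injected additively
print(out.shape)                             # torch.Size([2, 50, 768])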
88 |
+
|
89 |
+
def forward(self, x, padding_mask=None, layer=None):
|
90 |
+
x, layer_results = self.extract_features(x, padding_mask, layer)
|
91 |
+
|
92 |
+
if self.layer_norm_first and layer is None:
|
93 |
+
x = self.layer_norm(x)
|
94 |
+
|
95 |
+
return x, layer_results
|
96 |
+
|
97 |
+
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
98 |
+
|
99 |
+
if padding_mask is not None:
|
100 |
+
x = index_put(x, padding_mask, 0)
|
101 |
+
|
102 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
103 |
+
x_conv = x_conv.transpose(1, 2)
|
104 |
+
x = x + x_conv
|
105 |
+
|
106 |
+
if not self.layer_norm_first:
|
107 |
+
x = self.layer_norm(x)
|
108 |
+
|
109 |
+
# pad to the sequence length dimension
|
110 |
+
x, pad_length = pad_to_multiple(
|
111 |
+
x, self.required_seq_len_multiple, dim=-2, value=0
|
112 |
+
)
|
113 |
+
if pad_length > 0 and padding_mask is None:
|
114 |
+
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
|
115 |
+
padding_mask[:, -pad_length:] = True
|
116 |
+
else:
|
117 |
+
padding_mask, _ = pad_to_multiple(
|
118 |
+
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
|
119 |
+
)
|
120 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
121 |
+
|
122 |
+
# B x T x C -> T x B x C
|
123 |
+
x = x.transpose(0, 1)
|
124 |
+
|
125 |
+
if self.use_rel_pos_enc:
|
126 |
+
x_len = x.shape[0]
|
127 |
+
pos_seq = torch.arange(0, x_len).long().to(x.device)
|
128 |
+
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
129 |
+
pos_k, pos_v = self.pos_emb(pos_seq)
|
130 |
+
else:
|
131 |
+
pos_k = None
|
132 |
+
|
133 |
+
layer_results = []
|
134 |
+
r = None
|
135 |
+
for i, layer in enumerate(self.layers):
|
136 |
+
dropout_probability = np.random.random()
|
137 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
138 |
+
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
|
139 |
+
if tgt_layer is not None:
|
140 |
+
# unpad if needed
|
141 |
+
if pad_length > 0:
|
142 |
+
layer_results.append(
|
143 |
+
(
|
144 |
+
x[:-pad_length],
|
145 |
+
z[:, :-pad_length, :-pad_length]
|
146 |
+
if z is not None
|
147 |
+
else z,
|
148 |
+
)
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
layer_results.append((x, z))
|
152 |
+
if i == tgt_layer:
|
153 |
+
r = x
|
154 |
+
break
|
155 |
+
|
156 |
+
if r is not None:
|
157 |
+
x = r
|
158 |
+
|
159 |
+
# T x B x C -> B x T x C
|
160 |
+
x = x.transpose(0, 1)
|
161 |
+
# undo padding
|
162 |
+
if pad_length > 0:
|
163 |
+
x = x[:, :-pad_length]
|
164 |
+
|
165 |
+
return x, layer_results
|
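extract_features pads the time dimension up to required_seq_len_multiple (handy for fp16-friendly shapes), flags the padded frames in padding_mask, and strips them again after the layer stack. A hedged sketch of the padding step using plain F.pad rather than fairseq's pad_to_multiple helper:

import torch
import torch.nn.functional as F

def pad_seq_to_multiple(x: torch.Tensor, multiple: int):
    # x: (B, T, C); returns the padded tensor and the number of padded frames
    T = x.size(1)
    pad = (multiple - T % multiple) % multiple
    if pad:
        x = F.pad(x, (0, 0, 0, pad))         # pad the T dimension on the right
    return x, pad

x = torch.randn(2, 50, 8)
x, pad_length = pad_seq_to_multiple(x, multiple=16)
mask = torch.zeros(x.shape[:2], dtype=torch.bool)
if pad_length > 0:
    mask[:, -pad_length:] = True             # padded frames are masked out
print(x.shape, pad_length)                   # torch.Size([2, 64, 8]) 14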
166 |
+
|
167 |
+
|
168 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
169 |
+
"""
|
170 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
171 |
+
models.
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
embedding_dim: float = 768,
|
177 |
+
ffn_embedding_dim: float = 3072,
|
178 |
+
num_attention_heads: float = 8,
|
179 |
+
dropout: float = 0.1,
|
180 |
+
attention_dropout: float = 0.1,
|
181 |
+
activation_dropout: float = 0.1,
|
182 |
+
activation_fn: str = "relu",
|
183 |
+
layer_norm_first: bool = False,
|
184 |
+
has_relative_attention_bias: bool = False,
|
185 |
+
) -> None:
|
186 |
+
|
187 |
+
super().__init__()
|
188 |
+
# Initialize parameters
|
189 |
+
self.embedding_dim = embedding_dim
|
190 |
+
self.dropout = dropout
|
191 |
+
self.activation_dropout = activation_dropout
|
192 |
+
|
193 |
+
# Initialize blocks
|
194 |
+
self.activation_fn = utils.get_activation_fn(activation_fn)
|
195 |
+
self.self_attn = MultiheadAttention(
|
196 |
+
self.embedding_dim,
|
197 |
+
num_attention_heads,
|
198 |
+
dropout=attention_dropout,
|
199 |
+
self_attention=True,
|
200 |
+
)
|
201 |
+
|
202 |
+
self.dropout1 = nn.Dropout(dropout)
|
203 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
204 |
+
self.dropout3 = nn.Dropout(dropout)
|
205 |
+
|
206 |
+
self.layer_norm_first = layer_norm_first
|
207 |
+
|
208 |
+
# layer norm associated with the self attention layer
|
209 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
210 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
211 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
212 |
+
|
213 |
+
# layer norm associated with the position wise feed-forward NN
|
214 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
215 |
+
|
216 |
+
if has_relative_attention_bias:
|
217 |
+
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
|
218 |
+
|
219 |
+
def forward(
|
220 |
+
self,
|
221 |
+
x: torch.Tensor,
|
222 |
+
self_attn_mask: torch.Tensor = None,
|
223 |
+
self_attn_padding_mask: torch.Tensor = None,
|
224 |
+
need_weights: bool = False,
|
225 |
+
att_args=None,
|
226 |
+
pos_bias=None,
|
227 |
+
):
|
228 |
+
"""
|
229 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
230 |
+
modules, similar to the original Transformer implementation.
|
231 |
+
"""
|
232 |
+
residual = x
|
233 |
+
|
234 |
+
if self.layer_norm_first:
|
235 |
+
x = self.self_attn_layer_norm(x)
|
236 |
+
if pos_bias is not None:
|
237 |
+
pos_bias = self.norm_k(pos_bias)
|
238 |
+
x, attn = self.self_attn(
|
239 |
+
query=x,
|
240 |
+
key=x,
|
241 |
+
value=x,
|
242 |
+
key_padding_mask=self_attn_padding_mask,
|
243 |
+
attn_mask=self_attn_mask,
|
244 |
+
position_bias=pos_bias,
|
245 |
+
)
|
246 |
+
x = self.dropout1(x)
|
247 |
+
x = residual + x
|
248 |
+
|
249 |
+
residual = x
|
250 |
+
x = self.final_layer_norm(x)
|
251 |
+
x = self.activation_fn(self.fc1(x))
|
252 |
+
x = self.dropout2(x)
|
253 |
+
x = self.fc2(x)
|
254 |
+
x = self.dropout3(x)
|
255 |
+
x = residual + x
|
256 |
+
else:
|
257 |
+
x, attn = self.self_attn(
|
258 |
+
query=x,
|
259 |
+
key=x,
|
260 |
+
value=x,
|
261 |
+
key_padding_mask=self_attn_padding_mask,
|
262 |
+
position_bias=pos_bias,
|
263 |
+
)
|
264 |
+
|
265 |
+
x = self.dropout1(x)
|
266 |
+
x = residual + x
|
267 |
+
|
268 |
+
x = self.self_attn_layer_norm(x)
|
269 |
+
|
270 |
+
residual = x
|
271 |
+
x = self.activation_fn(self.fc1(x))
|
272 |
+
x = self.dropout2(x)
|
273 |
+
x = self.fc2(x)
|
274 |
+
x = self.dropout3(x)
|
275 |
+
x = residual + x
|
276 |
+
x = self.final_layer_norm(x)
|
277 |
+
|
278 |
+
return x, attn
|
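TransformerSentenceEncoderLayer supports both orderings: pre-norm when layer_norm_first is True (LayerNorm before attention/FFN, residual added afterwards) and post-norm otherwise (attention/FFN first, LayerNorm after the residual). Both orderings in miniature, with a single Linear standing in for the attention or FFN sublayer:

import torch
import torch.nn as nn

dim = 8
norm, fn = nn.LayerNorm(dim), nn.Linear(dim, dim)   # fn stands in for attention or FFN
x = torch.randn(2, dim)

pre_ln = x + fn(norm(x))    # layer_norm_first=True: normalize, transform, add residual
post_ln = norm(x + fn(x))   # layer_norm_first=False: transform, add residual, normalize
print(pre_ln.shape, post_ln.shape)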
SpeechT5/Speech2C/speech2c/models/speech2c.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import copy
|
12 |
+
import contextlib
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
from fairseq.data.dictionary import Dictionary
|
18 |
+
from fairseq.models import register_model
|
19 |
+
from fairseq.models.hubert import HubertConfig, HubertModel
|
20 |
+
from fairseq.models.transformer import Embedding
|
21 |
+
from torch import Tensor
|
22 |
+
from speech2c.tasks.speech2c_pretraining import (
|
23 |
+
Speech2cPretrainingConfig,
|
24 |
+
Speech2cPretrainingTask,
|
25 |
+
)
|
26 |
+
|
27 |
+
from speech2c.models.modules.transformer_decoder import TransformerDecoderScriptable
|
28 |
+
from speech2c.models.modules.transformer_encoder import TransformerEncoder
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class Speech2cConfig(HubertConfig):
|
35 |
+
use_rel_pos_enc: bool = field(
|
36 |
+
default=False,
|
37 |
+
metadata={"help": "whether to use relative positional encoding"},
|
38 |
+
)
|
39 |
+
|
40 |
+
# decoder
|
41 |
+
decoder_layers: int = field(
|
42 |
+
default=6, metadata={"help": "num decoder layers in the transformer"}
|
43 |
+
)
|
44 |
+
decoder_embed_dim: int = field(
|
45 |
+
default=768, metadata={"help": "decoder embedding dimension"}
|
46 |
+
)
|
47 |
+
decoder_ffn_embed_dim: int = field(
|
48 |
+
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
|
49 |
+
)
|
50 |
+
decoder_attention_heads: int = field(
|
51 |
+
default=12, metadata={"help": "num decoder attention heads"}
|
52 |
+
)
|
53 |
+
decoder_normalize_before: bool = field(
|
54 |
+
default=False,
|
55 |
+
metadata={"help": "apply layernorm before each decoder block"},
|
56 |
+
)
|
57 |
+
decoder_layerdrop: float = field(
|
58 |
+
default=0.0,
|
59 |
+
metadata={"help": "probability of dropping a tarnsformer layer"},
|
60 |
+
)
|
61 |
+
share_decoder_input_output_embed: bool = field(
|
62 |
+
default=False,
|
63 |
+
metadata={"help": "share decoder input and output embeddings"},
|
64 |
+
)
|
65 |
+
decoder_output_dim: int = field(
|
66 |
+
default=768, metadata={"help": "decoder output dimension"}
|
67 |
+
)
|
68 |
+
max_target_positions: int = field(
|
69 |
+
default=3000, metadata={"help": "max target position"}
|
70 |
+
)
|
71 |
+
no_scale_embedding: bool = field(
|
72 |
+
default=False,
|
73 |
+
metadata={"help": "not scale embedding"},
|
74 |
+
)
|
75 |
+
adaptive_input: bool = field(
|
76 |
+
default=False,
|
77 |
+
metadata={"help": "adaptive input"},
|
78 |
+
)
|
79 |
+
quant_noise_pq: int = field(
|
80 |
+
default=0, metadata={"help": "quant noise pq"}
|
81 |
+
)
|
82 |
+
decoder_learned_pos: bool = field(
|
83 |
+
default=False,
|
84 |
+
metadata={"help": "decoder learnable positional embedding"},
|
85 |
+
)
|
86 |
+
no_token_positional_embeddings: bool = field(
|
87 |
+
default=False,
|
88 |
+
metadata={"help": "no token positional embeddings"},
|
89 |
+
)
|
90 |
+
decoder_dict_size: int = field(
|
91 |
+
default=-1,
|
92 |
+
metadata={"help": "decoder dictionary dimension, only used for fine-tuning"},
|
93 |
+
)
|
94 |
+
|
95 |
+
# FP16 optimization
|
96 |
+
required_seq_len_multiple: int = field(
|
97 |
+
default=1,
|
98 |
+
metadata={
|
99 |
+
"help": "pad the input to encoder such that the sequence length is divisible by multiple"
|
100 |
+
},
|
101 |
+
)
|
102 |
+
crop_seq_to_multiple: int = field(
|
103 |
+
default=1,
|
104 |
+
metadata={
|
105 |
+
"help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
|
106 |
+
},
|
107 |
+
)
|
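Since Speech2cConfig is a dataclass extending HubertConfig, the decoder-side fields added above can be overridden directly when assembling a model config. A sketch assuming fairseq and the speech2c package are importable; the chosen values simply mirror the defaults:

from speech2c.models.speech2c import Speech2cConfig

cfg = Speech2cConfig(
    use_rel_pos_enc=True,        # enable the relative positional bias
    decoder_layers=6,
    decoder_embed_dim=768,
    decoder_ffn_embed_dim=3072,
    decoder_attention_heads=12,
    decoder_dict_size=-1,        # keep the full pretraining dictionary
)
print(cfg.decoder_layers, cfg.use_rel_pos_enc)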
108 |
+
|
109 |
+
|
110 |
+
@register_model("speech2c", dataclass=Speech2cConfig)
|
111 |
+
class Speech2cModel(HubertModel):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
cfg: Speech2cConfig,
|
115 |
+
task_cfg: Speech2cPretrainingConfig,
|
116 |
+
dictionaries: List[Dictionary],
|
117 |
+
) -> None:
|
118 |
+
super().__init__(cfg, task_cfg, dictionaries)
|
119 |
+
logger.info(f"Speech2cModel Config: {cfg}")
|
120 |
+
|
121 |
+
self.encoder = TransformerEncoder(cfg)
|
122 |
+
|
123 |
+
self.add_decoder = task_cfg.add_decoder
|
124 |
+
if task_cfg.add_decoder:
|
125 |
+
def build_embedding(dictionary, embed_dim):
|
126 |
+
num_embeddings = len(dictionary)
|
127 |
+
padding_idx = dictionary.pad()
|
128 |
+
return Embedding(num_embeddings, embed_dim, padding_idx)
|
129 |
+
|
130 |
+
# To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size
|
131 |
+
cut_dictionary = copy.deepcopy(dictionaries[0])
|
132 |
+
if cfg.decoder_dict_size != -1:
|
133 |
+
cut_dictionary.symbols = cut_dictionary.symbols[:cfg.decoder_dict_size]
|
134 |
+
|
135 |
+
decoder_embed_tokens = build_embedding(
|
136 |
+
cut_dictionary, cfg.decoder_embed_dim
|
137 |
+
)
|
138 |
+
|
139 |
+
self.decoder = TransformerDecoderScriptable(cfg, cut_dictionary, decoder_embed_tokens)
|
140 |
+
|
141 |
+
|
142 |
+
@classmethod
|
143 |
+
def build_model(cls, cfg: Speech2cConfig, task: Speech2cPretrainingTask):
|
144 |
+
"""Build a new model instance."""
|
145 |
+
|
146 |
+
model = Speech2cModel(cfg, task.cfg, task.dictionaries)
|
147 |
+
return model
|
148 |
+
|
149 |
+
def get_normalized_probs(
|
150 |
+
self,
|
151 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
152 |
+
log_probs: bool,
|
153 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
154 |
+
):
|
155 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
156 |
+
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
157 |
+
lprobs.batch_first = True
|
158 |
+
return lprobs
|
159 |
+
|
160 |
+
def forward(
|
161 |
+
self,
|
162 |
+
source: torch.Tensor,
|
163 |
+
target_list: Optional[List[torch.Tensor]] = None,
|
164 |
+
padding_mask: Optional[torch.Tensor] = None,
|
165 |
+
mask: bool = True,
|
166 |
+
features_only: bool = False,
|
167 |
+
output_layer: Optional[int] = None,
|
168 |
+
prev_output_tokens: Optional[torch.Tensor] = None,
|
169 |
+
) -> Dict[str, torch.Tensor]:
|
170 |
+
"""output layer is 1-based"""
|
171 |
+
features = self.forward_features(source)
|
172 |
+
if target_list is not None:
|
173 |
+
features, target_list = self.forward_targets(features, target_list)
|
174 |
+
|
175 |
+
features_pen = features.float().pow(2).mean()
|
176 |
+
|
177 |
+
features = features.transpose(1, 2)
|
178 |
+
features = self.layer_norm(features)
|
179 |
+
unmasked_features = features.clone()
|
180 |
+
|
181 |
+
if padding_mask is not None:
|
182 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
183 |
+
|
184 |
+
if self.post_extract_proj is not None:
|
185 |
+
features = self.post_extract_proj(features)
|
186 |
+
|
187 |
+
features = self.dropout_input(features)
|
188 |
+
unmasked_features = self.dropout_features(unmasked_features)
|
189 |
+
|
190 |
+
if mask:
|
191 |
+
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
192 |
+
else:
|
193 |
+
x = features
|
194 |
+
mask_indices = None
|
195 |
+
|
196 |
+
# feature: (B, T, D), float
|
197 |
+
# target: (B, T), long
|
198 |
+
# x: (B, T, D), float
|
199 |
+
# padding_mask: (B, T), bool
|
200 |
+
# mask_indices: (B, T), bool
|
201 |
+
x, _ = self.encoder(
|
202 |
+
x,
|
203 |
+
padding_mask=padding_mask,
|
204 |
+
layer=None if output_layer is None else output_layer - 1,
|
205 |
+
)
|
206 |
+
|
207 |
+
if features_only:
|
208 |
+
return {"x": x, "padding_mask": padding_mask, "features": features}
|
209 |
+
|
210 |
+
def compute_pred(proj_x, target, label_embs):
|
211 |
+
# compute logits for the i-th label set
|
212 |
+
y = torch.index_select(label_embs, 0, target.long())
|
213 |
+
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
214 |
+
if self.target_glu:
|
215 |
+
y = self.target_glu(y)
|
216 |
+
negs = self.target_glu(negs)
|
217 |
+
# proj_x: (S, D)
|
218 |
+
# y: (S, D)
|
219 |
+
# negs: (Neg, S, D)
|
220 |
+
return self.compute_nce(proj_x, y, negs)
|
221 |
+
|
222 |
+
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
223 |
+
|
224 |
+
if not self.skip_masked:
|
225 |
+
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
226 |
+
proj_x_m = self.final_proj(x[masked_indices])
|
227 |
+
if self.untie_final_proj:
|
228 |
+
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
|
229 |
+
else:
|
230 |
+
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
|
231 |
+
logit_m_list = [
|
232 |
+
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
|
233 |
+
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
|
234 |
+
]
|
235 |
+
else:
|
236 |
+
logit_m_list = [None for _ in target_list]
|
237 |
+
|
238 |
+
if not self.skip_nomask:
|
239 |
+
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
240 |
+
proj_x_u = self.final_proj(x[nomask_indices])
|
241 |
+
if self.untie_final_proj:
|
242 |
+
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
|
243 |
+
else:
|
244 |
+
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
|
245 |
+
|
246 |
+
logit_u_list = [
|
247 |
+
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
|
248 |
+
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
|
249 |
+
]
|
250 |
+
else:
|
251 |
+
logit_u_list = [None for _ in target_list]
|
252 |
+
|
253 |
+
result = {
|
254 |
+
"logit_m_list": logit_m_list,
|
255 |
+
"logit_u_list": logit_u_list,
|
256 |
+
"padding_mask": padding_mask,
|
257 |
+
"features_pen": features_pen,
|
258 |
+
}
|
259 |
+
if self.add_decoder:
|
260 |
+
encoder_out = {
|
261 |
+
"encoder_out": [x.transpose(0, 1)], # T x B x C
|
262 |
+
"encoder_padding_mask": [padding_mask], # B x T
|
263 |
+
}
|
264 |
+
assert prev_output_tokens is not None
|
265 |
+
decoder_out = self.decoder(
|
266 |
+
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
|
267 |
+
)
|
268 |
+
result['decoder_out'] = decoder_out
|
269 |
+
return result
|
270 |
+
|
271 |
+
def forward_torchscript(self, net_input: Dict[str, Tensor]):
|
272 |
+
"""A TorchScript-compatible version of forward.
|
273 |
+
Encoders which use additional arguments may want to override
|
274 |
+
this method for TorchScript compatibility.
|
275 |
+
"""
|
276 |
+
res = self.forward(
|
277 |
+
net_input["source"],
|
278 |
+
padding_mask=net_input["padding_mask"],
|
279 |
+
mask=False,
|
280 |
+
features_only=True
|
281 |
+
)
|
282 |
+
|
283 |
+
encoder_out = {
|
284 |
+
"encoder_out": [res["x"].transpose(0, 1)], # T x B x C
|
285 |
+
"encoder_padding_mask": [res["padding_mask"]], # B x T
|
286 |
+
}
|
287 |
+
return encoder_out
|
288 |
+
|
289 |
+
def extract_features(
|
290 |
+
self,
|
291 |
+
source: torch.Tensor,
|
292 |
+
padding_mask: Optional[torch.Tensor] = None,
|
293 |
+
mask: bool = False,
|
294 |
+
ret_conv: bool = False,
|
295 |
+
output_layer: Optional[int] = None,
|
296 |
+
prev_output_tokens: Optional[torch.Tensor] = None,
|
297 |
+
ft: bool = True,
|
298 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
299 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
300 |
+
res = self.forward(
|
301 |
+
source,
|
302 |
+
padding_mask=padding_mask,
|
303 |
+
mask=mask,
|
304 |
+
features_only=True,
|
305 |
+
output_layer=output_layer,
|
306 |
+
)
|
307 |
+
|
308 |
+
feature = res["features"] if ret_conv else res["x"]
|
309 |
+
if self.add_decoder:
|
310 |
+
encoder_out = {
|
311 |
+
"encoder_out": [feature.transpose(0, 1)], # T x B x C
|
312 |
+
"encoder_padding_mask": [res["padding_mask"]], # B x T
|
313 |
+
}
|
314 |
+
assert prev_output_tokens is not None
|
315 |
+
decoder_out = self.decoder(
|
316 |
+
prev_output_tokens=prev_output_tokens,
|
317 |
+
encoder_out=encoder_out,
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
decoder_out = None
|
321 |
+
return feature, res["padding_mask"], decoder_out
|
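extract_features keeps the pretrained backbone frozen by running it under torch.no_grad() until ft is True, using contextlib.ExitStack() as a no-op context manager. The conditional-context idiom in isolation:

import contextlib
import torch

def run(model: torch.nn.Module, x: torch.Tensor, ft: bool) -> torch.Tensor:
    # gradients flow only once fine-tuning of the backbone is enabled
    with torch.no_grad() if not ft else contextlib.ExitStack():
        return model(x)

m = torch.nn.Linear(4, 4)
frozen = run(m, torch.randn(2, 4), ft=False)
tuned = run(m, torch.randn(2, 4), ft=True)
print(frozen.requires_grad, tuned.requires_grad)   # False True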
SpeechT5/Speech2C/speech2c/models/speech2c_asr.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
from argparse import Namespace
|
11 |
+
from omegaconf import II
|
12 |
+
|
13 |
+
import torch.nn as nn
|
14 |
+
from dataclasses import dataclass, field
|
15 |
+
from fairseq import checkpoint_utils, tasks, utils
|
16 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
17 |
+
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
|
18 |
+
from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
|
19 |
+
from fairseq.tasks import FairseqTask
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class Speech2cAsrConfig(HubertAsrConfig):
|
24 |
+
# for decoder
|
25 |
+
decoder_layerdrop: float = field(
|
26 |
+
default=0.0,
|
27 |
+
metadata={"help": "probability of dropping a decoder layer in hubert"},
|
28 |
+
)
|
29 |
+
|
30 |
+
add_decoder: bool = II("task.add_decoder")
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class Speech2cCtcConfig(Speech2cAsrConfig):
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
|
38 |
+
class Speech2cCtc(BaseFairseqModel):
|
39 |
+
def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
|
40 |
+
super().__init__()
|
41 |
+
self.cfg = cfg
|
42 |
+
self.w2v_encoder = w2v_encoder
|
43 |
+
|
44 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
45 |
+
super().upgrade_state_dict_named(state_dict, name)
|
46 |
+
return state_dict
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
|
50 |
+
"""Build a new model instance."""
|
51 |
+
w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
|
52 |
+
return cls(cfg, w2v_encoder)
|
53 |
+
|
54 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
55 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
56 |
+
if "encoder_out" not in net_output:
|
57 |
+
return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
|
58 |
+
|
59 |
+
if "encoder_out_for_ctc" in net_output:
|
60 |
+
logits = net_output["encoder_out_for_ctc"]
|
61 |
+
else:
|
62 |
+
logits = net_output["encoder_out"]
|
63 |
+
|
64 |
+
if isinstance(logits, list):
|
65 |
+
logits = logits[0]
|
66 |
+
|
67 |
+
if log_probs:
|
68 |
+
return utils.log_softmax(logits.float(), dim=-1)
|
69 |
+
else:
|
70 |
+
return utils.softmax(logits.float(), dim=-1)
|
71 |
+
|
72 |
+
def get_logits(self, net_output):
|
73 |
+
logits = net_output["encoder_out"]
|
74 |
+
padding = net_output["encoder_padding_mask"]
|
75 |
+
if padding is not None and padding.any():
|
76 |
+
padding = padding.T
|
77 |
+
logits[padding][..., 0] = 0
|
78 |
+
logits[padding][..., 1:] = float("-inf")
|
79 |
+
|
80 |
+
return logits
|
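get_logits is meant to force every padded frame to emit only the CTC blank by keeping index 0 and pushing all other classes to -inf. A standalone sketch of that masking, assuming index 0 is the blank symbol and writing through the boolean mask in a single indexing step:

import torch

B, T, V = 2, 5, 7                      # batch, frames, vocab (index 0 assumed to be the CTC blank)
logits = torch.randn(T, B, V)          # T x B x V, matching the encoder output layout
padding = torch.zeros(B, T, dtype=torch.bool)
padding[:, -2:] = True                 # pretend the last two frames are padding

mask_value = torch.full((V,), float("-inf"))
mask_value[0] = 0.0                    # keep only the blank logit on padded frames
logits[padding.T] = mask_value         # boolean-mask assignment writes in place
print(logits[-1, 0])                   # blank stays 0, everything else is -inf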
81 |
+
|
82 |
+
def forward(self, **kwargs):
|
83 |
+
x = self.w2v_encoder(**kwargs)
|
84 |
+
return x
|
85 |
+
|
86 |
+
@property
|
87 |
+
def encoder(self):
|
88 |
+
return self.w2v_encoder
|
89 |
+
|
90 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
91 |
+
return self.encoder.reorder_encoder_out(encoder_out, new_order)
|
92 |
+
|
93 |
+
@property
|
94 |
+
def decoder(self):
|
95 |
+
return self.w2v_encoder.w2v_model.decoder
|
96 |
+
|
97 |
+
|
98 |
+
class Speech2cEncoder(FairseqEncoder):
|
99 |
+
def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
|
100 |
+
self.apply_mask = cfg.apply_mask
|
101 |
+
|
102 |
+
arg_overrides = {
|
103 |
+
"dropout": cfg.dropout,
|
104 |
+
"activation_dropout": cfg.activation_dropout,
|
105 |
+
"dropout_input": cfg.dropout_input,
|
106 |
+
"attention_dropout": cfg.attention_dropout,
|
107 |
+
"mask_length": cfg.mask_length,
|
108 |
+
"mask_prob": cfg.mask_prob,
|
109 |
+
"mask_selection": cfg.mask_selection,
|
110 |
+
"mask_other": cfg.mask_other,
|
111 |
+
"no_mask_overlap": cfg.no_mask_overlap,
|
112 |
+
"mask_channel_length": cfg.mask_channel_length,
|
113 |
+
"mask_channel_prob": cfg.mask_channel_prob,
|
114 |
+
"mask_channel_selection": cfg.mask_channel_selection,
|
115 |
+
"mask_channel_other": cfg.mask_channel_other,
|
116 |
+
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
117 |
+
"encoder_layerdrop": cfg.layerdrop,
|
118 |
+
"decoder_layerdrop": cfg.decoder_layerdrop,
|
119 |
+
"feature_grad_mult": cfg.feature_grad_mult,
|
120 |
+
"decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
|
121 |
+
}
|
122 |
+
|
123 |
+
if cfg.w2v_args is None:
|
124 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
|
125 |
+
w2v_args = state.get("cfg", None)
|
126 |
+
if w2v_args is None:
|
127 |
+
w2v_args = convert_namespace_to_omegaconf(state["args"])
|
128 |
+
cfg.w2v_args = w2v_args
|
129 |
+
else:
|
130 |
+
state = None
|
131 |
+
w2v_args = cfg.w2v_args
|
132 |
+
if isinstance(w2v_args, Namespace):
|
133 |
+
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
|
134 |
+
|
135 |
+
assert cfg.normalize == w2v_args.task.normalize, (
|
136 |
+
"Fine-tuning works best when data normalization is the same. "
|
137 |
+
"Please check that --normalize is set or unset for "
|
138 |
+
"both pre-training and here"
|
139 |
+
)
|
140 |
+
|
141 |
+
w2v_args.task.data = cfg.data
|
142 |
+
w2v_args.task.add_decoder = cfg.add_decoder
|
143 |
+
task = tasks.setup_task(w2v_args.task)
|
144 |
+
if state is not None and "task_state" in state:
|
145 |
+
# This will load the stored "dictionaries" object
|
146 |
+
task.load_state_dict(state["task_state"])
|
147 |
+
model = task.build_model(w2v_args.model)
|
148 |
+
|
149 |
+
if state is not None and not cfg.no_pretrained_weights:
|
150 |
+
if "decoder.embed_tokens.weight" in state["model"]:
|
151 |
+
del state["model"]["decoder.embed_tokens.weight"]
|
152 |
+
if "decoder.output_projection.weight" in state["model"]:
|
153 |
+
del state["model"]["decoder.output_projection.weight"]
|
154 |
+
# set strict=False because we omit some modules
|
155 |
+
model.load_state_dict(state["model"], strict=False)
|
156 |
+
|
157 |
+
model.remove_pretraining_modules()
|
158 |
+
|
159 |
+
super().__init__(task.source_dictionary)
|
160 |
+
|
161 |
+
d = model.mask_emb.size(0)
|
162 |
+
|
163 |
+
self.w2v_model = model
|
164 |
+
|
165 |
+
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
166 |
+
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
167 |
+
self.num_updates = 0
|
168 |
+
|
169 |
+
if tgt_dict is not None:
|
170 |
+
self.proj = Linear(d, len(tgt_dict))
|
171 |
+
elif getattr(cfg, "decoder_embed_dim", d) != d:
|
172 |
+
self.proj = Linear(d, cfg.decoder_embed_dim)
|
173 |
+
else:
|
174 |
+
self.proj = None
|
175 |
+
|
176 |
+
def set_num_updates(self, num_updates):
|
177 |
+
"""Set the number of parameters updates."""
|
178 |
+
super().set_num_updates(num_updates)
|
179 |
+
self.num_updates = num_updates
|
180 |
+
|
181 |
+
def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
|
182 |
+
|
183 |
+
ft = self.freeze_finetune_updates <= self.num_updates
|
184 |
+
w2v_args = {
|
185 |
+
"source": source,
|
186 |
+
"padding_mask": padding_mask,
|
187 |
+
"mask": self.apply_mask and self.training,
|
188 |
+
"prev_output_tokens": prev_output_tokens,
|
189 |
+
"ft": ft,
|
190 |
+
}
|
191 |
+
|
192 |
+
x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
|
193 |
+
|
194 |
+
if tbc:
|
195 |
+
# B x T x C -> T x B x C
|
196 |
+
x = x.transpose(0, 1)
|
197 |
+
|
198 |
+
x = self.final_dropout(x)
|
199 |
+
|
200 |
+
if self.proj:
|
201 |
+
x = self.proj(x)
|
202 |
+
|
203 |
+
return {
|
204 |
+
"encoder_out": x, # T x B x C
|
205 |
+
"encoder_padding_mask": padding_mask, # B x T
|
206 |
+
"padding_mask": padding_mask,
|
207 |
+
"decoder_out": decoder_out,
|
208 |
+
}
|
209 |
+
|
210 |
+
def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
|
211 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
212 |
+
return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
|
213 |
+
|
214 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
215 |
+
if encoder_out["encoder_out"] is not None:
|
216 |
+
if isinstance(encoder_out["encoder_out"], list):
|
217 |
+
encoder_out["encoder_out"] = (
|
218 |
+
[] if len(encoder_out["encoder_out"]) == 0
|
219 |
+
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
encoder_out["encoder_out"] = encoder_out[
|
223 |
+
"encoder_out"
|
224 |
+
].index_select(1, new_order)
|
225 |
+
if encoder_out["encoder_padding_mask"] is not None:
|
226 |
+
if isinstance(encoder_out["encoder_padding_mask"], list):
|
227 |
+
encoder_out["encoder_padding_mask"] = (
|
228 |
+
[] if len(encoder_out["encoder_padding_mask"]) == 0
|
229 |
+
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
encoder_out["encoder_padding_mask"] = encoder_out[
|
233 |
+
"encoder_padding_mask"
|
234 |
+
].index_select(0, new_order)
|
235 |
+
if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
|
236 |
+
if isinstance(encoder_out["decoder_out"], list):
|
237 |
+
encoder_out["decoder_out"] = (
|
238 |
+
[] if len(encoder_out["decoder_out"]) == 0
|
239 |
+
else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
encoder_out["decoder_out"] = encoder_out[
|
243 |
+
"decoder_out"
|
244 |
+
].index_select(0, new_order)
|
245 |
+
if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
|
246 |
+
if isinstance(encoder_out["encoder_out_for_ctc"], list):
|
247 |
+
encoder_out["encoder_out_for_ctc"] = (
|
248 |
+
[] if len(encoder_out["encoder_out_for_ctc"]) == 0
|
249 |
+
else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
|
250 |
+
)
|
251 |
+
else:
|
252 |
+
encoder_out["encoder_out_for_ctc"] = encoder_out[
|
253 |
+
"encoder_out_for_ctc"
|
254 |
+
].index_select(1, new_order)
|
255 |
+
|
256 |
+
return encoder_out
|
257 |
+
|
258 |
+
def forward_torchscript(self, net_input):
|
259 |
+
"""A TorchScript-compatible version of forward.
|
260 |
+
Encoders which use additional arguments may want to override
|
261 |
+
this method for TorchScript compatibility.
|
262 |
+
"""
|
263 |
+
encoder_out = self.w2v_model.forward_torchscript(net_input)
|
264 |
+
|
265 |
+
assert self.proj is not None
|
266 |
+
encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
|
267 |
+
|
268 |
+
return encoder_out
|
269 |
+
|
270 |
+
def max_positions(self):
|
271 |
+
"""Maximum input length supported by the encoder."""
|
272 |
+
return None
|
273 |
+
|
274 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
275 |
+
return state_dict
|
276 |
+
|
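The encoder dictionary above is consumed by beam search, which repeatedly reorders it as hypotheses are pruned. The sketch below (illustration only, not part of the repository; tensor sizes are made up) shows why `reorder_encoder_out` selects along dim 1 for the T x B x C encoder output but along dim 0 for the B x T padding mask.

```python
# Minimal sketch of beam-search reordering; not repository code.
import torch

T, B, C, beam = 5, 2, 4, 3

encoder_out = torch.randn(T, B * beam, C)        # T x (B*beam) x C
padding_mask = torch.zeros(B * beam, T).bool()   # (B*beam) x T

# During search, the surviving hypotheses are given by `new_order`,
# a flat index into the batch*beam dimension.
new_order = torch.tensor([0, 0, 1, 3, 3, 4])

reordered_out = encoder_out.index_select(1, new_order)    # batch dim is dim 1 here
reordered_mask = padding_mask.index_select(0, new_order)  # batch dim is dim 0 here

print(reordered_out.shape)   # torch.Size([5, 6, 4])
print(reordered_mask.shape)  # torch.Size([6, 5])
```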
SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py
ADDED
@@ -0,0 +1,25 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

from fairseq.models import (
    register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture


@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
    args.decoder_layers = getattr(args, "decoder_layers", 20)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)
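The architecture function above relies on the convention that `getattr(args, name, default)` only fills in values the user has not already set. A minimal sketch of that pattern using a hypothetical helper name and a plain `argparse.Namespace` (not repository code):

```python
# Illustration only: defaults apply solely to attributes the user left unset.
from argparse import Namespace

def apply_t5_lm_defaults(args):
    # Hypothetical helper mirroring the getattr-default pattern above.
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_layers = getattr(args, "decoder_layers", 20)
    return args

args = Namespace(decoder_layers=12)   # user overrides the depth only
args = apply_t5_lm_defaults(args)
print(args.decoder_embed_dim, args.decoder_layers)  # 1280 12
```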
SpeechT5/Speech2C/speech2c/squence_generator.py
ADDED
@@ -0,0 +1,1028 @@
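The generator below implements beam search with optional CTC prefix scoring and shallow LM fusion. One recurring idiom is flattening the (batch, beam) grid into a single batch*beam dimension via `bbsz_offsets`; a minimal sketch of that indexing with invented values (not repository code):

```python
# Sketch of the flat batch*beam indexing used by the beam search below.
import torch

bsz, beam_size = 2, 3

# Offset of each sentence's first beam slot in the flat batch*beam dimension.
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # [[0], [3]]

# Per-sentence beam choices for 2*beam_size candidates.
cand_beams = torch.tensor([[0, 2, 1, 0, 1, 2],
                           [2, 0, 0, 1, 2, 1]])

cand_bbsz_idx = cand_beams + bbsz_offsets  # flat indices into batch*beam
print(cand_bbsz_idx)
# tensor([[0, 2, 1, 0, 1, 2],
#         [5, 3, 3, 4, 5, 4]])
```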
1 |
+
# --------------------------------------------------------
|
2 |
+
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
from typing import Dict, List, Optional
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from fairseq import search, utils
|
17 |
+
from fairseq.data import data_utils
|
18 |
+
from fairseq.models import FairseqIncrementalDecoder
|
19 |
+
from torch import Tensor
|
20 |
+
from fairseq.ngram_repeat_block import NGramRepeatBlock
|
21 |
+
from speech2c.models.modules.ctc_prefix_score import CTCPrefixScore
|
22 |
+
import numpy
|
23 |
+
|
24 |
+
|
25 |
+
CTC_SCORING_RATIO = 7.0
|
26 |
+
|
27 |
+
class SequenceGenerator(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
models,
|
31 |
+
tgt_dict,
|
32 |
+
beam_size=1,
|
33 |
+
max_len_a=0,
|
34 |
+
max_len_b=200,
|
35 |
+
max_len=0,
|
36 |
+
min_len=1,
|
37 |
+
normalize_scores=True,
|
38 |
+
len_penalty=1.0,
|
39 |
+
unk_penalty=0.0,
|
40 |
+
temperature=1.0,
|
41 |
+
match_source_len=False,
|
42 |
+
no_repeat_ngram_size=0,
|
43 |
+
search_strategy=None,
|
44 |
+
eos=None,
|
45 |
+
symbols_to_strip_from_output=None,
|
46 |
+
lm_model=None,
|
47 |
+
lm_weight=1.0,
|
48 |
+
ctc_weight=0.0,
|
49 |
+
):
|
50 |
+
"""Generates translations of a given source sentence.
|
51 |
+
Args:
|
52 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
53 |
+
currently support fairseq.models.TransformerModel for scripting
|
54 |
+
beam_size (int, optional): beam width (default: 1)
|
55 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
56 |
+
ax + b, where x is the source length
|
57 |
+
max_len (int, optional): the maximum length of the generated output
|
58 |
+
(not including end-of-sentence)
|
59 |
+
min_len (int, optional): the minimum length of the generated output
|
60 |
+
(not including end-of-sentence)
|
61 |
+
normalize_scores (bool, optional): normalize scores by the length
|
62 |
+
of the output (default: True)
|
63 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
64 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
65 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
66 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
67 |
+
temperature (float, optional): temperature, where values
|
68 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
69 |
+
sharper samples (default: 1.0)
|
70 |
+
match_source_len (bool, optional): outputs should match the source
|
71 |
+
length (default: False)
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
if isinstance(models, EnsembleModel):
|
75 |
+
self.model = models
|
76 |
+
else:
|
77 |
+
self.model = EnsembleModel(models)
|
78 |
+
self.tgt_dict = tgt_dict
|
79 |
+
self.pad = tgt_dict.pad()
|
80 |
+
self.unk = tgt_dict.unk()
|
81 |
+
self.eos = tgt_dict.eos() if eos is None else eos
|
82 |
+
self.blank = self.tgt_dict.index("<s>")
|
83 |
+
self.symbols_to_strip_from_output = (
|
84 |
+
symbols_to_strip_from_output.union({self.eos})
|
85 |
+
if symbols_to_strip_from_output is not None
|
86 |
+
else {self.eos}
|
87 |
+
)
|
88 |
+
self.vocab_size = len(tgt_dict)
|
89 |
+
self.beam_size = beam_size
|
90 |
+
# the max beam size is the dictionary size - 1, since we never select pad
|
91 |
+
self.beam_size = min(beam_size, self.vocab_size - 1)
|
92 |
+
self.max_len_a = max_len_a
|
93 |
+
self.max_len_b = max_len_b
|
94 |
+
self.min_len = min_len
|
95 |
+
self.max_len = max_len or self.model.max_decoder_positions()
|
96 |
+
|
97 |
+
self.normalize_scores = normalize_scores
|
98 |
+
self.len_penalty = len_penalty
|
99 |
+
self.unk_penalty = unk_penalty
|
100 |
+
self.temperature = temperature
|
101 |
+
self.match_source_len = match_source_len
|
102 |
+
|
103 |
+
if no_repeat_ngram_size > 0:
|
104 |
+
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
|
105 |
+
else:
|
106 |
+
self.repeat_ngram_blocker = None
|
107 |
+
|
108 |
+
assert temperature > 0, "--temperature must be greater than 0"
|
109 |
+
|
110 |
+
self.search = (
|
111 |
+
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
|
112 |
+
)
|
113 |
+
# We only need to set src_lengths in LengthConstrainedBeamSearch.
|
114 |
+
# As a module attribute, setting it would break in multithread
|
115 |
+
# settings when the model is shared.
|
116 |
+
self.should_set_src_lengths = (
|
117 |
+
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
|
118 |
+
)
|
119 |
+
|
120 |
+
self.model.eval()
|
121 |
+
|
122 |
+
self.lm_model = lm_model
|
123 |
+
self.lm_weight = lm_weight
|
124 |
+
self.ctc_weight = ctc_weight
|
125 |
+
if self.lm_model is not None:
|
126 |
+
self.lm_model.eval()
|
127 |
+
|
128 |
+
def cuda(self):
|
129 |
+
self.model.cuda()
|
130 |
+
return self
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
sample: Dict[str, Dict[str, Tensor]],
|
136 |
+
prefix_tokens: Optional[Tensor] = None,
|
137 |
+
bos_token: Optional[int] = None,
|
138 |
+
):
|
139 |
+
"""Generate a batch of translations.
|
140 |
+
Args:
|
141 |
+
sample (dict): batch
|
142 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
143 |
+
with these tokens
|
144 |
+
bos_token (int, optional): beginning of sentence token
|
145 |
+
(default: self.eos)
|
146 |
+
"""
|
147 |
+
return self._generate(sample, prefix_tokens, bos_token=bos_token)
|
148 |
+
|
149 |
+
# TODO(myleott): unused, deprecate after pytorch-translate migration
|
150 |
+
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
|
151 |
+
"""Iterate over a batched dataset and yield individual translations.
|
152 |
+
Args:
|
153 |
+
cuda (bool, optional): use GPU for generation
|
154 |
+
timer (StopwatchMeter, optional): time generations
|
155 |
+
"""
|
156 |
+
for sample in data_itr:
|
157 |
+
s = utils.move_to_cuda(sample) if cuda else sample
|
158 |
+
if "net_input" not in s:
|
159 |
+
continue
|
160 |
+
input = s["net_input"]
|
161 |
+
# model.forward normally channels prev_output_tokens into the decoder
|
162 |
+
# separately, but SequenceGenerator directly calls model.encoder
|
163 |
+
encoder_input = {
|
164 |
+
k: v for k, v in input.items() if k != "prev_output_tokens"
|
165 |
+
}
|
166 |
+
if timer is not None:
|
167 |
+
timer.start()
|
168 |
+
with torch.no_grad():
|
169 |
+
hypos = self.generate(encoder_input)
|
170 |
+
if timer is not None:
|
171 |
+
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
|
172 |
+
for i, id in enumerate(s["id"].data):
|
173 |
+
# remove padding
|
174 |
+
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
|
175 |
+
ref = (
|
176 |
+
utils.strip_pad(s["target"].data[i, :], self.pad)
|
177 |
+
if s["target"] is not None
|
178 |
+
else None
|
179 |
+
)
|
180 |
+
yield id, src, ref, hypos[i]
|
181 |
+
|
182 |
+
@torch.no_grad()
|
183 |
+
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
|
184 |
+
"""Generate translations. Match the api of other fairseq generators.
|
185 |
+
Args:
|
186 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
187 |
+
sample (dict): batch
|
188 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
189 |
+
with these tokens
|
190 |
+
constraints (torch.LongTensor, optional): force decoder to include
|
191 |
+
the list of constraints
|
192 |
+
bos_token (int, optional): beginning of sentence token
|
193 |
+
(default: self.eos)
|
194 |
+
"""
|
195 |
+
return self._generate(sample, **kwargs)
|
196 |
+
|
197 |
+
def _generate(
|
198 |
+
self,
|
199 |
+
sample: Dict[str, Dict[str, Tensor]],
|
200 |
+
prefix_tokens: Optional[Tensor] = None,
|
201 |
+
constraints: Optional[Tensor] = None,
|
202 |
+
bos_token: Optional[int] = None,
|
203 |
+
):
|
204 |
+
incremental_states = torch.jit.annotate(
|
205 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
206 |
+
[
|
207 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
208 |
+
for i in range(self.model.models_size)
|
209 |
+
],
|
210 |
+
)
|
211 |
+
net_input = sample["net_input"]
|
212 |
+
|
213 |
+
if "src_tokens" in net_input:
|
214 |
+
src_tokens = net_input["src_tokens"]
|
215 |
+
# length of the source text being the character length except EndOfSentence and pad
|
216 |
+
src_lengths = (
|
217 |
+
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
218 |
+
)
|
219 |
+
elif "source" in net_input:
|
220 |
+
src_tokens = net_input["source"]
|
221 |
+
src_lengths = (
|
222 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
223 |
+
if net_input["padding_mask"] is not None
|
224 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
225 |
+
)
|
226 |
+
elif "features" in net_input:
|
227 |
+
src_tokens = net_input["features"]
|
228 |
+
src_lengths = (
|
229 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
230 |
+
if net_input["padding_mask"] is not None
|
231 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
|
235 |
+
|
236 |
+
# bsz: total number of sentences in beam
|
237 |
+
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
238 |
+
bsz, src_len = src_tokens.size()[:2]
|
239 |
+
beam_size = self.beam_size
|
240 |
+
|
241 |
+
if constraints is not None and not self.search.supports_constraints:
|
242 |
+
raise NotImplementedError(
|
243 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
244 |
+
)
|
245 |
+
|
246 |
+
# Initialize constraints, when active
|
247 |
+
self.search.init_constraints(constraints, beam_size)
|
248 |
+
|
249 |
+
max_len: int = -1
|
250 |
+
if self.match_source_len:
|
251 |
+
max_len = src_lengths.max().item()
|
252 |
+
else:
|
253 |
+
max_len = min(
|
254 |
+
int(self.max_len_a * src_len + self.max_len_b),
|
255 |
+
self.max_len - 1,
|
256 |
+
)
|
257 |
+
assert (
|
258 |
+
self.min_len <= max_len
|
259 |
+
), "min_len cannot be larger than max_len, please adjust these!"
|
260 |
+
# compute the encoder output for each beam
|
261 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
|
262 |
+
encoder_outs = self.model.forward_encoder(net_input)
|
263 |
+
|
264 |
+
# Get CTC lprobs and prep ctc_scorer
|
265 |
+
if self.ctc_weight > 0:
|
266 |
+
ctc_lprobs = self.model.models[0].get_normalized_probs(
|
267 |
+
encoder_outs[0], log_probs=True
|
268 |
+
).contiguous().transpose(0, 1) # (B, T, C) from the encoder
|
269 |
+
|
270 |
+
hyp = {}
|
271 |
+
ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
|
272 |
+
hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
|
273 |
+
hyp["ctc_score_prev"] = 0.0
|
274 |
+
ctc_beam = min(ctc_lprobs.shape[-1], int(beam_size * CTC_SCORING_RATIO))
|
275 |
+
ctc_hyps = {str(self.eos): hyp}
|
276 |
+
|
277 |
+
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
278 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
279 |
+
new_order = new_order.to(src_tokens.device).long()
|
280 |
+
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
|
281 |
+
# ensure encoder_outs is a List.
|
282 |
+
assert encoder_outs is not None
|
283 |
+
|
284 |
+
# initialize buffers
|
285 |
+
scores = (
|
286 |
+
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
|
287 |
+
) # +1 for eos; pad is never chosen for scoring
|
288 |
+
tokens = (
|
289 |
+
torch.zeros(bsz * beam_size, max_len + 2)
|
290 |
+
.to(src_tokens)
|
291 |
+
.long()
|
292 |
+
.fill_(self.pad)
|
293 |
+
) # +2 for eos and pad
|
294 |
+
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
295 |
+
attn: Optional[Tensor] = None
|
296 |
+
|
297 |
+
# A list that indicates candidates that should be ignored.
|
298 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
299 |
+
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
300 |
+
# so that we only finalize the remaining 3 samples.
|
301 |
+
cands_to_ignore = (
|
302 |
+
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
|
303 |
+
) # forward and backward-compatible False mask
|
304 |
+
|
305 |
+
# list of completed sentences
|
306 |
+
finalized = torch.jit.annotate(
|
307 |
+
List[List[Dict[str, Tensor]]],
|
308 |
+
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
309 |
+
) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
|
310 |
+
|
311 |
+
# a boolean array indicating if the sentence at the index is finished or not
|
312 |
+
finished = [False for i in range(bsz)]
|
313 |
+
num_remaining_sent = bsz # number of sentences remaining
|
314 |
+
|
315 |
+
# number of candidate hypos per step
|
316 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
317 |
+
|
318 |
+
# offset arrays for converting between different indexing schemes
|
319 |
+
bbsz_offsets = (
|
320 |
+
(torch.arange(0, bsz) * beam_size)
|
321 |
+
.unsqueeze(1)
|
322 |
+
.type_as(tokens)
|
323 |
+
.to(src_tokens.device)
|
324 |
+
)
|
325 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
|
326 |
+
|
327 |
+
reorder_state: Optional[Tensor] = None
|
328 |
+
batch_idxs: Optional[Tensor] = None
|
329 |
+
|
330 |
+
original_batch_idxs: Optional[Tensor] = None
|
331 |
+
if "id" in sample and isinstance(sample["id"], Tensor):
|
332 |
+
original_batch_idxs = sample["id"]
|
333 |
+
else:
|
334 |
+
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
335 |
+
|
336 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
337 |
+
# reorder decoder internal states based on the prev choice of beams
|
338 |
+
if reorder_state is not None:
|
339 |
+
if batch_idxs is not None:
|
340 |
+
# update beam indices to take into account removed sentences
|
341 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
342 |
+
batch_idxs
|
343 |
+
)
|
344 |
+
reorder_state.view(-1, beam_size).add_(
|
345 |
+
corr.unsqueeze(-1) * beam_size
|
346 |
+
)
|
347 |
+
original_batch_idxs = original_batch_idxs[batch_idxs]
|
348 |
+
self.model.reorder_incremental_state(incremental_states, reorder_state)
|
349 |
+
encoder_outs = self.model.reorder_encoder_out(
|
350 |
+
encoder_outs, reorder_state
|
351 |
+
)
|
352 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
|
353 |
+
lprobs, avg_attn_scores = self.model.forward_decoder(
|
354 |
+
tokens[:, : step + 1],
|
355 |
+
encoder_outs,
|
356 |
+
incremental_states,
|
357 |
+
self.temperature,
|
358 |
+
)
|
359 |
+
|
360 |
+
if self.ctc_weight > 0 and step != 0:
|
361 |
+
# lprobs[:, self.blank] = -math.inf # never select blank
|
362 |
+
ctc_lprobs = lprobs.clone()
|
363 |
+
ctc_lprobs[:, self.blank] = -math.inf # never select blank
|
364 |
+
_, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
|
365 |
+
for b in range(tokens.size(0)):
|
366 |
+
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
|
367 |
+
ctc_scores, ctc_states = ctc_prefix_score(
|
368 |
+
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
|
369 |
+
)
|
370 |
+
lprobs[b] = lprobs[b]
|
371 |
+
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
|
372 |
+
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
|
373 |
+
).to(device="cuda")
|
374 |
+
for j in range(len(local_best_ids[b])):
|
375 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
|
376 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
|
377 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
|
378 |
+
|
379 |
+
elif self.ctc_weight > 0 and step == 0:
|
380 |
+
ctc_lprobs = lprobs.clone()
|
381 |
+
ctc_lprobs[:, self.blank] = -math.inf # never select blank
|
382 |
+
_, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
|
383 |
+
for b in range(tokens.size(0)):
|
384 |
+
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
|
385 |
+
ctc_scores, ctc_states = ctc_prefix_score(
|
386 |
+
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
|
387 |
+
)
|
388 |
+
lprobs[b] = lprobs[b]
|
389 |
+
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
|
390 |
+
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
|
391 |
+
).to(device="cuda")
|
392 |
+
for j in range(len(local_best_ids[b])):
|
393 |
+
if b == 0:
|
394 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
|
395 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
|
396 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
|
397 |
+
|
398 |
+
if self.lm_model is not None:
|
399 |
+
lm_out = self.lm_model(tokens[:, : step + 1])
|
400 |
+
probs = self.lm_model.get_normalized_probs(
|
401 |
+
lm_out, log_probs=True, sample=None
|
402 |
+
)
|
403 |
+
probs = probs[:, -1, :] * self.lm_weight
|
404 |
+
lprobs += probs
|
405 |
+
# handle prefix tokens (possibly with different lengths)
|
406 |
+
if (
|
407 |
+
prefix_tokens is not None
|
408 |
+
and step < prefix_tokens.size(1)
|
409 |
+
and step < max_len
|
410 |
+
):
|
411 |
+
lprobs, tokens, scores = self._prefix_tokens(
|
412 |
+
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
413 |
+
)
|
414 |
+
elif step < self.min_len:
|
415 |
+
# minimum length constraint (does not apply if using prefix_tokens)
|
416 |
+
lprobs[:, self.eos] = -math.inf
|
417 |
+
|
418 |
+
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
419 |
+
|
420 |
+
lprobs[:, self.pad] = -math.inf # never select pad
|
421 |
+
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
422 |
+
lprobs[:, self.blank] = -math.inf # never select blank
|
423 |
+
|
424 |
+
# handle max length constraint
|
425 |
+
if step >= max_len:
|
426 |
+
lprobs[:, : self.eos] = -math.inf
|
427 |
+
lprobs[:, self.eos + 1 :] = -math.inf
|
428 |
+
|
429 |
+
# Record attention scores, only support avg_attn_scores is a Tensor
|
430 |
+
if avg_attn_scores is not None:
|
431 |
+
if attn is None:
|
432 |
+
attn = torch.empty(
|
433 |
+
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
434 |
+
).to(scores)
|
435 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
436 |
+
|
437 |
+
scores = scores.type_as(lprobs)
|
438 |
+
eos_bbsz_idx = torch.empty(0).to(
|
439 |
+
tokens
|
440 |
+
) # indices of hypothesis ending with eos (finished sentences)
|
441 |
+
eos_scores = torch.empty(0).to(
|
442 |
+
scores
|
443 |
+
) # scores of hypothesis ending with eos (finished sentences)
|
444 |
+
|
445 |
+
if self.should_set_src_lengths:
|
446 |
+
self.search.set_src_lengths(src_lengths)
|
447 |
+
|
448 |
+
if self.repeat_ngram_blocker is not None:
|
449 |
+
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
|
450 |
+
|
451 |
+
# Shape: (batch, cand_size)
|
452 |
+
cand_scores, cand_indices, cand_beams = self.search.step(
|
453 |
+
step,
|
454 |
+
lprobs.view(bsz, -1, self.vocab_size),
|
455 |
+
scores.view(bsz, beam_size, -1)[:, :, :step],
|
456 |
+
tokens[:, : step + 1],
|
457 |
+
original_batch_idxs,
|
458 |
+
)
|
459 |
+
|
460 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
461 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
462 |
+
# and dimensions: [bsz, cand_size]
|
463 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
464 |
+
|
465 |
+
# finalize hypotheses that end in eos
|
466 |
+
# Shape of eos_mask: (batch size, beam size)
|
467 |
+
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
468 |
+
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
469 |
+
|
470 |
+
# only consider eos when it's among the top beam_size indices
|
471 |
+
# Now we know what beam item(s) to finish
|
472 |
+
# Shape: 1d list of absolute-numbered
|
473 |
+
eos_bbsz_idx = torch.masked_select(
|
474 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
475 |
+
)
|
476 |
+
|
477 |
+
finalized_sents: List[int] = []
|
478 |
+
if eos_bbsz_idx.numel() > 0:
|
479 |
+
eos_scores = torch.masked_select(
|
480 |
+
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
481 |
+
)
|
482 |
+
|
483 |
+
finalized_sents = self.finalize_hypos(
|
484 |
+
step,
|
485 |
+
eos_bbsz_idx,
|
486 |
+
eos_scores,
|
487 |
+
tokens,
|
488 |
+
scores,
|
489 |
+
finalized,
|
490 |
+
finished,
|
491 |
+
beam_size,
|
492 |
+
attn,
|
493 |
+
src_lengths,
|
494 |
+
max_len,
|
495 |
+
)
|
496 |
+
num_remaining_sent -= len(finalized_sents)
|
497 |
+
|
498 |
+
assert num_remaining_sent >= 0
|
499 |
+
if num_remaining_sent == 0:
|
500 |
+
break
|
501 |
+
if self.search.stop_on_max_len and step >= max_len:
|
502 |
+
break
|
503 |
+
assert step < max_len, f"{step} < {max_len}"
|
504 |
+
|
505 |
+
# Remove finalized sentences (ones for which {beam_size}
|
506 |
+
# finished hypotheses have been generated) from the batch.
|
507 |
+
if len(finalized_sents) > 0:
|
508 |
+
new_bsz = bsz - len(finalized_sents)
|
509 |
+
|
510 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
511 |
+
batch_mask = torch.ones(
|
512 |
+
bsz, dtype=torch.bool, device=cand_indices.device
|
513 |
+
)
|
514 |
+
batch_mask[finalized_sents] = False
|
515 |
+
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
516 |
+
batch_idxs = torch.arange(
|
517 |
+
bsz, device=cand_indices.device
|
518 |
+
).masked_select(batch_mask)
|
519 |
+
|
520 |
+
# Choose the subset of the hypothesized constraints that will continue
|
521 |
+
self.search.prune_sentences(batch_idxs)
|
522 |
+
|
523 |
+
eos_mask = eos_mask[batch_idxs]
|
524 |
+
cand_beams = cand_beams[batch_idxs]
|
525 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
526 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
527 |
+
cand_scores = cand_scores[batch_idxs]
|
528 |
+
cand_indices = cand_indices[batch_idxs]
|
529 |
+
|
530 |
+
if prefix_tokens is not None:
|
531 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
532 |
+
src_lengths = src_lengths[batch_idxs]
|
533 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
534 |
+
|
535 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
536 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
537 |
+
if attn is not None:
|
538 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(
|
539 |
+
new_bsz * beam_size, attn.size(1), -1
|
540 |
+
)
|
541 |
+
bsz = new_bsz
|
542 |
+
else:
|
543 |
+
batch_idxs = None
|
544 |
+
|
545 |
+
# Set active_mask so that values > cand_size indicate eos hypos
|
546 |
+
# and values < cand_size indicate candidate active hypos.
|
547 |
+
# After, the min values per row are the top candidate active hypos
|
548 |
+
|
549 |
+
# Rewrite the operator since the element wise or is not supported in torchscript.
|
550 |
+
|
551 |
+
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
552 |
+
active_mask = torch.add(
|
553 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
554 |
+
cand_offsets[: eos_mask.size(1)],
|
555 |
+
)
|
556 |
+
|
557 |
+
# get the top beam_size active hypotheses, which are just
|
558 |
+
# the hypos with the smallest values in active_mask.
|
559 |
+
# {active_hypos} indicates which {beam_size} hypotheses
|
560 |
+
# from the list of {2 * beam_size} candidates were
|
561 |
+
# selected. Shapes: (batch size, beam size)
|
562 |
+
new_cands_to_ignore, active_hypos = torch.topk(
|
563 |
+
active_mask, k=beam_size, dim=1, largest=False
|
564 |
+
)
|
565 |
+
|
566 |
+
# update cands_to_ignore to ignore any finalized hypos.
|
567 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
568 |
+
# Make sure there is at least one active item for each sentence in the batch.
|
569 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
570 |
+
|
571 |
+
# update cands_to_ignore to ignore any finalized hypos
|
572 |
+
|
573 |
+
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
574 |
+
# can be selected more than once).
|
575 |
+
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
576 |
+
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
577 |
+
|
578 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
579 |
+
active_scores = active_scores.view(-1)
|
580 |
+
|
581 |
+
# copy tokens and scores for active hypotheses
|
582 |
+
|
583 |
+
# Set the tokens for each beam (can select the same row more than once)
|
584 |
+
tokens[:, : step + 1] = torch.index_select(
|
585 |
+
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
586 |
+
)
|
587 |
+
# Select the next token for each of them
|
588 |
+
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
589 |
+
cand_indices, dim=1, index=active_hypos
|
590 |
+
)
|
591 |
+
if step > 0:
|
592 |
+
scores[:, :step] = torch.index_select(
|
593 |
+
scores[:, :step], dim=0, index=active_bbsz_idx
|
594 |
+
)
|
595 |
+
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
596 |
+
cand_scores, dim=1, index=active_hypos
|
597 |
+
)
|
598 |
+
|
599 |
+
# Update constraints based on which candidates were selected for the next beam
|
600 |
+
self.search.update_constraints(active_hypos)
|
601 |
+
|
602 |
+
# copy attention for active hypotheses
|
603 |
+
if attn is not None:
|
604 |
+
attn[:, :, : step + 2] = torch.index_select(
|
605 |
+
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
606 |
+
)
|
607 |
+
|
608 |
+
# reorder incremental state in decoder
|
609 |
+
reorder_state = active_bbsz_idx
|
610 |
+
|
611 |
+
# sort by score descending
|
612 |
+
for sent in range(len(finalized)):
|
613 |
+
scores = torch.tensor(
|
614 |
+
[float(elem["score"].item()) for elem in finalized[sent]]
|
615 |
+
)
|
616 |
+
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
617 |
+
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
618 |
+
finalized[sent] = torch.jit.annotate(
|
619 |
+
List[Dict[str, Tensor]], finalized[sent]
|
620 |
+
)
|
621 |
+
return finalized
|
622 |
+
|
623 |
+
def _prefix_tokens(
|
624 |
+
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
|
625 |
+
):
|
626 |
+
"""Handle prefix tokens"""
|
627 |
+
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
628 |
+
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
629 |
+
prefix_mask = prefix_toks.ne(self.pad)
|
630 |
+
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
|
631 |
+
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
|
632 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
|
633 |
+
)
|
634 |
+
# if prefix includes eos, then we should make sure tokens and
|
635 |
+
# scores are the same across all beams
|
636 |
+
eos_mask = prefix_toks.eq(self.eos)
|
637 |
+
if eos_mask.any():
|
638 |
+
# validate that the first beam matches the prefix
|
639 |
+
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
|
640 |
+
:, 0, 1 : step + 1
|
641 |
+
]
|
642 |
+
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
643 |
+
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
644 |
+
assert (first_beam == target_prefix).all()
|
645 |
+
|
646 |
+
# copy tokens, scores and lprobs from the first beam to all beams
|
647 |
+
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
|
648 |
+
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
|
649 |
+
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
|
650 |
+
return lprobs, tokens, scores
|
651 |
+
|
652 |
+
def replicate_first_beam(self, tensor, mask, beam_size: int):
|
653 |
+
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
654 |
+
tensor[mask] = tensor[mask][:, :1, :]
|
655 |
+
return tensor.view(-1, tensor.size(-1))
|
656 |
+
|
657 |
+
def finalize_hypos(
|
658 |
+
self,
|
659 |
+
step: int,
|
660 |
+
bbsz_idx,
|
661 |
+
eos_scores,
|
662 |
+
tokens,
|
663 |
+
scores,
|
664 |
+
finalized: List[List[Dict[str, Tensor]]],
|
665 |
+
finished: List[bool],
|
666 |
+
beam_size: int,
|
667 |
+
attn: Optional[Tensor],
|
668 |
+
src_lengths,
|
669 |
+
max_len: int,
|
670 |
+
):
|
671 |
+
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
|
672 |
+
A sentence is finalized when {beam_size} finished items have been collected for it.
|
673 |
+
Returns number of sentences (not beam items) being finalized.
|
674 |
+
These will be removed from the batch and not processed further.
|
675 |
+
Args:
|
676 |
+
bbsz_idx (Tensor):
|
677 |
+
"""
|
678 |
+
assert bbsz_idx.numel() == eos_scores.numel()
|
679 |
+
|
680 |
+
# clone relevant token and attention tensors.
|
681 |
+
# tokens is (batch * beam, max_len). So the index_select
|
682 |
+
# gets the newly EOS rows, then selects cols 1..{step + 2}
|
683 |
+
tokens_clone = tokens.index_select(0, bbsz_idx)[
|
684 |
+
:, 1 : step + 2
|
685 |
+
] # skip the first index, which is EOS
|
686 |
+
|
687 |
+
tokens_clone[:, step] = self.eos
|
688 |
+
attn_clone = (
|
689 |
+
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
|
690 |
+
if attn is not None
|
691 |
+
else None
|
692 |
+
)
|
693 |
+
|
694 |
+
# compute scores per token position
|
695 |
+
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
|
696 |
+
pos_scores[:, step] = eos_scores
|
697 |
+
# convert from cumulative to per-position scores
|
698 |
+
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
699 |
+
|
700 |
+
# normalize sentence-level scores
|
701 |
+
if self.normalize_scores:
|
702 |
+
eos_scores /= (step + 1) ** self.len_penalty
|
703 |
+
|
704 |
+
# cum_unfin records which sentences in the batch are finished.
|
705 |
+
# It helps match indexing between (a) the original sentences
|
706 |
+
# in the batch and (b) the current, possibly-reduced set of
|
707 |
+
# sentences.
|
708 |
+
cum_unfin: List[int] = []
|
709 |
+
prev = 0
|
710 |
+
for f in finished:
|
711 |
+
if f:
|
712 |
+
prev += 1
|
713 |
+
else:
|
714 |
+
cum_unfin.append(prev)
|
715 |
+
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
|
716 |
+
|
717 |
+
unfin_idx = bbsz_idx // beam_size
|
718 |
+
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
|
719 |
+
|
720 |
+
# Create a set of "{sent}{unfin_idx}", where
|
721 |
+
# "unfin_idx" is the index in the current (possibly reduced)
|
722 |
+
# list of sentences, and "sent" is the index in the original,
|
723 |
+
# unreduced batch
|
724 |
+
# For every finished beam item
|
725 |
+
# sentence index in the current (possibly reduced) batch
|
726 |
+
seen = (sent << 32) + unfin_idx
|
727 |
+
unique_seen: List[int] = torch.unique(seen).tolist()
|
728 |
+
|
729 |
+
if self.match_source_len:
|
730 |
+
condition = step > torch.index_select(src_lengths, 0, unfin_idx)
|
731 |
+
eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
|
732 |
+
sent_list: List[int] = sent.tolist()
|
733 |
+
for i in range(bbsz_idx.size()[0]):
|
734 |
+
# An input sentence (among those in a batch) is finished when
|
735 |
+
# beam_size hypotheses have been collected for it
|
736 |
+
if len(finalized[sent_list[i]]) < beam_size:
|
737 |
+
if attn_clone is not None:
|
738 |
+
# remove padding tokens from attn scores
|
739 |
+
hypo_attn = attn_clone[i]
|
740 |
+
else:
|
741 |
+
hypo_attn = torch.empty(0)
|
742 |
+
|
743 |
+
finalized[sent_list[i]].append(
|
744 |
+
{
|
745 |
+
"tokens": tokens_clone[i],
|
746 |
+
"score": eos_scores[i],
|
747 |
+
"attention": hypo_attn, # src_len x tgt_len
|
748 |
+
"alignment": torch.empty(0),
|
749 |
+
"positional_scores": pos_scores[i],
|
750 |
+
}
|
751 |
+
)
|
752 |
+
|
753 |
+
newly_finished: List[int] = []
|
754 |
+
for unique_s in unique_seen:
|
755 |
+
# check termination conditions for this sentence
|
756 |
+
unique_sent: int = unique_s >> 32
|
757 |
+
unique_unfin_idx: int = unique_s - (unique_sent << 32)
|
758 |
+
|
759 |
+
if not finished[unique_sent] and self.is_finished(
|
760 |
+
step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
|
761 |
+
):
|
762 |
+
finished[unique_sent] = True
|
763 |
+
newly_finished.append(unique_unfin_idx)
|
764 |
+
|
765 |
+
return newly_finished
|
766 |
+
|
767 |
+
def is_finished(
|
768 |
+
self,
|
769 |
+
step: int,
|
770 |
+
unfin_idx: int,
|
771 |
+
max_len: int,
|
772 |
+
finalized_sent_len: int,
|
773 |
+
beam_size: int,
|
774 |
+
):
|
775 |
+
"""
|
776 |
+
Check whether decoding for a sentence is finished, which
|
777 |
+
occurs when the list of finalized sentences has reached the
|
778 |
+
beam size, or when we reach the maximum length.
|
779 |
+
"""
|
780 |
+
assert finalized_sent_len <= beam_size
|
781 |
+
if finalized_sent_len == beam_size or step == max_len:
|
782 |
+
return True
|
783 |
+
return False
|
784 |
+
|
785 |
+
|
786 |
+
class EnsembleModel(nn.Module):
|
787 |
+
"""A wrapper around an ensemble of models."""
|
788 |
+
|
789 |
+
def __init__(self, models):
|
790 |
+
super().__init__()
|
791 |
+
self.models_size = len(models)
|
792 |
+
# method '__len__' is not supported in ModuleList for torch script
|
793 |
+
self.single_model = models[0]
|
794 |
+
self.models = nn.ModuleList(models)
|
795 |
+
|
796 |
+
self.has_incremental: bool = False
|
797 |
+
if all(
|
798 |
+
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
|
799 |
+
for m in models
|
800 |
+
):
|
801 |
+
self.has_incremental = True
|
802 |
+
|
803 |
+
def forward(self):
|
804 |
+
pass
|
805 |
+
|
806 |
+
def has_encoder(self):
|
807 |
+
return hasattr(self.single_model, "encoder")
|
808 |
+
|
809 |
+
def has_incremental_states(self):
|
810 |
+
return self.has_incremental
|
811 |
+
|
812 |
+
def max_decoder_positions(self):
|
813 |
+
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
|
814 |
+
|
815 |
+
@torch.jit.export
|
816 |
+
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
817 |
+
if not self.has_encoder():
|
818 |
+
return None
|
819 |
+
return [model.encoder.forward_torchscript(net_input) for model in self.models]
|
820 |
+
|
821 |
+
@torch.jit.export
|
822 |
+
def forward_decoder(
|
823 |
+
self,
|
824 |
+
tokens,
|
825 |
+
encoder_outs: List[Dict[str, List[Tensor]]],
|
826 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
827 |
+
temperature: float = 1.0,
|
828 |
+
):
|
829 |
+
log_probs = []
|
830 |
+
avg_attn: Optional[Tensor] = None
|
831 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None
|
832 |
+
for i, model in enumerate(self.models):
|
833 |
+
if self.has_encoder():
|
834 |
+
encoder_out = encoder_outs[i]
|
835 |
+
# decode each model
|
836 |
+
if self.has_incremental_states():
|
837 |
+
decoder_out = model.decoder.forward(
|
838 |
+
tokens,
|
839 |
+
encoder_out=encoder_out,
|
840 |
+
incremental_state=incremental_states[i],
|
841 |
+
)
|
842 |
+
else:
|
843 |
+
if hasattr(model, "decoder"):
|
844 |
+
decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
|
845 |
+
else:
|
846 |
+
decoder_out = model.forward(tokens)
|
847 |
+
|
848 |
+
attn: Optional[Tensor] = None
|
849 |
+
decoder_len = len(decoder_out)
|
850 |
+
if decoder_len > 1 and decoder_out[1] is not None:
|
851 |
+
if isinstance(decoder_out[1], Tensor):
|
852 |
+
attn = decoder_out[1]
|
853 |
+
else:
|
854 |
+
attn_holder = decoder_out[1]["attn"]
|
855 |
+
if isinstance(attn_holder, Tensor):
|
856 |
+
attn = attn_holder
|
857 |
+
elif attn_holder is not None:
|
858 |
+
attn = attn_holder[0]
|
859 |
+
if attn is not None:
|
860 |
+
attn = attn[:, -1, :]
|
861 |
+
|
862 |
+
decoder_out_tuple = (
|
863 |
+
decoder_out[0][:, -1:, :].div_(temperature),
|
864 |
+
None if decoder_len <= 1 else decoder_out[1],
|
865 |
+
)
|
866 |
+
probs = model.get_normalized_probs(
|
867 |
+
decoder_out_tuple, log_probs=True, sample=None
|
868 |
+
)
|
869 |
+
probs = probs[:, -1, :]
|
870 |
+
if self.models_size == 1:
|
871 |
+
return probs, attn
|
872 |
+
|
873 |
+
log_probs.append(probs)
|
874 |
+
if attn is not None:
|
875 |
+
if avg_attn is None:
|
876 |
+
avg_attn = attn
|
877 |
+
else:
|
878 |
+
avg_attn.add_(attn)
|
879 |
+
|
880 |
+
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
881 |
+
self.models_size
|
882 |
+
)
|
883 |
+
|
884 |
+
if avg_attn is not None:
|
885 |
+
avg_attn.div_(self.models_size)
|
886 |
+
return avg_probs, avg_attn
|
887 |
+
|
888 |
+
@torch.jit.export
|
889 |
+
def reorder_encoder_out(
|
890 |
+
self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
|
891 |
+
):
|
892 |
+
"""
|
893 |
+
Reorder encoder output according to *new_order*.
|
894 |
+
Args:
|
895 |
+
encoder_out: output from the ``forward()`` method
|
896 |
+
new_order (LongTensor): desired order
|
897 |
+
Returns:
|
898 |
+
*encoder_out* rearranged according to *new_order*
|
899 |
+
"""
|
900 |
+
new_outs: List[Dict[str, List[Tensor]]] = []
|
901 |
+
if not self.has_encoder():
|
902 |
+
return new_outs
|
903 |
+
for i, model in enumerate(self.models):
|
904 |
+
assert encoder_outs is not None
|
905 |
+
new_outs.append(
|
906 |
+
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
|
907 |
+
)
|
908 |
+
return new_outs
|
909 |
+
|
910 |
+
@torch.jit.export
|
911 |
+
def reorder_incremental_state(
|
912 |
+
self,
|
913 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
914 |
+
new_order,
|
915 |
+
):
|
916 |
+
if not self.has_incremental_states():
|
917 |
+
return
|
918 |
+
for i, model in enumerate(self.models):
|
919 |
+
model.decoder.reorder_incremental_state_scripting(
|
920 |
+
incremental_states[i], new_order
|
921 |
+
)
|
922 |
+
|
923 |
+
|
924 |
+
class SequenceGeneratorWithAlignment(SequenceGenerator):
|
925 |
+
def __init__(
|
926 |
+
self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
|
927 |
+
):
|
928 |
+
"""Generates translations of a given source sentence.
|
929 |
+
Produces alignments following "Jointly Learning to Align and
|
930 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
931 |
+
Args:
|
932 |
+
left_pad_target (bool, optional): Whether or not the
|
933 |
+
hypothesis should be left padded or not when they are
|
934 |
+
teacher forced for generating alignments.
|
935 |
+
"""
|
936 |
+
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
|
937 |
+
self.left_pad_target = left_pad_target
|
938 |
+
|
939 |
+
if print_alignment == "hard":
|
940 |
+
self.extract_alignment = utils.extract_hard_alignment
|
941 |
+
elif print_alignment == "soft":
|
942 |
+
self.extract_alignment = utils.extract_soft_alignment
|
943 |
+
|
944 |
+
@torch.no_grad()
|
945 |
+
def generate(self, models, sample, **kwargs):
|
946 |
+
finalized = super()._generate(sample, **kwargs)
|
947 |
+
|
948 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
949 |
+
bsz = src_tokens.shape[0]
|
950 |
+
beam_size = self.beam_size
|
951 |
+
(
|
952 |
+
src_tokens,
|
953 |
+
src_lengths,
|
954 |
+
prev_output_tokens,
|
955 |
+
tgt_tokens,
|
956 |
+
) = self._prepare_batch_for_alignment(sample, finalized)
|
957 |
+
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
|
958 |
+
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
|
959 |
+
else:
|
960 |
+
attn = [
|
961 |
+
finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
|
962 |
+
for i in range(bsz * beam_size)
|
963 |
+
]
|
964 |
+
|
965 |
+
if src_tokens.device != "cpu":
|
966 |
+
src_tokens = src_tokens.to("cpu")
|
967 |
+
tgt_tokens = tgt_tokens.to("cpu")
|
968 |
+
attn = [i.to("cpu") for i in attn]
|
969 |
+
|
970 |
+
# Process the attn matrix to extract hard alignments.
|
971 |
+
for i in range(bsz * beam_size):
|
972 |
+
alignment = self.extract_alignment(
|
973 |
+
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
|
974 |
+
)
|
975 |
+
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
|
976 |
+
return finalized
|
977 |
+
|
978 |
+
def _prepare_batch_for_alignment(self, sample, hypothesis):
|
979 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
980 |
+
bsz = src_tokens.shape[0]
|
981 |
+
src_tokens = (
|
982 |
+
src_tokens[:, None, :]
|
983 |
+
.expand(-1, self.beam_size, -1)
|
984 |
+
.contiguous()
|
985 |
+
.view(bsz * self.beam_size, -1)
|
986 |
+
)
|
987 |
+
src_lengths = sample["net_input"]["src_lengths"]
|
988 |
+
src_lengths = (
|
989 |
+
src_lengths[:, None]
|
990 |
+
.expand(-1, self.beam_size)
|
991 |
+
.contiguous()
|
992 |
+
.view(bsz * self.beam_size)
|
993 |
+
)
|
994 |
+
prev_output_tokens = data_utils.collate_tokens(
|
995 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
996 |
+
self.pad,
|
997 |
+
self.eos,
|
998 |
+
self.left_pad_target,
|
999 |
+
move_eos_to_beginning=True,
|
1000 |
+
)
|
1001 |
+
tgt_tokens = data_utils.collate_tokens(
|
1002 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
1003 |
+
self.pad,
|
1004 |
+
self.eos,
|
1005 |
+
self.left_pad_target,
|
1006 |
+
move_eos_to_beginning=False,
|
1007 |
+
)
|
1008 |
+
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
|
1009 |
+
|
1010 |
+
|
1011 |
+
class EnsembleModelWithAlignment(EnsembleModel):
|
1012 |
+
"""A wrapper around an ensemble of models."""
|
1013 |
+
|
1014 |
+
def __init__(self, models):
|
1015 |
+
super().__init__(models)
|
1016 |
+
|
1017 |
+
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
|
1018 |
+
avg_attn = None
|
1019 |
+
for model in self.models:
|
1020 |
+
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
|
1021 |
+
attn = decoder_out[1]["attn"][0]
|
1022 |
+
if avg_attn is None:
|
1023 |
+
avg_attn = attn
|
1024 |
+
else:
|
1025 |
+
avg_attn.add_(attn)
|
1026 |
+
if len(self.models) > 1:
|
1027 |
+
avg_attn.div_(len(self.models))
|
1028 |
+
return avg_attn
|
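For reference, the per-step score combination the generator applies can be summarized as interpolating attention-decoder log-probabilities with incremental CTC prefix scores, plus an optional external LM term (shallow fusion). A minimal sketch under those assumptions; the function name and weights below are invented for illustration:

```python
# Illustrative sketch of joint CTC/attention scoring with optional LM fusion.
import torch

def combine_scores(att_lprobs, ctc_prefix_delta, lm_lprobs=None,
                   ctc_weight=0.3, lm_weight=0.1):
    """All inputs are (num_hyps, vocab) tensors of log-probabilities/scores."""
    scores = (1.0 - ctc_weight) * att_lprobs + ctc_weight * ctc_prefix_delta
    if lm_lprobs is not None:
        scores = scores + lm_weight * lm_lprobs
    return scores

vocab = 8
att = torch.log_softmax(torch.randn(2, vocab), dim=-1)
ctc_delta = torch.log_softmax(torch.randn(2, vocab), dim=-1)  # stands in for ctc_scores - ctc_score_prev
lm = torch.log_softmax(torch.randn(2, vocab), dim=-1)

print(combine_scores(att, ctc_delta, lm).shape)  # torch.Size([2, 8])
```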
SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py
ADDED
@@ -0,0 +1,91 @@
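The task below extends `HubertPretrainingConfig` with two extra fields (`add_decoder`, `ctc_weight`) through dataclass inheritance. A minimal sketch of that config pattern using plain dataclasses only; the class names here are illustrative, not the real fairseq base classes:

```python
# Sketch of the dataclass-config inheritance pattern; not repository code.
from dataclasses import dataclass, field

@dataclass
class BaseConfig:
    labels: tuple = ("km",)
    fine_tuning: bool = False

@dataclass
class DemoSpeech2cConfig(BaseConfig):
    add_decoder: bool = field(default=False, metadata={"help": "add decoder for CE loss on code"})
    ctc_weight: float = field(default=0.0, metadata={"help": "ctc weight during inference"})

cfg = DemoSpeech2cConfig(add_decoder=True)
print(cfg)  # parent fields first, then the two new fields with their defaults
```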
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import logging

from dataclasses import dataclass, field
from fairseq.data import Dictionary
from fairseq.tasks import register_task
from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig, HubertPretrainingTask, LabelEncoder
from speech2c.data.speech2c_dataset import Speech2cDataset

logger = logging.getLogger(__name__)


@dataclass
class Speech2cPretrainingConfig(HubertPretrainingConfig):
    add_decoder: bool = field(
        default=False,
        metadata={"help": "whether to add decoder for CE Loss on code"},
    )

    # For inference
    ctc_weight: float = field(
        default=0.0,
        metadata={"help": "ctc weight during inference"},
    )


@register_task("speech2c_pretraining", dataclass=Speech2cPretrainingConfig)
class Speech2cPretrainingTask(HubertPretrainingTask):

    cfg: Speech2cPretrainingConfig

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def load_dataset(self, split: str, **kwargs) -> None:
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        pad_list = [dict.pad() for dict in dicts]
        eos_list = [dict.eos() for dict in dicts]
        procs = [LabelEncoder(dict) for dict in dicts]
        paths = [
            f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
        ]

        # hubert v1: pad_audio=True, random_crop=False;
        self.datasets[split] = Speech2cDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
            tgt_dict=dicts[0],
            add_decoder=self.cfg.add_decoder,
            fine_tuning=self.cfg.fine_tuning,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        from speech2c.squence_generator import SequenceGenerator
        extra_gen_cls_kwargs = {
            "ctc_weight": self.cfg.ctc_weight,
            **extra_gen_cls_kwargs
        }
        return super().build_generator(
            models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
SpeechT5/Speech2S/README.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Speech2S
|
2 |
+
<!--**Pre-trained models for speech related tasks**-->
|
3 |
+
|
4 |
+
[**Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**](https://arxiv.org/abs/2210.17027)
|
5 |
+
|
6 |
+
|
7 |
+
- (Updating) Nov. 2022: release the code and models
|
8 |
+
- Nov. 2022: release preprint in [arXiv](https://arxiv.org/abs/2210.17027)
|
9 |
+
|
10 |
+
## Pre-Trained and Fine-tuned Models
|
11 |
+
|
12 |
+
| Model | Pre-training Dataset | Fine-tuning Dataset | Model |
|
13 |
+
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
|
14 |
+
| Speech2S_enes | Voxpopuli_en_v2 | - | [Google Drive](https://drive.google.com/file/d/1TYypFiEKoCixUro8FTTG23bRZYwAxhkX/view?usp=share_link) |
|
15 |
+
| Speech2S_enes | Voxpopuli_en_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/11RxeKznSrHcoP_KK9A1VgwRt3fNh_U_C/view?usp=share_link) |
|
16 |
+
| Speech2S_esen | Voxpopuli_es_v2 | - | [Google Drive](https://drive.google.com/file/d/1NoC7W-UtQZ-ugIptF1ex0ZlGJncsT1S4/view?usp=share_link) |
|
17 |
+
| Speech2S_esen | Voxpopuli_es_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/1eNcKw4ZWGmcABWXJxlf6MKocmiPrKSkH/view?usp=share_link) |
|
18 |
+
|
19 |
+
|
20 |
+
## Setup
|
21 |
+
```
|
22 |
+
cd Speech2S/speech2s
|
23 |
+
pip install --editable fairseq/
|
24 |
+
```
|
25 |
+
|
26 |
+
## Data Preparation
|
27 |
+
Please follow the steps of data preparation for S2ST in [here](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md).
|
28 |
+
|
29 |
+
## Pre-Training
|
30 |
+
```
|
31 |
+
cd speech2s/stpretrain_scripts
|
32 |
+
base_sc2c_enes.sh
|
33 |
+
```
|
34 |
+
## Finetune
|
35 |
+
```
|
36 |
+
cd speech2s/stpretrain_scripts
|
37 |
+
finetune_enes.sh
|
38 |
+
```
|
39 |
+
## Inference
|
40 |
+
```
|
41 |
+
cd speech2s/stpretrain_scripts
|
42 |
+
inference_ed.sh
|
43 |
+
```
|
44 |
+
## Results on Voxpopuli and Covst
|
45 |
+
|
46 |
+
|
47 |
+
## License
|
48 |
+
|
49 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
50 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq).
|
51 |
+
|
52 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
53 |
+
|
54 |
+
## Reference
|
55 |
+
|
56 |
+
If you find our work is useful in your research, please cite the following paper:
|
57 |
+
```bibtex
|
58 |
+
@article{wei2022joint,
|
59 |
+
title={Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation},
|
60 |
+
author={Wei, Kun and Zhou, Long and Zhang, Ziqiang and Chen, Liping and Liu, Shujie and He, Lei and Li, Jinyu and Wei, Furu},
|
61 |
+
journal={arXiv preprint arXiv:2210.17027},
|
62 |
+
year={2022}
|
63 |
+
}
|
64 |
+
```
|
SpeechT5/Speech2S/speech2s/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import data, tasks, criterions, models
|
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
log_format: json
|
6 |
+
log_interval: 100
|
7 |
+
tensorboard_logdir: tblog
|
8 |
+
seed: 1337
|
9 |
+
|
10 |
+
checkpoint:
|
11 |
+
save_interval: 1
|
12 |
+
keep_last_epochs: 1
|
13 |
+
keep_best_checkpoints: 5
|
14 |
+
best_checkpoint_metric: dec_accuracy
|
15 |
+
maximize_best_checkpoint_metric: true
|
16 |
+
restore_file: checkpoint_last.pt
|
17 |
+
|
18 |
+
distributed_training:
|
19 |
+
ddp_backend: legacy_ddp
|
20 |
+
find_unused_parameters: true
|
21 |
+
distributed_world_size: 1
|
22 |
+
distributed_port: -1
|
23 |
+
nprocs_per_node: 8
|
24 |
+
|
25 |
+
task:
|
26 |
+
_name: joint_sc2t_pretraining
|
27 |
+
data: ???
|
28 |
+
fine_tuning: true
|
29 |
+
label_dir: ???
|
30 |
+
normalize: false # must be consistent with pre-training
|
31 |
+
labels: ["ltr"]
|
32 |
+
store_labels: true
|
33 |
+
single_target: true
|
34 |
+
add_decoder_target: true
|
35 |
+
pad_audio: false
|
36 |
+
random_crop: true
|
37 |
+
hubert_tokenizer: "none"
|
38 |
+
sp_path: None
|
39 |
+
|
40 |
+
dataset:
|
41 |
+
num_workers: 0
|
42 |
+
max_tokens: 1300000
|
43 |
+
skip_invalid_size_inputs_valid_test: true
|
44 |
+
train_subset: train_100
|
45 |
+
valid_subset: dev_other
|
46 |
+
required_batch_size_multiple: 1
|
47 |
+
|
48 |
+
criterion:
|
49 |
+
_name: ctc_ce
|
50 |
+
zero_infinity: true
|
51 |
+
|
52 |
+
optimization:
|
53 |
+
max_update: 40000
|
54 |
+
lr: [0.00001]
|
55 |
+
sentence_avg: true
|
56 |
+
update_freq: [2]
|
57 |
+
|
58 |
+
optimizer:
|
59 |
+
_name: adam
|
60 |
+
adam_betas: (0.9,0.98)
|
61 |
+
adam_eps: 1e-08
|
62 |
+
weight_decay: 0.0
|
63 |
+
|
64 |
+
lr_scheduler:
|
65 |
+
_name: tri_stage
|
66 |
+
phase_ratio: [0.1, 0.4, 0.5]
|
67 |
+
final_lr_scale: 0.05
|
68 |
+
|
69 |
+
model:
|
70 |
+
_name: speechut_asr
|
71 |
+
w2v_path: ???
|
72 |
+
apply_mask: true
|
73 |
+
mask_prob: 0.65
|
74 |
+
mask_channel_prob: 0.5
|
75 |
+
mask_channel_length: 64
|
76 |
+
layerdrop: 0.1
|
77 |
+
activation_dropout: 0.1
|
78 |
+
feature_grad_mult: 0.0
|
79 |
+
freeze_finetune_updates: 0
|
80 |
+
add_decoder: true
|
81 |
+
|
82 |
+
hydra:
|
83 |
+
job:
|
84 |
+
config:
|
85 |
+
override_dirname:
|
86 |
+
kv_sep: '-'
|
87 |
+
item_sep: '__'
|
88 |
+
exclude_keys:
|
89 |
+
- run
|
90 |
+
- task.data
|
91 |
+
- task.label_dir
|
92 |
+
- model.w2v_path
|
93 |
+
- dataset.train_subset
|
94 |
+
- dataset.valid_subset
|
95 |
+
- criterion.wer_kenlm_model
|
96 |
+
- criterion.wer_lexicon
|
97 |
+
run:
|
98 |
+
dir: ???
|
99 |
+
sweep:
|
100 |
+
dir: ???
|
101 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
log_format: json
|
6 |
+
log_interval: 100
|
7 |
+
tensorboard_logdir: tblog
|
8 |
+
seed: 1337
|
9 |
+
|
10 |
+
checkpoint:
|
11 |
+
save_interval: 1
|
12 |
+
keep_last_epochs: 5
|
13 |
+
keep_best_checkpoints: 5
|
14 |
+
best_checkpoint_metric: dec_accuracy
|
15 |
+
maximize_best_checkpoint_metric: true
|
16 |
+
restore_file: checkpoint_last.pt
|
17 |
+
|
18 |
+
distributed_training:
|
19 |
+
ddp_backend: legacy_ddp
|
20 |
+
find_unused_parameters: true
|
21 |
+
distributed_world_size: 16
|
22 |
+
distributed_port: -1
|
23 |
+
nprocs_per_node: 8
|
24 |
+
|
25 |
+
task:
|
26 |
+
_name: joint_sc2t_pretraining
|
27 |
+
data: ???
|
28 |
+
fine_tuning: true
|
29 |
+
label_dir: ???
|
30 |
+
normalize: true # must be consistent with pre-training
|
31 |
+
labels: ["ltr"]
|
32 |
+
store_labels: true
|
33 |
+
single_target: true
|
34 |
+
add_decoder_target: true
|
35 |
+
pad_audio: false
|
36 |
+
random_crop: true
|
37 |
+
hubert_tokenizer: "none"
|
38 |
+
sp_path: None
|
39 |
+
|
40 |
+
dataset:
|
41 |
+
num_workers: 0
|
42 |
+
max_tokens: 1300000
|
43 |
+
skip_invalid_size_inputs_valid_test: true
|
44 |
+
train_subset: train_100
|
45 |
+
valid_subset: dev_other
|
46 |
+
required_batch_size_multiple: 1
|
47 |
+
|
48 |
+
criterion:
|
49 |
+
_name: ctc_ce
|
50 |
+
zero_infinity: true
|
51 |
+
|
52 |
+
optimization:
|
53 |
+
max_update: 40000
|
54 |
+
lr: [0.00001]
|
55 |
+
sentence_avg: true
|
56 |
+
update_freq: [2]
|
57 |
+
|
58 |
+
optimizer:
|
59 |
+
_name: adam
|
60 |
+
adam_betas: (0.9,0.98)
|
61 |
+
adam_eps: 1e-08
|
62 |
+
weight_decay: 0.0
|
63 |
+
|
64 |
+
lr_scheduler:
|
65 |
+
_name: tri_stage
|
66 |
+
phase_ratio: [0.1, 0.4, 0.5]
|
67 |
+
final_lr_scale: 0.05
|
68 |
+
|
69 |
+
model:
|
70 |
+
_name: speechut_asr
|
71 |
+
w2v_path: ???
|
72 |
+
apply_mask: true
|
73 |
+
mask_prob: 0.5
|
74 |
+
mask_channel_prob: 0.5
|
75 |
+
mask_channel_length: 64
|
76 |
+
layerdrop: 0.0
|
77 |
+
activation_dropout: 0.1
|
78 |
+
attention_dropout: 0.1
|
79 |
+
feature_grad_mult: 0.0
|
80 |
+
freeze_finetune_updates: 0
|
81 |
+
add_decoder: true
|
82 |
+
|
83 |
+
hydra:
|
84 |
+
job:
|
85 |
+
config:
|
86 |
+
override_dirname:
|
87 |
+
kv_sep: '-'
|
88 |
+
item_sep: '__'
|
89 |
+
exclude_keys:
|
90 |
+
- run
|
91 |
+
- task.data
|
92 |
+
- task.label_dir
|
93 |
+
- model.w2v_path
|
94 |
+
- dataset.train_subset
|
95 |
+
- dataset.valid_subset
|
96 |
+
- criterion.wer_kenlm_model
|
97 |
+
- criterion.wer_lexicon
|
98 |
+
run:
|
99 |
+
dir: ???
|
100 |
+
sweep:
|
101 |
+
dir: ???
|
102 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
log_format: json
|
6 |
+
log_interval: 100
|
7 |
+
tensorboard_logdir: tblog
|
8 |
+
|
9 |
+
checkpoint:
|
10 |
+
save_interval: 1
|
11 |
+
keep_last_epochs: 5
|
12 |
+
keep_best_checkpoints: 5
|
13 |
+
best_checkpoint_metric: dec_accuracy
|
14 |
+
maximize_best_checkpoint_metric: true
|
15 |
+
restore_file: checkpoint_last.pt
|
16 |
+
|
17 |
+
distributed_training:
|
18 |
+
ddp_backend: legacy_ddp
|
19 |
+
find_unused_parameters: true
|
20 |
+
distributed_world_size: 24
|
21 |
+
distributed_port: -1
|
22 |
+
nprocs_per_node: 8
|
23 |
+
|
24 |
+
task:
|
25 |
+
_name: joint_sc2t_pretraining
|
26 |
+
data: ???
|
27 |
+
fine_tuning: true
|
28 |
+
label_dir: ???
|
29 |
+
normalize: true # must be consistent with pre-training
|
30 |
+
labels: ["ltr"]
|
31 |
+
store_labels: true
|
32 |
+
single_target: true
|
33 |
+
add_decoder_target: true
|
34 |
+
pad_audio: false
|
35 |
+
random_crop: true
|
36 |
+
hubert_tokenizer: "none"
|
37 |
+
sp_path: None
|
38 |
+
|
39 |
+
dataset:
|
40 |
+
num_workers: 0
|
41 |
+
max_tokens: 1300000
|
42 |
+
skip_invalid_size_inputs_valid_test: true
|
43 |
+
train_subset: train_960
|
44 |
+
valid_subset: dev_other
|
45 |
+
required_batch_size_multiple: 1
|
46 |
+
|
47 |
+
criterion:
|
48 |
+
_name: ctc_ce
|
49 |
+
zero_infinity: true
|
50 |
+
|
51 |
+
optimization:
|
52 |
+
max_update: 40000
|
53 |
+
lr: [0.00001]
|
54 |
+
sentence_avg: true
|
55 |
+
update_freq: [2]
|
56 |
+
|
57 |
+
optimizer:
|
58 |
+
_name: adam
|
59 |
+
adam_betas: (0.9,0.98)
|
60 |
+
adam_eps: 1e-08
|
61 |
+
weight_decay: 0.0
|
62 |
+
|
63 |
+
lr_scheduler:
|
64 |
+
_name: tri_stage
|
65 |
+
phase_ratio: [0.1, 0.4, 0.5]
|
66 |
+
final_lr_scale: 0.05
|
67 |
+
|
68 |
+
model:
|
69 |
+
_name: speechut_asr
|
70 |
+
w2v_path: ???
|
71 |
+
apply_mask: true
|
72 |
+
mask_prob: 0.5
|
73 |
+
mask_channel_prob: 0.25
|
74 |
+
mask_channel_length: 64
|
75 |
+
layerdrop: 0.0
|
76 |
+
activation_dropout: 0.1
|
77 |
+
feature_grad_mult: 0.0
|
78 |
+
freeze_finetune_updates: 0
|
79 |
+
add_decoder: true
|
80 |
+
|
81 |
+
hydra:
|
82 |
+
job:
|
83 |
+
config:
|
84 |
+
override_dirname:
|
85 |
+
kv_sep: '-'
|
86 |
+
item_sep: '__'
|
87 |
+
exclude_keys:
|
88 |
+
- run
|
89 |
+
- task.data
|
90 |
+
- task.label_dir
|
91 |
+
- model.w2v_path
|
92 |
+
- dataset.train_subset
|
93 |
+
- dataset.valid_subset
|
94 |
+
- criterion.wer_kenlm_model
|
95 |
+
- criterion.wer_lexicon
|
96 |
+
run:
|
97 |
+
dir: ???
|
98 |
+
sweep:
|
99 |
+
dir: ???
|
100 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
log_format: json
|
6 |
+
log_interval: 200
|
7 |
+
seed: 1337
|
8 |
+
tensorboard_logdir: tblog
|
9 |
+
|
10 |
+
checkpoint:
|
11 |
+
save_dir: ???
|
12 |
+
save_interval: 4
|
13 |
+
keep_last_epochs: 4
|
14 |
+
save_interval_updates: 50000
|
15 |
+
keep_interval_updates: -1
|
16 |
+
keep_interval_updates_pattern: 50000
|
17 |
+
# no_epoch_checkpoints: true
|
18 |
+
|
19 |
+
distributed_training:
|
20 |
+
ddp_backend: no_c10d
|
21 |
+
distributed_backend: 'nccl'
|
22 |
+
distributed_port: -1
|
23 |
+
distributed_world_size: 32
|
24 |
+
nprocs_per_node: 8
|
25 |
+
find_unused_parameters: true
|
26 |
+
|
27 |
+
task:
|
28 |
+
_name: joint_sc2t_pretraining
|
29 |
+
data: ???
|
30 |
+
label_dir: ???
|
31 |
+
labels: ???
|
32 |
+
label_rate: ${model.label_rate}
|
33 |
+
store_labels: true
|
34 |
+
sample_rate: 16000
|
35 |
+
max_sample_size: 250000
|
36 |
+
min_sample_size: 32000
|
37 |
+
pad_audio: false
|
38 |
+
random_crop: true
|
39 |
+
normalize: false # must be consistent with extractor
|
40 |
+
add_decoder_target: true
|
41 |
+
text_cfg:
|
42 |
+
seed: ${common.seed}
|
43 |
+
text_data: ???
|
44 |
+
data_config: config.yaml
|
45 |
+
sample_break_mode: eos
|
46 |
+
tokens_per_sample: 1024
|
47 |
+
shorten_method: "random_crop"
|
48 |
+
text_maxtokens_ratio: 1.5
|
49 |
+
|
50 |
+
dataset:
|
51 |
+
num_workers: 6
|
52 |
+
max_tokens: 1400000
|
53 |
+
skip_invalid_size_inputs_valid_test: true
|
54 |
+
validate_interval: ${checkpoint.save_interval}
|
55 |
+
validate_interval_updates: ${checkpoint.save_interval_updates}
|
56 |
+
required_batch_size_multiple: 1
|
57 |
+
|
58 |
+
criterion:
|
59 |
+
_name: speechut_criterion
|
60 |
+
pred_masked_weight: 1.0
|
61 |
+
pred_nomask_weight: 0.0
|
62 |
+
loss_weights: [10,]
|
63 |
+
label_smoothing: 0.1
|
64 |
+
u2t_ed_weight: 0.1
|
65 |
+
u2t_ctc_weight: 0.1
|
66 |
+
text_mum_weight: 0.5
|
67 |
+
|
68 |
+
optimization:
|
69 |
+
max_update: 400000
|
70 |
+
lr: [0.0005]
|
71 |
+
clip_norm: 10.0
|
72 |
+
|
73 |
+
optimizer:
|
74 |
+
_name: adam
|
75 |
+
adam_betas: (0.9,0.98)
|
76 |
+
adam_eps: 1e-06
|
77 |
+
weight_decay: 0.01
|
78 |
+
|
79 |
+
lr_scheduler:
|
80 |
+
_name: polynomial_decay
|
81 |
+
warmup_updates: 32000
|
82 |
+
|
83 |
+
model:
|
84 |
+
_name: speechut
|
85 |
+
label_rate: ???
|
86 |
+
skip_masked: false
|
87 |
+
skip_nomask: false
|
88 |
+
mask_prob: 0.80
|
89 |
+
extractor_mode: default
|
90 |
+
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
91 |
+
final_dim: 256
|
92 |
+
activation_fn: "gelu"
|
93 |
+
encoder_layers: 6
|
94 |
+
encoder_attention_heads: 8
|
95 |
+
encoder_layerdrop: 0.0
|
96 |
+
dropout_input: 0.1
|
97 |
+
dropout_features: 0.1
|
98 |
+
dropout: 0.1
|
99 |
+
attention_dropout: 0.1
|
100 |
+
feature_grad_mult: 0.1
|
101 |
+
untie_final_proj: true
|
102 |
+
activation_dropout: 0.0
|
103 |
+
use_rel_pos_enc: true
|
104 |
+
add_unit_encoder: true
|
105 |
+
add_text_ctc: true
|
106 |
+
mask_u2t: false
|
107 |
+
mix_with_unit: true
|
108 |
+
add_decoder: true
|
109 |
+
reset_decoder_embedding_config: true
|
110 |
+
text_transformer:
|
111 |
+
activation_fn: ${model.activation_fn}
|
112 |
+
dropout: ${model.dropout}
|
113 |
+
attention_dropout: ${model.attention_dropout}
|
114 |
+
activation_dropout: ${model.activation_dropout}
|
115 |
+
max_source_positions: 3000
|
116 |
+
max_target_positions: 3000
|
117 |
+
no_scale_embedding: true
|
118 |
+
layernorm_embedding: true
|
119 |
+
no_token_positional_embeddings: false
|
120 |
+
share_decoder_input_output_embed: false
|
121 |
+
encoder:
|
122 |
+
embed_dim: 768
|
123 |
+
ffn_embed_dim: 3072
|
124 |
+
layers: 6
|
125 |
+
attention_heads: 8
|
126 |
+
normalize_before: false
|
127 |
+
learned_pos: true
|
128 |
+
layerdrop: ${model.encoder_layerdrop}
|
129 |
+
decoder:
|
130 |
+
layerdrop: 0.1
|
131 |
+
embed_dim: 768
|
132 |
+
ffn_embed_dim: 3072
|
133 |
+
layers: 6
|
134 |
+
attention_heads: 12
|
135 |
+
normalize_before: false
|
136 |
+
learned_pos: false
|
137 |
+
output_dim: 768
|
138 |
+
|
139 |
+
hydra:
|
140 |
+
job:
|
141 |
+
config:
|
142 |
+
override_dirname:
|
143 |
+
kv_sep: '-'
|
144 |
+
item_sep: '__'
|
145 |
+
exclude_keys:
|
146 |
+
- run
|
147 |
+
- task.data
|
148 |
+
- task.label_dir
|
149 |
+
run:
|
150 |
+
dir: ???
|
151 |
+
sweep:
|
152 |
+
dir: ???
|
153 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
fp16_scale_tolerance: 0.1 # alleviate fp16 overflow issue
|
6 |
+
log_format: json
|
7 |
+
log_interval: 200
|
8 |
+
seed: 1234
|
9 |
+
tensorboard_logdir: tblog
|
10 |
+
|
11 |
+
checkpoint:
|
12 |
+
save_dir: ???
|
13 |
+
save_interval: 1
|
14 |
+
keep_last_epochs: 4
|
15 |
+
save_interval_updates: 10000
|
16 |
+
keep_interval_updates: -1
|
17 |
+
keep_interval_updates_pattern: 10000
|
18 |
+
# no_epoch_checkpoints: true
|
19 |
+
|
20 |
+
distributed_training:
|
21 |
+
ddp_backend: no_c10d
|
22 |
+
distributed_backend: 'nccl'
|
23 |
+
distributed_port: -1
|
24 |
+
distributed_world_size: 128
|
25 |
+
nprocs_per_node: 8
|
26 |
+
find_unused_parameters: true
|
27 |
+
|
28 |
+
task:
|
29 |
+
_name: joint_sc2t_pretraining
|
30 |
+
data: ???
|
31 |
+
label_dir: ???
|
32 |
+
labels: ???
|
33 |
+
label_rate: ${model.label_rate}
|
34 |
+
store_labels: true
|
35 |
+
sample_rate: 16000
|
36 |
+
max_sample_size: 250000
|
37 |
+
min_sample_size: 32000
|
38 |
+
pad_audio: false
|
39 |
+
random_crop: true
|
40 |
+
normalize: true # must be consistent with extractor
|
41 |
+
add_decoder_target: true
|
42 |
+
text_cfg:
|
43 |
+
seed: ${common.seed}
|
44 |
+
text_data: ???
|
45 |
+
data_config: config.yaml
|
46 |
+
sample_break_mode: eos
|
47 |
+
tokens_per_sample: 1024
|
48 |
+
shorten_method: "random_crop"
|
49 |
+
text_maxtokens_ratio: 1.4
|
50 |
+
|
51 |
+
dataset:
|
52 |
+
num_workers: 6
|
53 |
+
max_tokens: 900000
|
54 |
+
skip_invalid_size_inputs_valid_test: true
|
55 |
+
validate_interval: ${checkpoint.save_interval}
|
56 |
+
validate_interval_updates: ${checkpoint.save_interval_updates}
|
57 |
+
required_batch_size_multiple: 2
|
58 |
+
|
59 |
+
criterion:
|
60 |
+
_name: speechut_criterion
|
61 |
+
pred_masked_weight: 1.0
|
62 |
+
pred_nomask_weight: 0.0
|
63 |
+
loss_weights: [10,]
|
64 |
+
label_smoothing: 0.1
|
65 |
+
u2t_ed_weight: 0.1
|
66 |
+
u2t_ctc_weight: 0.1
|
67 |
+
text_mum_weight: 0.5
|
68 |
+
|
69 |
+
optimization:
|
70 |
+
max_update: 400000
|
71 |
+
lr: [0.0005]
|
72 |
+
clip_norm: 1.0
|
73 |
+
|
74 |
+
optimizer:
|
75 |
+
_name: adam
|
76 |
+
adam_betas: (0.9,0.98)
|
77 |
+
adam_eps: 1e-06
|
78 |
+
weight_decay: 0.01
|
79 |
+
|
80 |
+
lr_scheduler:
|
81 |
+
_name: polynomial_decay
|
82 |
+
warmup_updates: 32000
|
83 |
+
end_learning_rate: 0.00015 # for future longger pre-training, e.g. 600K step
|
84 |
+
|
85 |
+
model:
|
86 |
+
_name: speechut
|
87 |
+
label_rate: ???
|
88 |
+
encoder_embed_dim: 1024
|
89 |
+
encoder_ffn_embed_dim: 4096
|
90 |
+
skip_masked: false
|
91 |
+
skip_nomask: false
|
92 |
+
mask_prob: 0.80
|
93 |
+
extractor_mode: layer_norm
|
94 |
+
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
95 |
+
final_dim: 768
|
96 |
+
activation_fn: "gelu"
|
97 |
+
encoder_layers: 12
|
98 |
+
encoder_attention_heads: 16
|
99 |
+
encoder_layerdrop: 0.0
|
100 |
+
dropout_input: 0.0
|
101 |
+
dropout_features: 0.0
|
102 |
+
dropout: 0.0
|
103 |
+
attention_dropout: 0.0
|
104 |
+
layer_norm_first: true
|
105 |
+
feature_grad_mult: 1.0
|
106 |
+
untie_final_proj: true
|
107 |
+
activation_dropout: 0.0
|
108 |
+
use_rel_pos_enc: true
|
109 |
+
add_unit_encoder: true
|
110 |
+
add_text_ctc: true
|
111 |
+
mask_u2t: false
|
112 |
+
mix_with_unit: true
|
113 |
+
add_decoder: true
|
114 |
+
reset_decoder_embedding_config: true
|
115 |
+
scaling_for_att: 32 # alleviate fp16 overflow issue
|
116 |
+
text_transformer:
|
117 |
+
activation_fn: ${model.activation_fn}
|
118 |
+
dropout: ${model.dropout}
|
119 |
+
attention_dropout: ${model.attention_dropout}
|
120 |
+
activation_dropout: ${model.activation_dropout}
|
121 |
+
max_source_positions: 3000
|
122 |
+
max_target_positions: 3000
|
123 |
+
no_scale_embedding: true
|
124 |
+
layernorm_embedding: true
|
125 |
+
no_token_positional_embeddings: true
|
126 |
+
share_decoder_input_output_embed: false
|
127 |
+
encoder:
|
128 |
+
embed_dim: 1024
|
129 |
+
ffn_embed_dim: 4096
|
130 |
+
layers: 12
|
131 |
+
attention_heads: 16
|
132 |
+
normalize_before: false
|
133 |
+
learned_pos: true
|
134 |
+
layerdrop: ${model.encoder_layerdrop}
|
135 |
+
decoder:
|
136 |
+
layerdrop: 0.1
|
137 |
+
embed_dim: 768
|
138 |
+
ffn_embed_dim: 3072
|
139 |
+
layers: 6
|
140 |
+
attention_heads: 12
|
141 |
+
normalize_before: false
|
142 |
+
learned_pos: false
|
143 |
+
output_dim: 768
|
144 |
+
|
145 |
+
hydra:
|
146 |
+
job:
|
147 |
+
config:
|
148 |
+
override_dirname:
|
149 |
+
kv_sep: '-'
|
150 |
+
item_sep: '__'
|
151 |
+
exclude_keys:
|
152 |
+
- run
|
153 |
+
- task.data
|
154 |
+
- task.label_dir
|
155 |
+
run:
|
156 |
+
dir: ???
|
157 |
+
sweep:
|
158 |
+
dir: ???
|
159 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
SpeechT5/Speech2S/speech2s/criterions/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
for file in os.listdir(os.path.dirname(__file__)):
|
5 |
+
if file.endswith(".py") and not file.startswith("_"):
|
6 |
+
criterion_name = file[: file.find(".py")]
|
7 |
+
importlib.import_module(
|
8 |
+
"speechut.criterions." + criterion_name
|
9 |
+
)
|
SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ----------------------------------------------------------------------------
|
2 |
+
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
|
4 |
+
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
|
5 |
+
#
|
6 |
+
# Copyright (c) 2022 Microsoft
|
7 |
+
# Licensed under The MIT License [see LICENSE for details]
|
8 |
+
# ----------------------------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
from argparse import Namespace
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from omegaconf import II
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from fairseq import metrics, utils
|
19 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
20 |
+
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
|
21 |
+
from fairseq.dataclass import FairseqDataclass
|
22 |
+
from fairseq.data.data_utils import post_process
|
23 |
+
from fairseq.tasks import FairseqTask
|
24 |
+
from fairseq.logging.meters import safe_round
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class CtcCeCriterionConfig(FairseqDataclass):
|
29 |
+
zero_infinity: bool = field(
|
30 |
+
default=False,
|
31 |
+
metadata={"help": "zero inf loss when source length <= target length"},
|
32 |
+
)
|
33 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
34 |
+
post_process: str = field(
|
35 |
+
default="letter",
|
36 |
+
metadata={
|
37 |
+
"help": "how to post process predictions into words. can be letter, "
|
38 |
+
"wordpiece, BPE symbols, etc. "
|
39 |
+
"See fairseq.data.data_utils.post_process() for full list of options"
|
40 |
+
},
|
41 |
+
)
|
42 |
+
wer_kenlm_model: Optional[str] = field(
|
43 |
+
default=None,
|
44 |
+
metadata={
|
45 |
+
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
|
46 |
+
},
|
47 |
+
)
|
48 |
+
wer_lexicon: Optional[str] = field(
|
49 |
+
default=None,
|
50 |
+
metadata={"help": "lexicon to use with wer_kenlm_model"},
|
51 |
+
)
|
52 |
+
wer_lm_weight: float = field(
|
53 |
+
default=2.0,
|
54 |
+
metadata={"help": "lm weight to use with wer_kenlm_model"},
|
55 |
+
)
|
56 |
+
wer_word_score: float = field(
|
57 |
+
default=-1.0,
|
58 |
+
metadata={"help": "lm word score to use with wer_kenlm_model"},
|
59 |
+
)
|
60 |
+
|
61 |
+
wer_args: Optional[str] = field(
|
62 |
+
default=None,
|
63 |
+
metadata={
|
64 |
+
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
|
65 |
+
},
|
66 |
+
)
|
67 |
+
|
68 |
+
dec_weight: float = field(
|
69 |
+
default=0.5,
|
70 |
+
metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
|
71 |
+
)
|
72 |
+
report_accuracy: bool = field(
|
73 |
+
default=True,
|
74 |
+
metadata={"help": "report decoder accuracy metric"},
|
75 |
+
)
|
76 |
+
ignore_prefix_size: int = field(
|
77 |
+
default=0,
|
78 |
+
metadata={"help": "Ignore first N tokens"},
|
79 |
+
)
|
80 |
+
label_smoothing: float = field(
|
81 |
+
default=0.1,
|
82 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
|
87 |
+
class CtcCeCriterion(FairseqCriterion):
|
88 |
+
def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
|
89 |
+
super().__init__(task)
|
90 |
+
self.blank_idx = (
|
91 |
+
task.target_dictionary.index(task.blank_symbol)
|
92 |
+
if hasattr(task, "blank_symbol")
|
93 |
+
else 0
|
94 |
+
)
|
95 |
+
self.pad_idx = task.target_dictionary.pad()
|
96 |
+
self.eos_idx = task.target_dictionary.eos()
|
97 |
+
self.post_process = cfg.post_process
|
98 |
+
|
99 |
+
if cfg.wer_args is not None:
|
100 |
+
(
|
101 |
+
cfg.wer_kenlm_model,
|
102 |
+
cfg.wer_lexicon,
|
103 |
+
cfg.wer_lm_weight,
|
104 |
+
cfg.wer_word_score,
|
105 |
+
) = eval(cfg.wer_args)
|
106 |
+
|
107 |
+
if cfg.wer_kenlm_model is not None:
|
108 |
+
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
109 |
+
|
110 |
+
dec_args = Namespace()
|
111 |
+
dec_args.nbest = 1
|
112 |
+
dec_args.criterion = "ctc"
|
113 |
+
dec_args.kenlm_model = cfg.wer_kenlm_model
|
114 |
+
dec_args.lexicon = cfg.wer_lexicon
|
115 |
+
dec_args.beam = 50
|
116 |
+
dec_args.beam_size_token = min(50, len(task.target_dictionary))
|
117 |
+
dec_args.beam_threshold = min(50, len(task.target_dictionary))
|
118 |
+
dec_args.lm_weight = cfg.wer_lm_weight
|
119 |
+
dec_args.word_score = cfg.wer_word_score
|
120 |
+
dec_args.unk_weight = -math.inf
|
121 |
+
dec_args.sil_weight = 0
|
122 |
+
|
123 |
+
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
|
124 |
+
else:
|
125 |
+
self.w2l_decoder = None
|
126 |
+
|
127 |
+
self.zero_infinity = cfg.zero_infinity
|
128 |
+
self.sentence_avg = cfg.sentence_avg
|
129 |
+
|
130 |
+
self.dec_weight = cfg.dec_weight
|
131 |
+
self.report_accuracy = cfg.report_accuracy
|
132 |
+
self.ignore_prefix_size = cfg.ignore_prefix_size
|
133 |
+
self.eps = cfg.label_smoothing
|
134 |
+
|
135 |
+
def forward(self, model, sample, reduce=True):
|
136 |
+
net_output = model(**sample["net_input"])
|
137 |
+
lprobs = model.get_normalized_probs(
|
138 |
+
net_output, log_probs=True
|
139 |
+
).contiguous() # (T, B, C) from the encoder
|
140 |
+
|
141 |
+
if "src_lengths" in sample["net_input"]:
|
142 |
+
input_lengths = sample["net_input"]["src_lengths"]
|
143 |
+
else:
|
144 |
+
if net_output["padding_mask"] is not None:
|
145 |
+
non_padding_mask = ~net_output["padding_mask"]
|
146 |
+
input_lengths = non_padding_mask.long().sum(-1)
|
147 |
+
else:
|
148 |
+
input_lengths = lprobs.new_full(
|
149 |
+
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
|
150 |
+
)
|
151 |
+
|
152 |
+
pad_mask = (sample["target"] != self.pad_idx) & (
|
153 |
+
sample["target"] != self.eos_idx
|
154 |
+
)
|
155 |
+
targets_flat = sample["target"].masked_select(pad_mask)
|
156 |
+
if "target_lengths" in sample:
|
157 |
+
target_lengths = sample["target_lengths"]
|
158 |
+
else:
|
159 |
+
target_lengths = pad_mask.sum(-1)
|
160 |
+
|
161 |
+
with torch.backends.cudnn.flags(enabled=False):
|
162 |
+
loss = F.ctc_loss(
|
163 |
+
lprobs,
|
164 |
+
targets_flat,
|
165 |
+
input_lengths,
|
166 |
+
target_lengths,
|
167 |
+
blank=self.blank_idx,
|
168 |
+
reduction="sum",
|
169 |
+
zero_infinity=self.zero_infinity,
|
170 |
+
)
|
171 |
+
|
172 |
+
ntokens = (
|
173 |
+
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
|
174 |
+
)
|
175 |
+
|
176 |
+
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
|
177 |
+
|
178 |
+
logging_output = {}
|
179 |
+
if "decoder_target" in sample:
|
180 |
+
if net_output["decoder_out"] is not None:
|
181 |
+
dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
|
182 |
+
dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
|
183 |
+
logging_output["ctc_loss"] = loss.item()
|
184 |
+
loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
|
185 |
+
logging_output["dec_loss"] = dec_loss.item()
|
186 |
+
logging_output["dec_nll_loss"] = dec_nll_loss.item()
|
187 |
+
logging_output["dec_sample_size"] = dec_sample_size
|
188 |
+
|
189 |
+
if self.report_accuracy:
|
190 |
+
n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
|
191 |
+
logging_output["dec_n_correct"] = utils.item(n_correct.data)
|
192 |
+
logging_output["total"] = utils.item(total.data)
|
193 |
+
else:
|
194 |
+
logging_output["ctc_loss"] = loss.item()
|
195 |
+
loss = (1 - self.dec_weight) * loss
|
196 |
+
logging_output["dec_loss"] = 0
|
197 |
+
logging_output["dec_nll_loss"] = 0
|
198 |
+
logging_output["dec_sample_size"] = 1
|
199 |
+
if self.report_accuracy:
|
200 |
+
logging_output["dec_n_correct"] = 0
|
201 |
+
logging_output["total"] = 1
|
202 |
+
|
203 |
+
logging_output = {
|
204 |
+
"loss": utils.item(loss.data), # * sample['ntokens'],
|
205 |
+
"ntokens": ntokens,
|
206 |
+
"nsentences": sample["id"].numel(),
|
207 |
+
"sample_size": sample_size,
|
208 |
+
**logging_output,
|
209 |
+
}
|
210 |
+
|
211 |
+
if not model.training and self.dec_weight < 1.0:
|
212 |
+
import editdistance
|
213 |
+
|
214 |
+
with torch.no_grad():
|
215 |
+
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
|
216 |
+
|
217 |
+
c_err = 0
|
218 |
+
c_len = 0
|
219 |
+
w_errs = 0
|
220 |
+
w_len = 0
|
221 |
+
wv_errs = 0
|
222 |
+
for lp, t, inp_l in zip(
|
223 |
+
lprobs_t,
|
224 |
+
sample["target_label"]
|
225 |
+
if "target_label" in sample
|
226 |
+
else sample["target"],
|
227 |
+
input_lengths,
|
228 |
+
):
|
229 |
+
lp = lp[:inp_l].unsqueeze(0)
|
230 |
+
|
231 |
+
decoded = None
|
232 |
+
if self.w2l_decoder is not None:
|
233 |
+
decoded = self.w2l_decoder.decode(lp)
|
234 |
+
if len(decoded) < 1:
|
235 |
+
decoded = None
|
236 |
+
else:
|
237 |
+
decoded = decoded[0]
|
238 |
+
if len(decoded) < 1:
|
239 |
+
decoded = None
|
240 |
+
else:
|
241 |
+
decoded = decoded[0]
|
242 |
+
|
243 |
+
p = (t != self.task.target_dictionary.pad()) & (
|
244 |
+
t != self.task.target_dictionary.eos()
|
245 |
+
)
|
246 |
+
targ = t[p]
|
247 |
+
targ_units = self.task.target_dictionary.string(targ)
|
248 |
+
targ_units_arr = targ.tolist()
|
249 |
+
|
250 |
+
toks = lp.argmax(dim=-1).unique_consecutive()
|
251 |
+
pred_units_arr = toks[toks != self.blank_idx].tolist()
|
252 |
+
|
253 |
+
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
|
254 |
+
c_len += len(targ_units_arr)
|
255 |
+
|
256 |
+
targ_words = post_process(targ_units, self.post_process).split()
|
257 |
+
|
258 |
+
pred_units = self.task.target_dictionary.string(pred_units_arr)
|
259 |
+
pred_words_raw = post_process(pred_units, self.post_process).split()
|
260 |
+
|
261 |
+
if decoded is not None and "words" in decoded:
|
262 |
+
pred_words = decoded["words"]
|
263 |
+
w_errs += editdistance.eval(pred_words, targ_words)
|
264 |
+
wv_errs += editdistance.eval(pred_words_raw, targ_words)
|
265 |
+
else:
|
266 |
+
dist = editdistance.eval(pred_words_raw, targ_words)
|
267 |
+
w_errs += dist
|
268 |
+
wv_errs += dist
|
269 |
+
|
270 |
+
w_len += len(targ_words)
|
271 |
+
|
272 |
+
logging_output["wv_errors"] = wv_errs
|
273 |
+
logging_output["w_errors"] = w_errs
|
274 |
+
logging_output["w_total"] = w_len
|
275 |
+
logging_output["c_errors"] = c_err
|
276 |
+
logging_output["c_total"] = c_len
|
277 |
+
|
278 |
+
return loss, sample_size, logging_output
|
279 |
+
|
280 |
+
def compute_ce_loss(self, model, net_output, sample, reduce=True):
|
281 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
282 |
+
loss, nll_loss = label_smoothed_nll_loss(
|
283 |
+
lprobs,
|
284 |
+
target,
|
285 |
+
self.eps,
|
286 |
+
ignore_index=self.pad_idx,
|
287 |
+
reduce=reduce,
|
288 |
+
)
|
289 |
+
return loss, nll_loss
|
290 |
+
|
291 |
+
def compute_accuracy(self, model, net_output, sample):
|
292 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
293 |
+
mask = target.ne(self.pad_idx)
|
294 |
+
n_correct = torch.sum(
|
295 |
+
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
296 |
+
)
|
297 |
+
total = torch.sum(mask)
|
298 |
+
return n_correct, total
|
299 |
+
|
300 |
+
def get_lprobs_and_target(self, model, net_output, sample):
|
301 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
302 |
+
target = sample["decoder_target"]
|
303 |
+
if self.ignore_prefix_size > 0:
|
304 |
+
if getattr(lprobs, "batch_first", False):
|
305 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
306 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
307 |
+
else:
|
308 |
+
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
309 |
+
target = target[self.ignore_prefix_size :, :].contiguous()
|
310 |
+
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
|
311 |
+
|
312 |
+
|
313 |
+
@staticmethod
|
314 |
+
def reduce_metrics(logging_outputs) -> None:
|
315 |
+
"""Aggregate logging outputs from data parallel training."""
|
316 |
+
|
317 |
+
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
318 |
+
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
319 |
+
nsentences = utils.item(
|
320 |
+
sum(log.get("nsentences", 0) for log in logging_outputs)
|
321 |
+
)
|
322 |
+
sample_size = utils.item(
|
323 |
+
sum(log.get("sample_size", 0) for log in logging_outputs)
|
324 |
+
)
|
325 |
+
|
326 |
+
metrics.log_scalar(
|
327 |
+
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
328 |
+
)
|
329 |
+
metrics.log_scalar("ntokens", ntokens)
|
330 |
+
metrics.log_scalar("nsentences", nsentences)
|
331 |
+
if sample_size != ntokens:
|
332 |
+
metrics.log_scalar(
|
333 |
+
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
|
334 |
+
)
|
335 |
+
|
336 |
+
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
|
337 |
+
metrics.log_scalar("_c_errors", c_errors)
|
338 |
+
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
|
339 |
+
metrics.log_scalar("_c_total", c_total)
|
340 |
+
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
|
341 |
+
metrics.log_scalar("_w_errors", w_errors)
|
342 |
+
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
|
343 |
+
metrics.log_scalar("_wv_errors", wv_errors)
|
344 |
+
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
|
345 |
+
metrics.log_scalar("_w_total", w_total)
|
346 |
+
|
347 |
+
if c_total > 0:
|
348 |
+
metrics.log_derived(
|
349 |
+
"uer",
|
350 |
+
lambda meters: safe_round(
|
351 |
+
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
|
352 |
+
)
|
353 |
+
if meters["_c_total"].sum > 0
|
354 |
+
else float("nan"),
|
355 |
+
)
|
356 |
+
if w_total > 0:
|
357 |
+
metrics.log_derived(
|
358 |
+
"wer",
|
359 |
+
lambda meters: safe_round(
|
360 |
+
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
361 |
+
)
|
362 |
+
if meters["_w_total"].sum > 0
|
363 |
+
else float("nan"),
|
364 |
+
)
|
365 |
+
metrics.log_derived(
|
366 |
+
"raw_wer",
|
367 |
+
lambda meters: safe_round(
|
368 |
+
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
369 |
+
)
|
370 |
+
if meters["_w_total"].sum > 0
|
371 |
+
else float("nan"),
|
372 |
+
)
|
373 |
+
|
374 |
+
if "dec_loss" in logging_outputs[0]:
|
375 |
+
ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
|
376 |
+
dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
|
377 |
+
dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
|
378 |
+
dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
|
379 |
+
metrics.log_scalar(
|
380 |
+
"dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
|
381 |
+
)
|
382 |
+
metrics.log_scalar(
|
383 |
+
"ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
|
384 |
+
)
|
385 |
+
metrics.log_scalar(
|
386 |
+
"dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
|
387 |
+
)
|
388 |
+
metrics.log_derived(
|
389 |
+
"dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
|
390 |
+
)
|
391 |
+
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
392 |
+
if total > 0:
|
393 |
+
metrics.log_scalar("total", total)
|
394 |
+
n_correct = utils.item(
|
395 |
+
sum(log.get("dec_n_correct", 0) for log in logging_outputs)
|
396 |
+
)
|
397 |
+
metrics.log_scalar("dec_n_correct", n_correct)
|
398 |
+
metrics.log_derived(
|
399 |
+
"dec_accuracy",
|
400 |
+
lambda meters: round(
|
401 |
+
meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
|
402 |
+
)
|
403 |
+
if meters["total"].sum > 0
|
404 |
+
else float("nan"),
|
405 |
+
)
|
406 |
+
|
407 |
+
@staticmethod
|
408 |
+
def logging_outputs_can_be_summed() -> bool:
|
409 |
+
"""
|
410 |
+
Whether the logging outputs returned by `forward` can be summed
|
411 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
412 |
+
to True will improves distributed training speed.
|
413 |
+
"""
|
414 |
+
return True
|
SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ----------------------------------------------------------------------------
|
2 |
+
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
|
3 |
+
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
|
4 |
+
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
|
5 |
+
#
|
6 |
+
# Copyright (c) 2022 Microsoft
|
7 |
+
# Licensed under The MIT License [see LICENSE for details]
|
8 |
+
# ----------------------------------------------------------------------------
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
import re
|
13 |
+
from dataclasses import dataclass, field
|
14 |
+
from typing import List, Optional
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from fairseq import metrics, utils
|
20 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
21 |
+
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
|
22 |
+
from fairseq.dataclass import FairseqDataclass
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class SpeechUTCriterionConfig(FairseqDataclass):
|
28 |
+
pred_masked_weight: float = field(
|
29 |
+
default=1.0,
|
30 |
+
metadata={"help": "weight for predictive loss for masked frames"},
|
31 |
+
)
|
32 |
+
pred_nomask_weight: float = field(
|
33 |
+
default=0.0,
|
34 |
+
metadata={"help": "weight for predictive loss for unmasked frames"},
|
35 |
+
)
|
36 |
+
loss_weights: Optional[List[float]] = field(
|
37 |
+
default=None,
|
38 |
+
metadata={"help": "weights for additional loss terms (not first one)"},
|
39 |
+
)
|
40 |
+
log_keys: List[str] = field(
|
41 |
+
default_factory=lambda: [],
|
42 |
+
metadata={"help": "output keys to log"},
|
43 |
+
)
|
44 |
+
u2t_ed_weight: float = field(
|
45 |
+
default=0.1,
|
46 |
+
metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
|
47 |
+
)
|
48 |
+
u2t_ctc_weight: float = field(
|
49 |
+
default=0.0,
|
50 |
+
metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
|
51 |
+
)
|
52 |
+
text_mum_weight: float = field(
|
53 |
+
default=0.0,
|
54 |
+
metadata={"help": "masked unit modeling weight from the text end"},
|
55 |
+
)
|
56 |
+
report_accuracy: bool = field(
|
57 |
+
default=True,
|
58 |
+
metadata={"help": "report decoder accuracy metric"},
|
59 |
+
)
|
60 |
+
ignore_prefix_size: int = field(
|
61 |
+
default=0,
|
62 |
+
metadata={"help": "Ignore first N tokens"},
|
63 |
+
)
|
64 |
+
label_smoothing: float = field(
|
65 |
+
default=0.0,
|
66 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
67 |
+
)
|
68 |
+
no_ctc_blank: bool = field(
|
69 |
+
default=False,
|
70 |
+
metadata={"help": "mask out the blank of ctc, only when dec_loss_type=ctc"},
|
71 |
+
)
|
72 |
+
label_smoothing: float = field(
|
73 |
+
default=0.0,
|
74 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
75 |
+
)
|
76 |
+
|
77 |
+
@register_criterion("speechut_criterion", dataclass=SpeechUTCriterionConfig)
|
78 |
+
class SpeechUTCriterion(FairseqCriterion):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
task,
|
82 |
+
pred_masked_weight,
|
83 |
+
pred_nomask_weight,
|
84 |
+
loss_weights=None,
|
85 |
+
log_keys=None,
|
86 |
+
u2t_ed_weight=0.1,
|
87 |
+
u2t_ctc_weight=0,
|
88 |
+
text_mum_weight=0,
|
89 |
+
report_accuracy=False,
|
90 |
+
ignore_prefix_size=0,
|
91 |
+
label_smoothing=0,
|
92 |
+
no_ctc_blank=False,
|
93 |
+
):
|
94 |
+
super().__init__(task)
|
95 |
+
self.pred_masked_weight = pred_masked_weight
|
96 |
+
self.pred_nomask_weight = pred_nomask_weight
|
97 |
+
self.loss_weights = loss_weights
|
98 |
+
self.log_keys = [] if log_keys is None else log_keys
|
99 |
+
self.u2t_ed_weight = u2t_ed_weight
|
100 |
+
self.u2t_ctc_weight = u2t_ctc_weight
|
101 |
+
self.text_mum_weight = text_mum_weight
|
102 |
+
self.report_accuracy = report_accuracy
|
103 |
+
self.ignore_prefix_size = ignore_prefix_size
|
104 |
+
self.eps = label_smoothing
|
105 |
+
self.no_ctc_blank = no_ctc_blank
|
106 |
+
self.padding_idx = task.dictionaries[0].pad()
|
107 |
+
self.eos_idx = task.dictionaries[0].eos()
|
108 |
+
self.blank_idx = task.dictionaries[0].bos()
|
109 |
+
|
110 |
+
def compute_hubert_loss(self, model, net_output, reduction, preffix='', suffix=''):
|
111 |
+
loss = 0
|
112 |
+
sample_size = []
|
113 |
+
logging_output = {}
|
114 |
+
loss_m_list = []
|
115 |
+
logp_m_list = model.get_logits(net_output, True)
|
116 |
+
targ_m_list = model.get_targets(net_output, True)
|
117 |
+
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
|
118 |
+
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
|
119 |
+
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
|
120 |
+
loss_m_list.append(loss_m)
|
121 |
+
logging_output[f"{preffix}loss_m_{i}"] = loss_m.detach().item()
|
122 |
+
if self.pred_masked_weight > 0:
|
123 |
+
loss += self.pred_masked_weight * sum(loss_m_list)
|
124 |
+
sample_size.append(targ_m_list[0].numel())
|
125 |
+
|
126 |
+
loss_u_list = []
|
127 |
+
logp_u_list = model.get_logits(net_output, False)
|
128 |
+
targ_u_list = model.get_targets(net_output, False)
|
129 |
+
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
|
130 |
+
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
|
131 |
+
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
|
132 |
+
loss_u_list.append(loss_u)
|
133 |
+
logging_output[f"{preffix}loss_u_{i}"] = loss_u.detach().item()
|
134 |
+
if self.pred_nomask_weight > 0:
|
135 |
+
loss += self.pred_nomask_weight * sum(loss_u_list)
|
136 |
+
sample_size.append(targ_u_list[0].numel())
|
137 |
+
|
138 |
+
sample_size = np.mean(sample_size)
|
139 |
+
|
140 |
+
def compute_correct(logits, targets):
|
141 |
+
if logits.numel() == 0:
|
142 |
+
return 0, 0
|
143 |
+
else:
|
144 |
+
assert logits.dim() > 1, logits.shape
|
145 |
+
max = logits.argmax(-1) == targets
|
146 |
+
min = logits.argmin(-1) == targets
|
147 |
+
both = max & min
|
148 |
+
corr = max.long().sum().item() - both.long().sum().item()
|
149 |
+
count = max.numel()
|
150 |
+
return corr, count
|
151 |
+
|
152 |
+
with torch.no_grad():
|
153 |
+
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
|
154 |
+
corr_m, count_m = compute_correct(logp_m, targ_m)
|
155 |
+
logging_output[f"correct_m_{i}{suffix}"] = corr_m
|
156 |
+
logging_output[f"count_m_{i}{suffix}"] = count_m
|
157 |
+
|
158 |
+
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
|
159 |
+
corr_u, count_u = compute_correct(logp_u, targ_u)
|
160 |
+
logging_output[f"correct_u_{i}{suffix}"] = corr_u
|
161 |
+
logging_output[f"count_u_{i}{suffix}"] = count_u
|
162 |
+
|
163 |
+
return loss, sample_size, logging_output
|
164 |
+
|
165 |
+
|
166 |
+
def forward(self, model, sample, reduce=True, log_pred=False):
|
167 |
+
"""Compute the loss for the given sample.
|
168 |
+
Returns a tuple with three elements:
|
169 |
+
1) the loss
|
170 |
+
2) the sample size, which is used as the denominator for the gradient
|
171 |
+
3) logging outputs to display while training
|
172 |
+
"""
|
173 |
+
reduction = "sum" if reduce else "none"
|
174 |
+
|
175 |
+
if "net_input" in sample:
|
176 |
+
unit_sample = text_sample = None
|
177 |
+
else:
|
178 |
+
unit_sample = sample.get("text_mono", None)
|
179 |
+
text_sample = sample.get("text_paired", None)
|
180 |
+
assert unit_sample is not None or text_sample is not None
|
181 |
+
sample = sample.get("speech")
|
182 |
+
|
183 |
+
### 1. S2U: do hubert forward and loss computation
|
184 |
+
sample["modality"] = "speech"
|
185 |
+
net_output = model(target_list=sample["target_list"], **sample["net_input"])
|
186 |
+
loss, sample_size, logging_output = self.compute_hubert_loss(
|
187 |
+
model,
|
188 |
+
net_output,
|
189 |
+
reduction,
|
190 |
+
)
|
191 |
+
if self.loss_weights is not None:
|
192 |
+
assert hasattr(model, "get_extra_losses")
|
193 |
+
extra_losses, names = model.get_extra_losses(net_output)
|
194 |
+
if torch.is_tensor(extra_losses):
|
195 |
+
extra_losses = [extra_losses]
|
196 |
+
names = [names]
|
197 |
+
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
198 |
+
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
199 |
+
assert len(extra_losses) == len(
|
200 |
+
self.loss_weights
|
201 |
+
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
202 |
+
for p, n, coef in zip(extra_losses, names, self.loss_weights):
|
203 |
+
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()
        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk]))

        ### 2. do text U2T forward and loss computation
        if text_sample is not None and (self.u2t_ctc_weight + self.u2t_ed_weight) > 0:
            ## 2.1 re-loading "target_list", in default case, target_list = [src_tokens],
            ## while in case of using "unit-phone-char" structure, target_list will be [ref_tokens]
            text_sample["net_input"]["target_list"] = [
                text_sample.get("ref_tokens", text_sample["net_input"]["src_tokens"].clone()),
            ]
            text_net_output = model(**text_sample["net_input"])
            text_sample_size = text_sample["ntokens"]

            ### 2.1 U2T_UCTC
            if self.u2t_ctc_weight > 0:
                text_ctc_loss = self.compute_ctc_loss(model, text_net_output, text_sample["target"], reduction=reduction)
                loss += self.u2t_ctc_weight * text_ctc_loss * sample_size / text_sample_size
                logging_output["text_ctc_loss"] = utils.item(text_ctc_loss)
                logging_output["text_sample_size"] = text_sample_size

            ### 2.2 U2T_ED
            if self.u2t_ed_weight > 0:
                text_dec_loss, text_dec_nll_loss = self.compute_ce_loss(model, text_net_output["decoder_out"], text_sample, reduce=reduce)
                loss += self.u2t_ed_weight * text_dec_loss * sample_size / text_sample_size
                logging_output["text_dec_loss"] = utils.item(text_dec_loss)
                logging_output["text_dec_nll_loss"] = utils.item(text_dec_nll_loss)
                logging_output["text_sample_size"] = text_sample_size
                if self.report_accuracy:
                    n_correct, total = self.compute_accuracy(model, text_net_output["decoder_out"], text_sample)
                    logging_output["correct_text_dec"] = utils.item(n_correct.data)
                    logging_output["count_text_dec"] = utils.item(total.data)

        ### 3. do unit MUM forward and loss computation
        if unit_sample is not None and self.text_mum_weight > 0:
            src_tokens = unit_sample["net_input"]["src_tokens"]
            target = unit_sample.get("target", None)
            target = src_tokens.clone() if target is None else target
            unit_net_output = model.forward_mum(src_tokens, target)
            loss_num, sample_size_mum, logging_output_mum = self.compute_hubert_loss(
                model,
                unit_net_output,
                reduction,
                preffix="mum_",
                suffix="_mum",
            )
            loss += self.text_mum_weight * loss_num * sample_size / sample_size_mum
            logging_output["unit_sample_size"] = sample_size_mum
            logging_output.update(logging_output_mum)

        logging_output = {
            "loss": utils.item(loss) if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel() + (text_sample["id"].numel() if text_sample is not None else 0),
            "sample_size": sample_size,
            **logging_output,
        }

        return loss, sample_size, logging_output

    def compute_ctc_loss(self, model, net_output, target, reduction):
        logits = net_output["encoder_out_ctc"][0]  # (T, B, C) from the code-encoder
        if self.no_ctc_blank:
            ## set prob of <blank> to -inf
            logits = logits.float()
            logits[:, :, self.blank_idx] = -1000000.0

        lprobs = F.log_softmax(logits.float(), dim=-1)

        encoder_padding_mask = net_output["encoder_padding_mask"][0]
        non_padding_mask = ~encoder_padding_mask
        input_lengths = non_padding_mask.long().sum(-1)
        pad_mask = (target != self.padding_idx) & (target != self.eos_idx)
        targets_flat = target.masked_select(pad_mask)
        target_lengths = pad_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction=reduction,
                zero_infinity=True,
            )
        return loss

    def compute_ce_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = sample["target"]

        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log.get(lk, 0) for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log.get(lk, 0) for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log.get(lk, 0) for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

        if "text_sample_size" in logging_outputs[0]:
            text_sample_size = sum(log.get("text_sample_size", 0) for log in logging_outputs)
            for lk in logging_outputs[0].keys():
                if lk.startswith("text_") and lk.endswith("_loss"):
                    val = sum(log.get(lk, 0) for log in logging_outputs)
                    metrics.log_scalar(lk, val / text_sample_size / math.log(2), round=3)

        if "unit_sample_size" in logging_outputs[0]:
            unit_sample_size = sum(log.get("unit_sample_size", 0) for log in logging_outputs)
            for lk in logging_outputs[0].keys():
                if lk.startswith("mum_loss_"):
                    val = sum(log.get(lk, 0) for log in logging_outputs)
                    metrics.log_scalar(lk, val / unit_sample_size / math.log(2), round=3)

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
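
The CTC branch above follows the standard layout for F.ctc_loss: log-probabilities shaped (T, B, C), per-utterance input lengths derived from the encoder padding mask, and a single flat 1-D tensor of targets with pad/eos removed. The sketch below only illustrates that call convention; the tensor sizes and blank index are invented for the example and are not taken from this repository.

import torch
import torch.nn.functional as F

T, B, C, blank_idx = 50, 2, 10, 0                    # hypothetical sizes and blank index
lprobs = F.log_softmax(torch.randn(T, B, C), dim=-1) # (time, batch, classes), as in compute_ctc_loss
input_lengths = torch.tensor([50, 42])               # frames per utterance (from ~padding_mask)
target_lengths = torch.tensor([12, 9])               # label counts after masking pad/eos
targets_flat = torch.randint(1, C, (int(target_lengths.sum()),))  # concatenated targets, no blanks
loss = F.ctc_loss(
    lprobs, targets_flat, input_lengths, target_lengths,
    blank=blank_idx, reduction="sum", zero_infinity=True,
)
print(loss.item())
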
SpeechT5/Speech2S/speech2s/data/concat_dataset.py
ADDED
@@ -0,0 +1,129 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from fairseq.data import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            if getattr(self.datasets[0], "shuffle", False):
                indices = np.random.permutation(len(self)).astype(np.int64)
            else:
                indices = np.arange(len(self), dtype=np.int64)
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
SpeechT5/Speech2S/speech2s/data/hubert_dataset.py
ADDED
@@ -0,0 +1,597 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import itertools
import logging
import io
import os
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Union, Tuple

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    read_from_stored_zip,
    is_sf_audio_data,
)

FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}

logger = logging.getLogger(__name__)


def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse data path which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

    Args:
        path (str): the data path to parse

    Returns:
        file_path (str): the file path
        slice_ptr (list of int): empty in case 1;
          byte offset and length for the slice in case 2
    """

    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        _path, *slice_ptr = path.split(":")
        if not Path(_path).is_file():
            raise FileNotFoundError(f"File not found: {_path}")
    assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
    slice_ptr = [int(i) for i in slice_ptr]
    return _path, slice_ptr


def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
    n_long, n_short = 0, 0
    names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
    for i in range(retry_times):
        with open(manifest_path) as f:
            root = f.readline().strip()
            for ind, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_keep is not None and sz < min_keep:
                    n_short += 1
                elif max_keep is not None and sz > max_keep:
                    n_long += 1
                else:
                    fname = items[0].split(":")
                    if len(fname) > 2:
                        if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
                            chunk_names.append(fname[0])
                            chunk_indices.append(len(names))
                    names.append(items[0])
                    inds.append(ind)
                    sizes.append(sz)
        if len(names) == 0:
            logger.warn(f"Fail to load manifest for the {i} time")
            time.sleep(1)
            continue
        else:
            break
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes, chunk_names, chunk_indices


def load_label(label_path, inds, tot, retry_times=5):
    for i in range(retry_times):
        with open(label_path) as f:
            labels = [line.rstrip() for line in f]
        if len(labels) == 0:
            logger.warn(f"Fail to load label for the {i} time")
            time.sleep(1)
            continue
        else:
            break
    assert (
        len(labels) == tot
    ), f"number of labels does not match ({len(labels)} != {tot})"
    labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot, retry_times=5):
    for i in range(retry_times):
        with open(label_path) as f:
            code_lengths = [len(line.encode("utf-8")) for line in f]
        if len(code_lengths) == 0:
            logger.warn(f"Fail to load label for the {i} time")
            time.sleep(1)
            continue
        else:
            break
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    offsets = list(itertools.accumulate([0] + code_lengths))
    offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        add_decoder_target: bool = False,
        fine_tuning: bool = False,
        tgt_lang_idx: int = None,
        tokenizer=None,
        mbart_style_lang_id: bool = False,
        retry_times: int = 5,
        reduce_label_for_dec: bool = True,
    ):
        self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.chunk_names, self.chunk_indices = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop
        self.tgt_dict = tgt_dict
        self.add_decoder_target = add_decoder_target
        self.fine_tuning = fine_tuning

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.epoch = 0

        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, int)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot, retry_times) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot, retry_times) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        self.tgt_lang_idx = tgt_lang_idx
        self.tokenizer = tokenizer
        self.mbart_style_lang_id = mbart_style_lang_id
        self.retry_times = retry_times
        self.reduce_label_for_dec = reduce_label_for_dec
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
            f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def set_epoch(self, epoch):
        self.epoch = epoch

    def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1):
        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.required_batch_size_multiple = required_batch_size_multiple
        if isinstance(indices[0], np.ndarray):
            batch_list = []
            for indice in indices:
                batch = super(HubertDataset, self).batch_by_size(indice, max_tokens, max_sentences, required_batch_size_multiple)
                batch_list.append(batch)
            return batch_list
        else:
            return super(HubertDataset, self).batch_by_size(indices, max_tokens, max_sentences, required_batch_size_multiple)

    def shuffle_batches(self, batches, seed):
        if isinstance(batches[0], list):
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
                for batch in batches:
                    np.random.shuffle(batch)
                    new_batches.extend(batch)
            return new_batches
        else:
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 1:
            import kaldiio
            feat = kaldiio.load_mat(wav_path)
            feat = torch.from_numpy(feat).float()
            if self.normalize:
                with torch.no_grad():
                    feat = F.layer_norm(feat, feat.shape[-1])
            return feat
        else:
            if len(slice_ptr) == 2:
                byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
                assert is_sf_audio_data(byte_data)
                wav_path = io.BytesIO(byte_data)
            for i in range(self.retry_times):
                if i < self.retry_times - 1:
                    try:
                        wav, cur_sample_rate = sf.read(wav_path)
                        break
                    except Exception as e:
                        logger.warn(f"Fail to load wav for the {i} time")
                        logger.warn(e)
                        time.sleep(1)
                        continue
                else:
                    wav, cur_sample_rate = sf.read(wav_path)

            wav = torch.from_numpy(wav).float()
            wav = self.postprocess(wav, cur_sample_rate)
            return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.tokenizer is not None and self.fine_tuning:
            label = self.tokenizer.encode(label)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.wav_sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size, feat_dim,
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        if self.add_decoder_target:
            if self.fine_tuning:
                decoder_label = [
                    torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                if self.tokenizer is not None:
                    decoder_label = [
                        # Set 48 for translate int to char and avoid \n
                        torch.cat(
                            (
                                torch.tensor(
                                    self.tokenizer.sp.Encode(
                                        "".join(
                                            [chr(j + 48) for j in (
                                                targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]]
                                            ).tolist()]
                                        ), out_type=int
                                    )
                                ),
                                torch.tensor([self.tgt_dict.eos()])
                            ), dim=0
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]
                else:
                    decoder_label = [
                        torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
                        for i in range(targets_list[0].size(0))
                    ]

            if self.mbart_style_lang_id:
                decoder_label = [
                    torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]

            dec_ntokens = sum(x.size(0) for x in decoder_label)
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
            prev_output_tokens = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=True,
            )

            if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
                assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
                prev_output_tokens[:, 0] = self.tgt_lang_idx

            net_input = {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
            }
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
                "decoder_target": decoder_target,
                "decoder_target_lengths": decoder_target_lengths,
                "dec_ntokens": dec_ntokens,
                "lang_idx": self.tgt_lang_idx,
            }
        else:
            net_input = {"source": collated_audios, "padding_mask": padding_mask}
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
            }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size, feat_dim=1):
        collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            audio = audio.view(-1, feat_dim)
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff, feat_dim), 0.0)])
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios.squeeze(-1), padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.wav_sizes[index]
        return min(self.wav_sizes[index], self.max_sample_size)

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""

        if self.shuffle:
            if len(self.chunk_names) > 0:
                logger.info(f"ordered indices for epoch {self.epoch}")
                with data_utils.numpy_seed(self.epoch):
                    self.chunk_order = np.random.permutation(len(self.chunk_names))
                    chunk_count = 0
                    tmp_sizes = []
                    tmp_indices = []
                    indice = []
                    for i in self.chunk_order:
                        chunk_count += 1
                        start = self.chunk_indices[i]
                        end = self.chunk_indices[i+1] if i < len(self.chunk_names) - 1 else len(self)
                        size = list(self.sizes[start:end])
                        tmp_indices.extend(list(np.arange(start, end)))
                        tmp_sizes.extend(size)
                        if chunk_count % 10 == 0 or i == self.chunk_order[0]:
                            order = [np.random.permutation(len(tmp_indices))]
                            order.append(
                                np.minimum(
                                    np.array(tmp_sizes),
                                    self.max_sample_size,
                                )
                            )
                            sort_idx = np.lexsort(order)[::-1]
                            indice.append(np.array([tmp_indices[k] for k in sort_idx]))
                            tmp_indices = []
                            tmp_sizes = []
                    return indice
            else:
                order = [np.random.permutation(len(self))]
                order.append(
                    np.minimum(
                        np.array(self.sizes),
                        self.max_sample_size,
                    )
                )
                return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
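
`load_audio` expects the usual HuBERT-style manifest: the first line is the audio root directory and every following line is a tab-separated relative path and sample count; entries of the form `archive.zip:offset:length` are later resolved by `parse_path`. The small illustration below mirrors that filtering; the file names, sample counts, and length bounds are invented for the example and are not from the repository.

min_keep, max_keep = 32000, 250000          # hypothetical length bounds in samples
manifest = (
    "/data/LibriSpeech/train-clean-100\n"   # root directory (first line)
    "103/1240/103-1240-0000.flac\t225360\n" # kept: within [min_keep, max_keep]
    "103/1240/103-1240-0001.flac\t255120\n" # too long -> skipped
    "blob.zip:1024:20480\t16000\n"          # zip-slice form -> too short, skipped
)
lines = manifest.splitlines()
root, entries = lines[0], [l.split("\t") for l in lines[1:]]
kept = [(p, int(n)) for p, n in entries if min_keep <= int(n) <= max_keep]
print(root, kept)
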
SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py
ADDED
@@ -0,0 +1,669 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import logging
import numpy as np
import torch
import os
import itertools

from fairseq.data import FairseqDataset, data_utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    PrependTokenDataset,
    data_utils,
    indexed_dataset,
)

logger = logging.getLogger(__name__)


def load_langtriple_dataset(
    data_path,
    split,
    src,
    src_dict,
    ref,
    ref_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
):
    assert not truncate_source
    def split_exists(split, src, ref, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    ref_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, ref, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
        elif split_exists(split_k, tgt, ref, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        src_datasets.append(src_dataset)

        ref_dataset = data_utils.load_indexed_dataset(
            prefix + ref, ref_dict, dataset_impl
        )
        ref_datasets.append(ref_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{}-{} {} examples".format(
                data_path, split_k, src, ref, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(ref_datasets)
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        ref_dataset = ref_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
        ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        ref_dataset = AppendTokenDataset(
            ref_dataset, ref_dict.index(lang_format.format(ref))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguageTripleDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        ref_dataset,
        ref_dataset.sizes,
        ref_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )


def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            None,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    ref_tokens = merge(
        "reference",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    ref_lengths = torch.LongTensor(
        [s["reference"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)
    ref_lengths = ref_lengths.index_select(0, sort_order)
    ref_tokens = ref_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "ref_tokens": ref_tokens,
        "ref_lengths": ref_lengths,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        bsz, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max(lens))).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = samples[i].get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch


class LanguageTripleDataset(FairseqDataset):
    """
    A triple of torch.utils.data.Datasets (source, reference, target).

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        ref,
        ref_sizes,
        ref_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        assert len(src) == len(
            ref
        ), "Source and reference must contain the same number of examples"
        self.src = src
        self.ref = ref
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.ref_sizes = np.array(ref_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.ref_dict = ref_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            self.ref = BucketPadLengthDataset(
                self.ref,
                sizes=self.ref_sizes,
                num_buckets=num_buckets,
                pad_idx=self.ref_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.ref_sizes = self.ref.sizes
            logger.info("bucketing reference lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        ref_item = self.ref[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we want
        # to use existing datasets for opposite directions, i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
            if self.ref[index][0] != bos:
                ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]
            if self.ref[index][-1] == eos:
                ref_item = self.ref[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "reference": ref_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
|
555 |
+
- `src_lengths` (LongTensor): 1D Tensor of the unpadded
|
556 |
+
lengths of each source sentence of shape `(bsz)`
|
557 |
+
- `prev_output_tokens` (LongTensor): a padded 2D Tensor of
|
558 |
+
tokens in the target sentence, shifted right by one
|
559 |
+
position for teacher forcing, of shape `(bsz, tgt_len)`.
|
560 |
+
This key will not be present if *input_feeding* is
|
561 |
+
``False``. Padding will appear on the left if
|
562 |
+
*left_pad_target* is ``True``.
|
563 |
+
- `src_lang_id` (LongTensor): a long Tensor which contains source
|
564 |
+
language IDs of each sample in the batch
|
565 |
+
|
566 |
+
- `target` (LongTensor): a padded 2D Tensor of tokens in the
|
567 |
+
target sentence of shape `(bsz, tgt_len)`. Padding will appear
|
568 |
+
on the left if *left_pad_target* is ``True``.
|
569 |
+
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
|
570 |
+
IDs of each sample in the batch
|
571 |
+
"""
|
572 |
+
res = collate(
|
573 |
+
samples,
|
574 |
+
pad_idx=self.src_dict.pad(),
|
575 |
+
eos_idx=self.eos,
|
576 |
+
left_pad_source=self.left_pad_source,
|
577 |
+
left_pad_target=self.left_pad_target,
|
578 |
+
input_feeding=self.input_feeding,
|
579 |
+
pad_to_length=pad_to_length,
|
580 |
+
pad_to_multiple=self.pad_to_multiple,
|
581 |
+
)
|
582 |
+
if self.src_lang_id is not None or self.tgt_lang_id is not None:
|
583 |
+
src_tokens = res["net_input"]["src_tokens"]
|
584 |
+
bsz = src_tokens.size(0)
|
585 |
+
if self.src_lang_id is not None:
|
586 |
+
res["net_input"]["src_lang_id"] = (
|
587 |
+
torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
|
588 |
+
)
|
589 |
+
if self.tgt_lang_id is not None:
|
590 |
+
res["tgt_lang_id"] = (
|
591 |
+
torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
|
592 |
+
)
|
593 |
+
return res
|
594 |
+
|
595 |
+
def num_tokens(self, index):
|
596 |
+
"""Return the number of tokens in a sample. This value is used to
|
597 |
+
enforce ``--max-tokens`` during batching."""
|
598 |
+
return max(
|
599 |
+
self.src_sizes[index],
|
600 |
+
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
|
601 |
+
)
|
602 |
+
|
603 |
+
def num_tokens_vec(self, indices):
|
604 |
+
"""Return the number of tokens for a set of positions defined by indices.
|
605 |
+
This value is used to enforce ``--max-tokens`` during batching."""
|
606 |
+
sizes = self.src_sizes[indices]
|
607 |
+
if self.tgt_sizes is not None:
|
608 |
+
sizes = np.maximum(sizes, self.tgt_sizes[indices])
|
609 |
+
return sizes
|
610 |
+
|
611 |
+
def size(self, index):
|
612 |
+
"""Return an example's size as a float or tuple. This value is used when
|
613 |
+
filtering a dataset with ``--max-positions``."""
|
614 |
+
return (
|
615 |
+
self.src_sizes[index],
|
616 |
+
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
|
617 |
+
)
|
618 |
+
|
619 |
+
def ordered_indices(self):
|
620 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
621 |
+
on this order."""
|
622 |
+
if self.shuffle:
|
623 |
+
indices = np.random.permutation(len(self)).astype(np.int64)
|
624 |
+
else:
|
625 |
+
indices = np.arange(len(self), dtype=np.int64)
|
626 |
+
if self.buckets is None:
|
627 |
+
# sort by target length, then source length
|
628 |
+
if self.tgt_sizes is not None:
|
629 |
+
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
|
630 |
+
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
|
631 |
+
else:
|
632 |
+
# sort by bucketed_num_tokens, which is:
|
633 |
+
# max(padded_src_len, padded_tgt_len)
|
634 |
+
return indices[
|
635 |
+
np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
|
636 |
+
]
|
637 |
+
|
638 |
+
@property
|
639 |
+
def supports_prefetch(self):
|
640 |
+
return getattr(self.src, "supports_prefetch", False) and (
|
641 |
+
getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
|
642 |
+
)
|
643 |
+
|
644 |
+
def prefetch(self, indices):
|
645 |
+
self.src.prefetch(indices)
|
646 |
+
if self.tgt is not None:
|
647 |
+
self.tgt.prefetch(indices)
|
648 |
+
if self.align_dataset is not None:
|
649 |
+
self.align_dataset.prefetch(indices)
|
650 |
+
|
651 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
652 |
+
"""Filter a list of sample indices. Remove those that are longer
|
653 |
+
than specified in max_sizes.
|
654 |
+
|
655 |
+
Args:
|
656 |
+
indices (np.array): original array of sample indices
|
657 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
658 |
+
can be defined separately for src and tgt (then list or tuple)
|
659 |
+
|
660 |
+
Returns:
|
661 |
+
np.array: filtered sample array
|
662 |
+
list: list of removed indices
|
663 |
+
"""
|
664 |
+
return data_utils.filter_paired_dataset_indices_by_size(
|
665 |
+
self.src_sizes,
|
666 |
+
self.tgt_sizes,
|
667 |
+
indices,
|
668 |
+
max_sizes,
|
669 |
+
)
|
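Editor's note (not part of the uploaded file): the collater above returns the nested batch dict documented in its docstring. The standalone Python sketch below, with made-up token values and a hypothetical attach_lang_ids helper, only illustrates how the optional src_lang_id / tgt_lang_id fields are expanded to (bsz, 1) tensors, mirroring the expand(bsz, 1).to(src_tokens) logic in collater.

import torch

def attach_lang_ids(batch, src_lang_id=None, tgt_lang_id=None):
    # Mirror of the collater's tail: expand scalar language ids to (bsz, 1)
    # tensors with the same dtype/device as src_tokens.
    src_tokens = batch["net_input"]["src_tokens"]
    bsz = src_tokens.size(0)
    if src_lang_id is not None:
        batch["net_input"]["src_lang_id"] = (
            torch.LongTensor([[src_lang_id]]).expand(bsz, 1).to(src_tokens)
        )
    if tgt_lang_id is not None:
        batch["tgt_lang_id"] = (
            torch.LongTensor([[tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
        )
    return batch

# Toy batch with the documented layout (all values are invented).
toy = {
    "id": torch.tensor([0, 1]),
    "ntokens": 6,
    "net_input": {
        "src_tokens": torch.tensor([[5, 6, 2], [1, 7, 2]]),
        "src_lengths": torch.tensor([3, 3]),
    },
    "target": torch.tensor([[8, 9, 2], [1, 10, 2]]),
}
toy = attach_lang_ids(toy, src_lang_id=4, tgt_lang_id=9)
assert toy["net_input"]["src_lang_id"].shape == (2, 1)
assert toy["tgt_lang_id"].shape == (2, 1)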
SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py
ADDED
@@ -0,0 +1,172 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

"""
Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
1. Add custom lang_format in function load_langpair_dataset
2. If truncate_source (default no), use RandomCropDataset instead of TruncateDataset
"""

import itertools
import logging
import os

from fairseq.data import (
    AppendTokenDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    RandomCropDataset,
    data_utils,
    indexed_dataset,
)

from speechut.data.concat_dataset import ConcatDataset


EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
    input_feeding=True,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            src_dataset = AppendTokenDataset(
                RandomCropDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        input_feeding=input_feeding,
    )
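Editor's note (not part of the uploaded file): load_langpair_dataset probes for fairseq-preprocess style binarized shards named "{split}.{src}-{tgt}.{lang}", trying both language orders and numbered shard suffixes before loading. The toy sketch below reproduces only that probing/naming convention; the data path and language codes are hypothetical and it does not call fairseq.

import itertools
import os

def candidate_prefixes(data_path, split, src, tgt):
    # Yield the same "{split}{k}.{a}-{b}." prefixes the loader probes,
    # trying both language orders for each shard index k.
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        for a, b in ((src, tgt), (tgt, src)):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, a, b))
            yield k, prefix

# Example: print the first few probe prefixes for a hypothetical en-de setup.
for k, prefix in candidate_prefixes("/data/bin", "train", "en", "de"):
    print(prefix)
    if k >= 1:
        break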
SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py
ADDED
@@ -0,0 +1,368 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import logging
from os import replace
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional

import numpy as np
from fairseq.data import data_utils

from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
    """
    see fairseq/fairseq/data/multi_corpus_dataset.__doc__

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance from
                      the corresponding dataset
        seed: random seed for sampling the datasets
        sort_indices: if true, will sort the ordered indices by size
        batch_sample: if true, will ensure each batch is from a single dataset
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        max_positions: Dict,
        distribution: List[float],
        max_tokens_ratio: List[float],
        seed: int = 1234,
        sort_indices: bool = False,
        check_length: bool = False,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        # assert sum(distribution) == 1
        self.datasets = datasets
        self.distribution = distribution
        self.max_tokens_ratio = max_tokens_ratio
        self.seed = seed
        self.sort_indices = sort_indices
        self.max_positions = max_positions
        self.check_length = check_length

        # Avoid repeated conversions to list later
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        # first_dataset = self.dataset_list[0]

        self.num_instances_per_dataset = []
        self.dataset_offsets = []
        for i, dataset in enumerate(self.dataset_list):
            assert isinstance(dataset, FairseqDataset)
            # assert type(dataset) is type(first_dataset)
            self.num_instances_per_dataset.append(
                0 if self.distribution[i] == 0 else len(dataset)
            )
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += self.num_instances_per_dataset[i]

    def ordered_indices(self):
        start = time.time()
        with data_utils.numpy_seed(self.seed, self.epoch):
            logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
            sampled_indices = {}

            # For each dataset i, sample self.distribution[i] * self.total_num_instances
            for i, key in enumerate(self.datasets):
                tp = time.time()
                if self.distribution[i] == 0:
                    # skip dataset if sampling probability is 0
                    continue

                if i < len(self.datasets) - 1:
                    num_instances = int(self.distribution[i] * self.total_num_instances)
                    high = self.dataset_offsets[i + 1]
                else:
                    num_instances = int(self.distribution[i] * self.total_num_instances)
                    high = self.total_num_instances

                logger.info(f"sampling {num_instances} from {key} dataset")

                # First, add k copies of the dataset where k = num_instances // len(dataset).
                # This ensures an equal distribution of the data points as much as possible.
                # For the remaining entries randomly sample them
                dataset_size = len(self.datasets[key])
                num_copies = num_instances // dataset_size
                dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[: num_instances - num_copies * dataset_size]
                if num_copies > 0:
                    dataset_indices = np.concatenate(
                        (
                            np.repeat(
                                np.arange(high - self.dataset_offsets[i]), num_copies
                            ),
                            dataset_indices,
                        )
                    )
                # filter by size, we should skip it by setting check_length=False,
                # as it is very time-consuming on large datasets
                if self.max_positions[key] is not None and self.check_length:
                    dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
                        dataset_indices,
                        self.max_positions[key],
                    )
                    if len(ignored) > 0:
                        logger.warning(
                            (
                                "{:,} samples have invalid sizes and will be skipped, "
                                "max_positions={}, first few sample ids={}"
                            ).format(len(ignored), self.max_positions[key], ignored[:10])
                        )

                if self.sort_indices:
                    logger.info(" - sampled indices took {}s".format(time.time() - tp))
                    tp = time.time()
                    dataset_indices = np.sort(dataset_indices)
                    ordered_indices = self.datasets[key].ordered_indices()
                    if isinstance(ordered_indices[0], np.ndarray):  # chunked audio data
                        dataset_indices = [order_idx + self.dataset_offsets[i] for order_idx in ordered_indices]
                        assert self.dataset_offsets[i] == 0
                        # TODO for chunked audio data, now assume len(dataset_indices) == len(dataset). Don't filter any data.
                    else:
                        dataset_indices = ordered_indices[dataset_indices] + self.dataset_offsets[i]
                    logger.info(" - ordered_indices took {}s".format(time.time() - tp))
                else:
                    np.random.shuffle(dataset_indices)

                sampled_indices[key] = dataset_indices

            logger.info(
                "multi_corpus_dataset ordered_indices took {}s".format(
                    time.time() - start
                )
            )
            return sampled_indices

    def _map_index(self, index: int):
        """
        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.
        """
        counter = 0
        for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
            if index < counter + num_instances:
                return index - counter, key
            counter += num_instances
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def __getitem__(self, index):
        new_index, key = self._map_index(index)
        try:
            item = self.datasets[key][new_index]
            item["full_id"] = index
            return item
        except Exception as e:
            e.args = (f"Error from {key} dataset", *e.args)
            raise

    def collater(self, samples):
        """
        If we are doing batch sampling, then pick the right collater to use.

        Otherwise we assume all collaters are the same.
        """
        if len(samples) == 0:
            return None

        samples_dict = {key: [] for key in self.datasets}
        for s in samples:
            _, key = self._map_index(s["full_id"])
            samples_dict[key].append(s)

        batch = {}
        for key in samples_dict:
            if len(samples_dict[key]) == 0:
                continue
            batch[key] = self.datasets[key].collater(samples_dict[key])

        return batch

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
        for ds in self.dataset_list:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        dataset_indices = indices
        batches_dict = {}
        for n, key in enumerate(dataset_indices):
            max_tokens_ratio = self.max_tokens_ratio[n]
            if isinstance(dataset_indices[key][0], np.ndarray):  # chunked audio data
                cur_batches = self.datasets[key].batch_by_size(
                    dataset_indices[key],
                    round(max_tokens * max_tokens_ratio),
                    max_sentences,
                    required_batch_size_multiple,
                )
                logger.info(f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}")
            else:
                cur_batches = super().batch_by_size(
                    np.array(dataset_indices[key], dtype=np.int64),
                    round(max_tokens * max_tokens_ratio),
                    max_sentences,
                    required_batch_size_multiple,
                )
                logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
            batches_dict[key] = cur_batches

        return batches_dict

    def get_batch_sampler(
        self,
        indices,
        num_shards,
        seed,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
        split_modality_batch=False,
    ):

        def batch_sampler(dataset, epoch):
            start = time.time()
            batches_dict = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
            start = time.time()
            new_batches = []

            ### shuffle inner group size, split into speech/text batches
            shuffled_batches_list = []
            speech_batches = []
            ### we need speech_batches because we have to concatenate the different speech datasets
            # (e.g. ltr or km) instead of loading them in parallel.
            for name, batches in batches_dict.items():
                if name.startswith("speech"):
                    if isinstance(batches[0], list):  # chunked audio data
                        batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
                        shuffled_batches_list.append(batches)
                    else:
                        batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
                        batches = batches[: (len(batches) // num_shards) * num_shards]
                        if len(batches) == 0:
                            logger.warning(f"Sampled 0 batches for {name}, you should ensure that no {name} data is provided.")
                        else:
                            speech_batches += batches
                else:
                    batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
                    batches = batches[: (len(batches) // num_shards) * num_shards]
                    if len(batches) == 0:
                        logger.warning(f"Sampled 0 batches for {name}, you should ensure that no {name} data is provided.")
                    else:
                        batches = shuffle_buckets(batches, seed=seed+epoch, inner_shuf=False)
                        shuffled_batches_list.append(batches)
            if len(speech_batches) > 0:
                speech_batches = shuffle_buckets(speech_batches, seed=seed+epoch, inner_shuf=False)
                shuffled_batches_list.append(speech_batches)

            ### create the final new_batches
            num_batch = min(len(batches) for batches in shuffled_batches_list)
            if split_modality_batch:
                for i in range(0, num_batch, num_shards):
                    for batches in shuffled_batches_list:
                        new_batches += batches[i: i + num_shards]
            else:
                for i in range(num_batch):
                    new_batches.append(np.concatenate([batches[i] for batches in shuffled_batches_list]))

            logger.info(f"multi_corpus_dataset sampled {len(new_batches)} batches, took {time.time() - start}s")
            return new_batches

        def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
            """we assume batches are sorted from long to short.
            shuffle samples within a bucket (e.g. 10 batches).
            batches: a list of numpy arrays"""
            num_batch = len(batches)
            new_batches = []
            num_buckets = len(batches) // bucket_size
            i = 0
            while i < num_batch:
                if (i < bucket_size * thr or
                    i >= bucket_size * (num_buckets - thr)
                ):
                    new_batches.append(batches[i])
                    i += 1
                else:
                    group = np.concatenate(batches[i: i+bucket_size])
                    with data_utils.numpy_seed(seed):
                        np.random.shuffle(group)
                    new_batches += np.array_split(group, bucket_size)
                    i += bucket_size
            assert all([len(batch) > 0 for batch in new_batches])
            return new_batches

        def shuffle_buckets(batches, seed, inner_shuf=True):
            if inner_shuf:
                batches = inner_bucket_shuffle(batches, seed, num_shards*10)
            batches = [batches[i: i + num_shards] for i in range(0, len(batches)-num_shards+1, num_shards)]
            assert len(batches[-1]) == num_shards
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            for group in batches:
                new_batches += group
            return new_batches

        return batch_sampler
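Editor's note (not part of the uploaded file): the _map_index method above turns a global sample index into a (local index, corpus name) pair by walking the per-dataset instance counts. A self-contained toy version of that offset arithmetic, using plain counts in place of FairseqDataset instances, is sketched below.

from collections import OrderedDict

def map_index(index, sizes):
    # sizes: OrderedDict of corpus name -> number of instances.
    counter = 0
    for key, n in sizes.items():
        if index < counter + n:
            return index - counter, key
        counter += n
    raise ValueError(f"Invalid index: {index}, max: {counter}")

# Two hypothetical corpora: 3 speech items followed by 2 text items.
sizes = OrderedDict([("speech", 3), ("text", 2)])
assert map_index(0, sizes) == (0, "speech")
assert map_index(3, sizes) == (0, "text")
assert map_index(4, sizes) == (1, "text")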
SpeechT5/Speech2S/speech2s/models/__init__.py
ADDED
File without changes
|
SpeechT5/Speech2S/speech2s/models/speechut.py
ADDED
@@ -0,0 +1,785 @@
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils, checkpoint_utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.transformer import Embedding
from fairseq.file_io import PathManager
from torch import Tensor
from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)
from fairseq.models.hubert import HubertConfig
from fairseq.models.transformer import TransformerConfig
from speechut.modules import TransformerEncoder
from speechut.modules import TransformerEncoderBase
from speechut.modules import TransformerDecoderBaseScriptable

logger = logging.getLogger(__name__)

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])


@dataclass
class SpeechutConfig(HubertConfig):
    use_rel_pos_enc: bool = field(
        default=False,
        metadata={"help": "whether to use relative positional encoding"},
    )
    scaling_for_att: float = field(
        default=1.0,
        metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
    )

    # unit encoder-decoder
    text_transformer: TransformerConfig = TransformerConfig()
    reset_decoder_embedding_config: bool = field(
        default=False,
        metadata={"help": "reset the no_scale_embedding/layernorm_embedding to default for the decoder"},
    )
    add_unit_encoder: bool = field(
        default=False,
        metadata={"help": "add unit encoder"},
    )
    add_decoder: bool = field(
        default=True,
        metadata={"help": "add decoder"},
    )
    add_text_ctc: bool = field(
        default=False,
        metadata={"help": "add_text_ctc head"},
    )
    text_ctc_conv_kernel: int = field(
        default=2,
        metadata={"help": "text_ctc_conv kernel size"},
    )
    mask_u2t: bool = field(
        default=True,
        metadata={"help": "mask the unit input in unit-to-text task"},
    )

    # embedding mixing
    mix_with_unit: bool = field(
        default=True,
        metadata={"help": "mix with the unit embeddings"},
    )
    use_pred_unit: bool = field(
        default=False,
        metadata={"help": "use the embeddings of predicted units"},
    )
    l2_embedding: bool = field(
        default=False,
        metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
    )

    # Finetune related
    encoder_dict_size: int = field(
        default=-1,
        metadata={"help": "text encoder dictionary dimension"},
    )

    decoder_dict_size: int = field(
        default=-1,
        metadata={"help": "decoder dictionary dimension"},
    )


@register_model("speechut", dataclass=SpeechutConfig)
class SpeechutModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: SpeechutConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
        unit_dictionary: Dictionary = None,
        text_tgt_dictionary: Dictionary = None,
    ) -> None:
        super().__init__()
        logger.info(f"SpeechutModel Config: {cfg}")

        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_dim = final_dim
        assert len(dictionaries) <= 2, f"Only support <=2 kinds of targets, get {len(dictionaries)} dictionaries"
        if len(dictionaries) == 1:
            dictionaries = [dictionaries[0], dictionaries[0]]
        self.num_classes = [len(d) for d in dictionaries]

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
        self.code_encoder_proj = nn.Linear(cfg.text_transformer.encoder.embed_dim, self.num_classes[-1])
        self.final_proj_list = [self.final_proj, self.code_encoder_proj]

        self.label_embs_concat = nn.Parameter(torch.FloatTensor(self.num_classes[0], final_dim))
        self.label_embs_list = [self.label_embs_concat]
        for p in self.label_embs_list:
            nn.init.uniform_(p)

        ### build unit encoder:
        self.mask_u2t = cfg.mask_u2t
        self.add_text_ctc = cfg.add_text_ctc
        self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
        self.padding_idx = unit_dictionary.pad()
        self.unit_mask_idx = unit_dictionary.index("<mask>")

        self.add_unit_encoder = cfg.add_unit_encoder
        self.mix_with_unit = cfg.mix_with_unit
        self.use_pred_unit = cfg.use_pred_unit
        self.l2_embedding = cfg.l2_embedding
        if self.add_unit_encoder:
            assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
            ### build unit pre-net, and shared with hubert label_embs if needed (default: False)
            self.unit_embed_tokens = self.build_embedding(
                unit_dictionary,
                cfg.text_transformer.encoder.embed_dim,
            )
            if self.final_dim == cfg.text_transformer.encoder.embed_dim:
                logger.info("Share label_embs[0] with unit_embed_tokens ...")
                nn.init.uniform_(self.unit_embed_tokens.weight)
                self.label_embs_list[0] = self.unit_embed_tokens.weight

            ### build unit encoder
            self.unit_encoder = TransformerEncoderBase(
                cfg.text_transformer,
                unit_dictionary,
                self.unit_embed_tokens,
                use_rel_pos_enc=cfg.use_rel_pos_enc,
                scaling_for_att=cfg.scaling_for_att,
            )

            ### build text ctc head
            if self.add_text_ctc:
                conv = nn.Conv1d(
                    cfg.text_transformer.encoder.embed_dim, cfg.text_transformer.encoder.embed_dim,
                    self.text_ctc_conv_kernel,
                    stride=self.text_ctc_conv_kernel // 2,
                    bias=False,
                    padding=self.text_ctc_conv_kernel // 2,
                )
                nn.init.kaiming_normal_(conv.weight)
                self.unit_encoder_ctc_head = nn.Sequential(
                    Rotate3D(),
                    conv,
                    nn.Dropout(p=0.1),
                    nn.Sequential(
                        Rotate3D(),
                        Rotate3D(),
                        LayerNorm(cfg.text_transformer.encoder.embed_dim),
                    ),
                    nn.GELU(),
                    nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
                )

        ### build unit2text decoder, not available for now
        self.add_decoder = cfg.add_decoder
        self.text_transformer_cfg = cfg.text_transformer
        if self.add_decoder:
            # To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size or bpe code dict size
            dec_dictionary = self.cutting_dictionary(text_tgt_dictionary, cfg.decoder_dict_size)
            decoder_embed_tokens = self.build_embedding(
                dec_dictionary, cfg.text_transformer.decoder.embed_dim
            )
            if cfg.reset_decoder_embedding_config:
                cfg.text_transformer.no_scale_embedding = False
                cfg.text_transformer.layernorm_embedding = False
                cfg.text_transformer.no_token_positional_embeddings = False
            self.decoder = TransformerDecoderBaseScriptable(cfg.text_transformer, dec_dictionary, decoder_embed_tokens, use_rel_pos_enc=cfg.use_rel_pos_enc)

    def cutting_dictionary(self, dictionary, dict_size):
        if dictionary is None or dict_size <= 0:
            return dictionary
        else:
            import copy
            cut_dictionary = copy.deepcopy(dictionary)
            if dict_size > len(cut_dictionary):
                for i in range(dict_size - len(cut_dictionary)):
                    cut_dictionary.symbols.append(f'_{i}_')
            else:
                cut_dictionary.symbols = cut_dictionary.symbols[:dict_size]
            return cut_dictionary

    def build_embedding(self, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        return Embedding(num_embeddings, embed_dim, padding_idx)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechutConfig, task: HubertPretrainingTask):
        """Build a new model instance."""
        unit_dictionary = getattr(task, "text_src_dictionary", None)
        text_tgt_dictionary = getattr(task, "text_dictionary", None)
        model = SpeechutModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_inds += np.random.choice(int(self.feat2tar_ratio))
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def downsample_ctc_padding_mask(self, padding_mask):
        """
        padding_mask: (B, T)
        """
        stride = self.text_ctc_conv_kernel // 2
        return padding_mask[:, ::stride]

    def compute_pred(self, proj_x, label_embs):
        if self.target_glu:
            label_embs = self.target_glu(label_embs)
        x = F.normalize(proj_x.float(), dim=-1)  # (S, D)
        label_embs = F.normalize(label_embs.float(), dim=-1)  # (C, D)
        logits = torch.matmul(x, label_embs.T).type_as(proj_x)  # (S, C)
        logits /= self.logit_temp
        return logits

    def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = proj(x[masked_indices])
            logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
        else:
            logit_m_list = [None]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = proj(x[nomask_indices])
            logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
        else:
            logit_u_list = [None]

        return logit_m_list, logit_u_list

    def compute_ce_logits(self, x, target, proj, padding_mask, mask_indices):
        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            logit_m_list = [(proj(x[masked_indices]), target[masked_indices])]
        else:
            logit_m_list = [None]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            logit_u_list = [(proj(x[nomask_indices]), target[nomask_indices])]
        else:
            logit_u_list = [None]

        return logit_m_list, logit_u_list

    def convert_embeddings(self,
        x,
        padding_mask,
        target=None,
        mask_indices=None,
        mix_with_unit=False,
        use_pred_unit=False,
        l2_embedding=False,
        remask=False
    ):
        """
        1. Mix with units if needed (default: True)
        2. Prepare for unit_encoder inputs
        Inputs:
            x, (B, T, D)
        Return:
            src_tokens, (B, T)
            soft_embeddings, (B, T, D)
            l2_loss, a loss
        """
        soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
        if padding_mask is None:
            padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
        if use_pred_unit:
            src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
            src_tokens[padding_mask] = self.padding_idx
        elif target is not None:
            src_tokens = target
        else:
            src_tokens = padding_mask.long()

        if l2_embedding | mix_with_unit:
            unit_embeddings = self.unit_embed_tokens(src_tokens)  # (B, T, D)

        l2_loss = 0
        if l2_embedding:
            if mask_indices is not None:
                l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
                scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
            else:
                l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
                scale = unit_embeddings.float().pow(2).sum(dim=-1)
            l2_loss = (l2_loss / scale).mean()

        if mix_with_unit:
            B, T, D = x.shape
            selected_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob / 2,
                self.mask_length // 2,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            selected_indices = torch.from_numpy(selected_indices).to(x.device)
            if mask_indices is not None:
                if remask:
                    remask_indices = torch.logical_and(selected_indices, mask_indices)
                    soft_embeddings[remask_indices] = self.mask_emb
                swap_indices = torch.logical_and(selected_indices, ~mask_indices)
            else:
                swap_indices = selected_indices
            soft_embeddings[swap_indices] = unit_embeddings[swap_indices]

        soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
        return src_tokens, soft_embeddings, l2_loss

    def forward(
        self,
        source: torch.Tensor = None,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        prev_output_tokens: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert source is not None or src_tokens is not None
        if source is not None:
            return self.forward_speech(
                source=source,
                target_list=target_list,
                padding_mask=padding_mask,
                mask=mask,
                features_only=features_only,
                output_layer=output_layer,
            )
        else:
            return self.forward_text(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                prev_output_tokens=prev_output_tokens,
                mask=self.mask_u2t,
                features_only=features_only,
                output_layer=output_layer,
            )

    def forward_speech(
        self,
        source: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        logit_m_list, logit_u_list = self.compute_hubert_logits(
            x,
            target_list[0],
            self.final_proj_list[0],
            self.label_embs_list[0],
            padding_mask,
            mask_indices,
        )

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if self.add_unit_encoder:
            src_tokens, x_emb, l2_loss = self.convert_embeddings(
                x,
                padding_mask, target_list[0],
                mask_indices=mask_indices,
                mix_with_unit=self.mix_with_unit,
                use_pred_unit=self.use_pred_unit,
                l2_embedding=self.l2_embedding,
            )
            encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)

            result['encoder_out'] = encoder_out['encoder_out']  # [(T, B, D)]
            result['encoder_padding_mask'] = encoder_out['encoder_padding_mask']  # [(B, T)]
            if self.l2_embedding:
                result['embedding_l2_loss'] = l2_loss

            code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
                encoder_out['encoder_out'][0].transpose(0, 1),  # -> (B, T, C)
                target_list[-1],
                self.final_proj_list[1],
                padding_mask,
                mask_indices,
            )
            result['logit_m_list'] += code_logit_m_list
            result['logit_u_list'] += code_logit_u_list
        return result

    def forward_text(
        self,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        prev_output_tokens: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert self.add_unit_encoder, f"Can not forward unit-text branch without unit_encoder!"

        padding_mask = src_tokens == self.padding_idx
        unit_embeddings = self.unit_embed_tokens(src_tokens)
        if mask:
            unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])

        encoder_out = self.unit_encoder(
            src_tokens,
            token_embeddings=unit_embeddings,
|
646 |
+
return_all_hiddens=output_layer is not None,
|
647 |
+
)
|
648 |
+
|
649 |
+
result = {}
|
650 |
+
result["encoder_out"] = encoder_out["encoder_out"]
|
651 |
+
result["encoder_states"] = encoder_out["encoder_states"]
|
652 |
+
result["padding_mask"] = padding_mask
|
653 |
+
|
654 |
+
if self.add_text_ctc:
|
655 |
+
result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
|
656 |
+
result["encoder_padding_mask"] = [
|
657 |
+
self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
|
658 |
+
]
|
659 |
+
|
660 |
+
if features_only:
|
661 |
+
return result
|
662 |
+
if self.add_decoder:
|
663 |
+
assert prev_output_tokens is not None
|
664 |
+
decoder_out = self.decoder(
|
665 |
+
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out,
|
666 |
+
)
|
667 |
+
result['decoder_out'] = decoder_out
|
668 |
+
return result
|
669 |
+
|
670 |
+
def forward_mum(self, src_tokens, target, mask=True):
|
671 |
+
target_list = [target]
|
672 |
+
padding_mask = src_tokens.eq(self.unit_encoder.padding_idx)
|
673 |
+
unit_embeddings = self.unit_embed_tokens(src_tokens)
|
674 |
+
if mask:
|
675 |
+
unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, target_list)
|
676 |
+
else:
|
677 |
+
### If already applied mask on src_tokens, then the target_list should contains many padding_idx
|
678 |
+
mask_indices = target_list[-1] != self.padding_idx
|
679 |
+
unit_embeddings[mask_indices] = self.mask_emb
|
680 |
+
|
681 |
+
encoder_out = self.unit_encoder(
|
682 |
+
src_tokens,
|
683 |
+
token_embeddings=unit_embeddings,
|
684 |
+
)
|
685 |
+
code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
|
686 |
+
encoder_out["encoder_out"][0].transpose(0, 1),
|
687 |
+
target_list[-1],
|
688 |
+
self.final_proj_list[1],
|
689 |
+
padding_mask,
|
690 |
+
mask_indices,
|
691 |
+
)
|
692 |
+
result = {}
|
693 |
+
result["logit_m_list"] = code_logit_m_list
|
694 |
+
result["logit_u_list"] = code_logit_u_list
|
695 |
+
result["padding_mask"] = padding_mask
|
696 |
+
return result
|
697 |
+
|
698 |
+
def extract_features(
|
699 |
+
self,
|
700 |
+
source: torch.Tensor,
|
701 |
+
padding_mask: Optional[torch.Tensor] = None,
|
702 |
+
mask: bool = False,
|
703 |
+
ret_conv: bool = False,
|
704 |
+
output_layer: Optional[int] = None,
|
705 |
+
**kwargs,
|
706 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
707 |
+
"""Extract encoder features for only speech input"""
|
708 |
+
res = self.forward(
|
709 |
+
source,
|
710 |
+
padding_mask=padding_mask,
|
711 |
+
mask=mask,
|
712 |
+
features_only=True,
|
713 |
+
output_layer=output_layer,
|
714 |
+
)
|
715 |
+
x = res["x"] # B x T x D
|
716 |
+
padding_mask = res["padding_mask"]
|
717 |
+
|
718 |
+
if self.add_unit_encoder:
|
719 |
+
src_tokens, x, _ = self.convert_embeddings(
|
720 |
+
x,
|
721 |
+
padding_mask,
|
722 |
+
mix_with_unit=False,
|
723 |
+
use_pred_unit=False,
|
724 |
+
)
|
725 |
+
encoder_out = self.unit_encoder(
|
726 |
+
src_tokens,
|
727 |
+
token_embeddings=x,
|
728 |
+
return_all_hiddens=output_layer is not None
|
729 |
+
)
|
730 |
+
res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
|
731 |
+
|
732 |
+
feature = res["features"] if ret_conv else res["x"]
|
733 |
+
if output_layer is not None:
|
734 |
+
feature = encoder_out['encoder_states']
|
735 |
+
|
736 |
+
return feature, padding_mask
|
737 |
+
|
738 |
+
def get_logits(self, net_output, is_masked=True):
|
739 |
+
if is_masked:
|
740 |
+
logits_list = net_output["logit_m_list"]
|
741 |
+
else:
|
742 |
+
logits_list = net_output["logit_u_list"]
|
743 |
+
logits_list = [x[0].float() for x in logits_list if x is not None]
|
744 |
+
return logits_list
|
745 |
+
|
746 |
+
def get_targets(self, net_output, is_masked=True):
|
747 |
+
if is_masked:
|
748 |
+
logits_list = net_output["logit_m_list"]
|
749 |
+
else:
|
750 |
+
logits_list = net_output["logit_u_list"]
|
751 |
+
targets_list = [x[1].long() for x in logits_list if x is not None]
|
752 |
+
return targets_list
|
753 |
+
|
754 |
+
def get_extra_losses(self, net_output):
|
755 |
+
extra_losses = []
|
756 |
+
names = []
|
757 |
+
|
758 |
+
if "features_pen" in net_output:
|
759 |
+
extra_losses.append(net_output["features_pen"])
|
760 |
+
names.append("features_pen")
|
761 |
+
|
762 |
+
if "embedding_l2_loss" in net_output:
|
763 |
+
extra_losses.append(net_output["embedding_l2_loss"])
|
764 |
+
names.append("embedding_l2_loss")
|
765 |
+
|
766 |
+
return extra_losses, names
|
767 |
+
|
768 |
+
def remove_pretraining_modules(self, step2=False):
|
769 |
+
self.target_glu = None
|
770 |
+
|
771 |
+
def load_checkpoint(self, checkpoint: str):
|
772 |
+
if not PathManager.exists(checkpoint):
|
773 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
774 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
|
775 |
+
return state
|
776 |
+
|
777 |
+
class Rotate3D(nn.Module):
|
778 |
+
"""
|
779 |
+
(T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
|
780 |
+
"""
|
781 |
+
def __init__(self):
|
782 |
+
super().__init__()
|
783 |
+
|
784 |
+
def forward(self, x):
|
785 |
+
return x.permute(1, 2, 0)
|
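A side note on the embedding-L2 term computed in convert_embeddings above: the squared distance between the projected speech features and the looked-up unit embeddings is normalized by the squared norm of the unit embeddings before averaging. A standalone toy restatement of just that arithmetic, with made-up shapes and random tensors standing in for soft_embeddings and unit_embeddings:

import torch

# Toy restatement of the scaled embedding-L2 loss from convert_embeddings.
# soft/unit are stand-ins for soft_embeddings / unit_embeddings; (B, T, D) = (2, 5, 8) is arbitrary.
soft = torch.randn(2, 5, 8)
unit = torch.randn(2, 5, 8)
l2 = (soft - unit).float().pow(2).mean(dim=-1)   # (B, T) mean squared difference per position
scale = unit.float().pow(2).sum(dim=-1)          # (B, T) squared norm of the unit embedding
loss = (l2 / scale).mean()                       # scalar, as returned alongside src_tokens and soft_embeddings
print(loss)
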
SpeechT5/Speech2S/speech2s/models/speechut_asr.py
ADDED
@@ -0,0 +1,165 @@
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import contextlib
import torch
from dataclasses import dataclass, field
from fairseq import utils
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.models.hubert import HubertAsrConfig, HubertEncoder
from fairseq.tasks import FairseqTask

@dataclass
class SpeechUTASRConfig(HubertAsrConfig):
    add_decoder: bool = field(
        default=True,
        metadata={"help": "add decoder for fine-tune"},
    )

@register_model("speechut_asr", dataclass=SpeechUTASRConfig)
class SpeechUTASR(BaseFairseqModel):
    """
    An encoder-ctc-decoder model if cfg.add_decoder is True, otherwise an encoder-ctc model
    """
    def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
        super().__init__()
        self.cfg = cfg
        self.encoder = encoder
        if not cfg.add_decoder:
            self.encoder.w2v_model.decoder = None

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = SpeechUTEncoder(cfg, task)
        return cls(cfg, encoder)

    def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
        encoder_out = self.encoder(source, padding_mask, **kwargs)

        x = self.encoder.final_dropout(encoder_out['encoder_out'][0])  # (T, B, C)
        if self.encoder.proj:
            x = self.encoder.proj(x)
        if self.encoder.conv_ctc_proj:
            padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(encoder_out["encoder_padding_mask"][0])
        else:
            padding_mask = encoder_out["encoder_padding_mask"]

        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        ) if self.cfg.add_decoder else None

        return {
            "encoder_out_ctc": x,          # (T, B, C), for CTC loss
            "padding_mask": padding_mask,  # (B, T), for CTC loss
            "decoder_out": decoder_out,    # for ED loss
        }

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def get_logits(self, net_output):
        """For CTC decoding"""
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            padding = padding.T
            logits[padding][..., 0] = 0
            logits[padding][..., 1:] = float("-inf")

        return logits

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """For 1) computing CTC loss, 2) decoder decoding."""

        if "encoder_out_ctc" in net_output:
            logits = net_output["encoder_out_ctc"]
        else:
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)

        if isinstance(logits, list):
            logits = logits[0]

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    @property
    def decoder(self):
        return self.encoder.w2v_model.decoder


class SpeechUTEncoder(HubertEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with encoder-decoder model
    """
    def __init__(self, cfg: HubertAsrConfig, task):
        super().__init__(cfg, task)

        if (task.target_dictionary is not None) and (
            hasattr(self.w2v_model, "unit_encoder_ctc_head")
        ):
            self.proj = self.w2v_model.unit_encoder_ctc_head
            self.conv_ctc_proj = True
        else:
            self.conv_ctc_proj = False

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }
        ft = self.freeze_finetune_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)
            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Forward the encoder out.
        """
        x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = {
            "encoder_out" : [x],
            "encoder_padding_mask" : [padding_mask],
        }
        if self.proj:
            x = self.proj(x)
            encoder_out["encoder_out_ctc"] = x

        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = [
                x.index_select(1, new_order) for x in encoder_out["encoder_out"]
            ]
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = [
                x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
            ]
        return encoder_out

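For the CTC branch, get_normalized_probs above reduces to a (log-)softmax over the vocabulary dimension of the (T, B, C) head output. A minimal sketch of that computation with random logits in place of a real encoder (sizes are hypothetical):

import torch
import torch.nn.functional as F

T, B, C = 10, 2, 32                      # hypothetical time steps, batch size, vocabulary size
encoder_out_ctc = torch.randn(T, B, C)   # stands in for net_output["encoder_out_ctc"]
log_probs = F.log_softmax(encoder_out_ctc.float(), dim=-1)   # what the CTC criterion consumes
print(log_probs.shape)                   # torch.Size([10, 2, 32])
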
SpeechT5/Speech2S/speech2s/models/speechut_st.py
ADDED
@@ -0,0 +1,221 @@
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import logging
import contextlib
import torch
import torch.nn as nn
from argparse import Namespace
from dataclasses import dataclass
from typing import Any
from fairseq import checkpoint_utils, tasks
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.tasks import FairseqTask
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import lengths_to_padding_mask

from fairseq.models.hubert import HubertAsrConfig

logger = logging.getLogger(__name__)

@dataclass
class SpeechUTS2TConfig(HubertAsrConfig):
    ### the following config is only for the compatibility to fairseq speech_to_text task
    input_feat_per_channel: Any = None
    input_channels: Any = None
    speaker_to_id: Any = None

@register_model("speechut_st_legacy", dataclass=SpeechUTS2TConfig)
class SpeechUTS2T(BaseFairseqModel):
    """An encoder-decoder model."""
    def __init__(self, cfg: SpeechUTS2TConfig, encoder: FairseqEncoder):
        super().__init__()
        self.cfg = cfg
        self.encoder = encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechUTS2TConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = SpeechUTEncoder(cfg, task)
        return cls(cfg, encoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        encoder_out = self.encoder(src_tokens, src_lengths, **kwargs)
        decoder_out = self.encoder.w2v_model.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.encoder.w2v_model.decoder(prev_output_tokens, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """For decoder decoding."""
        return self.encoder.w2v_model.decoder.get_normalized_probs(net_output, log_probs, sample)

    @property
    def decoder(self):
        return self.encoder.w2v_model.decoder


class SpeechUTEncoder(FairseqEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with fairseq speech_to_text task
    2. make it compatible with encoder-decoder model
    """
    def __init__(self, cfg: SpeechUTS2TConfig, task):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        pretrain_task = tasks.setup_task(w2v_args.task, load_local_states=False)
        assert state is not None and "task_state" in state, f"the stored dictionaries not found in checkpoint!"
        # This will load the stored "dictionaries" object
        pretrain_task.load_state_dict(state["task_state"])

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            try:
                model.load_state_dict(state["model"], strict=True)
            except Exception as e:
                logger.warn(e)
                model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, src_tokens=None, src_lengths=None, **kwargs):

        w2v_args = {
            "source": src_tokens,
            "padding_mask": lengths_to_padding_mask(src_lengths),
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)
            # B x T x C -> T x B x C
            x = x.transpose(0, 1)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
            "padding_mask": [padding_mask],
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Forward the encoder out.
        """
        _net_input = {
            "source": net_input["src_tokens"],
            "padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
            "mask": False,
        }

        x, padding_mask = self.w2v_model.extract_features(**_net_input)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = {
            "encoder_out" : [x],
            "encoder_padding_mask" : [padding_mask],
        }
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = [
                x.index_select(1, new_order) for x in encoder_out["encoder_out"]
            ]
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = [
                x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
            ]
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m

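The ST encoder above derives its padding mask from src_lengths via fairseq's lengths_to_padding_mask. A small self-contained sketch of the same behavior, re-implemented locally so it runs without fairseq (the helper name only mirrors the import above):

import torch

def lengths_to_padding_mask(lens: torch.Tensor) -> torch.Tensor:
    # True marks padded positions beyond each sequence length.
    max_len = int(lens.max())
    return torch.arange(max_len, device=lens.device)[None, :] >= lens[:, None]

print(lengths_to_padding_mask(torch.tensor([3, 5])))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
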
SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py
ADDED
@@ -0,0 +1,25 @@
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

from fairseq.models import (
    register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture


@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
    args.decoder_layers = getattr(args, "decoder_layers", 20)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)

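The architecture function above only supplies defaults for attributes the user has not set, via getattr(args, name, default). A quick check of that pattern on a bare Namespace (hypothetical values, no fairseq needed):

from argparse import Namespace

args = Namespace()  # pretend the user passed no overrides
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
args.decoder_layers = getattr(args, "decoder_layers", 20)
print(args.decoder_embed_dim, args.decoder_layers)  # 1280 20
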
SpeechT5/Speech2S/speech2s/modules/__init__.py
ADDED
@@ -0,0 +1,27 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .relative_pos_enc import RelativePositionalEncoding
from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
from .transformer_encoder import TransformerEncoderBase
from .transformer_decoder import TransformerDecoderScriptable, TransformerDecoderBaseScriptable

__all__ = [
    "MultiheadAttention",
    "RelativePositionalEncoding",
    "LearnedPositionalEmbedding",
    "TransformerEncoderLayerBase",
    "TransformerDecoderLayerBase",
    "TransformerEncoder",
    "TransformerSentenceEncoderLayer",
    "TransformerEncoderBase",
    "TransformerDecoderScriptable",
    "TransformerDecoderBaseScriptable",
]

SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py
ADDED
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import numpy as np
import six


class CTCPrefixScore(object):
    """Compute CTC label sequence scores
    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state
        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in six.moves.range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels
        :param y     : prefix label sequence
        :param cs    : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in six.moves.range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in six.moves.range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)

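One possible way to drive CTCPrefixScore, assuming the class above is in scope: build a random (T, V) log-probability matrix, take the state for the bare <sos> prefix, and score a set of candidate next labels. The blank, eos and <sos> ids below are illustrative choices, not values fixed by the repository:

import numpy as np

T, V = 6, 5
x = np.log(np.random.dirichlet(np.ones(V), size=T)).astype(np.float32)  # (T, V) frame log-probs
scorer = CTCPrefixScore(x, blank=0, eos=4, xp=np)
r_prev = scorer.initial_state()          # (T, 2) state for the <sos>-only prefix
y = np.array([1])                        # prefix so far: just <sos> (id 1 chosen arbitrarily)
cs = np.array([1, 2, 3, 4])              # candidate next labels
log_psi, new_states = scorer(y, cs, r_prev)
print(log_psi.shape, new_states.shape)   # (4,) (4, 6, 2)
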
SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py
ADDED
@@ -0,0 +1,69 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
1. Add clamping if the input length exceeds the max-source-tokens
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                positions = utils.make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
            positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

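The clamp added above keeps position ids inside the learned table when an input is longer than the embedding allows. A quick check, assuming fairseq is installed (the module relies on fairseq.utils.make_positions) and using made-up sizes:

import torch

emb = LearnedPositionalEmbedding(num_embeddings=8, embedding_dim=4, padding_idx=1)
tokens = torch.full((1, 12), 5, dtype=torch.long)  # 12 tokens, longer than max_positions = 8 - 1 - 1
out = emb(tokens)                                  # positions past the table are clamped, not an error
print(out.shape)                                   # torch.Size([1, 12, 4])
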
SpeechT5/Speech2S/speech2s/modules/multihead_attention.py
ADDED
@@ -0,0 +1,346 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor

from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention


class MultiheadAttention(FairseqMultiheadAttention):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        scaling_for_att=1.0
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
            q_noise,
            qn_block_size,
        )
        self.scaling_for_att = scaling_for_att

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and position_bias is None
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        q *= (1 / self.scaling_for_att)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        if position_bias is not None:  ## first order
            ## position_bias: [241, 241, 64]
            #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1)  #[241, 492, 64]
            #print ("reshape_q: ", reshape_q.size())
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
            #print ("B: ", B.size()) ## [241, 492, 241]
            #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
            B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
            #print ("B 2: ", B.size())
            attn_weights += B

        attn_weights *= self.scaling_for_att
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.scaling_for_att > 1.0:
            attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

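The scaling_for_att hook above divides q by an extra factor before the QK^T product and multiplies the attention scores back afterwards, so the result is mathematically unchanged while the intermediate values stay smaller. A toy check of that equivalence with arbitrary shapes and factor:

import torch

q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
s = 4.0                                                  # hypothetical scaling_for_att
plain = torch.bmm(q, k.transpose(1, 2))
scaled = torch.bmm(q * (1 / s), k.transpose(1, 2)) * s   # same product, rescaled after the bmm
print(torch.allclose(plain, scaled, atol=1e-5))          # True
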
SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py
ADDED
@@ -0,0 +1,33 @@
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import torch

class RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()

        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq, incremental_state=None):
        pos_seq[pos_seq < -self.maxlen] = -self.maxlen
        pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
        pos_seq = pos_seq + self.maxlen

        if incremental_state is not None:
            pos_seq = pos_seq[-1:]

        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None

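A minimal usage sketch for RelativePositionalEncoding, assuming the class above is importable; the model dimension and sequence length are arbitrary:

import torch

pos_emb = RelativePositionalEncoding(d_model=16, maxlen=100)
seq_len = 5
# Signed relative offsets i - j for every query/key pair, clipped and shifted inside forward().
offsets = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]
pe_k, pe_v = pos_emb(offsets)
print(pe_k.shape, pe_v)   # torch.Size([5, 5, 16]) None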