amupd commited on
Commit
62e9ca6
·
1 Parent(s): 2347c4b

SpeechT5 upload

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpeechT5 +0 -1
  2. SpeechT5/CODE_OF_CONDUCT.md +9 -0
  3. SpeechT5/LICENSE +21 -0
  4. SpeechT5/README.md +267 -0
  5. SpeechT5/SECURITY.md +41 -0
  6. SpeechT5/Speech2C/README.md +145 -0
  7. SpeechT5/Speech2C/speech2c/__init__.py +1 -0
  8. SpeechT5/Speech2C/speech2c/config/base_100h.yaml +93 -0
  9. SpeechT5/Speech2C/speech2c/config/base_10h.yaml +104 -0
  10. SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml +100 -0
  11. SpeechT5/Speech2C/speech2c/criterions/__init__.py +10 -0
  12. SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py +404 -0
  13. SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py +261 -0
  14. SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py +145 -0
  15. SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py +93 -0
  16. SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py +341 -0
  17. SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py +35 -0
  18. SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py +485 -0
  19. SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py +215 -0
  20. SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py +278 -0
  21. SpeechT5/Speech2C/speech2c/models/speech2c.py +321 -0
  22. SpeechT5/Speech2C/speech2c/models/speech2c_asr.py +276 -0
  23. SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py +25 -0
  24. SpeechT5/Speech2C/speech2c/squence_generator.py +1028 -0
  25. SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py +91 -0
  26. SpeechT5/Speech2S/README.md +64 -0
  27. SpeechT5/Speech2S/speech2s/__init__.py +1 -0
  28. SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml +101 -0
  29. SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml +102 -0
  30. SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml +100 -0
  31. SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml +153 -0
  32. SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml +159 -0
  33. SpeechT5/Speech2S/speech2s/criterions/__init__.py +9 -0
  34. SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py +414 -0
  35. SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py +384 -0
  36. SpeechT5/Speech2S/speech2s/data/concat_dataset.py +129 -0
  37. SpeechT5/Speech2S/speech2s/data/hubert_dataset.py +597 -0
  38. SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py +669 -0
  39. SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py +172 -0
  40. SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py +368 -0
  41. SpeechT5/Speech2S/speech2s/models/__init__.py +0 -0
  42. SpeechT5/Speech2S/speech2s/models/speechut.py +785 -0
  43. SpeechT5/Speech2S/speech2s/models/speechut_asr.py +165 -0
  44. SpeechT5/Speech2S/speech2s/models/speechut_st.py +221 -0
  45. SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py +25 -0
  46. SpeechT5/Speech2S/speech2s/modules/__init__.py +27 -0
  47. SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py +93 -0
  48. SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py +69 -0
  49. SpeechT5/Speech2S/speech2s/modules/multihead_attention.py +346 -0
  50. SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py +33 -0
SpeechT5 DELETED
@@ -1 +0,0 @@
1
- Subproject commit 8b5ade783571e63450aaa5507444150dcb08fa94
 
 
SpeechT5/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
SpeechT5/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
SpeechT5/README.md ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SpeechT5
2
+
3
+ Unified-modal speech-text pre-training for spoken language processing:
4
+
5
+ > [**SpeechT5**](https://arxiv.org/abs/2110.07205) (```ACL 2022```): **SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing**
6
+
7
+ > [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**
8
+
9
+ > [**YiTrans**](https://arxiv.org/abs/2206.05777) (```IWSLT 2022```): **The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task**
10
+
11
+ > [**SpeechUT**](https://arxiv.org/abs/2210.03730) (```EMNLP 2022```): **SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training**
12
+
13
+ > [**SpeechLM**](https://arxiv.org/abs/2209.15329) (```Arxiv 2022```): **SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data**
14
+
15
+ > [**Speech2S**](https://arxiv.org/abs/2210.17027) (```ICASSP 2023```): **Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**
16
+
17
+ > [**Prosody-SpeechT5**](https://ieeexplore.ieee.org/document/10096530/) (```ICASSP 2023```): **Prosody-aware SpeechT5 for Expressive Neural TTS**
18
+
19
+ > [**VATLM**](https://arxiv.org/abs/2211.11275) (```IEEE Transactions on Multimedia```): **VATLM: Visual-Audio-Text Pre-Training with Unified Masked Prediction for Speech Representation Learning**
20
+
21
+ > [**VALL-E X**](https://arxiv.org/abs/2303.03926) (```Arxiv 2023```): **Speak Foreign Languages with Your Own Voice: Cross-Lingual Neural Codec Language Modeling**
22
+
23
+ > [**VioLA**](https://arxiv.org/abs/2305.16107) (```Arxiv 2023```): **VioLA: Unified Codec Language Models for Speech Recognition, Synthesis, and Translation**
24
+
25
+ <!-- Model introductions, evaluation results, and model inference instructions are located in the corresponding folders. The source code is [https://github.com/microsoft/SpeechT5/tree/main/ModelName]. -->
26
+
27
+
28
+ ## Update
29
+
30
+ - May, 2023: VioLA [**Arxiv**](https://arxiv.org/abs/2305.16107).
31
+ - May, 2023: [**VATLM**](https://arxiv.org/abs/2211.11275) was accepted by IEEE Transactions on Multimedia.
32
+ - March, 2023: VALL-E X [**Arxiv**](https://arxiv.org/abs/2303.03926) and [**Demo**](https://aka.ms/vallex).
33
+ - February, 2023: [**Speech2S**](https://arxiv.org/abs/2210.17027) and [**Prosody-SpeechT5**](https://arxiv.org/abs/2211.11275) were accepted by ICASSP 2023.
34
+ - [HuggingFace Integration] February, 2023: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) models are on [**HuggingFace**](https://huggingface.co/blog/speecht5).
35
+ - [Model Release] November, 2022: [**VATLM**](https://github.com/microsoft/SpeechT5/tree/main/VATLM) models are released.
36
+ - November, 2022: VATLM [**Arxiv**](https://arxiv.org/abs/2211.11275).
37
+ - November, 2022: Speech2S [**Arxiv**](https://arxiv.org/abs/2210.17027).
38
+ - [Model Release] October, 2022: [**SpeechUT**](https://github.com/microsoft/SpeechT5/tree/main/SpeechUT) models are released.
39
+ - October, 2022: [**SpeechUT**](https://arxiv.org/abs/2210.03730) was accepted by EMNLP 2022.
40
+ - [Model Release] October, 2022: [**SpeechLM**](https://github.com/microsoft/SpeechT5/tree/main/SpeechLM) models are released.
41
+ - September, 2022: SpeechLM [**Arxiv**](https://arxiv.org/abs/2209.15329).
42
+ - [Evaluation] June, 2022: The end-to-end ST system [**YiTrans**](https://arxiv.org/abs/2206.05777) achieved top results on [**IWSLT 2022**](https://iwslt.org/2022/offline) shared tasks.
43
+ - June, 2022: [**Speech2C**](https://www.isca-speech.org/archive/interspeech_2022/ao22_interspeech.html) was accepted by InterSpeech 2022.
44
+ - [Model Release] May, 2022: [**Speech2C**](https://github.com/microsoft/SpeechT5/tree/main/Speech2C) models are released.
45
+ - [Model Release] April, 2022: [**SpeechT5**](https://github.com/microsoft/SpeechT5/tree/main/SpeechT5) models are released.
46
+ - March, 2022: Speech2C [**Arxiv**](https://arxiv.org/abs/2203.17113).
47
+ - February, 2022: [**SpeechT5**](https://aclanthology.org/2022.acl-long.393/) was accepted by ACL 2022.
48
+ - October, 2021: SpeechT5 [**Arxiv**](https://arxiv.org/abs/2110.07205).
49
+
50
+
51
+ ## Pre-Trained Models
52
+
53
+
54
+ | Model | Pre-training Dataset | Fine-tuning Dataset | Model |
55
+ | :------: | :----------------------------------------------: | :-----------------: | :-----: |
56
+ | SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base.pt)<br /> [Google Drive](https://drive.google.com/file/d/1Sq00uZ1pw6Z4OUaqhOWzQEJxIVWgAO5U/view?usp=sharing) |
57
+ | SpeechT5 Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [HuggingFace](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base_asr.pt)<br /> [Google Drive](https://drive.google.com/file/d/1qLKJ81JPWOGf1MHfjSmgtZyqqTqgI6kT/view?usp=sharing) |
58
+ | SpeechT5 Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [LibriSpeech LM Dataset](https://www.openslr.org/11/) | - | [Google Drive](https://drive.google.com/file/d/1M79b1jetSPOVxWVMIX-y0URvDjNskZKp/view?usp=sharing) |
59
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
60
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
61
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
62
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1iJvhSGghNrMT-wAY1nwVu2YaYuTy1pxx/view?usp=sharing) |
63
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1mH3N7iKMWYk3rSBJErQPYf3x5ugqDq5x/view?usp=sharing) |
64
+ | SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1eblW8U8f9t-NTuCNRrNHwr-8BeLAUAmQ/view?usp=sharing) |
65
+ | SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1vXyO5DolbiWiTYZ6pkkKQsu2wJetaPlv/view?usp=sharing) |
66
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_ende.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
67
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enca.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
68
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_enar.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
69
+ | SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage](https://valle.blob.core.windows.net/share/speechlm/finetune_covost/checkpoint_entr.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
70
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1QjLIgTJKIylVIp5hUkfSjGPtz8Xo7Lky/view?usp=sharing) |
71
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1YZQDVv096o8Opt0RBnkRiZXYPRDqKZnP/view?usp=sharing) |
72
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1qYygNWSc11TQbBI1OzC4ChlR-dNh8t9S/view?usp=sharing) |
73
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/162U88mwso2aVfzzPkEM2nP_vwTpcb57T/view?usp=sharing) |
74
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1lbTSRXewEeb2t45URunD6EiJcbniyjWW/view?usp=sharing) |
75
+ | SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1Er4I_jHS175pQQph223yKtiiLQ378VvH/view?usp=sharing) |
76
+ | SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4asr_32gpu_1accum/checkpoint_298_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A39%3A48Z&se=2024-03-09T01%3A39%3A00Z&sr=b&sp=r&sig=l3gJS1D%2BJfLfNfS3z33WjmSMGrOBJ63CvqGGewC6WeU%3D)|
77
+ | SpeechUT Base (ASR) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_base_asr100h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A41%3A22Z&se=2024-03-09T01%3A41%3A00Z&sr=b&sp=r&sig=%2B9lpGrqtZXa%2F6n1uZT%2Biey54ky31bYKSJytgfnBbbN4%3D)|
78
+ | SpeechUT Large (ASR) | [60k hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/large_speechut4asr_32gpu_4accum/checkpoint_22_400k.pt?sv=2020-04-08&st=2023-03-08T01%3A42%3A10Z&se=2024-03-09T01%3A42%3A00Z&sr=b&sp=r&sig=TZNcsHQAqapyj%2BAvpHtl749kZy9flTkWm8P5L4W26qs%3D)|
79
+ | SpeechUT Large (ASR) | [60k hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/speechut_large_asr960h_checkpoint_best.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A02Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=PmO%2BgSAMXRgMC7GfpS4c%2BrDPsfJGekqUzD5AJm7RrYU%3D)|
80
+ | SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/checkpoint_217_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A43%3A47Z&se=2024-03-09T01%3A43%3A00Z&sr=b&sp=r&sig=XDEesMdGQ027j7YtpSql1kZtwgfv39gruOuWYlKlJ7w%3D)|
81
+ | SpeechUT Base (En-De) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [408 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [4.6M Text](https://www.statmt.org/wmt16/) | [En-De MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4ende_32gpu_1accum/fineutne_ende_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A44%3A15Z&se=2024-03-09T01%3A44%3A00Z&sr=b&sp=r&sig=8dcenahRg46EJdwiHUalVBJgKra6JoSN7tUxdLAwzOM%3D)|
82
+ | SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/checkpoint_204_400000.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A16Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=hWoCM0y0SGZTD4CznC%2F5CejFczkqDYTOaISmlhCAYAU%3D)|
83
+ | SpeechUT Base (En-Es) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [504 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [15M Text](https://www.statmt.org/wmt13/) | [En-Es MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enes_32gpu_1accum/fineutne_enes_checkpoint_avg.pt?sv=2020-04-08&st=2023-03-08T01%3A48%3A41Z&se=2024-03-09T01%3A48%3A00Z&sr=b&sp=r&sig=KGfzgKfKkDVQI0JxxnS%2BsYdBQzhUqFLQAVYG0OSGBtk%3D)|
84
+ | SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | - | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/checkpoint_297_600000.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A09Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=1eqpXMLCjWpfyd7AiOHGzfk%2B8ZYqWwVWdHk1GqXgoeg%3D)|
85
+ | SpeechUT Base (En-Fr) | [960 hrs LibriSpeech](http://www.openslr.org/12) + [492 hrs MuST-C v1](https://ict.fbk.eu/must-c/) + [40M Text](https://www.statmt.org/wmt14/) | [En-Fr MuST-C v1](https://ict.fbk.eu/must-c/) | [Azure Storage](https://valle.blob.core.windows.net/share/speechut/base_speechut4enfr_32gpu_1accum/fineutne_enfr_checkpoint.pt?sv=2020-04-08&st=2023-03-08T01%3A49%3A34Z&se=2024-03-09T01%3A49%3A00Z&sr=b&sp=r&sig=i3jMAqvL1Vp7DRjACAbrdoQKhlv2Cmi40%2F14SJ6%2BoiU%3D)|
86
+
87
+
88
+
89
+ ## SpeechT5 Introduction
90
+
91
+ Motivated by the success of T5 (Text-To-Text Transfer Transformer) in pre-trained natural language processing models, we propose a unified-modal SpeechT5 framework that explores the encoder-decoder pre-training for self-supervised speech/text representation learning.
92
+ The SpeechT5 framework consists of a shared encoder-decoder network and six modal-specific (speech/text) pre/post-nets.
93
+ After preprocessing the input speech/text through the pre-nets, the shared encoder-decoder network models the sequence-to-sequence transformation, and then the post-nets generate the output in the speech/text modality based on the output of the decoder.
94
+
95
+ <img src="SpeechT5/speecht5_framework.png" alt="se" width="1000" />
96
+
97
+ Leveraging large-scale unlabeled speech and text data, we pre-train SpeechT5 to learn a unified-modal representation, hoping to improve the modeling capability for both speech and text.
98
+ To align the textual and speech information into this unified semantic space, we propose a cross-modal vector quantization approach that randomly mixes up speech/text states with latent units as the interface between encoder and decoder.
99
+ Extensive evaluations show the superiority of the proposed SpeechT5 framework on a wide variety of spoken language processing tasks, including automatic speech recognition, speech synthesis, speech translation, voice conversion, speech enhancement, and speaker identification.
100
+
101
+ <!--
102
+ Model introductions, evaluation results, and model inference instructions are located in the corresponding folders. The source code is here [https://github.com/microsoft/SpeechT5/tree/main/SpeechT5].
103
+ -->
104
+
105
+ ## SpeechT5 Downstream Task Performance
106
+
107
+ We evaluate our models on typical spoken language processing tasks, including automatic speech recognition, text to speech, speech to text translation, voice conversion, speech enhancement, and speaker identification.
108
+
109
+ ### Automatic Speech Recognition
110
+
111
+ Evaluation on the [LibriSpeech](http://www.openslr.org/12)
112
+
113
+ | Model |LM | dev-clean | dev-other | test-clean | test-other |
114
+ | ------------- |------------- | ------| ----- | ----| ----|
115
+ | wav2vec2.0 Base | - | 6.1 | 13.5 | 6.1 | 13.3 |
116
+ | HuBERT Base | - | 5.5 | 13.1 | 5.8 | 13.3 |
117
+ | Baseline (w/o CTC) | - | 5.8 | 12.3 | 6.2 | 12.3 |
118
+ | Baseline | - | 4.9 | 11.7 | 5.0 | 11.9 |
119
+ | SpeechT5 (w/o CTC) | - | 5.4 | 10.7 | 5.8 | 10.7 |
120
+ | **SpeechT5** | - | **4.3** | **10.3** | **4.4** | **10.4** |
121
+ | DiscreteBERT | 4-gram | 4.0 |10.9 |4.5 |12.1 |
122
+ | wav2vec 2.0 Base | 4-gram | 2.7 |7.9 |3.4 |8.0 |
123
+ | HuBERT Base | 4-gram | 2.7 |7.8 |3.4 |8.1 |
124
+ | wav2vec 2.0 Base | Transf. | 2.2 |6.3 |2.6 |6.3 |
125
+ | Baseline | Transf. | 2.3 |6.3 |2.5 |6.3 |
126
+ | **SpeechT5** | Transf. | **2.1** |**5.5** |**2.4** |**5.8** |
127
+
128
+ ### Text-to-Speech
129
+
130
+ Evaluation on the [LibriTTS](http://www.openslr.org/60/)
131
+
132
+
133
+ | Model | Naturalness | MOS | CMOS |
134
+ | ------------- |------------ | ------ | ----- |
135
+ | Ground Truth | - | 3.87 | - |
136
+ | Baseline | 2.76 | 3.56 | 0 |
137
+ | **SpeechT5** | 2.91 | **3.65** | **+0.290** |
138
+
139
+ ### Speech Translation
140
+
141
+ Evaluation on the [MUST-C v1](https://ict.fbk.eu/must-c/)
142
+
143
+ | Model | EN-DE | EN-FR |
144
+ | ------------- |------------ | ------ |
145
+ | Fairseq ST | 22.70 | 32.90 |
146
+ | ESPnet ST | 22.91 | 32.69 |
147
+ | Adapter Tuning| 24.63 | 34.98 |
148
+ | Baseline | 23.43 | 33.76 |
149
+ | SpeechT5 (w/o initializing decoder) | 24.44 | 34.5 |
150
+ | **SpeechT5** | **25.18** | **35.30** |
151
+
152
+
153
+ ### Voice Conversion
154
+
155
+ Evaluation on the [CMU Arctic](http://www.festvox.org/cmu_arctic/)
156
+
157
+
158
+ | Model | WER | WER | MCD | MCD |
159
+ | ------------- | ------ | ----- | ---- | ----|
160
+ | | bdl to slt | clb to slt | bdl to slt | clb to slt |
161
+ | VTN w/ ASR | 11.1 | 10.9 | 6.5 | 6.11 |
162
+ | VTN w/ TTS | 7.6 | 9.1 | 6.33 | 13.3 |
163
+ | Many-to-many VTN | - | - | 6.13 | 5.97 |
164
+ | Baseline | 21.5 | 10.8 | 6.26 | 6.16 |
165
+ | **SpeechT5** | **7.8** | **6.4** | **5.93**| **5.87** |
166
+
167
+
168
+
169
+ ### Speech Enhancement
170
+
171
+ Evaluation on the [WSJ0 Hipster AmbientMixtures (WHAM!)](http://wham.whisper.ai/)
172
+
173
+
174
+ | Model | WER |
175
+ | ------------- |------------ |
176
+ | Ground Truth Speech | 3.2 |
177
+ | Noisy Speech | 76.1 |
178
+ | Baseline | 10.9 |
179
+ | **SpeechT5** | **8.9** |
180
+
181
+
182
+ ### Speaker Identification
183
+
184
+ Evaluation on the [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html)
185
+
186
+ | Model | Acc |
187
+ | ------------- |------------ |
188
+ | SUPERB, wav2vec 2.0 Base | 75.18% |
189
+ | SUPERB, HuBERT Base | 81.42% |
190
+ | SUPERB, HuBERT Large | 90.33% |
191
+ | SpeechNet, single task | 86.00% |
192
+ | SpeechNet, multi-task with TTS | 87.90% |
193
+ | Thin ResNet-34 | 89.00% |
194
+ | Baseline | 91.92% |
195
+ | **SpeechT5** | **96.49%** |
196
+
197
+ ## License
198
+
199
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
200
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [ESPnet](https://github.com/espnet/espnet) projects.
201
+
202
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
203
+
204
+ ### Reference
205
+
206
+ If you find our work is useful in your research, please cite the following paper:
207
+
208
+ ```bibtex
209
+ @article{Ao2021SpeechT5,
210
+ title = {SpeechT5: Unified-Modal Encoder-Decoder Pre-training for Spoken Language Processing},
211
+ author = {Junyi Ao and Rui Wang and Long Zhou and Chengyi Wang and Shuo Ren and Yu Wu and Shujie Liu and Tom Ko and Qing Li and Yu Zhang and Zhihua Wei and Yao Qian and Jinyu Li and Furu Wei},
212
+ eprint={2110.07205},
213
+ archivePrefix={arXiv},
214
+ primaryClass={eess.AS},
215
+ year={2021}
216
+ }
217
+ ```
218
+
219
+ ```bibtex
220
+ @article{Ao2022Speech2C,
221
+ title = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
222
+ author = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
223
+ eprint={2203.17113},
224
+ archivePrefix={arXiv},
225
+ primaryClass={cs.SD},
226
+ year={2022}
227
+ }
228
+ ```
229
+
230
+ ```bibtex
231
+ @article{Zhang2022Yitrans,
232
+ title = {The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task},
233
+ author = {Zhang, Ziqiang and Ao, Junyi and Zhou, Long and Liu, Shujie and Wei, Furu and Li, Jinyu},
234
+ eprint={2206.05777},
235
+ archivePrefix={arXiv},
236
+ primaryClass={cs.CL},
237
+ year={2022}
238
+ }
239
+ ```
240
+
241
+ ```bibtex
242
+ @article{zhang2022speechut,
243
+ title = {SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training},
244
+ author = {Zhang, Ziqiang and Zhou, Long and Ao, Junyi and Liu, Shujie and Dai, Lirong and Li, Jinyu and Wei, Furu},
245
+ eprint={2210.03730},
246
+ archivePrefix={arXiv},
247
+ primaryClass={cs.CL},
248
+ year={2022}
249
+ }
250
+ ```
251
+
252
+ ```bibtex
253
+ @article{zhang2022speechlm,
254
+ title = {SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data},
255
+ author = {Zhang, Ziqiang and Chen, Sanyuan and Zhou, Long and Wu, Yu and Ren, Shuo and Liu, Shujie and Yao, Zhuoyuan and Gong, Xun and Dai, Lirong and Li, Jinyu and Wei, Furu},
256
+ eprint={2209.15329},
257
+ archivePrefix={arXiv},
258
+ primaryClass={cs.CL},
259
+ year={2022}
260
+ }
261
+ ```
262
+
263
+ ### Contact Information
264
+
265
+ For help or issues using SpeechT5 models, please submit a GitHub issue.
266
+
267
+ For other communications related to SpeechT5, please contact Long Zhou (`lozhou@microsoft.com`).
SpeechT5/SECURITY.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
SpeechT5/Speech2C/README.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Speech2C
2
+
3
+ > [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**
4
+
5
+ ## Pre-Trained and Fine-tuned Models
6
+
7
+ | Model | Pre-training Dataset | Fine-tuning Dataset | Model |
8
+ | :------: | :----------------------------------------------: | :-----------------: | :-----: |
9
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
10
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
11
+ | Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
12
+
13
+
14
+ ## Language Model and Vocabulary
15
+ | Model | Dataset | Model | Vocabulary |
16
+ | :------: | :------: | :---: | :--------: |
17
+ | LM | [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [Model](https://drive.google.com/file/d/1UDCcNJT1DlquSRw0iRAXH6GHlf6zK6-8/view?usp=sharing) | [Vocabulary](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt) |
18
+
19
+ ## Setup
20
+ ```
21
+ git submodule update --init Speech2C/fairseq
22
+ cd Speech2C/
23
+ pip install --editable fairseq/
24
+ ```
25
+
26
+ ## Data Preparation
27
+ Please follow the steps of data preparation for HuBERT in [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation).
28
+
29
+ ## Pre-Training
30
+ ```
31
+ DATA_DIR=
32
+ LABEL_DIR=
33
+ FAIRSEQ_PATH=
34
+
35
+ python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
36
+ --config-dir speech2c/config \
37
+ --config-name speech2c_base_librispeech \
38
+ task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} task.labels='["km"]' \
39
+ model.label_rate=50 common.user_dir=SpeechT5/Speech2C/speech2c \
40
+ ```
41
+
42
+ ## Finetune
43
+
44
+ ```
45
+ DATA_DIR=
46
+ LABEL_DIR=
47
+ FAIRSEQ_PATH=
48
+ W2V_PATH=
49
+ CONFIG_NAME=
50
+
51
+ python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
52
+ --config-dir speech2c/config \
53
+ --config-name ${CONFIG_NAME} \
54
+ task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \
55
+ model.w2v_path=${W2V_PATH} common.user_dir=SpeechT5/Speech2C/speech2c \
56
+ ```
57
+
58
+ ## Inference
59
+ Note that joint CTC and decoder inference is only supported when the batch size is 1.
60
+
61
+ ```
62
+ FAIRSEQ_PATH=
63
+ DATA_DIR=
64
+ LABEL_DIR=
65
+ BEAM_SIZE=
66
+ CTC_WEIGHT=
67
+ TEST_SET=
68
+ CHECKPOINT_PATH=
69
+ W2V_PATH=
70
+
71
+
72
+ python ${FAIRSEQ_PATH}/fairseq_cli/generate.py ${DATA_DIR} \
73
+ --label-dir ${LABEL_DIR} \
74
+ --path ${CHECKPOINT_PATH} \
75
+ --user-dir SpeechT5/Speech2C/speech2c \
76
+ --model-overrides "{'w2v_path': '${W2V_PATH}'}" \
77
+ --gen-subset ${TEST_SET} \
78
+ --task speech2c_pretraining \
79
+ --post-process letter \
80
+ --add-decoder \
81
+ --labels '["ltr"]' \
82
+ --fine-tuning \
83
+ --scoring wer \
84
+ --max-len-a 0 \
85
+ --max-len-b 620 \
86
+ --pad-audio \
87
+ --random-crop \
88
+ --ctc-weight ${CTC_WEIGHT} \
89
+ --max-tokens 8000000 \
90
+ --beam ${BEAM_SIZE} \
91
+ --single-target \
92
+ ```
93
+
94
+ ## Results on Librispeech
95
+
96
+ ### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 10hr subset
97
+
98
+ | Model |LM | test-clean | test-other |
99
+ | ------------- |------------- | ----| ----|
100
+ | wav2vec2.0 Base | - | 11.1 | 17.6 |
101
+ | HuBERT Base | - | 10.1 | 16.8 |
102
+ | **Speech2C** | - | **7.8** | **13.1** |
103
+ | wav2vec 2.0 Base | 4-gram | 4.3 |9.5 |
104
+ | wav2vec 2.0 Base | Transf. |3.2 |7.8 |
105
+ | HuBERT Base | 4-gram |4.3 |9.4 |
106
+ | **Speech2C** | **Transf.** | **3.1** | **7.0** |
107
+
108
+ ### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 100hr subset
109
+
110
+ | Model |LM | test-clean | test-other |
111
+ | ------------- |------------- | ----| ----|
112
+ | wav2vec2.0 Base | - | 6.1 | 13.3 |
113
+ | wav2vec2.0 Large | - | 4.7 | 9.0 |
114
+ | HuBERT Base | - | 6.3 | 13.2 |
115
+ | SpeechT5 | - | 4.4 | 10.4 |
116
+ | Baseline | - | 5.0 | 11.9 |
117
+ | **Speech2C** | - | **4.3** |**9.0** |
118
+ | wav2vec 2.0 Base | 4-gram | 3.4 |8.0 |
119
+ | wav2vec 2.0 Base | Transf. | 2.6 | 6.3 |
120
+ | HuBERT Base | 4-gram | 3.4 |8.1 |
121
+ | SpeechT5 | Transf. | 2.4 |5.8 |
122
+ | Baseline | Transf. | 2.5 |6.3 |
123
+ | **Speech2C** | **Transf.** | **2.4** |**5.2** |
124
+
125
+ ## License
126
+
127
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
128
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq).
129
+
130
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
131
+
132
+ ## Reference
133
+
134
+ If you find our work is useful in your research, please cite the following paper:
135
+
136
+ ```bibtex
137
+ @article{Ao2022Speech2C,
138
+ title = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
139
+ author = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
140
+ eprint={2203.17113},
141
+ archivePrefix={arXiv},
142
+ primaryClass={cs.SD},
143
+ year={2022}
144
+ }
145
+ ```
SpeechT5/Speech2C/speech2c/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import data, tasks, criterions, models # noqa
SpeechT5/Speech2C/speech2c/config/base_100h.yaml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ tensorboard_logdir: tblog
8
+ seed: 1337
9
+
10
+ checkpoint:
11
+ no_epoch_checkpoints: true
12
+ best_checkpoint_metric: dec_accuracy
13
+ maximize_best_checkpoint_metric: true
14
+
15
+ distributed_training:
16
+ ddp_backend: c10d
17
+ find_unused_parameters: true
18
+ distributed_world_size: 1
19
+ distributed_port: 29671
20
+ nprocs_per_node: 8
21
+
22
+ task:
23
+ _name: speech2c_pretraining
24
+ data: ???
25
+ fine_tuning: true
26
+ label_dir: ???
27
+ normalize: false # must be consistent with pre-training
28
+ labels: ["ltr"]
29
+ single_target: true
30
+ add_decoder: true
31
+ pad_audio: true
32
+ random_crop: false
33
+
34
+ dataset:
35
+ num_workers: 6
36
+ max_tokens: 3200000
37
+ skip_invalid_size_inputs_valid_test: true
38
+ train_subset: train_100h
39
+ valid_subset: dev_other
40
+
41
+ criterion:
42
+ _name: ctc_ce
43
+ zero_infinity: true
44
+
45
+ optimization:
46
+ max_update: 80000
47
+ lr: [0.00004]
48
+ sentence_avg: true
49
+ update_freq: [1]
50
+
51
+ optimizer:
52
+ _name: adam
53
+ adam_betas: (0.9,0.98)
54
+ adam_eps: 1e-08
55
+
56
+ lr_scheduler:
57
+ _name: tri_stage
58
+ phase_ratio: [0.1, 0.4, 0.5]
59
+ final_lr_scale: 0.05
60
+
61
+ model:
62
+ _name: speech2c_ctc
63
+ w2v_path: ???
64
+ apply_mask: true
65
+ mask_prob: 0.65
66
+ mask_channel_prob: 0.5
67
+ mask_channel_length: 64
68
+ layerdrop: 0.1
69
+ decoder_layerdrop: 0.1
70
+ activation_dropout: 0.1
71
+ feature_grad_mult: 0.0
72
+ freeze_finetune_updates: 25000
73
+
74
+ hydra:
75
+ job:
76
+ config:
77
+ override_dirname:
78
+ kv_sep: '-'
79
+ item_sep: '__'
80
+ exclude_keys:
81
+ - run
82
+ - task.data
83
+ - task.label_dir
84
+ - model.w2v_path
85
+ - dataset.train_subset
86
+ - dataset.valid_subset
87
+ - criterion.wer_kenlm_model
88
+ - criterion.wer_lexicon
89
+ run:
90
+ dir: ???
91
+ sweep:
92
+ dir: ???
93
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/config/base_10h.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ tensorboard_logdir: tblog
8
+ seed: 1337
9
+
10
+ checkpoint:
11
+ save_interval: 5
12
+ keep_interval_updates: 1
13
+ no_epoch_checkpoints: true
14
+ best_checkpoint_metric: dec_accuracy
15
+ maximize_best_checkpoint_metric: true
16
+
17
+ distributed_training:
18
+ ddp_backend: c10d
19
+ find_unused_parameters: true
20
+ distributed_world_size: 1
21
+ distributed_port: 29671
22
+ nprocs_per_node: 8
23
+
24
+ task:
25
+ _name: speech2c_pretraining
26
+ data: ???
27
+ fine_tuning: true
28
+ label_dir: ???
29
+ normalize: false # must be consistent with pre-training
30
+ labels: ["ltr"]
31
+ single_target: true
32
+ add_decoder: true
33
+ pad_audio: true
34
+ random_crop: false
35
+
36
+ dataset:
37
+ num_workers: 6
38
+ max_tokens: 3200000
39
+ skip_invalid_size_inputs_valid_test: true
40
+ validate_after_updates: ${model.freeze_finetune_updates}
41
+ validate_interval: 5
42
+ train_subset: train_10h
43
+ valid_subset: dev_other
44
+
45
+ criterion:
46
+ _name: ctc_ce
47
+ zero_infinity: true
48
+
49
+ optimization:
50
+ max_update: 25000
51
+ lr: [2e-5]
52
+ sentence_avg: true
53
+ update_freq: [1]
54
+
55
+ optimizer:
56
+ _name: adam
57
+ adam_betas: (0.9,0.98)
58
+ adam_eps: 1e-08
59
+
60
+ lr_scheduler:
61
+ _name: tri_stage
62
+ phase_ratio: [0.1, 0.4, 0.5]
63
+ final_lr_scale: 0.05
64
+
65
+ model:
66
+ _name: speech2c_ctc
67
+ w2v_path: ???
68
+ apply_mask: true
69
+ mask_selection: static
70
+ mask_length: 10
71
+ mask_other: 0
72
+ mask_prob: 0.75
73
+ mask_channel_selection: static
74
+ mask_channel_length: 64
75
+ mask_channel_other: 0
76
+ mask_channel_prob: 0.5
77
+ layerdrop: 0.1
78
+ decoder_layerdrop: 0.1
79
+ dropout: 0.0
80
+ activation_dropout: 0.1
81
+ attention_dropout: 0.0
82
+ feature_grad_mult: 0.0
83
+ freeze_finetune_updates: 10000
84
+
85
+ hydra:
86
+ job:
87
+ config:
88
+ override_dirname:
89
+ kv_sep: '-'
90
+ item_sep: '__'
91
+ exclude_keys:
92
+ - run
93
+ - task.data
94
+ - task.label_dir
95
+ - model.w2v_path
96
+ - dataset.train_subset
97
+ - dataset.valid_subset
98
+ - criterion.wer_kenlm_model
99
+ - criterion.wer_lexicon
100
+ run:
101
+ dir: ???
102
+ sweep:
103
+ dir: ???
104
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/config/speech2c_base_librispeech.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ seed: 1337
8
+ tensorboard_logdir: tblog
9
+
10
+ checkpoint:
11
+ save_interval_updates: 25000
12
+ keep_interval_updates: 1
13
+ no_epoch_checkpoints: true
14
+
15
+
16
+ distributed_training:
17
+ ddp_backend: no_c10d
18
+ distributed_backend: 'nccl'
19
+ distributed_world_size: 32
20
+ distributed_port: 29671
21
+ nprocs_per_node: 8
22
+ find_unused_parameters: true
23
+
24
+ task:
25
+ _name: speech2c_pretraining
26
+ data: ???
27
+ label_dir: ???
28
+ labels: ???
29
+ label_rate: ${model.label_rate}
30
+ sample_rate: 16000
31
+ max_sample_size: 250000
32
+ min_sample_size: 32000
33
+ pad_audio: false
34
+ random_crop: true
35
+ normalize: false # must be consistent with extractor
36
+ add_decoder: true
37
+
38
+ dataset:
39
+ num_workers: 6
40
+ max_tokens: 1400000
41
+ skip_invalid_size_inputs_valid_test: true
42
+ validate_interval: 5
43
+ validate_interval_updates: 10000
44
+
45
+ criterion:
46
+ _name: speech2c
47
+ pred_masked_weight: 1.0
48
+ pred_nomask_weight: 0.0
49
+ loss_weights: [10,]
50
+
51
+ optimization:
52
+ max_update: 400000
53
+ lr: [0.0005]
54
+ clip_norm: 10.0
55
+
56
+ optimizer:
57
+ _name: adam
58
+ adam_betas: (0.9,0.98)
59
+ adam_eps: 1e-06
60
+ weight_decay: 0.01
61
+
62
+ lr_scheduler:
63
+ _name: polynomial_decay
64
+ warmup_updates: 32000
65
+
66
+ model:
67
+ _name: speech2c
68
+ label_rate: ???
69
+ skip_masked: false
70
+ skip_nomask: false
71
+ mask_prob: 0.80
72
+ extractor_mode: default
73
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
74
+ final_dim: 256
75
+ encoder_layerdrop: 0.05
76
+ dropout_input: 0.1
77
+ dropout_features: 0.1
78
+ dropout: 0.1
79
+ attention_dropout: 0.1
80
+ feature_grad_mult: 0.1
81
+ untie_final_proj: true
82
+ activation_dropout: 0.0
83
+ use_rel_pos_enc: true
84
+ decoder_dict_size: -1
85
+
86
+ hydra:
87
+ job:
88
+ config:
89
+ override_dirname:
90
+ kv_sep: '-'
91
+ item_sep: '__'
92
+ exclude_keys:
93
+ - run
94
+ - task.data
95
+ - task.label_dir
96
+ run:
97
+ dir: ???
98
+ sweep:
99
+ dir: ???
100
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2C/speech2c/criterions/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+
4
+
5
+ for file in os.listdir(os.path.dirname(__file__)):
6
+ if file.endswith(".py") and not file.startswith("_"):
7
+ criterion_name = file[: file.find(".py")]
8
+ importlib.import_module(
9
+ "speech2c.criterions." + criterion_name
10
+ )
SpeechT5/Speech2C/speech2c/criterions/ctc_ce.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from argparse import Namespace
12
+ from dataclasses import dataclass, field
13
+ from omegaconf import II
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from fairseq import metrics, utils
19
+ from fairseq.criterions import FairseqCriterion, register_criterion
20
+ from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
21
+ from fairseq.dataclass import FairseqDataclass
22
+ from fairseq.data.data_utils import post_process
23
+ from fairseq.tasks import FairseqTask
24
+ from fairseq.logging.meters import safe_round
25
+
26
+
27
+ @dataclass
28
+ class CtcCeCriterionConfig(FairseqDataclass):
29
+ zero_infinity: bool = field(
30
+ default=False,
31
+ metadata={"help": "zero inf loss when source length <= target length"},
32
+ )
33
+ sentence_avg: bool = II("optimization.sentence_avg")
34
+ post_process: str = field(
35
+ default="letter",
36
+ metadata={
37
+ "help": "how to post process predictions into words. can be letter, "
38
+ "wordpiece, BPE symbols, etc. "
39
+ "See fairseq.data.data_utils.post_process() for full list of options"
40
+ },
41
+ )
42
+ wer_kenlm_model: Optional[str] = field(
43
+ default=None,
44
+ metadata={
45
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
46
+ },
47
+ )
48
+ wer_lexicon: Optional[str] = field(
49
+ default=None,
50
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
51
+ )
52
+ wer_lm_weight: float = field(
53
+ default=2.0,
54
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
55
+ )
56
+ wer_word_score: float = field(
57
+ default=-1.0,
58
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
59
+ )
60
+
61
+ wer_args: Optional[str] = field(
62
+ default=None,
63
+ metadata={
64
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
65
+ },
66
+ )
67
+
68
+ dec_weight: float = field(
69
+ default=0.5,
70
+ metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
71
+ )
72
+ report_accuracy: bool = field(
73
+ default=True,
74
+ metadata={"help": "report decoder accuracy metric"},
75
+ )
76
+ ignore_prefix_size: int = field(
77
+ default=0,
78
+ metadata={"help": "Ignore first N tokens"},
79
+ )
80
+ label_smoothing: float = field(
81
+ default=0.1,
82
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
83
+ )
84
+
85
+
86
+ @register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
87
+ class CtcCeCriterion(FairseqCriterion):
88
+ def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
89
+ super().__init__(task)
90
+ self.blank_idx = (
91
+ task.target_dictionary.index(task.blank_symbol)
92
+ if hasattr(task, "blank_symbol")
93
+ else 0
94
+ )
95
+ self.pad_idx = task.target_dictionary.pad()
96
+ self.eos_idx = task.target_dictionary.eos()
97
+ self.post_process = cfg.post_process
98
+
99
+ if cfg.wer_args is not None:
100
+ (
101
+ cfg.wer_kenlm_model,
102
+ cfg.wer_lexicon,
103
+ cfg.wer_lm_weight,
104
+ cfg.wer_word_score,
105
+ ) = eval(cfg.wer_args)
106
+
107
+ if cfg.wer_kenlm_model is not None:
108
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
109
+
110
+ dec_args = Namespace()
111
+ dec_args.nbest = 1
112
+ dec_args.criterion = "ctc"
113
+ dec_args.kenlm_model = cfg.wer_kenlm_model
114
+ dec_args.lexicon = cfg.wer_lexicon
115
+ dec_args.beam = 50
116
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
117
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
118
+ dec_args.lm_weight = cfg.wer_lm_weight
119
+ dec_args.word_score = cfg.wer_word_score
120
+ dec_args.unk_weight = -math.inf
121
+ dec_args.sil_weight = 0
122
+
123
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
124
+ else:
125
+ self.w2l_decoder = None
126
+
127
+ self.zero_infinity = cfg.zero_infinity
128
+ self.sentence_avg = cfg.sentence_avg
129
+
130
+ self.dec_weight = cfg.dec_weight
131
+ self.report_accuracy = cfg.report_accuracy
132
+ self.ignore_prefix_size = cfg.ignore_prefix_size
133
+ self.eps = cfg.label_smoothing
134
+
135
+ def forward(self, model, sample, reduce=True):
136
+ net_output = model(**sample["net_input"])
137
+ lprobs = model.get_normalized_probs(
138
+ net_output, log_probs=True
139
+ ).contiguous() # (T, B, C) from the encoder
140
+
141
+ if "src_lengths" in sample["net_input"]:
142
+ input_lengths = sample["net_input"]["src_lengths"]
143
+ else:
144
+ if net_output["padding_mask"] is not None:
145
+ non_padding_mask = ~net_output["padding_mask"]
146
+ input_lengths = non_padding_mask.long().sum(-1)
147
+ else:
148
+ input_lengths = lprobs.new_full(
149
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
150
+ )
151
+
152
+ pad_mask = (sample["target"] != self.pad_idx) & (
153
+ sample["target"] != self.eos_idx
154
+ )
155
+ targets_flat = sample["target"].masked_select(pad_mask)
156
+ if "target_lengths" in sample:
157
+ target_lengths = sample["target_lengths"]
158
+ else:
159
+ target_lengths = pad_mask.sum(-1)
160
+
161
+ with torch.backends.cudnn.flags(enabled=False):
162
+ loss = F.ctc_loss(
163
+ lprobs,
164
+ targets_flat,
165
+ input_lengths,
166
+ target_lengths,
167
+ blank=self.blank_idx,
168
+ reduction="sum",
169
+ zero_infinity=self.zero_infinity,
170
+ )
171
+
172
+ ntokens = (
173
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
174
+ )
175
+
176
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
177
+
178
+ logging_output = {}
179
+ if "decoder_target" in sample:
180
+ dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
181
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
182
+ logging_output["ctc_loss"] = loss.item()
183
+ loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
184
+ logging_output["dec_loss"] = dec_loss.item()
185
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
186
+ logging_output["dec_sample_size"] = dec_sample_size
187
+
188
+ if self.report_accuracy:
189
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
190
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
191
+ logging_output["total"] = utils.item(total.data)
192
+
193
+ logging_output = {
194
+ "loss": utils.item(loss.data), # * sample['ntokens'],
195
+ "ntokens": ntokens,
196
+ "nsentences": sample["id"].numel(),
197
+ "sample_size": sample_size,
198
+ **logging_output,
199
+ }
200
+
201
+ if not model.training:
202
+ import editdistance
203
+
204
+ with torch.no_grad():
205
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
206
+
207
+ c_err = 0
208
+ c_len = 0
209
+ w_errs = 0
210
+ w_len = 0
211
+ wv_errs = 0
212
+ for lp, t, inp_l in zip(
213
+ lprobs_t,
214
+ sample["target_label"]
215
+ if "target_label" in sample
216
+ else sample["target"],
217
+ input_lengths,
218
+ ):
219
+ lp = lp[:inp_l].unsqueeze(0)
220
+
221
+ decoded = None
222
+ if self.w2l_decoder is not None:
223
+ decoded = self.w2l_decoder.decode(lp)
224
+ if len(decoded) < 1:
225
+ decoded = None
226
+ else:
227
+ decoded = decoded[0]
228
+ if len(decoded) < 1:
229
+ decoded = None
230
+ else:
231
+ decoded = decoded[0]
232
+
233
+ p = (t != self.task.target_dictionary.pad()) & (
234
+ t != self.task.target_dictionary.eos()
235
+ )
236
+ targ = t[p]
237
+ targ_units = self.task.target_dictionary.string(targ)
238
+ targ_units_arr = targ.tolist()
239
+
240
+ toks = lp.argmax(dim=-1).unique_consecutive()
241
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
242
+
243
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
244
+ c_len += len(targ_units_arr)
245
+
246
+ targ_words = post_process(targ_units, self.post_process).split()
247
+
248
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
249
+ pred_words_raw = post_process(pred_units, self.post_process).split()
250
+
251
+ if decoded is not None and "words" in decoded:
252
+ pred_words = decoded["words"]
253
+ w_errs += editdistance.eval(pred_words, targ_words)
254
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
255
+ else:
256
+ dist = editdistance.eval(pred_words_raw, targ_words)
257
+ w_errs += dist
258
+ wv_errs += dist
259
+
260
+ w_len += len(targ_words)
261
+
262
+ logging_output["wv_errors"] = wv_errs
263
+ logging_output["w_errors"] = w_errs
264
+ logging_output["w_total"] = w_len
265
+ logging_output["c_errors"] = c_err
266
+ logging_output["c_total"] = c_len
267
+
268
+ return loss, sample_size, logging_output
269
+
270
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
271
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
272
+ loss, nll_loss = label_smoothed_nll_loss(
273
+ lprobs,
274
+ target,
275
+ self.eps,
276
+ ignore_index=self.pad_idx,
277
+ reduce=reduce,
278
+ )
279
+ return loss, nll_loss
280
+
281
+ def compute_accuracy(self, model, net_output, sample):
282
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
283
+ mask = target.ne(self.pad_idx)
284
+ n_correct = torch.sum(
285
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
286
+ )
287
+ total = torch.sum(mask)
288
+ return n_correct, total
289
+
290
+ def get_lprobs_and_target(self, model, net_output, sample):
291
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
292
+ target = sample["decoder_target"]
293
+ if self.ignore_prefix_size > 0:
294
+ if getattr(lprobs, "batch_first", False):
295
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
296
+ target = target[:, self.ignore_prefix_size :].contiguous()
297
+ else:
298
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
299
+ target = target[self.ignore_prefix_size :, :].contiguous()
300
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
301
+
302
+
303
+ @staticmethod
304
+ def reduce_metrics(logging_outputs) -> None:
305
+ """Aggregate logging outputs from data parallel training."""
306
+
307
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
308
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
309
+ nsentences = utils.item(
310
+ sum(log.get("nsentences", 0) for log in logging_outputs)
311
+ )
312
+ sample_size = utils.item(
313
+ sum(log.get("sample_size", 0) for log in logging_outputs)
314
+ )
315
+
316
+ metrics.log_scalar(
317
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
318
+ )
319
+ metrics.log_scalar("ntokens", ntokens)
320
+ metrics.log_scalar("nsentences", nsentences)
321
+ if sample_size != ntokens:
322
+ metrics.log_scalar(
323
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
324
+ )
325
+
326
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
327
+ metrics.log_scalar("_c_errors", c_errors)
328
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
329
+ metrics.log_scalar("_c_total", c_total)
330
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
331
+ metrics.log_scalar("_w_errors", w_errors)
332
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
333
+ metrics.log_scalar("_wv_errors", wv_errors)
334
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
335
+ metrics.log_scalar("_w_total", w_total)
336
+
337
+ if c_total > 0:
338
+ metrics.log_derived(
339
+ "uer",
340
+ lambda meters: safe_round(
341
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
342
+ )
343
+ if meters["_c_total"].sum > 0
344
+ else float("nan"),
345
+ )
346
+ if w_total > 0:
347
+ metrics.log_derived(
348
+ "wer",
349
+ lambda meters: safe_round(
350
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
351
+ )
352
+ if meters["_w_total"].sum > 0
353
+ else float("nan"),
354
+ )
355
+ metrics.log_derived(
356
+ "raw_wer",
357
+ lambda meters: safe_round(
358
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
359
+ )
360
+ if meters["_w_total"].sum > 0
361
+ else float("nan"),
362
+ )
363
+
364
+ if "dec_loss" in logging_outputs[0]:
365
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
366
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
367
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
368
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
369
+ metrics.log_scalar(
370
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
371
+ )
372
+ metrics.log_scalar(
373
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
374
+ )
375
+ metrics.log_scalar(
376
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
377
+ )
378
+ metrics.log_derived(
379
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
380
+ )
381
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
382
+ if total > 0:
383
+ metrics.log_scalar("total", total)
384
+ n_correct = utils.item(
385
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
386
+ )
387
+ metrics.log_scalar("dec_n_correct", n_correct)
388
+ metrics.log_derived(
389
+ "dec_accuracy",
390
+ lambda meters: round(
391
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
392
+ )
393
+ if meters["total"].sum > 0
394
+ else float("nan"),
395
+ )
396
+
397
+ @staticmethod
398
+ def logging_outputs_can_be_summed() -> bool:
399
+ """
400
+ Whether the logging outputs returned by `forward` can be summed
401
+ across workers prior to calling `reduce_metrics`. Setting this
402
+ to True will improves distributed training speed.
403
+ """
404
+ return True
SpeechT5/Speech2C/speech2c/criterions/speech2c_criterion.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import re
12
+ from dataclasses import dataclass, field
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq import metrics, utils
17
+ from fairseq.criterions import FairseqCriterion, register_criterion
18
+ from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
19
+ from fairseq.criterions.hubert_criterion import HubertCriterionConfig
20
+
21
+ @dataclass
22
+ class Speech2cCriterionConfig(HubertCriterionConfig):
23
+ dec_weight: float = field(
24
+ default=1.0,
25
+ metadata={"help": "weights for decoder CE Loss, loss will be (hubert_loss + dec_weight * CE_Loss)"},
26
+ )
27
+ report_accuracy: bool = field(
28
+ default=True,
29
+ metadata={"help": "report decoder accuracy metric"},
30
+ )
31
+ ignore_prefix_size: int = field(
32
+ default=0,
33
+ metadata={"help": "Ignore first N tokens"},
34
+ )
35
+ label_smoothing: float = field(
36
+ default=0.0,
37
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
38
+ )
39
+
40
+
41
+ @register_criterion("speech2c", dataclass=Speech2cCriterionConfig)
42
+ class Speech2cCriterion(FairseqCriterion):
43
+ def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None, dec_weight=1.0, report_accuracy=False, ignore_prefix_size=0, label_smoothing=0.0):
44
+ super().__init__(task)
45
+ self.pred_masked_weight = pred_masked_weight
46
+ self.pred_nomask_weight = pred_nomask_weight
47
+ self.loss_weights = loss_weights
48
+ self.log_keys = [] if log_keys is None else log_keys
49
+ self.dec_weight = dec_weight
50
+ self.report_accuracy = report_accuracy
51
+ self.ignore_prefix_size = ignore_prefix_size
52
+ self.eps = label_smoothing
53
+ self.padding_idx = task.dictionaries[0].pad()
54
+
55
+ def forward(self, model, sample, reduce=True, log_pred=False):
56
+ """Compute the loss for the given sample.
57
+ Returns a tuple with three elements:
58
+ 1) the loss
59
+ 2) the sample size, which is used as the denominator for the gradient
60
+ 3) logging outputs to display while training
61
+ """
62
+ net_output = model(target_list=sample["target_list"], **sample["net_input"])
63
+ loss = 0.0
64
+ sample_size = 0
65
+ logging_output = {}
66
+ reduction = "sum" if reduce else "none"
67
+
68
+ loss_m_list = []
69
+ logp_m_list = model.get_logits(net_output, True)
70
+ targ_m_list = model.get_targets(net_output, True)
71
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
72
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
73
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
74
+ loss_m_list.append(loss_m)
75
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
76
+ if self.pred_masked_weight > 0:
77
+ loss += self.pred_masked_weight * sum(loss_m_list)
78
+ sample_size += targ_m_list[0].numel()
79
+
80
+ loss_u_list = []
81
+ logp_u_list = model.get_logits(net_output, False)
82
+ targ_u_list = model.get_targets(net_output, False)
83
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
84
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
85
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
86
+ loss_u_list.append(loss_u)
87
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
88
+ if self.pred_nomask_weight > 0:
89
+ loss += self.pred_nomask_weight * sum(loss_u_list)
90
+ sample_size += targ_u_list[0].numel()
91
+
92
+ if self.loss_weights is not None:
93
+ assert hasattr(model, "get_extra_losses")
94
+ extra_losses, names = model.get_extra_losses(net_output)
95
+ if torch.is_tensor(extra_losses):
96
+ extra_losses = [extra_losses]
97
+ names = [names]
98
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
99
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
100
+ assert len(extra_losses) == len(
101
+ self.loss_weights
102
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
103
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
104
+ if coef != 0 and p is not None:
105
+ p = coef * p.float() * sample_size
106
+ loss += p
107
+ logging_output[f"loss_{n}"] = p.item()
108
+
109
+ if "decoder_target" in sample:
110
+ dec_sample_size = sample["dec_ntokens"]
111
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
112
+ loss = loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
113
+ logging_output["dec_loss"] = dec_loss.item()
114
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
115
+ logging_output["dec_sample_size"] = dec_sample_size
116
+
117
+ if self.report_accuracy:
118
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
119
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
120
+ logging_output["total"] = utils.item(total.data)
121
+
122
+ logging_output = {
123
+ "loss": loss.item() if reduce else loss,
124
+ "ntokens": sample_size,
125
+ "nsentences": sample["id"].numel(),
126
+ "sample_size": sample_size,
127
+ **logging_output,
128
+ }
129
+
130
+ for lk in self.log_keys:
131
+ if lk in net_output:
132
+ logging_output[lk] = float((net_output[lk]))
133
+
134
+ def compute_correct(logits):
135
+ if logits.numel() == 0:
136
+ return 0, 0
137
+ else:
138
+ assert logits.dim() > 1, logits.shape
139
+ max = logits.argmax(-1) == 0
140
+ min = logits.argmin(-1) == 0
141
+ both = max & min
142
+ corr = max.long().sum().item() - both.long().sum().item()
143
+ count = max.numel()
144
+ return corr, count
145
+
146
+ with torch.no_grad():
147
+ for i, logp_m in enumerate(logp_m_list):
148
+ corr_m, count_m = compute_correct(logp_m)
149
+ logging_output[f"correct_m_{i}"] = corr_m
150
+ logging_output[f"count_m_{i}"] = count_m
151
+
152
+ for i, logp_u in enumerate(logp_u_list):
153
+ corr_u, count_u = compute_correct(logp_u)
154
+ logging_output[f"correct_u_{i}"] = corr_u
155
+ logging_output[f"count_u_{i}"] = count_u
156
+
157
+ return loss, sample_size, logging_output
158
+
159
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
160
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
161
+ loss, nll_loss = label_smoothed_nll_loss(
162
+ lprobs,
163
+ target,
164
+ self.eps,
165
+ ignore_index=self.padding_idx,
166
+ reduce=reduce,
167
+ )
168
+ return loss, nll_loss
169
+
170
+ def compute_accuracy(self, model, net_output, sample):
171
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
172
+ mask = target.ne(self.padding_idx)
173
+ n_correct = torch.sum(
174
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
175
+ )
176
+ total = torch.sum(mask)
177
+ return n_correct, total
178
+
179
+ def get_lprobs_and_target(self, model, net_output, sample):
180
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
181
+ target = sample["decoder_target"]
182
+ if self.ignore_prefix_size > 0:
183
+ if getattr(lprobs, "batch_first", False):
184
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
185
+ target = target[:, self.ignore_prefix_size :].contiguous()
186
+ else:
187
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
188
+ target = target[self.ignore_prefix_size :, :].contiguous()
189
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
190
+
191
+ @staticmethod
192
+ def reduce_metrics(logging_outputs) -> None:
193
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
194
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
195
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
196
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
197
+
198
+ metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
199
+ if sample_size != ntokens:
200
+ metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
201
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
202
+ else:
203
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
204
+
205
+ counts = {}
206
+ for lk in logging_outputs[0].keys():
207
+ if lk.startswith("count_"):
208
+ val = sum(log[lk] for log in logging_outputs)
209
+ metrics.log_scalar(lk, val)
210
+ counts[lk] = val
211
+
212
+ for lk in logging_outputs[0].keys():
213
+ if lk.startswith("loss_"):
214
+ val = sum(log[lk] for log in logging_outputs)
215
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
216
+ elif lk.startswith("correct_"):
217
+ val = sum(log[lk] for log in logging_outputs)
218
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
219
+
220
+ if "dec_loss" in logging_outputs[0]:
221
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
222
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
223
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
224
+ metrics.log_scalar(
225
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
226
+ )
227
+ metrics.log_scalar(
228
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
229
+ )
230
+ metrics.log_derived(
231
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
232
+ )
233
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
234
+ if total > 0:
235
+ metrics.log_scalar("total", total)
236
+ n_correct = utils.item(
237
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
238
+ )
239
+ metrics.log_scalar("dec_n_correct", n_correct)
240
+ metrics.log_derived(
241
+ "dec_accuracy",
242
+ lambda meters: round(
243
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
244
+ )
245
+ if meters["total"].sum > 0
246
+ else float("nan"),
247
+ )
248
+
249
+ @staticmethod
250
+ def aggregate_logging_outputs(logging_outputs):
251
+ """Aggregate logging outputs from data parallel training."""
252
+ raise NotImplementedError()
253
+
254
+ @staticmethod
255
+ def logging_outputs_can_be_summed() -> bool:
256
+ """
257
+ Whether the logging outputs returned by `forward` can be summed
258
+ across workers prior to calling `reduce_metrics`. Setting this
259
+ to True will improves distributed training speed.
260
+ """
261
+ return False
SpeechT5/Speech2C/speech2c/data/speech2c_dataset.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import logging
11
+ from typing import Any, List, Optional, Union
12
+
13
+ import torch
14
+ from fairseq.data import data_utils, Dictionary
15
+ from fairseq.data.audio.hubert_dataset import HubertDataset
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class Speech2cDataset(HubertDataset):
20
+ def __init__(
21
+ self,
22
+ manifest_path: str,
23
+ sample_rate: float,
24
+ label_paths: List[str],
25
+ label_rates: Union[List[float], float], # -1 for sequence labels
26
+ pad_list: List[str],
27
+ eos_list: List[str],
28
+ label_processors: Optional[List[Any]] = None,
29
+ max_keep_sample_size: Optional[int] = None,
30
+ min_keep_sample_size: Optional[int] = None,
31
+ max_sample_size: Optional[int] = None,
32
+ shuffle: bool = True,
33
+ pad_audio: bool = False,
34
+ normalize: bool = False,
35
+ store_labels: bool = True,
36
+ random_crop: bool = False,
37
+ single_target: bool = False,
38
+ tgt_dict: Optional[Dictionary] = None,
39
+ add_decoder: bool = False,
40
+ fine_tuning: bool = False,
41
+ ):
42
+ super().__init__(
43
+ manifest_path,
44
+ sample_rate,
45
+ label_paths,
46
+ label_rates,
47
+ pad_list,
48
+ eos_list,
49
+ label_processors,
50
+ max_keep_sample_size,
51
+ min_keep_sample_size,
52
+ max_sample_size,
53
+ shuffle,
54
+ pad_audio,
55
+ normalize,
56
+ store_labels,
57
+ random_crop,
58
+ single_target
59
+ )
60
+
61
+ self.tgt_dict = tgt_dict
62
+ self.add_decoder = add_decoder
63
+ self.fine_tuning = fine_tuning
64
+
65
+ def collater(self, samples):
66
+ # target = max(sizes) -> random_crop not used
67
+ # target = max_sample_size -> random_crop used for long
68
+ samples = [s for s in samples if s["source"] is not None]
69
+ if len(samples) == 0:
70
+ return {}
71
+
72
+ audios = [s["source"] for s in samples]
73
+ audio_sizes = [len(s) for s in audios]
74
+ if self.pad_audio:
75
+ audio_size = min(max(audio_sizes), self.max_sample_size)
76
+ else:
77
+ audio_size = min(min(audio_sizes), self.max_sample_size)
78
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
79
+ audios, audio_size
80
+ )
81
+
82
+ targets_by_label = [
83
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
84
+ ]
85
+ targets_list, lengths_list, ntokens_list = self.collater_label(
86
+ targets_by_label, audio_size, audio_starts
87
+ )
88
+
89
+ if self.add_decoder:
90
+ if self.fine_tuning:
91
+ decoder_label = [
92
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
93
+ for i in range(targets_list[0].size(0))
94
+ ]
95
+ else:
96
+ decoder_label = [
97
+ torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive(), torch.tensor([self.tgt_dict.eos()])), 0).long()
98
+ for i in range(targets_list[0].size(0))
99
+ ]
100
+ dec_ntokens = sum(x.size(0) for x in decoder_label)
101
+ decoder_target = data_utils.collate_tokens(
102
+ decoder_label,
103
+ self.tgt_dict.pad(),
104
+ self.tgt_dict.eos(),
105
+ left_pad=False,
106
+ move_eos_to_beginning=False,
107
+ )
108
+ decoder_target_lengths = torch.tensor(
109
+ [x.size(0) for x in decoder_label], dtype=torch.long
110
+ )
111
+ prev_output_tokens = data_utils.collate_tokens(
112
+ decoder_label,
113
+ self.tgt_dict.pad(),
114
+ self.tgt_dict.eos(),
115
+ left_pad=False,
116
+ move_eos_to_beginning=True,
117
+ )
118
+ net_input = {
119
+ "source": collated_audios,
120
+ "padding_mask": padding_mask,
121
+ "prev_output_tokens": prev_output_tokens,
122
+ }
123
+ batch = {
124
+ "id": torch.LongTensor([s["id"] for s in samples]),
125
+ "net_input": net_input,
126
+ "decoder_target": decoder_target,
127
+ "decoder_target_lengths": decoder_target_lengths,
128
+ "dec_ntokens": dec_ntokens,
129
+ }
130
+ else:
131
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
132
+ batch = {
133
+ "id": torch.LongTensor([s["id"] for s in samples]),
134
+ "net_input": net_input,
135
+ }
136
+
137
+ if self.single_target:
138
+ batch["target_lengths"] = lengths_list[0]
139
+ batch["ntokens"] = ntokens_list[0]
140
+ batch["target"] = targets_list[0]
141
+ else:
142
+ batch["target_lengths_list"] = lengths_list
143
+ batch["ntokens_list"] = ntokens_list
144
+ batch["target_list"] = targets_list
145
+ return batch
SpeechT5/Speech2C/speech2c/models/modules/ctc_prefix_score.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ import numpy as np
7
+ import six
8
+
9
+
10
+ class CTCPrefixScore(object):
11
+ """Compute CTC label sequence scores
12
+ which is based on Algorithm 2 in WATANABE et al.
13
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
14
+ but extended to efficiently compute the probablities of multiple labels
15
+ simultaneously
16
+ """
17
+
18
+ def __init__(self, x, blank, eos, xp):
19
+ self.xp = xp
20
+ self.logzero = -10000000000.0
21
+ self.blank = blank
22
+ self.eos = eos
23
+ self.input_length = len(x)
24
+ self.x = x
25
+
26
+ def initial_state(self):
27
+ """Obtain an initial CTC state
28
+ :return: CTC state
29
+ """
30
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
31
+ # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
32
+ # superscripts n and b (non-blank and blank), respectively.
33
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
34
+ r[0, 1] = self.x[0, self.blank]
35
+ for i in six.moves.range(1, self.input_length):
36
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
37
+ return r
38
+
39
+ def __call__(self, y, cs, r_prev):
40
+ """Compute CTC prefix scores for next labels
41
+ :param y : prefix label sequence
42
+ :param cs : array of next labels
43
+ :param r_prev: previous CTC state
44
+ :return ctc_scores, ctc_states
45
+ """
46
+ # initialize CTC states
47
+ output_length = len(y) - 1 # ignore sos
48
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
49
+ # that corresponds to r_t^n(h) and r_t^b(h).
50
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
51
+ xs = self.x[:, cs]
52
+ if output_length == 0:
53
+ r[0, 0] = xs[0]
54
+ r[0, 1] = self.logzero
55
+ else:
56
+ r[output_length - 1] = self.logzero
57
+
58
+ # prepare forward probabilities for the last label
59
+ r_sum = self.xp.logaddexp(
60
+ r_prev[:, 0], r_prev[:, 1]
61
+ ) # log(r_t^n(g) + r_t^b(g))
62
+ last = y[-1]
63
+ if output_length > 0 and last in cs:
64
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
65
+ for i in six.moves.range(len(cs)):
66
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
67
+ else:
68
+ log_phi = r_sum
69
+
70
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
71
+ # and log prefix probabilities log(psi)
72
+ start = max(output_length, 1)
73
+ log_psi = r[start - 1, 0]
74
+ for t in six.moves.range(start, self.input_length):
75
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
76
+ r[t, 1] = (
77
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
78
+ )
79
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
80
+
81
+ # get P(...eos|X) that ends with the prefix itself
82
+ eos_pos = self.xp.where(cs == self.eos)[0]
83
+ if len(eos_pos) > 0:
84
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
85
+
86
+ # exclude blank probs
87
+ blank_pos = self.xp.where(cs == self.blank)[0]
88
+ if len(blank_pos) > 0:
89
+ log_psi[blank_pos] = self.logzero
90
+
91
+ # return the log prefix probability and CTC states, where the label axis
92
+ # of the CTC states is moved to the first axis to slice it easily
93
+ return log_psi, self.xp.rollaxis(r, 2)
SpeechT5/Speech2C/speech2c/models/modules/multihead_attention.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ from typing import Dict, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from torch import Tensor
16
+
17
+ from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
18
+
19
+
20
+ class MultiheadAttention(FairseqMultiheadAttention):
21
+ """Multi-headed attention.
22
+
23
+ See "Attention Is All You Need" for more details.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim,
29
+ num_heads,
30
+ kdim=None,
31
+ vdim=None,
32
+ dropout=0.0,
33
+ bias=True,
34
+ add_bias_kv=False,
35
+ add_zero_attn=False,
36
+ self_attention=False,
37
+ encoder_decoder_attention=False,
38
+ q_noise=0.0,
39
+ qn_block_size=8,
40
+ ):
41
+ super().__init__(
42
+ embed_dim,
43
+ num_heads,
44
+ kdim,
45
+ vdim,
46
+ dropout,
47
+ bias,
48
+ add_bias_kv,
49
+ add_zero_attn,
50
+ self_attention,
51
+ encoder_decoder_attention,
52
+ q_noise,
53
+ qn_block_size,
54
+ )
55
+
56
+ def forward(
57
+ self,
58
+ query,
59
+ key: Optional[Tensor],
60
+ value: Optional[Tensor],
61
+ key_padding_mask: Optional[Tensor] = None,
62
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
63
+ need_weights: bool = True,
64
+ static_kv: bool = False,
65
+ attn_mask: Optional[Tensor] = None,
66
+ before_softmax: bool = False,
67
+ need_head_weights: bool = False,
68
+ position_bias: Optional[Tensor] = None,
69
+ ) -> Tuple[Tensor, Optional[Tensor]]:
70
+ """Input shape: Time x Batch x Channel
71
+
72
+ Args:
73
+ key_padding_mask (ByteTensor, optional): mask to exclude
74
+ keys that are pads, of shape `(batch, src_len)`, where
75
+ padding elements are indicated by 1s.
76
+ need_weights (bool, optional): return the attention weights,
77
+ averaged over heads (default: False).
78
+ attn_mask (ByteTensor, optional): typically used to
79
+ implement causal attention, where the mask prevents the
80
+ attention from looking forward in time (default: None).
81
+ before_softmax (bool, optional): return the raw attention
82
+ weights and values before the attention softmax.
83
+ need_head_weights (bool, optional): return the attention
84
+ weights for each head. Implies *need_weights*. Default:
85
+ return the average attention weights over all heads.
86
+ """
87
+ if need_head_weights:
88
+ need_weights = True
89
+
90
+ is_tpu = query.device.type == "xla"
91
+
92
+ tgt_len, bsz, embed_dim = query.size()
93
+ src_len = tgt_len
94
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
95
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
96
+ if key is not None:
97
+ src_len, key_bsz, _ = key.size()
98
+ if not torch.jit.is_scripting():
99
+ assert key_bsz == bsz
100
+ assert value is not None
101
+ assert src_len, bsz == value.shape[:2]
102
+
103
+ if (
104
+ not self.onnx_trace
105
+ and not is_tpu # don't use PyTorch version on TPUs
106
+ and incremental_state is None
107
+ and not static_kv
108
+ # A workaround for quantization to work. Otherwise JIT compilation
109
+ # treats bias in linear module as method.
110
+ and not torch.jit.is_scripting()
111
+ and position_bias is None
112
+ ):
113
+ assert key is not None and value is not None
114
+ return F.multi_head_attention_forward(
115
+ query,
116
+ key,
117
+ value,
118
+ self.embed_dim,
119
+ self.num_heads,
120
+ torch.empty([0]),
121
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
122
+ self.bias_k,
123
+ self.bias_v,
124
+ self.add_zero_attn,
125
+ self.dropout_module.p,
126
+ self.out_proj.weight,
127
+ self.out_proj.bias,
128
+ self.training or self.dropout_module.apply_during_inference,
129
+ key_padding_mask,
130
+ need_weights,
131
+ attn_mask,
132
+ use_separate_proj_weight=True,
133
+ q_proj_weight=self.q_proj.weight,
134
+ k_proj_weight=self.k_proj.weight,
135
+ v_proj_weight=self.v_proj.weight,
136
+ )
137
+
138
+ if incremental_state is not None:
139
+ saved_state = self._get_input_buffer(incremental_state)
140
+ if saved_state is not None and "prev_key" in saved_state:
141
+ # previous time steps are cached - no need to recompute
142
+ # key and value if they are static
143
+ if static_kv:
144
+ assert self.encoder_decoder_attention and not self.self_attention
145
+ key = value = None
146
+ else:
147
+ saved_state = None
148
+
149
+ if self.self_attention:
150
+ q = self.q_proj(query)
151
+ k = self.k_proj(query)
152
+ v = self.v_proj(query)
153
+ elif self.encoder_decoder_attention:
154
+ # encoder-decoder attention
155
+ q = self.q_proj(query)
156
+ if key is None:
157
+ assert value is None
158
+ k = v = None
159
+ else:
160
+ k = self.k_proj(key)
161
+ v = self.v_proj(key)
162
+
163
+ else:
164
+ assert key is not None and value is not None
165
+ q = self.q_proj(query)
166
+ k = self.k_proj(key)
167
+ v = self.v_proj(value)
168
+ q *= self.scaling
169
+
170
+ if self.bias_k is not None:
171
+ assert self.bias_v is not None
172
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
173
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
174
+ if attn_mask is not None:
175
+ attn_mask = torch.cat(
176
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
177
+ )
178
+ if key_padding_mask is not None:
179
+ key_padding_mask = torch.cat(
180
+ [
181
+ key_padding_mask,
182
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
183
+ ],
184
+ dim=1,
185
+ )
186
+
187
+ q = (
188
+ q.contiguous()
189
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
190
+ .transpose(0, 1)
191
+ )
192
+ if k is not None:
193
+ k = (
194
+ k.contiguous()
195
+ .view(-1, bsz * self.num_heads, self.head_dim)
196
+ .transpose(0, 1)
197
+ )
198
+ if v is not None:
199
+ v = (
200
+ v.contiguous()
201
+ .view(-1, bsz * self.num_heads, self.head_dim)
202
+ .transpose(0, 1)
203
+ )
204
+
205
+ if saved_state is not None:
206
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
207
+ if "prev_key" in saved_state:
208
+ _prev_key = saved_state["prev_key"]
209
+ assert _prev_key is not None
210
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
211
+ if static_kv:
212
+ k = prev_key
213
+ else:
214
+ assert k is not None
215
+ k = torch.cat([prev_key, k], dim=1)
216
+ src_len = k.size(1)
217
+ if "prev_value" in saved_state:
218
+ _prev_value = saved_state["prev_value"]
219
+ assert _prev_value is not None
220
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
221
+ if static_kv:
222
+ v = prev_value
223
+ else:
224
+ assert v is not None
225
+ v = torch.cat([prev_value, v], dim=1)
226
+ prev_key_padding_mask: Optional[Tensor] = None
227
+ if "prev_key_padding_mask" in saved_state:
228
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
229
+ assert k is not None and v is not None
230
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
231
+ key_padding_mask=key_padding_mask,
232
+ prev_key_padding_mask=prev_key_padding_mask,
233
+ batch_size=bsz,
234
+ src_len=k.size(1),
235
+ static_kv=static_kv,
236
+ )
237
+
238
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
239
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
240
+ saved_state["prev_key_padding_mask"] = key_padding_mask
241
+ # In this branch incremental_state is never None
242
+ assert incremental_state is not None
243
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
244
+ assert k is not None
245
+ assert k.size(1) == src_len
246
+
247
+ # This is part of a workaround to get around fork/join parallelism
248
+ # not supporting Optional types.
249
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
250
+ key_padding_mask = None
251
+
252
+ if key_padding_mask is not None:
253
+ assert key_padding_mask.size(0) == bsz
254
+ assert key_padding_mask.size(1) == src_len
255
+
256
+ if self.add_zero_attn:
257
+ assert v is not None
258
+ src_len += 1
259
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
260
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
261
+ if attn_mask is not None:
262
+ attn_mask = torch.cat(
263
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
264
+ )
265
+ if key_padding_mask is not None:
266
+ key_padding_mask = torch.cat(
267
+ [
268
+ key_padding_mask,
269
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
270
+ key_padding_mask
271
+ ),
272
+ ],
273
+ dim=1,
274
+ )
275
+
276
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
277
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
278
+
279
+ if position_bias is not None: ## first order
280
+ ## position_bias: [241, 241, 64]
281
+ #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
282
+ reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
283
+ #print ("reshape_q: ", reshape_q.size())
284
+ B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
285
+ #print ("B: ", B.size()) ## [241, 492, 241]
286
+ #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
287
+ B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
288
+ #print ("B 2: ", B.size())
289
+ attn_weights += B
290
+
291
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
292
+
293
+ if attn_mask is not None:
294
+ attn_mask = attn_mask.unsqueeze(0)
295
+ if self.onnx_trace:
296
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
297
+ attn_weights += attn_mask
298
+
299
+ if key_padding_mask is not None:
300
+ # don't attend to padding symbols
301
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
302
+ if not is_tpu:
303
+ attn_weights = attn_weights.masked_fill(
304
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
305
+ float("-inf"),
306
+ )
307
+ else:
308
+ attn_weights = attn_weights.transpose(0, 2)
309
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
310
+ attn_weights = attn_weights.transpose(0, 2)
311
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
312
+
313
+ if before_softmax:
314
+ return attn_weights, v
315
+
316
+ attn_weights_float = utils.softmax(
317
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
318
+ )
319
+ attn_weights = attn_weights_float.type_as(attn_weights)
320
+ attn_probs = self.dropout_module(attn_weights)
321
+
322
+ assert v is not None
323
+ attn = torch.bmm(attn_probs, v)
324
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
325
+ if self.onnx_trace and attn.size(1) == 1:
326
+ # when ONNX tracing a single decoder step (sequence length == 1)
327
+ # the transpose is a no-op copy before view, thus unnecessary
328
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
329
+ else:
330
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
331
+ attn = self.out_proj(attn)
332
+ attn_weights: Optional[Tensor] = None
333
+ if need_weights:
334
+ attn_weights = attn_weights_float.view(
335
+ bsz, self.num_heads, tgt_len, src_len
336
+ ).transpose(1, 0)
337
+ if not need_head_weights:
338
+ # average attention weights over heads
339
+ attn_weights = attn_weights.mean(dim=0)
340
+
341
+ return attn, attn_weights
SpeechT5/Speech2C/speech2c/models/modules/relative_pos_enc.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+
12
+ class RelativePositionalEncoding(torch.nn.Module):
13
+ def __init__(self, d_model, maxlen=1000, embed_v=False):
14
+ super(RelativePositionalEncoding, self).__init__()
15
+
16
+ self.d_model = d_model
17
+ self.maxlen = maxlen
18
+ self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
19
+ if embed_v:
20
+ self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
21
+ self.embed_v = embed_v
22
+
23
+
24
+ def forward(self, pos_seq, incremental_state=None):
25
+ pos_seq[pos_seq < -self.maxlen] = -self.maxlen
26
+ pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
27
+ pos_seq = pos_seq + self.maxlen
28
+
29
+ if incremental_state is not None:
30
+ pos_seq = pos_seq[-1:]
31
+
32
+ if self.embed_v:
33
+ return self.pe_k(pos_seq), self.pe_v(pos_seq)
34
+ else:
35
+ return self.pe_k(pos_seq), None
SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from fairseq import utils
16
+ from fairseq.distributed import fsdp_wrap
17
+ from fairseq.models import FairseqIncrementalDecoder
18
+ from fairseq.models.transformer import TransformerConfig
19
+ from fairseq.models.transformer.transformer_decoder import module_name_fordropout, Linear
20
+ from fairseq.modules import (
21
+ AdaptiveSoftmax,
22
+ BaseLayer,
23
+ FairseqDropout,
24
+ LayerDropModuleList,
25
+ LayerNorm,
26
+ PositionalEmbedding,
27
+ SinusoidalPositionalEmbedding,
28
+ )
29
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
30
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
31
+ from torch import Tensor
32
+
33
+
34
+ from speech2c.models.modules.transformer_decoder_layer import TransformerDecoderLayerBase
35
+ from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
36
+
37
+
38
+ class TransformerDecoderBase(FairseqIncrementalDecoder):
39
+ """
40
+ Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
41
+ is a :class:`TransformerDecoderLayer`.
42
+
43
+ Args:
44
+ args (argparse.Namespace): parsed command-line arguments
45
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
46
+ embed_tokens (torch.nn.Embedding): output embedding
47
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
48
+ (default: False).
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ cfg,
54
+ dictionary,
55
+ embed_tokens,
56
+ no_encoder_attn=False,
57
+ output_projection=None,
58
+ use_rel_pos_enc=False,
59
+ ):
60
+ self.cfg = cfg
61
+ super().__init__(dictionary)
62
+ self.register_buffer("version", torch.Tensor([3]))
63
+ self._future_mask = torch.empty(0)
64
+
65
+ self.dropout_module = FairseqDropout(
66
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
67
+ )
68
+ self.decoder_layerdrop = cfg.decoder.layerdrop
69
+ self.share_input_output_embed = cfg.share_decoder_input_output_embed
70
+
71
+ input_embed_dim = embed_tokens.embedding_dim
72
+ embed_dim = cfg.decoder.embed_dim
73
+ self.embed_dim = embed_dim
74
+ self.output_embed_dim = cfg.decoder.output_dim
75
+
76
+ self.padding_idx = embed_tokens.padding_idx
77
+ self.max_target_positions = cfg.max_target_positions
78
+
79
+ self.embed_tokens = embed_tokens
80
+
81
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
82
+
83
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
84
+ self.quant_noise = apply_quant_noise_(
85
+ nn.Linear(embed_dim, embed_dim, bias=False),
86
+ cfg.quant_noise.pq,
87
+ cfg.quant_noise.pq_block_size,
88
+ )
89
+ else:
90
+ self.quant_noise = None
91
+
92
+ self.project_in_dim = (
93
+ Linear(input_embed_dim, embed_dim, bias=False)
94
+ if embed_dim != input_embed_dim
95
+ else None
96
+ )
97
+ self.embed_positions = (
98
+ PositionalEmbedding(
99
+ self.max_target_positions,
100
+ embed_dim,
101
+ self.padding_idx,
102
+ learned=cfg.decoder.learned_pos,
103
+ )
104
+ if not cfg.no_token_positional_embeddings
105
+ else None
106
+ )
107
+ if cfg.layernorm_embedding:
108
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
109
+ else:
110
+ self.layernorm_embedding = None
111
+
112
+ self.cross_self_attention = cfg.cross_self_attention
113
+
114
+ self.use_rel_pos_enc = use_rel_pos_enc
115
+ if self.decoder_layerdrop > 0.0:
116
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
117
+ else:
118
+ self.layers = nn.ModuleList([])
119
+ self.layers.extend(
120
+ [
121
+ self.build_decoder_layer(cfg, no_encoder_attn)
122
+ for _ in range(cfg.decoder.layers)
123
+ ]
124
+ )
125
+ self.num_layers = len(self.layers)
126
+
127
+ if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
128
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
129
+ else:
130
+ self.layer_norm = None
131
+
132
+ self.project_out_dim = (
133
+ Linear(embed_dim, self.output_embed_dim, bias=False)
134
+ if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
135
+ else None
136
+ )
137
+
138
+ self.adaptive_softmax = None
139
+ self.output_projection = output_projection
140
+ if self.output_projection is None:
141
+ self.build_output_projection(cfg, dictionary, embed_tokens)
142
+
143
+ if self.use_rel_pos_enc:
144
+ self.pos_emb = RelativePositionalEncoding(self.embed_dim // cfg.decoder.attention_heads, 24)
145
+
146
+ def build_output_projection(self, cfg, dictionary, embed_tokens):
147
+ if cfg.adaptive_softmax_cutoff is not None:
148
+ self.adaptive_softmax = AdaptiveSoftmax(
149
+ len(dictionary),
150
+ self.output_embed_dim,
151
+ utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
152
+ dropout=cfg.adaptive_softmax_dropout,
153
+ adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
154
+ factor=cfg.adaptive_softmax_factor,
155
+ tie_proj=cfg.tie_adaptive_proj,
156
+ )
157
+ elif self.share_input_output_embed:
158
+ self.output_projection = nn.Linear(
159
+ self.embed_tokens.weight.shape[1],
160
+ self.embed_tokens.weight.shape[0],
161
+ bias=False,
162
+ )
163
+ self.output_projection.weight = self.embed_tokens.weight
164
+ else:
165
+ self.output_projection = nn.Linear(
166
+ self.output_embed_dim, len(dictionary), bias=False
167
+ )
168
+ nn.init.normal_(
169
+ self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
170
+ )
171
+ num_base_layers = cfg.base_layers
172
+ for i in range(num_base_layers):
173
+ self.layers.insert(
174
+ ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
175
+ BaseLayer(cfg),
176
+ )
177
+
178
+ def build_decoder_layer(self, cfg, no_encoder_attn=False):
179
+ layer = TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
180
+ checkpoint = cfg.checkpoint_activations
181
+ if checkpoint:
182
+ offload_to_cpu = cfg.offload_activations
183
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
184
+ # if we are checkpointing, enforce that FSDP always wraps the
185
+ # checkpointed layer, regardless of layer size
186
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
187
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
188
+ return layer
189
+
190
+ def forward(
191
+ self,
192
+ prev_output_tokens,
193
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
194
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
195
+ features_only: bool = False,
196
+ full_context_alignment: bool = False,
197
+ alignment_layer: Optional[int] = None,
198
+ alignment_heads: Optional[int] = None,
199
+ src_lengths: Optional[Any] = None,
200
+ return_all_hiddens: bool = False,
201
+ ):
202
+ """
203
+ Args:
204
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
205
+ `(batch, tgt_len)`, for teacher forcing
206
+ encoder_out (optional): output from the encoder, used for
207
+ encoder-side attention, should be of size T x B x C
208
+ incremental_state (dict): dictionary used for storing state during
209
+ :ref:`Incremental decoding`
210
+ features_only (bool, optional): only return features without
211
+ applying output layer (default: False).
212
+ full_context_alignment (bool, optional): don't apply
213
+ auto-regressive mask to self-attention (default: False).
214
+
215
+ Returns:
216
+ tuple:
217
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
218
+ - a dictionary with any model-specific outputs
219
+ """
220
+
221
+ x, extra = self.extract_features(
222
+ prev_output_tokens,
223
+ encoder_out=encoder_out,
224
+ incremental_state=incremental_state,
225
+ full_context_alignment=full_context_alignment,
226
+ alignment_layer=alignment_layer,
227
+ alignment_heads=alignment_heads,
228
+ )
229
+
230
+ if not features_only:
231
+ x = self.output_layer(x)
232
+ return x, extra
233
+
234
+ def extract_features_scriptable(
235
+ self,
236
+ prev_output_tokens,
237
+ encoder_out: Optional[Dict[str, List[Tensor]]],
238
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
239
+ full_context_alignment: bool = False,
240
+ alignment_layer: Optional[int] = None,
241
+ alignment_heads: Optional[int] = None,
242
+ ):
243
+ """
244
+ Similar to *forward* but only return features.
245
+
246
+ Includes several features from "Jointly Learning to Align and
247
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
248
+
249
+ Args:
250
+ full_context_alignment (bool, optional): don't apply
251
+ auto-regressive mask to self-attention (default: False).
252
+ alignment_layer (int, optional): return mean alignment over
253
+ heads at this layer (default: last layer).
254
+ alignment_heads (int, optional): only average alignment over
255
+ this many heads (default: all heads).
256
+
257
+ Returns:
258
+ tuple:
259
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
260
+ - a dictionary with any model-specific outputs
261
+ """
262
+ bs, slen = prev_output_tokens.size()
263
+ if alignment_layer is None:
264
+ alignment_layer = self.num_layers - 1
265
+
266
+ enc: Optional[Tensor] = None
267
+ padding_mask: Optional[Tensor] = None
268
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
269
+ enc = encoder_out["encoder_out"][0]
270
+ assert (
271
+ enc.size()[1] == bs
272
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
273
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
274
+ padding_mask = encoder_out["encoder_padding_mask"][0]
275
+
276
+ # embed positions
277
+ positions = None
278
+ if self.embed_positions is not None:
279
+ positions = self.embed_positions(
280
+ prev_output_tokens, incremental_state=incremental_state
281
+ )
282
+
283
+ if incremental_state is not None:
284
+ prev_output_tokens = prev_output_tokens[:, -1:]
285
+ if positions is not None:
286
+ positions = positions[:, -1:]
287
+
288
+ # embed tokens and positions
289
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
290
+
291
+ if self.quant_noise is not None:
292
+ x = self.quant_noise(x)
293
+
294
+ if self.project_in_dim is not None:
295
+ x = self.project_in_dim(x)
296
+
297
+ if positions is not None:
298
+ x += positions
299
+
300
+ if self.layernorm_embedding is not None:
301
+ x = self.layernorm_embedding(x)
302
+
303
+ x = self.dropout_module(x)
304
+
305
+ # B x T x C -> T x B x C
306
+ x = x.transpose(0, 1)
307
+ if self.use_rel_pos_enc:
308
+ pos_seq = torch.arange(0, slen).long().to(x.device)
309
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
310
+ pos_k, _ = self.pos_emb(pos_seq, incremental_state)
311
+ else:
312
+ pos_k = None
313
+
314
+ self_attn_padding_mask: Optional[Tensor] = None
315
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
316
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
317
+
318
+ # decoder layers
319
+ attn: Optional[Tensor] = None
320
+ inner_states: List[Optional[Tensor]] = [x]
321
+ for idx, layer in enumerate(self.layers):
322
+ if incremental_state is None and not full_context_alignment:
323
+ self_attn_mask = self.buffered_future_mask(x)
324
+ else:
325
+ self_attn_mask = None
326
+
327
+ x, layer_attn, _ = layer(
328
+ x,
329
+ enc,
330
+ padding_mask,
331
+ incremental_state,
332
+ self_attn_mask=self_attn_mask,
333
+ self_attn_padding_mask=self_attn_padding_mask,
334
+ need_attn=bool((idx == alignment_layer)),
335
+ need_head_weights=bool((idx == alignment_layer)),
336
+ pos_bias=pos_k,
337
+ )
338
+ inner_states.append(x)
339
+ if layer_attn is not None and idx == alignment_layer:
340
+ attn = layer_attn.float().to(x)
341
+
342
+ if attn is not None:
343
+ if alignment_heads is not None:
344
+ attn = attn[:alignment_heads]
345
+
346
+ # average probabilities over heads
347
+ attn = attn.mean(dim=0)
348
+
349
+ if self.layer_norm is not None:
350
+ x = self.layer_norm(x)
351
+
352
+ # T x B x C -> B x T x C
353
+ x = x.transpose(0, 1)
354
+
355
+ if self.project_out_dim is not None:
356
+ x = self.project_out_dim(x)
357
+
358
+ return x, {"attn": [attn], "inner_states": inner_states}
359
+
360
+ def output_layer(self, features):
361
+ """Project features to the vocabulary size."""
362
+ if self.adaptive_softmax is None:
363
+ # project back to size of vocabulary
364
+ return self.output_projection(features)
365
+ else:
366
+ return features
367
+
368
+ def max_positions(self):
369
+ """Maximum output length supported by the decoder."""
370
+ if self.embed_positions is None:
371
+ return self.max_target_positions
372
+ return min(self.max_target_positions, self.embed_positions.max_positions)
373
+
374
+ def buffered_future_mask(self, tensor):
375
+ dim = tensor.size(0)
376
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
377
+ if (
378
+ self._future_mask.size(0) == 0
379
+ or (not self._future_mask.device == tensor.device)
380
+ or self._future_mask.size(0) < dim
381
+ ):
382
+ self._future_mask = torch.triu(
383
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
384
+ )
385
+ self._future_mask = self._future_mask.to(tensor)
386
+ return self._future_mask[:dim, :dim]
387
+
388
+ def upgrade_state_dict_named(self, state_dict, name):
389
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
390
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
391
+ weights_key = "{}.embed_positions.weights".format(name)
392
+ if weights_key in state_dict:
393
+ del state_dict[weights_key]
394
+ state_dict[
395
+ "{}.embed_positions._float_tensor".format(name)
396
+ ] = torch.FloatTensor(1)
397
+
398
+ if f"{name}.output_projection.weight" not in state_dict:
399
+ if self.share_input_output_embed:
400
+ embed_out_key = f"{name}.embed_tokens.weight"
401
+ else:
402
+ embed_out_key = f"{name}.embed_out"
403
+ if embed_out_key in state_dict:
404
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
405
+ embed_out_key
406
+ ]
407
+ if not self.share_input_output_embed:
408
+ del state_dict[embed_out_key]
409
+
410
+ for i in range(self.num_layers):
411
+ # update layer norms
412
+ layer_norm_map = {
413
+ "0": "self_attn_layer_norm",
414
+ "1": "encoder_attn_layer_norm",
415
+ "2": "final_layer_norm",
416
+ }
417
+ for old, new in layer_norm_map.items():
418
+ for m in ("weight", "bias"):
419
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
420
+ if k in state_dict:
421
+ state_dict[
422
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
423
+ ] = state_dict[k]
424
+ del state_dict[k]
425
+
426
+ version_key = "{}.version".format(name)
427
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
428
+ # earlier checkpoints did not normalize after the stack of layers
429
+ self.layer_norm = None
430
+ self.normalize = False
431
+ state_dict[version_key] = torch.Tensor([1])
432
+
433
+ return state_dict
434
+
435
+
436
+ class TransformerDecoder(TransformerDecoderBase):
437
+ def __init__(
438
+ self,
439
+ args,
440
+ dictionary,
441
+ embed_tokens,
442
+ no_encoder_attn=False,
443
+ output_projection=None,
444
+ ):
445
+ self.args = args
446
+ super().__init__(
447
+ TransformerConfig.from_namespace(args),
448
+ dictionary,
449
+ embed_tokens,
450
+ no_encoder_attn=no_encoder_attn,
451
+ output_projection=output_projection,
452
+ use_rel_pos_enc=args.use_rel_pos_enc,
453
+ )
454
+
455
+ def build_output_projection(self, args, dictionary, embed_tokens):
456
+ super().build_output_projection(
457
+ TransformerConfig.from_namespace(args), dictionary, embed_tokens
458
+ )
459
+
460
+ def build_decoder_layer(self, args, no_encoder_attn=False):
461
+ return super().build_decoder_layer(
462
+ TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
463
+ )
464
+
465
+ class TransformerDecoderScriptable(TransformerDecoder):
466
+ def extract_features(
467
+ self,
468
+ prev_output_tokens,
469
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
470
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
471
+ full_context_alignment: bool = False,
472
+ alignment_layer: Optional[int] = None,
473
+ alignment_heads: Optional[int] = None,
474
+ ):
475
+ # call scriptable method from parent class
476
+ x, _ = self.extract_features_scriptable(
477
+ prev_output_tokens,
478
+ encoder_out,
479
+ incremental_state,
480
+ full_context_alignment,
481
+ alignment_layer,
482
+ alignment_heads,
483
+ )
484
+ return x, None
485
+
SpeechT5/Speech2C/speech2c/models/modules/transformer_decoder_layer.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ from typing import Dict, List, Optional
11
+
12
+ import torch
13
+ from torch import Tensor
14
+ from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
15
+ from fairseq.modules import LayerNorm
16
+
17
+ from speech2c.models.modules.multihead_attention import MultiheadAttention
18
+
19
+
20
+ class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
21
+ """Decoder layer block.
22
+
23
+ In the original paper each operation (multi-head attention, encoder
24
+ attention or FFN) is postprocessed with: `dropout -> add residual ->
25
+ layernorm`. In the tensor2tensor code they suggest that learning is more
26
+ robust when preprocessing each layer with layernorm and postprocessing with:
27
+ `dropout -> add residual`. We default to the approach in the paper, but the
28
+ tensor2tensor approach can be enabled by setting
29
+ *cfg.decoder.normalize_before* to ``True``.
30
+
31
+ Args:
32
+ args (argparse.Namespace): parsed command-line arguments
33
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
34
+ (default: False).
35
+ """
36
+
37
+ def __init__(
38
+ self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
39
+ ):
40
+ super().__init__(
41
+ cfg,
42
+ no_encoder_attn,
43
+ add_bias_kv,
44
+ add_zero_attn,
45
+ )
46
+
47
+ if has_relative_attention_bias:
48
+ self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
49
+
50
+ def build_self_attention(
51
+ self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
52
+ ):
53
+ return MultiheadAttention(
54
+ embed_dim,
55
+ cfg.decoder.attention_heads,
56
+ dropout=cfg.attention_dropout,
57
+ add_bias_kv=add_bias_kv,
58
+ add_zero_attn=add_zero_attn,
59
+ self_attention=not cfg.cross_self_attention,
60
+ q_noise=self.quant_noise,
61
+ qn_block_size=self.quant_noise_block_size,
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ x,
67
+ encoder_out: Optional[torch.Tensor] = None,
68
+ encoder_padding_mask: Optional[torch.Tensor] = None,
69
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
70
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
71
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
72
+ self_attn_mask: Optional[torch.Tensor] = None,
73
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
74
+ need_attn: bool = False,
75
+ need_head_weights: bool = False,
76
+ pos_bias=None,
77
+ ):
78
+ """
79
+ Args:
80
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
81
+ encoder_padding_mask (ByteTensor, optional): binary
82
+ ByteTensor of shape `(batch, src_len)` where padding
83
+ elements are indicated by ``1``.
84
+ need_attn (bool, optional): return attention weights
85
+ need_head_weights (bool, optional): return attention weights
86
+ for each head (default: return average over heads).
87
+ Returns:
88
+ encoded output of shape `(seq_len, batch, embed_dim)`
89
+ """
90
+ if need_head_weights:
91
+ need_attn = True
92
+
93
+ residual = x
94
+ if self.normalize_before:
95
+ x = self.self_attn_layer_norm(x)
96
+ if pos_bias is not None:
97
+ pos_bias = self.norm_k(pos_bias)
98
+ if prev_self_attn_state is not None:
99
+ prev_key, prev_value = prev_self_attn_state[:2]
100
+ saved_state: Dict[str, Optional[Tensor]] = {
101
+ "prev_key": prev_key,
102
+ "prev_value": prev_value,
103
+ }
104
+ if len(prev_self_attn_state) >= 3:
105
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
106
+ assert incremental_state is not None
107
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
108
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
109
+ if self.cross_self_attention and not (
110
+ incremental_state is not None
111
+ and _self_attn_input_buffer is not None
112
+ and "prev_key" in _self_attn_input_buffer
113
+ ):
114
+ if self_attn_mask is not None:
115
+ assert encoder_out is not None
116
+ self_attn_mask = torch.cat(
117
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
118
+ )
119
+ if self_attn_padding_mask is not None:
120
+ if encoder_padding_mask is None:
121
+ assert encoder_out is not None
122
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
123
+ encoder_out.size(1), encoder_out.size(0)
124
+ )
125
+ self_attn_padding_mask = torch.cat(
126
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
127
+ )
128
+ assert encoder_out is not None
129
+ y = torch.cat((encoder_out, x), dim=0)
130
+ else:
131
+ y = x
132
+
133
+ x, attn = self.self_attn(
134
+ query=x,
135
+ key=y,
136
+ value=y,
137
+ key_padding_mask=self_attn_padding_mask,
138
+ incremental_state=incremental_state,
139
+ need_weights=False,
140
+ attn_mask=self_attn_mask,
141
+ position_bias=pos_bias,
142
+ )
143
+ if self.c_attn is not None:
144
+ tgt_len, bsz = x.size(0), x.size(1)
145
+ x = x.view(tgt_len, bsz, self.nh, self.head_dim)
146
+ x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
147
+ x = x.reshape(tgt_len, bsz, self.embed_dim)
148
+ if self.attn_ln is not None:
149
+ x = self.attn_ln(x)
150
+ x = self.dropout_module(x)
151
+ x = self.residual_connection(x, residual)
152
+ if not self.normalize_before:
153
+ x = self.self_attn_layer_norm(x)
154
+
155
+ if self.encoder_attn is not None and encoder_out is not None:
156
+ residual = x
157
+ if self.normalize_before:
158
+ x = self.encoder_attn_layer_norm(x)
159
+ if prev_attn_state is not None:
160
+ prev_key, prev_value = prev_attn_state[:2]
161
+ saved_state: Dict[str, Optional[Tensor]] = {
162
+ "prev_key": prev_key,
163
+ "prev_value": prev_value,
164
+ }
165
+ if len(prev_attn_state) >= 3:
166
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
167
+ assert incremental_state is not None
168
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
169
+
170
+ x, attn = self.encoder_attn(
171
+ query=x,
172
+ key=encoder_out,
173
+ value=encoder_out,
174
+ key_padding_mask=encoder_padding_mask,
175
+ incremental_state=incremental_state,
176
+ static_kv=True,
177
+ need_weights=need_attn or (not self.training and self.need_attn),
178
+ need_head_weights=need_head_weights,
179
+ )
180
+ x = self.dropout_module(x)
181
+ x = self.residual_connection(x, residual)
182
+ if not self.normalize_before:
183
+ x = self.encoder_attn_layer_norm(x)
184
+
185
+ residual = x
186
+ if self.normalize_before:
187
+ x = self.final_layer_norm(x)
188
+
189
+ x = self.activation_fn(self.fc1(x))
190
+ x = self.activation_dropout_module(x)
191
+ if self.ffn_layernorm is not None:
192
+ x = self.ffn_layernorm(x)
193
+ x = self.fc2(x)
194
+ x = self.dropout_module(x)
195
+ if self.w_resid is not None:
196
+ residual = torch.mul(self.w_resid, residual)
197
+ x = self.residual_connection(x, residual)
198
+ if not self.normalize_before:
199
+ x = self.final_layer_norm(x)
200
+ if self.onnx_trace and incremental_state is not None:
201
+ saved_state = self.self_attn._get_input_buffer(incremental_state)
202
+ assert saved_state is not None
203
+ if self_attn_padding_mask is not None:
204
+ self_attn_state = [
205
+ saved_state["prev_key"],
206
+ saved_state["prev_value"],
207
+ saved_state["prev_key_padding_mask"],
208
+ ]
209
+ else:
210
+ self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
211
+ return x, attn, self_attn_state
212
+ return x, attn, None
213
+
214
+ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
215
+ self.need_attn = need_attn
SpeechT5/Speech2C/speech2c/models/modules/transformer_encoder.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from fairseq import utils
17
+ from fairseq.dataclass import ChoiceEnum
18
+ from fairseq.modules import (
19
+ LayerNorm,
20
+ MultiheadAttention,
21
+ SamePad,
22
+ )
23
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
24
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
25
+ from fairseq.utils import index_put
26
+ from fairseq.distributed import fsdp_wrap
27
+ from fairseq.models.wav2vec.utils import pad_to_multiple
28
+ from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
29
+
30
+ from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
31
+ from speech2c.models.modules.multihead_attention import MultiheadAttention
32
+
33
+ EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
34
+ MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
35
+
36
+
37
+ class TransformerEncoder(W2vTransformerEncoder):
38
+ def __init__(self, args):
39
+ super().__init__(args)
40
+
41
+ self.dropout = args.dropout
42
+ self.embedding_dim = args.encoder_embed_dim
43
+ self.required_seq_len_multiple = args.required_seq_len_multiple
44
+ self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
45
+
46
+ self.pos_conv = nn.Conv1d(
47
+ self.embedding_dim,
48
+ self.embedding_dim,
49
+ kernel_size=args.conv_pos,
50
+ padding=args.conv_pos // 2,
51
+ groups=args.conv_pos_groups,
52
+ )
53
+ dropout = 0
54
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
55
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
56
+ nn.init.constant_(self.pos_conv.bias, 0)
57
+
58
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
59
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
60
+
61
+ layers = []
62
+ for _ in range(args.encoder_layers):
63
+ layer = TransformerSentenceEncoderLayer(
64
+ embedding_dim=self.embedding_dim,
65
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
66
+ num_attention_heads=args.encoder_attention_heads,
67
+ dropout=self.dropout,
68
+ attention_dropout=args.attention_dropout,
69
+ activation_dropout=args.activation_dropout,
70
+ activation_fn=args.activation_fn,
71
+ layer_norm_first=args.layer_norm_first,
72
+ has_relative_attention_bias=self.use_rel_pos_enc,
73
+ )
74
+ if args.checkpoint_activations:
75
+ layer = fsdp_wrap(layer)
76
+ layer = checkpoint_wrapper(layer)
77
+ layers.append(layer)
78
+ self.layers = nn.ModuleList(layers)
79
+
80
+ self.layer_norm_first = args.layer_norm_first
81
+ self.layer_norm = LayerNorm(self.embedding_dim)
82
+ self.layerdrop = args.encoder_layerdrop
83
+ if self.use_rel_pos_enc:
84
+ self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
85
+
86
+
87
+ self.apply(init_bert_params)
88
+
89
+ def forward(self, x, padding_mask=None, layer=None):
90
+ x, layer_results = self.extract_features(x, padding_mask, layer)
91
+
92
+ if self.layer_norm_first and layer is None:
93
+ x = self.layer_norm(x)
94
+
95
+ return x, layer_results
96
+
97
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
98
+
99
+ if padding_mask is not None:
100
+ x = index_put(x, padding_mask, 0)
101
+
102
+ x_conv = self.pos_conv(x.transpose(1, 2))
103
+ x_conv = x_conv.transpose(1, 2)
104
+ x = x + x_conv
105
+
106
+ if not self.layer_norm_first:
107
+ x = self.layer_norm(x)
108
+
109
+ # pad to the sequence length dimension
110
+ x, pad_length = pad_to_multiple(
111
+ x, self.required_seq_len_multiple, dim=-2, value=0
112
+ )
113
+ if pad_length > 0 and padding_mask is None:
114
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
115
+ padding_mask[:, -pad_length:] = True
116
+ else:
117
+ padding_mask, _ = pad_to_multiple(
118
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
119
+ )
120
+ x = F.dropout(x, p=self.dropout, training=self.training)
121
+
122
+ # B x T x C -> T x B x C
123
+ x = x.transpose(0, 1)
124
+
125
+ if self.use_rel_pos_enc:
126
+ x_len = x.shape[0]
127
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
128
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
129
+ pos_k, pos_v = self.pos_emb(pos_seq)
130
+ else:
131
+ pos_k = None
132
+
133
+ layer_results = []
134
+ r = None
135
+ for i, layer in enumerate(self.layers):
136
+ dropout_probability = np.random.random()
137
+ if not self.training or (dropout_probability > self.layerdrop):
138
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
139
+ if tgt_layer is not None:
140
+ # unpad if needed
141
+ if pad_length > 0:
142
+ layer_results.append(
143
+ (
144
+ x[:-pad_length],
145
+ z[:, :-pad_length, :-pad_length]
146
+ if z is not None
147
+ else z,
148
+ )
149
+ )
150
+ else:
151
+ layer_results.append((x, z))
152
+ if i == tgt_layer:
153
+ r = x
154
+ break
155
+
156
+ if r is not None:
157
+ x = r
158
+
159
+ # T x B x C -> B x T x C
160
+ x = x.transpose(0, 1)
161
+ # undo paddding
162
+ if pad_length > 0:
163
+ x = x[:, :-pad_length]
164
+
165
+ return x, layer_results
166
+
167
+
168
+ class TransformerSentenceEncoderLayer(nn.Module):
169
+ """
170
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
171
+ models.
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ embedding_dim: float = 768,
177
+ ffn_embedding_dim: float = 3072,
178
+ num_attention_heads: float = 8,
179
+ dropout: float = 0.1,
180
+ attention_dropout: float = 0.1,
181
+ activation_dropout: float = 0.1,
182
+ activation_fn: str = "relu",
183
+ layer_norm_first: bool = False,
184
+ has_relative_attention_bias: bool = False,
185
+ ) -> None:
186
+
187
+ super().__init__()
188
+ # Initialize parameters
189
+ self.embedding_dim = embedding_dim
190
+ self.dropout = dropout
191
+ self.activation_dropout = activation_dropout
192
+
193
+ # Initialize blocks
194
+ self.activation_fn = utils.get_activation_fn(activation_fn)
195
+ self.self_attn = MultiheadAttention(
196
+ self.embedding_dim,
197
+ num_attention_heads,
198
+ dropout=attention_dropout,
199
+ self_attention=True,
200
+ )
201
+
202
+ self.dropout1 = nn.Dropout(dropout)
203
+ self.dropout2 = nn.Dropout(self.activation_dropout)
204
+ self.dropout3 = nn.Dropout(dropout)
205
+
206
+ self.layer_norm_first = layer_norm_first
207
+
208
+ # layer norm associated with the self attention layer
209
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
210
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
211
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
212
+
213
+ # layer norm associated with the position wise feed-forward NN
214
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
215
+
216
+ if has_relative_attention_bias:
217
+ self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
218
+
219
+ def forward(
220
+ self,
221
+ x: torch.Tensor,
222
+ self_attn_mask: torch.Tensor = None,
223
+ self_attn_padding_mask: torch.Tensor = None,
224
+ need_weights: bool = False,
225
+ att_args=None,
226
+ pos_bias=None,
227
+ ):
228
+ """
229
+ LayerNorm is applied either before or after the self-attention/ffn
230
+ modules similar to the original Transformer imlementation.
231
+ """
232
+ residual = x
233
+
234
+ if self.layer_norm_first:
235
+ x = self.self_attn_layer_norm(x)
236
+ if pos_bias is not None:
237
+ pos_bias = self.norm_k(pos_bias)
238
+ x, attn = self.self_attn(
239
+ query=x,
240
+ key=x,
241
+ value=x,
242
+ key_padding_mask=self_attn_padding_mask,
243
+ attn_mask=self_attn_mask,
244
+ position_bias=pos_bias,
245
+ )
246
+ x = self.dropout1(x)
247
+ x = residual + x
248
+
249
+ residual = x
250
+ x = self.final_layer_norm(x)
251
+ x = self.activation_fn(self.fc1(x))
252
+ x = self.dropout2(x)
253
+ x = self.fc2(x)
254
+ x = self.dropout3(x)
255
+ x = residual + x
256
+ else:
257
+ x, attn = self.self_attn(
258
+ query=x,
259
+ key=x,
260
+ value=x,
261
+ key_padding_mask=self_attn_padding_mask,
262
+ position_bias=pos_bias,
263
+ )
264
+
265
+ x = self.dropout1(x)
266
+ x = residual + x
267
+
268
+ x = self.self_attn_layer_norm(x)
269
+
270
+ residual = x
271
+ x = self.activation_fn(self.fc1(x))
272
+ x = self.dropout2(x)
273
+ x = self.fc2(x)
274
+ x = self.dropout3(x)
275
+ x = residual + x
276
+ x = self.final_layer_norm(x)
277
+
278
+ return x, attn
SpeechT5/Speech2C/speech2c/models/speech2c.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import logging
11
+ import copy
12
+ import contextlib
13
+ from typing import Dict, List, Optional, Tuple
14
+
15
+ import torch
16
+ from dataclasses import dataclass, field
17
+ from fairseq.data.dictionary import Dictionary
18
+ from fairseq.models import register_model
19
+ from fairseq.models.hubert import HubertConfig, HubertModel
20
+ from fairseq.models.transformer import Embedding
21
+ from torch import Tensor
22
+ from speech2c.tasks.speech2c_pretraining import (
23
+ Speech2cPretrainingConfig,
24
+ Speech2cPretrainingTask,
25
+ )
26
+
27
+ from speech2c.models.modules.transformer_decoder import TransformerDecoderScriptable
28
+ from speech2c.models.modules.transformer_encoder import TransformerEncoder
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ @dataclass
34
+ class Speech2cConfig(HubertConfig):
35
+ use_rel_pos_enc: bool = field(
36
+ default=False,
37
+ metadata={"help": "whether to use relative positional encoding"},
38
+ )
39
+
40
+ # decoder
41
+ decoder_layers: int = field(
42
+ default=6, metadata={"help": "num decoder layers in the transformer"}
43
+ )
44
+ decoder_embed_dim: int = field(
45
+ default=768, metadata={"help": "decoder embedding dimension"}
46
+ )
47
+ decoder_ffn_embed_dim: int = field(
48
+ default=3072, metadata={"help": "decoder embedding dimension for FFN"}
49
+ )
50
+ decoder_attention_heads: int = field(
51
+ default=12, metadata={"help": "num decoder attention heads"}
52
+ )
53
+ decoder_normalize_before: bool = field(
54
+ default=False,
55
+ metadata={"help": "apply layernorm before each decoder block"},
56
+ )
57
+ decoder_layerdrop: float = field(
58
+ default=0.0,
59
+ metadata={"help": "probability of dropping a tarnsformer layer"},
60
+ )
61
+ share_decoder_input_output_embed: bool = field(
62
+ default=False,
63
+ metadata={"help": "share decoder input and output embeddings"},
64
+ )
65
+ decoder_output_dim: int = field(
66
+ default=768, metadata={"help": "decoder output dimension"}
67
+ )
68
+ max_target_positions: int = field(
69
+ default=3000, metadata={"help": "max target position"}
70
+ )
71
+ no_scale_embedding: bool = field(
72
+ default=False,
73
+ metadata={"help": "not scale embedding"},
74
+ )
75
+ adaptive_input: bool = field(
76
+ default=False,
77
+ metadata={"help": "adaptive input"},
78
+ )
79
+ quant_noise_pq: int = field(
80
+ default=0, metadata={"help": "quant noise pq"}
81
+ )
82
+ decoder_learned_pos: bool = field(
83
+ default=False,
84
+ metadata={"help": "decoder learnable positional embedding"},
85
+ )
86
+ no_token_positional_embeddings: bool = field(
87
+ default=False,
88
+ metadata={"help": "no token positional embeddings"},
89
+ )
90
+ decoder_dict_size: int = field(
91
+ default=-1,
92
+ metadata={"help": "decoder dictionary dimension, only used for fine-tuning"},
93
+ )
94
+
95
+ # FP16 optimization
96
+ required_seq_len_multiple: int = field(
97
+ default=1,
98
+ metadata={
99
+ "help": "pad the input to encoder such that the sequence length is divisible by multiple"
100
+ },
101
+ )
102
+ crop_seq_to_multiple: int = field(
103
+ default=1,
104
+ metadata={
105
+ "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
106
+ },
107
+ )
108
+
109
+
110
+ @register_model("speech2c", dataclass=Speech2cConfig)
111
+ class Speech2cModel(HubertModel):
112
+ def __init__(
113
+ self,
114
+ cfg: Speech2cConfig,
115
+ task_cfg: Speech2cPretrainingConfig,
116
+ dictionaries: List[Dictionary],
117
+ ) -> None:
118
+ super().__init__(cfg, task_cfg, dictionaries)
119
+ logger.info(f"Speech2cModel Config: {cfg}")
120
+
121
+ self.encoder = TransformerEncoder(cfg)
122
+
123
+ self.add_decoder = task_cfg.add_decoder
124
+ if task_cfg.add_decoder:
125
+ def build_embedding(dictionary, embed_dim):
126
+ num_embeddings = len(dictionary)
127
+ padding_idx = dictionary.pad()
128
+ return Embedding(num_embeddings, embed_dim, padding_idx)
129
+
130
+ # To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size
131
+ cut_dictionary = copy.deepcopy(dictionaries[0])
132
+ if cfg.decoder_dict_size != -1:
133
+ cut_dictionary.symbols = cut_dictionary.symbols[:cfg.decoder_dict_size]
134
+
135
+ decoder_embed_tokens = build_embedding(
136
+ cut_dictionary, cfg.decoder_embed_dim
137
+ )
138
+
139
+ self.decoder = TransformerDecoderScriptable(cfg, cut_dictionary, decoder_embed_tokens)
140
+
141
+
142
+ @classmethod
143
+ def build_model(cls, cfg: Speech2cConfig, task: Speech2cPretrainingTask):
144
+ """Build a new model instance."""
145
+
146
+ model = Speech2cModel(cfg, task.cfg, task.dictionaries)
147
+ return model
148
+
149
+ def get_normalized_probs(
150
+ self,
151
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
152
+ log_probs: bool,
153
+ sample: Optional[Dict[str, Tensor]] = None,
154
+ ):
155
+ # net_output['encoder_out'] is a (B, T, D) tensor
156
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
157
+ lprobs.batch_first = True
158
+ return lprobs
159
+
160
+ def forward(
161
+ self,
162
+ source: torch.Tensor,
163
+ target_list: Optional[List[torch.Tensor]] = None,
164
+ padding_mask: Optional[torch.Tensor] = None,
165
+ mask: bool = True,
166
+ features_only: bool = False,
167
+ output_layer: Optional[int] = None,
168
+ prev_output_tokens: Optional[torch.Tensor] = None,
169
+ ) -> Dict[str, torch.Tensor]:
170
+ """output layer is 1-based"""
171
+ features = self.forward_features(source)
172
+ if target_list is not None:
173
+ features, target_list = self.forward_targets(features, target_list)
174
+
175
+ features_pen = features.float().pow(2).mean()
176
+
177
+ features = features.transpose(1, 2)
178
+ features = self.layer_norm(features)
179
+ unmasked_features = features.clone()
180
+
181
+ if padding_mask is not None:
182
+ padding_mask = self.forward_padding_mask(features, padding_mask)
183
+
184
+ if self.post_extract_proj is not None:
185
+ features = self.post_extract_proj(features)
186
+
187
+ features = self.dropout_input(features)
188
+ unmasked_features = self.dropout_features(unmasked_features)
189
+
190
+ if mask:
191
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
192
+ else:
193
+ x = features
194
+ mask_indices = None
195
+
196
+ # feature: (B, T, D), float
197
+ # target: (B, T), long
198
+ # x: (B, T, D), float
199
+ # padding_mask: (B, T), bool
200
+ # mask_indices: (B, T), bool
201
+ x, _ = self.encoder(
202
+ x,
203
+ padding_mask=padding_mask,
204
+ layer=None if output_layer is None else output_layer - 1,
205
+ )
206
+
207
+ if features_only:
208
+ return {"x": x, "padding_mask": padding_mask, "features": features}
209
+
210
+ def compute_pred(proj_x, target, label_embs):
211
+ # compute logits for the i-th label set
212
+ y = torch.index_select(label_embs, 0, target.long())
213
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
214
+ if self.target_glu:
215
+ y = self.target_glu(y)
216
+ negs = self.target_glu(negs)
217
+ # proj_x: (S, D)
218
+ # y: (S, D)
219
+ # negs: (Neg, S, D)
220
+ return self.compute_nce(proj_x, y, negs)
221
+
222
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
223
+
224
+ if not self.skip_masked:
225
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
226
+ proj_x_m = self.final_proj(x[masked_indices])
227
+ if self.untie_final_proj:
228
+ proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
229
+ else:
230
+ proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
231
+ logit_m_list = [
232
+ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
233
+ for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
234
+ ]
235
+ else:
236
+ logit_m_list = [None for _ in target_list]
237
+
238
+ if not self.skip_nomask:
239
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
240
+ proj_x_u = self.final_proj(x[nomask_indices])
241
+ if self.untie_final_proj:
242
+ proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
243
+ else:
244
+ proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
245
+
246
+ logit_u_list = [
247
+ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
248
+ for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
249
+ ]
250
+ else:
251
+ logit_u_list = [None for _ in target_list]
252
+
253
+ result = {
254
+ "logit_m_list": logit_m_list,
255
+ "logit_u_list": logit_u_list,
256
+ "padding_mask": padding_mask,
257
+ "features_pen": features_pen,
258
+ }
259
+ if self.add_decoder:
260
+ encoder_out = {
261
+ "encoder_out": [x.transpose(0, 1)], # T x B x C
262
+ "encoder_padding_mask": [padding_mask], # B x T
263
+ }
264
+ assert prev_output_tokens is not None
265
+ decoder_out = self.decoder(
266
+ prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
267
+ )
268
+ result['decoder_out'] = decoder_out
269
+ return result
270
+
271
+ def forward_torchscript(self, net_input: Dict[str, Tensor]):
272
+ """A TorchScript-compatible version of forward.
273
+ Encoders which use additional arguments may want to override
274
+ this method for TorchScript compatibility.
275
+ """
276
+ res = self.forward(
277
+ net_input["source"],
278
+ padding_mask=net_input["padding_mask"],
279
+ mask=False,
280
+ features_only=True
281
+ )
282
+
283
+ encoder_out = {
284
+ "encoder_out": [res["x"].transpose(0, 1)], # T x B x C
285
+ "encoder_padding_mask": [res["padding_mask"]], # B x T
286
+ }
287
+ return encoder_out
288
+
289
+ def extract_features(
290
+ self,
291
+ source: torch.Tensor,
292
+ padding_mask: Optional[torch.Tensor] = None,
293
+ mask: bool = False,
294
+ ret_conv: bool = False,
295
+ output_layer: Optional[int] = None,
296
+ prev_output_tokens: Optional[torch.Tensor] = None,
297
+ ft: bool = True,
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ with torch.no_grad() if not ft else contextlib.ExitStack():
300
+ res = self.forward(
301
+ source,
302
+ padding_mask=padding_mask,
303
+ mask=mask,
304
+ features_only=True,
305
+ output_layer=output_layer,
306
+ )
307
+
308
+ feature = res["features"] if ret_conv else res["x"]
309
+ if self.add_decoder:
310
+ encoder_out = {
311
+ "encoder_out": [feature.transpose(0, 1)], # T x B x C
312
+ "encoder_padding_mask": [res["padding_mask"]], # B x T
313
+ }
314
+ assert prev_output_tokens is not None
315
+ decoder_out = self.decoder(
316
+ prev_output_tokens=prev_output_tokens,
317
+ encoder_out=encoder_out,
318
+ )
319
+ else:
320
+ decoder_out = None
321
+ return feature, res["padding_mask"], decoder_out
SpeechT5/Speech2C/speech2c/models/speech2c_asr.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ from argparse import Namespace
11
+ from omegaconf import II
12
+
13
+ import torch.nn as nn
14
+ from dataclasses import dataclass, field
15
+ from fairseq import checkpoint_utils, tasks, utils
16
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
17
+ from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
18
+ from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
19
+ from fairseq.tasks import FairseqTask
20
+
21
+
22
+ @dataclass
23
+ class Speech2cAsrConfig(HubertAsrConfig):
24
+ # for decoder
25
+ decoder_layerdrop: float = field(
26
+ default=0.0,
27
+ metadata={"help": "probability of dropping a decoder layer in hubert"},
28
+ )
29
+
30
+ add_decoder: bool = II("task.add_decoder")
31
+
32
+ @dataclass
33
+ class Speech2cCtcConfig(Speech2cAsrConfig):
34
+ pass
35
+
36
+
37
+ @register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
38
+ class Speech2cCtc(BaseFairseqModel):
39
+ def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
40
+ super().__init__()
41
+ self.cfg = cfg
42
+ self.w2v_encoder = w2v_encoder
43
+
44
+ def upgrade_state_dict_named(self, state_dict, name):
45
+ super().upgrade_state_dict_named(state_dict, name)
46
+ return state_dict
47
+
48
+ @classmethod
49
+ def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
50
+ """Build a new model instance."""
51
+ w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
52
+ return cls(cfg, w2v_encoder)
53
+
54
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
55
+ """Get normalized probabilities (or log probs) from a net's output."""
56
+ if "encoder_out" not in net_output:
57
+ return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
58
+
59
+ if "encoder_out_for_ctc" in net_output:
60
+ logits = net_output["encoder_out_for_ctc"]
61
+ else:
62
+ logits = net_output["encoder_out"]
63
+
64
+ if isinstance(logits, list):
65
+ logits = logits[0]
66
+
67
+ if log_probs:
68
+ return utils.log_softmax(logits.float(), dim=-1)
69
+ else:
70
+ return utils.softmax(logits.float(), dim=-1)
71
+
72
+ def get_logits(self, net_output):
73
+ logits = net_output["encoder_out"]
74
+ padding = net_output["encoder_padding_mask"]
75
+ if padding is not None and padding.any():
76
+ padding = padding.T
77
+ logits[padding][..., 0] = 0
78
+ logits[padding][..., 1:] = float("-inf")
79
+
80
+ return logits
81
+
82
+ def forward(self, **kwargs):
83
+ x = self.w2v_encoder(**kwargs)
84
+ return x
85
+
86
+ @property
87
+ def encoder(self):
88
+ return self.w2v_encoder
89
+
90
+ def reorder_encoder_out(self, encoder_out, new_order):
91
+ return self.encoder.reorder_encoder_out(encoder_out, new_order)
92
+
93
+ @property
94
+ def decoder(self):
95
+ return self.w2v_encoder.w2v_model.decoder
96
+
97
+
98
+ class Speech2cEncoder(FairseqEncoder):
99
+ def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
100
+ self.apply_mask = cfg.apply_mask
101
+
102
+ arg_overrides = {
103
+ "dropout": cfg.dropout,
104
+ "activation_dropout": cfg.activation_dropout,
105
+ "dropout_input": cfg.dropout_input,
106
+ "attention_dropout": cfg.attention_dropout,
107
+ "mask_length": cfg.mask_length,
108
+ "mask_prob": cfg.mask_prob,
109
+ "mask_selection": cfg.mask_selection,
110
+ "mask_other": cfg.mask_other,
111
+ "no_mask_overlap": cfg.no_mask_overlap,
112
+ "mask_channel_length": cfg.mask_channel_length,
113
+ "mask_channel_prob": cfg.mask_channel_prob,
114
+ "mask_channel_selection": cfg.mask_channel_selection,
115
+ "mask_channel_other": cfg.mask_channel_other,
116
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
117
+ "encoder_layerdrop": cfg.layerdrop,
118
+ "decoder_layerdrop": cfg.decoder_layerdrop,
119
+ "feature_grad_mult": cfg.feature_grad_mult,
120
+ "decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
121
+ }
122
+
123
+ if cfg.w2v_args is None:
124
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
125
+ w2v_args = state.get("cfg", None)
126
+ if w2v_args is None:
127
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
128
+ cfg.w2v_args = w2v_args
129
+ else:
130
+ state = None
131
+ w2v_args = cfg.w2v_args
132
+ if isinstance(w2v_args, Namespace):
133
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
134
+
135
+ assert cfg.normalize == w2v_args.task.normalize, (
136
+ "Fine-tuning works best when data normalization is the same. "
137
+ "Please check that --normalize is set or unset for "
138
+ "both pre-training and here"
139
+ )
140
+
141
+ w2v_args.task.data = cfg.data
142
+ w2v_args.task.add_decoder = cfg.add_decoder
143
+ task = tasks.setup_task(w2v_args.task)
144
+ if state is not None and "task_state" in state:
145
+ # This will load the stored "dictionaries" object
146
+ task.load_state_dict(state["task_state"])
147
+ model = task.build_model(w2v_args.model)
148
+
149
+ if state is not None and not cfg.no_pretrained_weights:
150
+ if "decoder.embed_tokens.weight" in state["model"]:
151
+ del state["model"]["decoder.embed_tokens.weight"]
152
+ if "decoder.output_projection.weight" in state["model"]:
153
+ del state["model"]["decoder.output_projection.weight"]
154
+ # set strict=False because we omit some modules
155
+ model.load_state_dict(state["model"], strict=False)
156
+
157
+ model.remove_pretraining_modules()
158
+
159
+ super().__init__(task.source_dictionary)
160
+
161
+ d = model.mask_emb.size(0)
162
+
163
+ self.w2v_model = model
164
+
165
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
166
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
167
+ self.num_updates = 0
168
+
169
+ if tgt_dict is not None:
170
+ self.proj = Linear(d, len(tgt_dict))
171
+ elif getattr(cfg, "decoder_embed_dim", d) != d:
172
+ self.proj = Linear(d, cfg.decoder_embed_dim)
173
+ else:
174
+ self.proj = None
175
+
176
+ def set_num_updates(self, num_updates):
177
+ """Set the number of parameters updates."""
178
+ super().set_num_updates(num_updates)
179
+ self.num_updates = num_updates
180
+
181
+ def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
182
+
183
+ ft = self.freeze_finetune_updates <= self.num_updates
184
+ w2v_args = {
185
+ "source": source,
186
+ "padding_mask": padding_mask,
187
+ "mask": self.apply_mask and self.training,
188
+ "prev_output_tokens": prev_output_tokens,
189
+ "ft": ft,
190
+ }
191
+
192
+ x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
193
+
194
+ if tbc:
195
+ # B x T x C -> T x B x C
196
+ x = x.transpose(0, 1)
197
+
198
+ x = self.final_dropout(x)
199
+
200
+ if self.proj:
201
+ x = self.proj(x)
202
+
203
+ return {
204
+ "encoder_out": x, # T x B x C
205
+ "encoder_padding_mask": padding_mask, # B x T
206
+ "padding_mask": padding_mask,
207
+ "decoder_out": decoder_out,
208
+ }
209
+
210
+ def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
211
+ # net_output['encoder_out'] is a (B, T, D) tensor
212
+ return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
213
+
214
+ def reorder_encoder_out(self, encoder_out, new_order):
215
+ if encoder_out["encoder_out"] is not None:
216
+ if isinstance(encoder_out["encoder_out"], list):
217
+ encoder_out["encoder_out"] = (
218
+ [] if len(encoder_out["encoder_out"]) == 0
219
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
220
+ )
221
+ else:
222
+ encoder_out["encoder_out"] = encoder_out[
223
+ "encoder_out"
224
+ ].index_select(1, new_order)
225
+ if encoder_out["encoder_padding_mask"] is not None:
226
+ if isinstance(encoder_out["encoder_padding_mask"], list):
227
+ encoder_out["encoder_padding_mask"] = (
228
+ [] if len(encoder_out["encoder_padding_mask"]) == 0
229
+ else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
230
+ )
231
+ else:
232
+ encoder_out["encoder_padding_mask"] = encoder_out[
233
+ "encoder_padding_mask"
234
+ ].index_select(0, new_order)
235
+ if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
236
+ if isinstance(encoder_out["decoder_out"], list):
237
+ encoder_out["decoder_out"] = (
238
+ [] if len(encoder_out["decoder_out"]) == 0
239
+ else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
240
+ )
241
+ else:
242
+ encoder_out["decoder_out"] = encoder_out[
243
+ "decoder_out"
244
+ ].index_select(0, new_order)
245
+ if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
246
+ if isinstance(encoder_out["encoder_out_for_ctc"], list):
247
+ encoder_out["encoder_out_for_ctc"] = (
248
+ [] if len(encoder_out["encoder_out_for_ctc"]) == 0
249
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
250
+ )
251
+ else:
252
+ encoder_out["encoder_out_for_ctc"] = encoder_out[
253
+ "encoder_out_for_ctc"
254
+ ].index_select(1, new_order)
255
+
256
+ return encoder_out
257
+
258
+ def forward_torchscript(self, net_input):
259
+ """A TorchScript-compatible version of forward.
260
+ Encoders which use additional arguments may want to override
261
+ this method for TorchScript compatibility.
262
+ """
263
+ encoder_out = self.w2v_model.forward_torchscript(net_input)
264
+
265
+ assert self.proj is not None
266
+ encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
267
+
268
+ return encoder_out
269
+
270
+ def max_positions(self):
271
+ """Maximum input length supported by the encoder."""
272
+ return None
273
+
274
+ def upgrade_state_dict_named(self, state_dict, name):
275
+ return state_dict
276
+
SpeechT5/Speech2C/speech2c/models/t5_transformer_lm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ from fairseq.models import (
11
+ register_model_architecture,
12
+ )
13
+ from fairseq.models.transformer_lm import base_lm_architecture
14
+
15
+
16
+ @register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
17
+ def transformer_lm_t5(args):
18
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
19
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
20
+ args.decoder_layers = getattr(args, "decoder_layers", 20)
21
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
22
+ args.dropout = getattr(args, "dropout", 0.1)
23
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
24
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
25
+ base_lm_architecture(args)
SpeechT5/Speech2C/speech2c/squence_generator.py ADDED
@@ -0,0 +1,1028 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from typing import Dict, List, Optional
12
+ import sys
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from fairseq import search, utils
17
+ from fairseq.data import data_utils
18
+ from fairseq.models import FairseqIncrementalDecoder
19
+ from torch import Tensor
20
+ from fairseq.ngram_repeat_block import NGramRepeatBlock
21
+ from speech2c.models.modules.ctc_prefix_score import CTCPrefixScore
22
+ import numpy
23
+
24
+
25
+ CTC_SCORING_RATIO = 7.0
26
+
27
+ class SequenceGenerator(nn.Module):
28
+ def __init__(
29
+ self,
30
+ models,
31
+ tgt_dict,
32
+ beam_size=1,
33
+ max_len_a=0,
34
+ max_len_b=200,
35
+ max_len=0,
36
+ min_len=1,
37
+ normalize_scores=True,
38
+ len_penalty=1.0,
39
+ unk_penalty=0.0,
40
+ temperature=1.0,
41
+ match_source_len=False,
42
+ no_repeat_ngram_size=0,
43
+ search_strategy=None,
44
+ eos=None,
45
+ symbols_to_strip_from_output=None,
46
+ lm_model=None,
47
+ lm_weight=1.0,
48
+ ctc_weight=0.0,
49
+ ):
50
+ """Generates translations of a given source sentence.
51
+ Args:
52
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
53
+ currently support fairseq.models.TransformerModel for scripting
54
+ beam_size (int, optional): beam width (default: 1)
55
+ max_len_a/b (int, optional): generate sequences of maximum length
56
+ ax + b, where x is the source length
57
+ max_len (int, optional): the maximum length of the generated output
58
+ (not including end-of-sentence)
59
+ min_len (int, optional): the minimum length of the generated output
60
+ (not including end-of-sentence)
61
+ normalize_scores (bool, optional): normalize scores by the length
62
+ of the output (default: True)
63
+ len_penalty (float, optional): length penalty, where <1.0 favors
64
+ shorter, >1.0 favors longer sentences (default: 1.0)
65
+ unk_penalty (float, optional): unknown word penalty, where <0
66
+ produces more unks, >0 produces fewer (default: 0.0)
67
+ temperature (float, optional): temperature, where values
68
+ >1.0 produce more uniform samples and values <1.0 produce
69
+ sharper samples (default: 1.0)
70
+ match_source_len (bool, optional): outputs should match the source
71
+ length (default: False)
72
+ """
73
+ super().__init__()
74
+ if isinstance(models, EnsembleModel):
75
+ self.model = models
76
+ else:
77
+ self.model = EnsembleModel(models)
78
+ self.tgt_dict = tgt_dict
79
+ self.pad = tgt_dict.pad()
80
+ self.unk = tgt_dict.unk()
81
+ self.eos = tgt_dict.eos() if eos is None else eos
82
+ self.blank = self.tgt_dict.index("<s>")
83
+ self.symbols_to_strip_from_output = (
84
+ symbols_to_strip_from_output.union({self.eos})
85
+ if symbols_to_strip_from_output is not None
86
+ else {self.eos}
87
+ )
88
+ self.vocab_size = len(tgt_dict)
89
+ self.beam_size = beam_size
90
+ # the max beam size is the dictionary size - 1, since we never select pad
91
+ self.beam_size = min(beam_size, self.vocab_size - 1)
92
+ self.max_len_a = max_len_a
93
+ self.max_len_b = max_len_b
94
+ self.min_len = min_len
95
+ self.max_len = max_len or self.model.max_decoder_positions()
96
+
97
+ self.normalize_scores = normalize_scores
98
+ self.len_penalty = len_penalty
99
+ self.unk_penalty = unk_penalty
100
+ self.temperature = temperature
101
+ self.match_source_len = match_source_len
102
+
103
+ if no_repeat_ngram_size > 0:
104
+ self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
105
+ else:
106
+ self.repeat_ngram_blocker = None
107
+
108
+ assert temperature > 0, "--temperature must be greater than 0"
109
+
110
+ self.search = (
111
+ search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
112
+ )
113
+ # We only need to set src_lengths in LengthConstrainedBeamSearch.
114
+ # As a module attribute, setting it would break in multithread
115
+ # settings when the model is shared.
116
+ self.should_set_src_lengths = (
117
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
118
+ )
119
+
120
+ self.model.eval()
121
+
122
+ self.lm_model = lm_model
123
+ self.lm_weight = lm_weight
124
+ self.ctc_weight = ctc_weight
125
+ if self.lm_model is not None:
126
+ self.lm_model.eval()
127
+
128
+ def cuda(self):
129
+ self.model.cuda()
130
+ return self
131
+
132
+ @torch.no_grad()
133
+ def forward(
134
+ self,
135
+ sample: Dict[str, Dict[str, Tensor]],
136
+ prefix_tokens: Optional[Tensor] = None,
137
+ bos_token: Optional[int] = None,
138
+ ):
139
+ """Generate a batch of translations.
140
+ Args:
141
+ sample (dict): batch
142
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
143
+ with these tokens
144
+ bos_token (int, optional): beginning of sentence token
145
+ (default: self.eos)
146
+ """
147
+ return self._generate(sample, prefix_tokens, bos_token=bos_token)
148
+
149
+ # TODO(myleott): unused, deprecate after pytorch-translate migration
150
+ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
151
+ """Iterate over a batched dataset and yield individual translations.
152
+ Args:
153
+ cuda (bool, optional): use GPU for generation
154
+ timer (StopwatchMeter, optional): time generations
155
+ """
156
+ for sample in data_itr:
157
+ s = utils.move_to_cuda(sample) if cuda else sample
158
+ if "net_input" not in s:
159
+ continue
160
+ input = s["net_input"]
161
+ # model.forward normally channels prev_output_tokens into the decoder
162
+ # separately, but SequenceGenerator directly calls model.encoder
163
+ encoder_input = {
164
+ k: v for k, v in input.items() if k != "prev_output_tokens"
165
+ }
166
+ if timer is not None:
167
+ timer.start()
168
+ with torch.no_grad():
169
+ hypos = self.generate(encoder_input)
170
+ if timer is not None:
171
+ timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
172
+ for i, id in enumerate(s["id"].data):
173
+ # remove padding
174
+ src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
175
+ ref = (
176
+ utils.strip_pad(s["target"].data[i, :], self.pad)
177
+ if s["target"] is not None
178
+ else None
179
+ )
180
+ yield id, src, ref, hypos[i]
181
+
182
+ @torch.no_grad()
183
+ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
184
+ """Generate translations. Match the api of other fairseq generators.
185
+ Args:
186
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
187
+ sample (dict): batch
188
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
189
+ with these tokens
190
+ constraints (torch.LongTensor, optional): force decoder to include
191
+ the list of constraints
192
+ bos_token (int, optional): beginning of sentence token
193
+ (default: self.eos)
194
+ """
195
+ return self._generate(sample, **kwargs)
196
+
197
+ def _generate(
198
+ self,
199
+ sample: Dict[str, Dict[str, Tensor]],
200
+ prefix_tokens: Optional[Tensor] = None,
201
+ constraints: Optional[Tensor] = None,
202
+ bos_token: Optional[int] = None,
203
+ ):
204
+ incremental_states = torch.jit.annotate(
205
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
206
+ [
207
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
208
+ for i in range(self.model.models_size)
209
+ ],
210
+ )
211
+ net_input = sample["net_input"]
212
+
213
+ if "src_tokens" in net_input:
214
+ src_tokens = net_input["src_tokens"]
215
+ # length of the source text being the character length except EndOfSentence and pad
216
+ src_lengths = (
217
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
218
+ )
219
+ elif "source" in net_input:
220
+ src_tokens = net_input["source"]
221
+ src_lengths = (
222
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
223
+ if net_input["padding_mask"] is not None
224
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
225
+ )
226
+ elif "features" in net_input:
227
+ src_tokens = net_input["features"]
228
+ src_lengths = (
229
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
230
+ if net_input["padding_mask"] is not None
231
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
232
+ )
233
+ else:
234
+ raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
235
+
236
+ # bsz: total number of sentences in beam
237
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
238
+ bsz, src_len = src_tokens.size()[:2]
239
+ beam_size = self.beam_size
240
+
241
+ if constraints is not None and not self.search.supports_constraints:
242
+ raise NotImplementedError(
243
+ "Target-side constraints were provided, but search method doesn't support them"
244
+ )
245
+
246
+ # Initialize constraints, when active
247
+ self.search.init_constraints(constraints, beam_size)
248
+
249
+ max_len: int = -1
250
+ if self.match_source_len:
251
+ max_len = src_lengths.max().item()
252
+ else:
253
+ max_len = min(
254
+ int(self.max_len_a * src_len + self.max_len_b),
255
+ self.max_len - 1,
256
+ )
257
+ assert (
258
+ self.min_len <= max_len
259
+ ), "min_len cannot be larger than max_len, please adjust these!"
260
+ # compute the encoder output for each beam
261
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
262
+ encoder_outs = self.model.forward_encoder(net_input)
263
+
264
+ # Get CTC lprobs and prep ctc_scorer
265
+ if self.ctc_weight > 0:
266
+ ctc_lprobs = self.model.models[0].get_normalized_probs(
267
+ encoder_outs[0], log_probs=True
268
+ ).contiguous().transpose(0, 1) # (B, T, C) from the encoder
269
+
270
+ hyp = {}
271
+ ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
272
+ hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
273
+ hyp["ctc_score_prev"] = 0.0
274
+ ctc_beam = min(ctc_lprobs.shape[-1], int(beam_size * CTC_SCORING_RATIO))
275
+ ctc_hyps = {str(self.eos): hyp}
276
+
277
+ # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
278
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
279
+ new_order = new_order.to(src_tokens.device).long()
280
+ encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
281
+ # ensure encoder_outs is a List.
282
+ assert encoder_outs is not None
283
+
284
+ # initialize buffers
285
+ scores = (
286
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
287
+ ) # +1 for eos; pad is never chosen for scoring
288
+ tokens = (
289
+ torch.zeros(bsz * beam_size, max_len + 2)
290
+ .to(src_tokens)
291
+ .long()
292
+ .fill_(self.pad)
293
+ ) # +2 for eos and pad
294
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
295
+ attn: Optional[Tensor] = None
296
+
297
+ # A list that indicates candidates that should be ignored.
298
+ # For example, suppose we're sampling and have already finalized 2/5
299
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
300
+ # so that we only finalize the remaining 3 samples.
301
+ cands_to_ignore = (
302
+ torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
303
+ ) # forward and backward-compatible False mask
304
+
305
+ # list of completed sentences
306
+ finalized = torch.jit.annotate(
307
+ List[List[Dict[str, Tensor]]],
308
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
309
+ ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
310
+
311
+ # a boolean array indicating if the sentence at the index is finished or not
312
+ finished = [False for i in range(bsz)]
313
+ num_remaining_sent = bsz # number of sentences remaining
314
+
315
+ # number of candidate hypos per step
316
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
317
+
318
+ # offset arrays for converting between different indexing schemes
319
+ bbsz_offsets = (
320
+ (torch.arange(0, bsz) * beam_size)
321
+ .unsqueeze(1)
322
+ .type_as(tokens)
323
+ .to(src_tokens.device)
324
+ )
325
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
326
+
327
+ reorder_state: Optional[Tensor] = None
328
+ batch_idxs: Optional[Tensor] = None
329
+
330
+ original_batch_idxs: Optional[Tensor] = None
331
+ if "id" in sample and isinstance(sample["id"], Tensor):
332
+ original_batch_idxs = sample["id"]
333
+ else:
334
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
335
+
336
+ for step in range(max_len + 1): # one extra step for EOS marker
337
+ # reorder decoder internal states based on the prev choice of beams
338
+ if reorder_state is not None:
339
+ if batch_idxs is not None:
340
+ # update beam indices to take into account removed sentences
341
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
342
+ batch_idxs
343
+ )
344
+ reorder_state.view(-1, beam_size).add_(
345
+ corr.unsqueeze(-1) * beam_size
346
+ )
347
+ original_batch_idxs = original_batch_idxs[batch_idxs]
348
+ self.model.reorder_incremental_state(incremental_states, reorder_state)
349
+ encoder_outs = self.model.reorder_encoder_out(
350
+ encoder_outs, reorder_state
351
+ )
352
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
353
+ lprobs, avg_attn_scores = self.model.forward_decoder(
354
+ tokens[:, : step + 1],
355
+ encoder_outs,
356
+ incremental_states,
357
+ self.temperature,
358
+ )
359
+
360
+ if self.ctc_weight > 0 and step != 0:
361
+ # lprobs[:, self.blank] = -math.inf # never select blank
362
+ ctc_lprobs = lprobs.clone()
363
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
364
+ _, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
365
+ for b in range(tokens.size(0)):
366
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
367
+ ctc_scores, ctc_states = ctc_prefix_score(
368
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
369
+ )
370
+ lprobs[b] = lprobs[b]
371
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
372
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
373
+ ).to(device="cuda")
374
+ for j in range(len(local_best_ids[b])):
375
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
376
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
377
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
378
+
379
+ elif self.ctc_weight > 0 and step == 0:
380
+ ctc_lprobs = lprobs.clone()
381
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
382
+ _, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
383
+ for b in range(tokens.size(0)):
384
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
385
+ ctc_scores, ctc_states = ctc_prefix_score(
386
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
387
+ )
388
+ lprobs[b] = lprobs[b]
389
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
390
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
391
+ ).to(device="cuda")
392
+ for j in range(len(local_best_ids[b])):
393
+ if b == 0:
394
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
395
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
396
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
397
+
398
+ if self.lm_model is not None:
399
+ lm_out = self.lm_model(tokens[:, : step + 1])
400
+ probs = self.lm_model.get_normalized_probs(
401
+ lm_out, log_probs=True, sample=None
402
+ )
403
+ probs = probs[:, -1, :] * self.lm_weight
404
+ lprobs += probs
405
+ # handle prefix tokens (possibly with different lengths)
406
+ if (
407
+ prefix_tokens is not None
408
+ and step < prefix_tokens.size(1)
409
+ and step < max_len
410
+ ):
411
+ lprobs, tokens, scores = self._prefix_tokens(
412
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
413
+ )
414
+ elif step < self.min_len:
415
+ # minimum length constraint (does not apply if using prefix_tokens)
416
+ lprobs[:, self.eos] = -math.inf
417
+
418
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
419
+
420
+ lprobs[:, self.pad] = -math.inf # never select pad
421
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
422
+ lprobs[:, self.blank] = -math.inf # never select blank
423
+
424
+ # handle max length constraint
425
+ if step >= max_len:
426
+ lprobs[:, : self.eos] = -math.inf
427
+ lprobs[:, self.eos + 1 :] = -math.inf
428
+
429
+ # Record attention scores, only support avg_attn_scores is a Tensor
430
+ if avg_attn_scores is not None:
431
+ if attn is None:
432
+ attn = torch.empty(
433
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
434
+ ).to(scores)
435
+ attn[:, :, step + 1].copy_(avg_attn_scores)
436
+
437
+ scores = scores.type_as(lprobs)
438
+ eos_bbsz_idx = torch.empty(0).to(
439
+ tokens
440
+ ) # indices of hypothesis ending with eos (finished sentences)
441
+ eos_scores = torch.empty(0).to(
442
+ scores
443
+ ) # scores of hypothesis ending with eos (finished sentences)
444
+
445
+ if self.should_set_src_lengths:
446
+ self.search.set_src_lengths(src_lengths)
447
+
448
+ if self.repeat_ngram_blocker is not None:
449
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
450
+
451
+ # Shape: (batch, cand_size)
452
+ cand_scores, cand_indices, cand_beams = self.search.step(
453
+ step,
454
+ lprobs.view(bsz, -1, self.vocab_size),
455
+ scores.view(bsz, beam_size, -1)[:, :, :step],
456
+ tokens[:, : step + 1],
457
+ original_batch_idxs,
458
+ )
459
+
460
+ # cand_bbsz_idx contains beam indices for the top candidate
461
+ # hypotheses, with a range of values: [0, bsz*beam_size),
462
+ # and dimensions: [bsz, cand_size]
463
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
464
+
465
+ # finalize hypotheses that end in eos
466
+ # Shape of eos_mask: (batch size, beam size)
467
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
468
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
469
+
470
+ # only consider eos when it's among the top beam_size indices
471
+ # Now we know what beam item(s) to finish
472
+ # Shape: 1d list of absolute-numbered
473
+ eos_bbsz_idx = torch.masked_select(
474
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
475
+ )
476
+
477
+ finalized_sents: List[int] = []
478
+ if eos_bbsz_idx.numel() > 0:
479
+ eos_scores = torch.masked_select(
480
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
481
+ )
482
+
483
+ finalized_sents = self.finalize_hypos(
484
+ step,
485
+ eos_bbsz_idx,
486
+ eos_scores,
487
+ tokens,
488
+ scores,
489
+ finalized,
490
+ finished,
491
+ beam_size,
492
+ attn,
493
+ src_lengths,
494
+ max_len,
495
+ )
496
+ num_remaining_sent -= len(finalized_sents)
497
+
498
+ assert num_remaining_sent >= 0
499
+ if num_remaining_sent == 0:
500
+ break
501
+ if self.search.stop_on_max_len and step >= max_len:
502
+ break
503
+ assert step < max_len, f"{step} < {max_len}"
504
+
505
+ # Remove finalized sentences (ones for which {beam_size}
506
+ # finished hypotheses have been generated) from the batch.
507
+ if len(finalized_sents) > 0:
508
+ new_bsz = bsz - len(finalized_sents)
509
+
510
+ # construct batch_idxs which holds indices of batches to keep for the next pass
511
+ batch_mask = torch.ones(
512
+ bsz, dtype=torch.bool, device=cand_indices.device
513
+ )
514
+ batch_mask[finalized_sents] = False
515
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
516
+ batch_idxs = torch.arange(
517
+ bsz, device=cand_indices.device
518
+ ).masked_select(batch_mask)
519
+
520
+ # Choose the subset of the hypothesized constraints that will continue
521
+ self.search.prune_sentences(batch_idxs)
522
+
523
+ eos_mask = eos_mask[batch_idxs]
524
+ cand_beams = cand_beams[batch_idxs]
525
+ bbsz_offsets.resize_(new_bsz, 1)
526
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
527
+ cand_scores = cand_scores[batch_idxs]
528
+ cand_indices = cand_indices[batch_idxs]
529
+
530
+ if prefix_tokens is not None:
531
+ prefix_tokens = prefix_tokens[batch_idxs]
532
+ src_lengths = src_lengths[batch_idxs]
533
+ cands_to_ignore = cands_to_ignore[batch_idxs]
534
+
535
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
536
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
537
+ if attn is not None:
538
+ attn = attn.view(bsz, -1)[batch_idxs].view(
539
+ new_bsz * beam_size, attn.size(1), -1
540
+ )
541
+ bsz = new_bsz
542
+ else:
543
+ batch_idxs = None
544
+
545
+ # Set active_mask so that values > cand_size indicate eos hypos
546
+ # and values < cand_size indicate candidate active hypos.
547
+ # After, the min values per row are the top candidate active hypos
548
+
549
+ # Rewrite the operator since the element wise or is not supported in torchscript.
550
+
551
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
552
+ active_mask = torch.add(
553
+ eos_mask.type_as(cand_offsets) * cand_size,
554
+ cand_offsets[: eos_mask.size(1)],
555
+ )
556
+
557
+ # get the top beam_size active hypotheses, which are just
558
+ # the hypos with the smallest values in active_mask.
559
+ # {active_hypos} indicates which {beam_size} hypotheses
560
+ # from the list of {2 * beam_size} candidates were
561
+ # selected. Shapes: (batch size, beam size)
562
+ new_cands_to_ignore, active_hypos = torch.topk(
563
+ active_mask, k=beam_size, dim=1, largest=False
564
+ )
565
+
566
+ # update cands_to_ignore to ignore any finalized hypos.
567
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
568
+ # Make sure there is at least one active item for each sentence in the batch.
569
+ assert (~cands_to_ignore).any(dim=1).all()
570
+
571
+ # update cands_to_ignore to ignore any finalized hypos
572
+
573
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
574
+ # can be selected more than once).
575
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
576
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
577
+
578
+ active_bbsz_idx = active_bbsz_idx.view(-1)
579
+ active_scores = active_scores.view(-1)
580
+
581
+ # copy tokens and scores for active hypotheses
582
+
583
+ # Set the tokens for each beam (can select the same row more than once)
584
+ tokens[:, : step + 1] = torch.index_select(
585
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
586
+ )
587
+ # Select the next token for each of them
588
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
589
+ cand_indices, dim=1, index=active_hypos
590
+ )
591
+ if step > 0:
592
+ scores[:, :step] = torch.index_select(
593
+ scores[:, :step], dim=0, index=active_bbsz_idx
594
+ )
595
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
596
+ cand_scores, dim=1, index=active_hypos
597
+ )
598
+
599
+ # Update constraints based on which candidates were selected for the next beam
600
+ self.search.update_constraints(active_hypos)
601
+
602
+ # copy attention for active hypotheses
603
+ if attn is not None:
604
+ attn[:, :, : step + 2] = torch.index_select(
605
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
606
+ )
607
+
608
+ # reorder incremental state in decoder
609
+ reorder_state = active_bbsz_idx
610
+
611
+ # sort by score descending
612
+ for sent in range(len(finalized)):
613
+ scores = torch.tensor(
614
+ [float(elem["score"].item()) for elem in finalized[sent]]
615
+ )
616
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
617
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
618
+ finalized[sent] = torch.jit.annotate(
619
+ List[Dict[str, Tensor]], finalized[sent]
620
+ )
621
+ return finalized
622
+
623
+ def _prefix_tokens(
624
+ self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
625
+ ):
626
+ """Handle prefix tokens"""
627
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
628
+ prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
629
+ prefix_mask = prefix_toks.ne(self.pad)
630
+ lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
631
+ lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
632
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
633
+ )
634
+ # if prefix includes eos, then we should make sure tokens and
635
+ # scores are the same across all beams
636
+ eos_mask = prefix_toks.eq(self.eos)
637
+ if eos_mask.any():
638
+ # validate that the first beam matches the prefix
639
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
640
+ :, 0, 1 : step + 1
641
+ ]
642
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
643
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
644
+ assert (first_beam == target_prefix).all()
645
+
646
+ # copy tokens, scores and lprobs from the first beam to all beams
647
+ tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
648
+ scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
649
+ lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
650
+ return lprobs, tokens, scores
651
+
652
+ def replicate_first_beam(self, tensor, mask, beam_size: int):
653
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
654
+ tensor[mask] = tensor[mask][:, :1, :]
655
+ return tensor.view(-1, tensor.size(-1))
656
+
657
+ def finalize_hypos(
658
+ self,
659
+ step: int,
660
+ bbsz_idx,
661
+ eos_scores,
662
+ tokens,
663
+ scores,
664
+ finalized: List[List[Dict[str, Tensor]]],
665
+ finished: List[bool],
666
+ beam_size: int,
667
+ attn: Optional[Tensor],
668
+ src_lengths,
669
+ max_len: int,
670
+ ):
671
+ """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
672
+ A sentence is finalized when {beam_size} finished items have been collected for it.
673
+ Returns number of sentences (not beam items) being finalized.
674
+ These will be removed from the batch and not processed further.
675
+ Args:
676
+ bbsz_idx (Tensor):
677
+ """
678
+ assert bbsz_idx.numel() == eos_scores.numel()
679
+
680
+ # clone relevant token and attention tensors.
681
+ # tokens is (batch * beam, max_len). So the index_select
682
+ # gets the newly EOS rows, then selects cols 1..{step + 2}
683
+ tokens_clone = tokens.index_select(0, bbsz_idx)[
684
+ :, 1 : step + 2
685
+ ] # skip the first index, which is EOS
686
+
687
+ tokens_clone[:, step] = self.eos
688
+ attn_clone = (
689
+ attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
690
+ if attn is not None
691
+ else None
692
+ )
693
+
694
+ # compute scores per token position
695
+ pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
696
+ pos_scores[:, step] = eos_scores
697
+ # convert from cumulative to per-position scores
698
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
699
+
700
+ # normalize sentence-level scores
701
+ if self.normalize_scores:
702
+ eos_scores /= (step + 1) ** self.len_penalty
703
+
704
+ # cum_unfin records which sentences in the batch are finished.
705
+ # It helps match indexing between (a) the original sentences
706
+ # in the batch and (b) the current, possibly-reduced set of
707
+ # sentences.
708
+ cum_unfin: List[int] = []
709
+ prev = 0
710
+ for f in finished:
711
+ if f:
712
+ prev += 1
713
+ else:
714
+ cum_unfin.append(prev)
715
+ cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
716
+
717
+ unfin_idx = bbsz_idx // beam_size
718
+ sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
719
+
720
+ # Create a set of "{sent}{unfin_idx}", where
721
+ # "unfin_idx" is the index in the current (possibly reduced)
722
+ # list of sentences, and "sent" is the index in the original,
723
+ # unreduced batch
724
+ # For every finished beam item
725
+ # sentence index in the current (possibly reduced) batch
726
+ seen = (sent << 32) + unfin_idx
727
+ unique_seen: List[int] = torch.unique(seen).tolist()
728
+
729
+ if self.match_source_len:
730
+ condition = step > torch.index_select(src_lengths, 0, unfin_idx)
731
+ eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
732
+ sent_list: List[int] = sent.tolist()
733
+ for i in range(bbsz_idx.size()[0]):
734
+ # An input sentence (among those in a batch) is finished when
735
+ # beam_size hypotheses have been collected for it
736
+ if len(finalized[sent_list[i]]) < beam_size:
737
+ if attn_clone is not None:
738
+ # remove padding tokens from attn scores
739
+ hypo_attn = attn_clone[i]
740
+ else:
741
+ hypo_attn = torch.empty(0)
742
+
743
+ finalized[sent_list[i]].append(
744
+ {
745
+ "tokens": tokens_clone[i],
746
+ "score": eos_scores[i],
747
+ "attention": hypo_attn, # src_len x tgt_len
748
+ "alignment": torch.empty(0),
749
+ "positional_scores": pos_scores[i],
750
+ }
751
+ )
752
+
753
+ newly_finished: List[int] = []
754
+ for unique_s in unique_seen:
755
+ # check termination conditions for this sentence
756
+ unique_sent: int = unique_s >> 32
757
+ unique_unfin_idx: int = unique_s - (unique_sent << 32)
758
+
759
+ if not finished[unique_sent] and self.is_finished(
760
+ step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
761
+ ):
762
+ finished[unique_sent] = True
763
+ newly_finished.append(unique_unfin_idx)
764
+
765
+ return newly_finished
766
+
767
+ def is_finished(
768
+ self,
769
+ step: int,
770
+ unfin_idx: int,
771
+ max_len: int,
772
+ finalized_sent_len: int,
773
+ beam_size: int,
774
+ ):
775
+ """
776
+ Check whether decoding for a sentence is finished, which
777
+ occurs when the list of finalized sentences has reached the
778
+ beam size, or when we reach the maximum length.
779
+ """
780
+ assert finalized_sent_len <= beam_size
781
+ if finalized_sent_len == beam_size or step == max_len:
782
+ return True
783
+ return False
784
+
785
+
786
+ class EnsembleModel(nn.Module):
787
+ """A wrapper around an ensemble of models."""
788
+
789
+ def __init__(self, models):
790
+ super().__init__()
791
+ self.models_size = len(models)
792
+ # method '__len__' is not supported in ModuleList for torch script
793
+ self.single_model = models[0]
794
+ self.models = nn.ModuleList(models)
795
+
796
+ self.has_incremental: bool = False
797
+ if all(
798
+ hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
799
+ for m in models
800
+ ):
801
+ self.has_incremental = True
802
+
803
+ def forward(self):
804
+ pass
805
+
806
+ def has_encoder(self):
807
+ return hasattr(self.single_model, "encoder")
808
+
809
+ def has_incremental_states(self):
810
+ return self.has_incremental
811
+
812
+ def max_decoder_positions(self):
813
+ return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
814
+
815
+ @torch.jit.export
816
+ def forward_encoder(self, net_input: Dict[str, Tensor]):
817
+ if not self.has_encoder():
818
+ return None
819
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
820
+
821
+ @torch.jit.export
822
+ def forward_decoder(
823
+ self,
824
+ tokens,
825
+ encoder_outs: List[Dict[str, List[Tensor]]],
826
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
827
+ temperature: float = 1.0,
828
+ ):
829
+ log_probs = []
830
+ avg_attn: Optional[Tensor] = None
831
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
832
+ for i, model in enumerate(self.models):
833
+ if self.has_encoder():
834
+ encoder_out = encoder_outs[i]
835
+ # decode each model
836
+ if self.has_incremental_states():
837
+ decoder_out = model.decoder.forward(
838
+ tokens,
839
+ encoder_out=encoder_out,
840
+ incremental_state=incremental_states[i],
841
+ )
842
+ else:
843
+ if hasattr(model, "decoder"):
844
+ decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
845
+ else:
846
+ decoder_out = model.forward(tokens)
847
+
848
+ attn: Optional[Tensor] = None
849
+ decoder_len = len(decoder_out)
850
+ if decoder_len > 1 and decoder_out[1] is not None:
851
+ if isinstance(decoder_out[1], Tensor):
852
+ attn = decoder_out[1]
853
+ else:
854
+ attn_holder = decoder_out[1]["attn"]
855
+ if isinstance(attn_holder, Tensor):
856
+ attn = attn_holder
857
+ elif attn_holder is not None:
858
+ attn = attn_holder[0]
859
+ if attn is not None:
860
+ attn = attn[:, -1, :]
861
+
862
+ decoder_out_tuple = (
863
+ decoder_out[0][:, -1:, :].div_(temperature),
864
+ None if decoder_len <= 1 else decoder_out[1],
865
+ )
866
+ probs = model.get_normalized_probs(
867
+ decoder_out_tuple, log_probs=True, sample=None
868
+ )
869
+ probs = probs[:, -1, :]
870
+ if self.models_size == 1:
871
+ return probs, attn
872
+
873
+ log_probs.append(probs)
874
+ if attn is not None:
875
+ if avg_attn is None:
876
+ avg_attn = attn
877
+ else:
878
+ avg_attn.add_(attn)
879
+
880
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
881
+ self.models_size
882
+ )
883
+
884
+ if avg_attn is not None:
885
+ avg_attn.div_(self.models_size)
886
+ return avg_probs, avg_attn
887
+
888
+ @torch.jit.export
889
+ def reorder_encoder_out(
890
+ self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
891
+ ):
892
+ """
893
+ Reorder encoder output according to *new_order*.
894
+ Args:
895
+ encoder_out: output from the ``forward()`` method
896
+ new_order (LongTensor): desired order
897
+ Returns:
898
+ *encoder_out* rearranged according to *new_order*
899
+ """
900
+ new_outs: List[Dict[str, List[Tensor]]] = []
901
+ if not self.has_encoder():
902
+ return new_outs
903
+ for i, model in enumerate(self.models):
904
+ assert encoder_outs is not None
905
+ new_outs.append(
906
+ model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
907
+ )
908
+ return new_outs
909
+
910
+ @torch.jit.export
911
+ def reorder_incremental_state(
912
+ self,
913
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
914
+ new_order,
915
+ ):
916
+ if not self.has_incremental_states():
917
+ return
918
+ for i, model in enumerate(self.models):
919
+ model.decoder.reorder_incremental_state_scripting(
920
+ incremental_states[i], new_order
921
+ )
922
+
923
+
924
+ class SequenceGeneratorWithAlignment(SequenceGenerator):
925
+ def __init__(
926
+ self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
927
+ ):
928
+ """Generates translations of a given source sentence.
929
+ Produces alignments following "Jointly Learning to Align and
930
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
931
+ Args:
932
+ left_pad_target (bool, optional): Whether or not the
933
+ hypothesis should be left padded or not when they are
934
+ teacher forced for generating alignments.
935
+ """
936
+ super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
937
+ self.left_pad_target = left_pad_target
938
+
939
+ if print_alignment == "hard":
940
+ self.extract_alignment = utils.extract_hard_alignment
941
+ elif print_alignment == "soft":
942
+ self.extract_alignment = utils.extract_soft_alignment
943
+
944
+ @torch.no_grad()
945
+ def generate(self, models, sample, **kwargs):
946
+ finalized = super()._generate(sample, **kwargs)
947
+
948
+ src_tokens = sample["net_input"]["src_tokens"]
949
+ bsz = src_tokens.shape[0]
950
+ beam_size = self.beam_size
951
+ (
952
+ src_tokens,
953
+ src_lengths,
954
+ prev_output_tokens,
955
+ tgt_tokens,
956
+ ) = self._prepare_batch_for_alignment(sample, finalized)
957
+ if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
958
+ attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
959
+ else:
960
+ attn = [
961
+ finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
962
+ for i in range(bsz * beam_size)
963
+ ]
964
+
965
+ if src_tokens.device != "cpu":
966
+ src_tokens = src_tokens.to("cpu")
967
+ tgt_tokens = tgt_tokens.to("cpu")
968
+ attn = [i.to("cpu") for i in attn]
969
+
970
+ # Process the attn matrix to extract hard alignments.
971
+ for i in range(bsz * beam_size):
972
+ alignment = self.extract_alignment(
973
+ attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
974
+ )
975
+ finalized[i // beam_size][i % beam_size]["alignment"] = alignment
976
+ return finalized
977
+
978
+ def _prepare_batch_for_alignment(self, sample, hypothesis):
979
+ src_tokens = sample["net_input"]["src_tokens"]
980
+ bsz = src_tokens.shape[0]
981
+ src_tokens = (
982
+ src_tokens[:, None, :]
983
+ .expand(-1, self.beam_size, -1)
984
+ .contiguous()
985
+ .view(bsz * self.beam_size, -1)
986
+ )
987
+ src_lengths = sample["net_input"]["src_lengths"]
988
+ src_lengths = (
989
+ src_lengths[:, None]
990
+ .expand(-1, self.beam_size)
991
+ .contiguous()
992
+ .view(bsz * self.beam_size)
993
+ )
994
+ prev_output_tokens = data_utils.collate_tokens(
995
+ [beam["tokens"] for example in hypothesis for beam in example],
996
+ self.pad,
997
+ self.eos,
998
+ self.left_pad_target,
999
+ move_eos_to_beginning=True,
1000
+ )
1001
+ tgt_tokens = data_utils.collate_tokens(
1002
+ [beam["tokens"] for example in hypothesis for beam in example],
1003
+ self.pad,
1004
+ self.eos,
1005
+ self.left_pad_target,
1006
+ move_eos_to_beginning=False,
1007
+ )
1008
+ return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
1009
+
1010
+
1011
+ class EnsembleModelWithAlignment(EnsembleModel):
1012
+ """A wrapper around an ensemble of models."""
1013
+
1014
+ def __init__(self, models):
1015
+ super().__init__(models)
1016
+
1017
+ def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
1018
+ avg_attn = None
1019
+ for model in self.models:
1020
+ decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
1021
+ attn = decoder_out[1]["attn"][0]
1022
+ if avg_attn is None:
1023
+ avg_attn = attn
1024
+ else:
1025
+ avg_attn.add_(attn)
1026
+ if len(self.models) > 1:
1027
+ avg_attn.div_(len(self.models))
1028
+ return avg_attn
SpeechT5/Speech2C/speech2c/tasks/speech2c_pretraining.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import logging
11
+
12
+ from dataclasses import dataclass, field
13
+ from fairseq.data import Dictionary
14
+ from fairseq.tasks import register_task
15
+ from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig, HubertPretrainingTask, LabelEncoder
16
+ from speech2c.data.speech2c_dataset import Speech2cDataset
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class Speech2cPretrainingConfig(HubertPretrainingConfig):
23
+ add_decoder: bool = field(
24
+ default=False,
25
+ metadata={"help": "whether to add decoder for CE Loss on code"},
26
+ )
27
+
28
+ # For inference
29
+ ctc_weight: float = field(
30
+ default=0.0,
31
+ metadata={"help": "ctc weight during inference"},
32
+ )
33
+
34
+
35
+ @register_task("speech2c_pretraining", dataclass=Speech2cPretrainingConfig)
36
+ class Speech2cPretrainingTask(HubertPretrainingTask):
37
+
38
+ cfg: Speech2cPretrainingConfig
39
+
40
+ def load_dictionaries(self):
41
+ label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
42
+ dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
43
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
44
+
45
+ def load_dataset(self, split: str, **kwargs) -> None:
46
+ manifest = f"{self.cfg.data}/{split}.tsv"
47
+ dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
48
+ pad_list = [dict.pad() for dict in dicts]
49
+ eos_list = [dict.eos() for dict in dicts]
50
+ procs = [LabelEncoder(dict) for dict in dicts]
51
+ paths = [
52
+ f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
53
+ ]
54
+
55
+ # hubert v1: pad_audio=True, random_crop=False;
56
+ self.datasets[split] = Speech2cDataset(
57
+ manifest,
58
+ sample_rate=self.cfg.sample_rate,
59
+ label_paths=paths,
60
+ label_rates=self.cfg.label_rate,
61
+ pad_list=pad_list,
62
+ eos_list=eos_list,
63
+ label_processors=procs,
64
+ max_keep_sample_size=self.cfg.max_keep_size,
65
+ min_keep_sample_size=self.cfg.min_sample_size,
66
+ max_sample_size=self.cfg.max_sample_size,
67
+ pad_audio=self.cfg.pad_audio,
68
+ normalize=self.cfg.normalize,
69
+ store_labels=False,
70
+ random_crop=self.cfg.random_crop,
71
+ single_target=self.cfg.single_target,
72
+ tgt_dict=dicts[0],
73
+ add_decoder=self.cfg.add_decoder,
74
+ fine_tuning=self.cfg.fine_tuning,
75
+ )
76
+
77
+ def build_generator(
78
+ self,
79
+ models,
80
+ args,
81
+ seq_gen_cls=None,
82
+ extra_gen_cls_kwargs=None,
83
+ ):
84
+ from speech2c.squence_generator import SequenceGenerator
85
+ extra_gen_cls_kwargs = {
86
+ "ctc_weight": self.cfg.ctc_weight,
87
+ **extra_gen_cls_kwargs
88
+ }
89
+ return super().build_generator(
90
+ models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
91
+ )
SpeechT5/Speech2S/README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Speech2S
2
+ <!--**Pre-trained models for speech related tasks**-->
3
+
4
+ [**Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**](https://arxiv.org/abs/2210.17027)
5
+
6
+
7
+ - (Updating) Nov. 2022: release the code and models
8
+ - Nov. 2022: release preprint in [arXiv](https://arxiv.org/abs/2210.17027)
9
+
10
+ ## Pre-Trained and Fine-tuned Models
11
+
12
+ | Model | Pre-training Dataset | Fine-tuning Dataset | Model |
13
+ | :------: | :----------------------------------------------: | :-----------------: | :-----: |
14
+ | Speech2S_enes | Voxpopuli_en_v2 | - | [Google Drive](https://drive.google.com/file/d/1TYypFiEKoCixUro8FTTG23bRZYwAxhkX/view?usp=share_link) |
15
+ | Speech2S_enes | Voxpopuli_en_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/11RxeKznSrHcoP_KK9A1VgwRt3fNh_U_C/view?usp=share_link) |
16
+ | Speech2S_esen | Voxpopuli_es_v2 | - | [Google Drive](https://drive.google.com/file/d/1NoC7W-UtQZ-ugIptF1ex0ZlGJncsT1S4/view?usp=share_link) |
17
+ | Speech2S_esen | Voxpopuli_es_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/1eNcKw4ZWGmcABWXJxlf6MKocmiPrKSkH/view?usp=share_link) |
18
+
19
+
20
+ ## Setup
21
+ ```
22
+ cd Speech2S/speech2s
23
+ pip install --editable fairseq/
24
+ ```
25
+
26
+ ## Data Preparation
27
+ Please follow the steps of data preparation for S2ST in [here](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md).
28
+
29
+ ## Pre-Training
30
+ ```
31
+ cd speech2s/stpretrain_scripts
32
+ base_sc2c_enes.sh
33
+ ```
34
+ ## Finetune
35
+ ```
36
+ cd speech2s/stpretrain_scripts
37
+ finetune_enes.sh
38
+ ```
39
+ ## Inference
40
+ ```
41
+ cd speech2s/stpretrain_scripts
42
+ inference_ed.sh
43
+ ```
44
+ ## Results on Voxpopuli and Covst
45
+
46
+
47
+ ## License
48
+
49
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
50
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq).
51
+
52
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
53
+
54
+ ## Reference
55
+
56
+ If you find our work is useful in your research, please cite the following paper:
57
+ ```bibtex
58
+ @article{wei2022joint,
59
+ title={Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation},
60
+ author={Wei, Kun and Zhou, Long and Zhang, Ziqiang and Chen, Liping and Liu, Shujie and He, Lei and Li, Jinyu and Wei, Furu},
61
+ journal={arXiv preprint arXiv:2210.17027},
62
+ year={2022}
63
+ }
64
+ ```
SpeechT5/Speech2S/speech2s/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import data, tasks, criterions, models
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_base_100h.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 100
7
+ tensorboard_logdir: tblog
8
+ seed: 1337
9
+
10
+ checkpoint:
11
+ save_interval: 1
12
+ keep_last_epochs: 1
13
+ keep_best_checkpoints: 5
14
+ best_checkpoint_metric: dec_accuracy
15
+ maximize_best_checkpoint_metric: true
16
+ restore_file: checkpoint_last.pt
17
+
18
+ distributed_training:
19
+ ddp_backend: legacy_ddp
20
+ find_unused_parameters: true
21
+ distributed_world_size: 1
22
+ distributed_port: -1
23
+ nprocs_per_node: 8
24
+
25
+ task:
26
+ _name: joint_sc2t_pretraining
27
+ data: ???
28
+ fine_tuning: true
29
+ label_dir: ???
30
+ normalize: false # must be consistent with pre-training
31
+ labels: ["ltr"]
32
+ store_labels: true
33
+ single_target: true
34
+ add_decoder_target: true
35
+ pad_audio: false
36
+ random_crop: true
37
+ hubert_tokenizer: "none"
38
+ sp_path: None
39
+
40
+ dataset:
41
+ num_workers: 0
42
+ max_tokens: 1300000
43
+ skip_invalid_size_inputs_valid_test: true
44
+ train_subset: train_100
45
+ valid_subset: dev_other
46
+ required_batch_size_multiple: 1
47
+
48
+ criterion:
49
+ _name: ctc_ce
50
+ zero_infinity: true
51
+
52
+ optimization:
53
+ max_update: 40000
54
+ lr: [0.00001]
55
+ sentence_avg: true
56
+ update_freq: [2]
57
+
58
+ optimizer:
59
+ _name: adam
60
+ adam_betas: (0.9,0.98)
61
+ adam_eps: 1e-08
62
+ weight_decay: 0.0
63
+
64
+ lr_scheduler:
65
+ _name: tri_stage
66
+ phase_ratio: [0.1, 0.4, 0.5]
67
+ final_lr_scale: 0.05
68
+
69
+ model:
70
+ _name: speechut_asr
71
+ w2v_path: ???
72
+ apply_mask: true
73
+ mask_prob: 0.65
74
+ mask_channel_prob: 0.5
75
+ mask_channel_length: 64
76
+ layerdrop: 0.1
77
+ activation_dropout: 0.1
78
+ feature_grad_mult: 0.0
79
+ freeze_finetune_updates: 0
80
+ add_decoder: true
81
+
82
+ hydra:
83
+ job:
84
+ config:
85
+ override_dirname:
86
+ kv_sep: '-'
87
+ item_sep: '__'
88
+ exclude_keys:
89
+ - run
90
+ - task.data
91
+ - task.label_dir
92
+ - model.w2v_path
93
+ - dataset.train_subset
94
+ - dataset.valid_subset
95
+ - criterion.wer_kenlm_model
96
+ - criterion.wer_lexicon
97
+ run:
98
+ dir: ???
99
+ sweep:
100
+ dir: ???
101
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_100h.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 100
7
+ tensorboard_logdir: tblog
8
+ seed: 1337
9
+
10
+ checkpoint:
11
+ save_interval: 1
12
+ keep_last_epochs: 5
13
+ keep_best_checkpoints: 5
14
+ best_checkpoint_metric: dec_accuracy
15
+ maximize_best_checkpoint_metric: true
16
+ restore_file: checkpoint_last.pt
17
+
18
+ distributed_training:
19
+ ddp_backend: legacy_ddp
20
+ find_unused_parameters: true
21
+ distributed_world_size: 16
22
+ distributed_port: -1
23
+ nprocs_per_node: 8
24
+
25
+ task:
26
+ _name: joint_sc2t_pretraining
27
+ data: ???
28
+ fine_tuning: true
29
+ label_dir: ???
30
+ normalize: true # must be consistent with pre-training
31
+ labels: ["ltr"]
32
+ store_labels: true
33
+ single_target: true
34
+ add_decoder_target: true
35
+ pad_audio: false
36
+ random_crop: true
37
+ hubert_tokenizer: "none"
38
+ sp_path: None
39
+
40
+ dataset:
41
+ num_workers: 0
42
+ max_tokens: 1300000
43
+ skip_invalid_size_inputs_valid_test: true
44
+ train_subset: train_100
45
+ valid_subset: dev_other
46
+ required_batch_size_multiple: 1
47
+
48
+ criterion:
49
+ _name: ctc_ce
50
+ zero_infinity: true
51
+
52
+ optimization:
53
+ max_update: 40000
54
+ lr: [0.00001]
55
+ sentence_avg: true
56
+ update_freq: [2]
57
+
58
+ optimizer:
59
+ _name: adam
60
+ adam_betas: (0.9,0.98)
61
+ adam_eps: 1e-08
62
+ weight_decay: 0.0
63
+
64
+ lr_scheduler:
65
+ _name: tri_stage
66
+ phase_ratio: [0.1, 0.4, 0.5]
67
+ final_lr_scale: 0.05
68
+
69
+ model:
70
+ _name: speechut_asr
71
+ w2v_path: ???
72
+ apply_mask: true
73
+ mask_prob: 0.5
74
+ mask_channel_prob: 0.5
75
+ mask_channel_length: 64
76
+ layerdrop: 0.0
77
+ activation_dropout: 0.1
78
+ attention_dropout: 0.1
79
+ feature_grad_mult: 0.0
80
+ freeze_finetune_updates: 0
81
+ add_decoder: true
82
+
83
+ hydra:
84
+ job:
85
+ config:
86
+ override_dirname:
87
+ kv_sep: '-'
88
+ item_sep: '__'
89
+ exclude_keys:
90
+ - run
91
+ - task.data
92
+ - task.label_dir
93
+ - model.w2v_path
94
+ - dataset.train_subset
95
+ - dataset.valid_subset
96
+ - criterion.wer_kenlm_model
97
+ - criterion.wer_lexicon
98
+ run:
99
+ dir: ???
100
+ sweep:
101
+ dir: ???
102
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2S/speech2s/config/finetune_asr/speechut_large_960h.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 100
7
+ tensorboard_logdir: tblog
8
+
9
+ checkpoint:
10
+ save_interval: 1
11
+ keep_last_epochs: 5
12
+ keep_best_checkpoints: 5
13
+ best_checkpoint_metric: dec_accuracy
14
+ maximize_best_checkpoint_metric: true
15
+ restore_file: checkpoint_last.pt
16
+
17
+ distributed_training:
18
+ ddp_backend: legacy_ddp
19
+ find_unused_parameters: true
20
+ distributed_world_size: 24
21
+ distributed_port: -1
22
+ nprocs_per_node: 8
23
+
24
+ task:
25
+ _name: joint_sc2t_pretraining
26
+ data: ???
27
+ fine_tuning: true
28
+ label_dir: ???
29
+ normalize: true # must be consistent with pre-training
30
+ labels: ["ltr"]
31
+ store_labels: true
32
+ single_target: true
33
+ add_decoder_target: true
34
+ pad_audio: false
35
+ random_crop: true
36
+ hubert_tokenizer: "none"
37
+ sp_path: None
38
+
39
+ dataset:
40
+ num_workers: 0
41
+ max_tokens: 1300000
42
+ skip_invalid_size_inputs_valid_test: true
43
+ train_subset: train_960
44
+ valid_subset: dev_other
45
+ required_batch_size_multiple: 1
46
+
47
+ criterion:
48
+ _name: ctc_ce
49
+ zero_infinity: true
50
+
51
+ optimization:
52
+ max_update: 40000
53
+ lr: [0.00001]
54
+ sentence_avg: true
55
+ update_freq: [2]
56
+
57
+ optimizer:
58
+ _name: adam
59
+ adam_betas: (0.9,0.98)
60
+ adam_eps: 1e-08
61
+ weight_decay: 0.0
62
+
63
+ lr_scheduler:
64
+ _name: tri_stage
65
+ phase_ratio: [0.1, 0.4, 0.5]
66
+ final_lr_scale: 0.05
67
+
68
+ model:
69
+ _name: speechut_asr
70
+ w2v_path: ???
71
+ apply_mask: true
72
+ mask_prob: 0.5
73
+ mask_channel_prob: 0.25
74
+ mask_channel_length: 64
75
+ layerdrop: 0.0
76
+ activation_dropout: 0.1
77
+ feature_grad_mult: 0.0
78
+ freeze_finetune_updates: 0
79
+ add_decoder: true
80
+
81
+ hydra:
82
+ job:
83
+ config:
84
+ override_dirname:
85
+ kv_sep: '-'
86
+ item_sep: '__'
87
+ exclude_keys:
88
+ - run
89
+ - task.data
90
+ - task.label_dir
91
+ - model.w2v_path
92
+ - dataset.train_subset
93
+ - dataset.valid_subset
94
+ - criterion.wer_kenlm_model
95
+ - criterion.wer_lexicon
96
+ run:
97
+ dir: ???
98
+ sweep:
99
+ dir: ???
100
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2S/speech2s/config/pretrain/speechut_base_librispeech.yaml ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ seed: 1337
8
+ tensorboard_logdir: tblog
9
+
10
+ checkpoint:
11
+ save_dir: ???
12
+ save_interval: 4
13
+ keep_last_epochs: 4
14
+ save_interval_updates: 50000
15
+ keep_interval_updates: -1
16
+ keep_interval_updates_pattern: 50000
17
+ # no_epoch_checkpoints: true
18
+
19
+ distributed_training:
20
+ ddp_backend: no_c10d
21
+ distributed_backend: 'nccl'
22
+ distributed_port: -1
23
+ distributed_world_size: 32
24
+ nprocs_per_node: 8
25
+ find_unused_parameters: true
26
+
27
+ task:
28
+ _name: joint_sc2t_pretraining
29
+ data: ???
30
+ label_dir: ???
31
+ labels: ???
32
+ label_rate: ${model.label_rate}
33
+ store_labels: true
34
+ sample_rate: 16000
35
+ max_sample_size: 250000
36
+ min_sample_size: 32000
37
+ pad_audio: false
38
+ random_crop: true
39
+ normalize: false # must be consistent with extractor
40
+ add_decoder_target: true
41
+ text_cfg:
42
+ seed: ${common.seed}
43
+ text_data: ???
44
+ data_config: config.yaml
45
+ sample_break_mode: eos
46
+ tokens_per_sample: 1024
47
+ shorten_method: "random_crop"
48
+ text_maxtokens_ratio: 1.5
49
+
50
+ dataset:
51
+ num_workers: 6
52
+ max_tokens: 1400000
53
+ skip_invalid_size_inputs_valid_test: true
54
+ validate_interval: ${checkpoint.save_interval}
55
+ validate_interval_updates: ${checkpoint.save_interval_updates}
56
+ required_batch_size_multiple: 1
57
+
58
+ criterion:
59
+ _name: speechut_criterion
60
+ pred_masked_weight: 1.0
61
+ pred_nomask_weight: 0.0
62
+ loss_weights: [10,]
63
+ label_smoothing: 0.1
64
+ u2t_ed_weight: 0.1
65
+ u2t_ctc_weight: 0.1
66
+ text_mum_weight: 0.5
67
+
68
+ optimization:
69
+ max_update: 400000
70
+ lr: [0.0005]
71
+ clip_norm: 10.0
72
+
73
+ optimizer:
74
+ _name: adam
75
+ adam_betas: (0.9,0.98)
76
+ adam_eps: 1e-06
77
+ weight_decay: 0.01
78
+
79
+ lr_scheduler:
80
+ _name: polynomial_decay
81
+ warmup_updates: 32000
82
+
83
+ model:
84
+ _name: speechut
85
+ label_rate: ???
86
+ skip_masked: false
87
+ skip_nomask: false
88
+ mask_prob: 0.80
89
+ extractor_mode: default
90
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
91
+ final_dim: 256
92
+ activation_fn: "gelu"
93
+ encoder_layers: 6
94
+ encoder_attention_heads: 8
95
+ encoder_layerdrop: 0.0
96
+ dropout_input: 0.1
97
+ dropout_features: 0.1
98
+ dropout: 0.1
99
+ attention_dropout: 0.1
100
+ feature_grad_mult: 0.1
101
+ untie_final_proj: true
102
+ activation_dropout: 0.0
103
+ use_rel_pos_enc: true
104
+ add_unit_encoder: true
105
+ add_text_ctc: true
106
+ mask_u2t: false
107
+ mix_with_unit: true
108
+ add_decoder: true
109
+ reset_decoder_embedding_config: true
110
+ text_transformer:
111
+ activation_fn: ${model.activation_fn}
112
+ dropout: ${model.dropout}
113
+ attention_dropout: ${model.attention_dropout}
114
+ activation_dropout: ${model.activation_dropout}
115
+ max_source_positions: 3000
116
+ max_target_positions: 3000
117
+ no_scale_embedding: true
118
+ layernorm_embedding: true
119
+ no_token_positional_embeddings: false
120
+ share_decoder_input_output_embed: false
121
+ encoder:
122
+ embed_dim: 768
123
+ ffn_embed_dim: 3072
124
+ layers: 6
125
+ attention_heads: 8
126
+ normalize_before: false
127
+ learned_pos: true
128
+ layerdrop: ${model.encoder_layerdrop}
129
+ decoder:
130
+ layerdrop: 0.1
131
+ embed_dim: 768
132
+ ffn_embed_dim: 3072
133
+ layers: 6
134
+ attention_heads: 12
135
+ normalize_before: false
136
+ learned_pos: false
137
+ output_dim: 768
138
+
139
+ hydra:
140
+ job:
141
+ config:
142
+ override_dirname:
143
+ kv_sep: '-'
144
+ item_sep: '__'
145
+ exclude_keys:
146
+ - run
147
+ - task.data
148
+ - task.label_dir
149
+ run:
150
+ dir: ???
151
+ sweep:
152
+ dir: ???
153
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2S/speech2s/config/pretrain/speechut_large_librilight.yaml ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ fp16_scale_tolerance: 0.1 # alleviate fp16 overflow issue
6
+ log_format: json
7
+ log_interval: 200
8
+ seed: 1234
9
+ tensorboard_logdir: tblog
10
+
11
+ checkpoint:
12
+ save_dir: ???
13
+ save_interval: 1
14
+ keep_last_epochs: 4
15
+ save_interval_updates: 10000
16
+ keep_interval_updates: -1
17
+ keep_interval_updates_pattern: 10000
18
+ # no_epoch_checkpoints: true
19
+
20
+ distributed_training:
21
+ ddp_backend: no_c10d
22
+ distributed_backend: 'nccl'
23
+ distributed_port: -1
24
+ distributed_world_size: 128
25
+ nprocs_per_node: 8
26
+ find_unused_parameters: true
27
+
28
+ task:
29
+ _name: joint_sc2t_pretraining
30
+ data: ???
31
+ label_dir: ???
32
+ labels: ???
33
+ label_rate: ${model.label_rate}
34
+ store_labels: true
35
+ sample_rate: 16000
36
+ max_sample_size: 250000
37
+ min_sample_size: 32000
38
+ pad_audio: false
39
+ random_crop: true
40
+ normalize: true # must be consistent with extractor
41
+ add_decoder_target: true
42
+ text_cfg:
43
+ seed: ${common.seed}
44
+ text_data: ???
45
+ data_config: config.yaml
46
+ sample_break_mode: eos
47
+ tokens_per_sample: 1024
48
+ shorten_method: "random_crop"
49
+ text_maxtokens_ratio: 1.4
50
+
51
+ dataset:
52
+ num_workers: 6
53
+ max_tokens: 900000
54
+ skip_invalid_size_inputs_valid_test: true
55
+ validate_interval: ${checkpoint.save_interval}
56
+ validate_interval_updates: ${checkpoint.save_interval_updates}
57
+ required_batch_size_multiple: 2
58
+
59
+ criterion:
60
+ _name: speechut_criterion
61
+ pred_masked_weight: 1.0
62
+ pred_nomask_weight: 0.0
63
+ loss_weights: [10,]
64
+ label_smoothing: 0.1
65
+ u2t_ed_weight: 0.1
66
+ u2t_ctc_weight: 0.1
67
+ text_mum_weight: 0.5
68
+
69
+ optimization:
70
+ max_update: 400000
71
+ lr: [0.0005]
72
+ clip_norm: 1.0
73
+
74
+ optimizer:
75
+ _name: adam
76
+ adam_betas: (0.9,0.98)
77
+ adam_eps: 1e-06
78
+ weight_decay: 0.01
79
+
80
+ lr_scheduler:
81
+ _name: polynomial_decay
82
+ warmup_updates: 32000
83
+ end_learning_rate: 0.00015 # for future longger pre-training, e.g. 600K step
84
+
85
+ model:
86
+ _name: speechut
87
+ label_rate: ???
88
+ encoder_embed_dim: 1024
89
+ encoder_ffn_embed_dim: 4096
90
+ skip_masked: false
91
+ skip_nomask: false
92
+ mask_prob: 0.80
93
+ extractor_mode: layer_norm
94
+ conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
95
+ final_dim: 768
96
+ activation_fn: "gelu"
97
+ encoder_layers: 12
98
+ encoder_attention_heads: 16
99
+ encoder_layerdrop: 0.0
100
+ dropout_input: 0.0
101
+ dropout_features: 0.0
102
+ dropout: 0.0
103
+ attention_dropout: 0.0
104
+ layer_norm_first: true
105
+ feature_grad_mult: 1.0
106
+ untie_final_proj: true
107
+ activation_dropout: 0.0
108
+ use_rel_pos_enc: true
109
+ add_unit_encoder: true
110
+ add_text_ctc: true
111
+ mask_u2t: false
112
+ mix_with_unit: true
113
+ add_decoder: true
114
+ reset_decoder_embedding_config: true
115
+ scaling_for_att: 32 # alleviate fp16 overflow issue
116
+ text_transformer:
117
+ activation_fn: ${model.activation_fn}
118
+ dropout: ${model.dropout}
119
+ attention_dropout: ${model.attention_dropout}
120
+ activation_dropout: ${model.activation_dropout}
121
+ max_source_positions: 3000
122
+ max_target_positions: 3000
123
+ no_scale_embedding: true
124
+ layernorm_embedding: true
125
+ no_token_positional_embeddings: true
126
+ share_decoder_input_output_embed: false
127
+ encoder:
128
+ embed_dim: 1024
129
+ ffn_embed_dim: 4096
130
+ layers: 12
131
+ attention_heads: 16
132
+ normalize_before: false
133
+ learned_pos: true
134
+ layerdrop: ${model.encoder_layerdrop}
135
+ decoder:
136
+ layerdrop: 0.1
137
+ embed_dim: 768
138
+ ffn_embed_dim: 3072
139
+ layers: 6
140
+ attention_heads: 12
141
+ normalize_before: false
142
+ learned_pos: false
143
+ output_dim: 768
144
+
145
+ hydra:
146
+ job:
147
+ config:
148
+ override_dirname:
149
+ kv_sep: '-'
150
+ item_sep: '__'
151
+ exclude_keys:
152
+ - run
153
+ - task.data
154
+ - task.label_dir
155
+ run:
156
+ dir: ???
157
+ sweep:
158
+ dir: ???
159
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
SpeechT5/Speech2S/speech2s/criterions/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+
4
+ for file in os.listdir(os.path.dirname(__file__)):
5
+ if file.endswith(".py") and not file.startswith("_"):
6
+ criterion_name = file[: file.find(".py")]
7
+ importlib.import_module(
8
+ "speechut.criterions." + criterion_name
9
+ )
SpeechT5/Speech2S/speech2s/criterions/ctc_ce.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
4
+ # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
5
+ #
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # ----------------------------------------------------------------------------
9
+
10
+ import math
11
+ from argparse import Namespace
12
+ from dataclasses import dataclass, field
13
+ from omegaconf import II
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from fairseq import metrics, utils
19
+ from fairseq.criterions import FairseqCriterion, register_criterion
20
+ from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
21
+ from fairseq.dataclass import FairseqDataclass
22
+ from fairseq.data.data_utils import post_process
23
+ from fairseq.tasks import FairseqTask
24
+ from fairseq.logging.meters import safe_round
25
+
26
+
27
+ @dataclass
28
+ class CtcCeCriterionConfig(FairseqDataclass):
29
+ zero_infinity: bool = field(
30
+ default=False,
31
+ metadata={"help": "zero inf loss when source length <= target length"},
32
+ )
33
+ sentence_avg: bool = II("optimization.sentence_avg")
34
+ post_process: str = field(
35
+ default="letter",
36
+ metadata={
37
+ "help": "how to post process predictions into words. can be letter, "
38
+ "wordpiece, BPE symbols, etc. "
39
+ "See fairseq.data.data_utils.post_process() for full list of options"
40
+ },
41
+ )
42
+ wer_kenlm_model: Optional[str] = field(
43
+ default=None,
44
+ metadata={
45
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
46
+ },
47
+ )
48
+ wer_lexicon: Optional[str] = field(
49
+ default=None,
50
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
51
+ )
52
+ wer_lm_weight: float = field(
53
+ default=2.0,
54
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
55
+ )
56
+ wer_word_score: float = field(
57
+ default=-1.0,
58
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
59
+ )
60
+
61
+ wer_args: Optional[str] = field(
62
+ default=None,
63
+ metadata={
64
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
65
+ },
66
+ )
67
+
68
+ dec_weight: float = field(
69
+ default=0.5,
70
+ metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
71
+ )
72
+ report_accuracy: bool = field(
73
+ default=True,
74
+ metadata={"help": "report decoder accuracy metric"},
75
+ )
76
+ ignore_prefix_size: int = field(
77
+ default=0,
78
+ metadata={"help": "Ignore first N tokens"},
79
+ )
80
+ label_smoothing: float = field(
81
+ default=0.1,
82
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
83
+ )
84
+
85
+
86
+ @register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
87
+ class CtcCeCriterion(FairseqCriterion):
88
+ def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
89
+ super().__init__(task)
90
+ self.blank_idx = (
91
+ task.target_dictionary.index(task.blank_symbol)
92
+ if hasattr(task, "blank_symbol")
93
+ else 0
94
+ )
95
+ self.pad_idx = task.target_dictionary.pad()
96
+ self.eos_idx = task.target_dictionary.eos()
97
+ self.post_process = cfg.post_process
98
+
99
+ if cfg.wer_args is not None:
100
+ (
101
+ cfg.wer_kenlm_model,
102
+ cfg.wer_lexicon,
103
+ cfg.wer_lm_weight,
104
+ cfg.wer_word_score,
105
+ ) = eval(cfg.wer_args)
106
+
107
+ if cfg.wer_kenlm_model is not None:
108
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
109
+
110
+ dec_args = Namespace()
111
+ dec_args.nbest = 1
112
+ dec_args.criterion = "ctc"
113
+ dec_args.kenlm_model = cfg.wer_kenlm_model
114
+ dec_args.lexicon = cfg.wer_lexicon
115
+ dec_args.beam = 50
116
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
117
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
118
+ dec_args.lm_weight = cfg.wer_lm_weight
119
+ dec_args.word_score = cfg.wer_word_score
120
+ dec_args.unk_weight = -math.inf
121
+ dec_args.sil_weight = 0
122
+
123
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
124
+ else:
125
+ self.w2l_decoder = None
126
+
127
+ self.zero_infinity = cfg.zero_infinity
128
+ self.sentence_avg = cfg.sentence_avg
129
+
130
+ self.dec_weight = cfg.dec_weight
131
+ self.report_accuracy = cfg.report_accuracy
132
+ self.ignore_prefix_size = cfg.ignore_prefix_size
133
+ self.eps = cfg.label_smoothing
134
+
135
+ def forward(self, model, sample, reduce=True):
136
+ net_output = model(**sample["net_input"])
137
+ lprobs = model.get_normalized_probs(
138
+ net_output, log_probs=True
139
+ ).contiguous() # (T, B, C) from the encoder
140
+
141
+ if "src_lengths" in sample["net_input"]:
142
+ input_lengths = sample["net_input"]["src_lengths"]
143
+ else:
144
+ if net_output["padding_mask"] is not None:
145
+ non_padding_mask = ~net_output["padding_mask"]
146
+ input_lengths = non_padding_mask.long().sum(-1)
147
+ else:
148
+ input_lengths = lprobs.new_full(
149
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
150
+ )
151
+
152
+ pad_mask = (sample["target"] != self.pad_idx) & (
153
+ sample["target"] != self.eos_idx
154
+ )
155
+ targets_flat = sample["target"].masked_select(pad_mask)
156
+ if "target_lengths" in sample:
157
+ target_lengths = sample["target_lengths"]
158
+ else:
159
+ target_lengths = pad_mask.sum(-1)
160
+
161
+ with torch.backends.cudnn.flags(enabled=False):
162
+ loss = F.ctc_loss(
163
+ lprobs,
164
+ targets_flat,
165
+ input_lengths,
166
+ target_lengths,
167
+ blank=self.blank_idx,
168
+ reduction="sum",
169
+ zero_infinity=self.zero_infinity,
170
+ )
171
+
172
+ ntokens = (
173
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
174
+ )
175
+
176
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
177
+
178
+ logging_output = {}
179
+ if "decoder_target" in sample:
180
+ if net_output["decoder_out"] is not None:
181
+ dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
182
+ dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
183
+ logging_output["ctc_loss"] = loss.item()
184
+ loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
185
+ logging_output["dec_loss"] = dec_loss.item()
186
+ logging_output["dec_nll_loss"] = dec_nll_loss.item()
187
+ logging_output["dec_sample_size"] = dec_sample_size
188
+
189
+ if self.report_accuracy:
190
+ n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
191
+ logging_output["dec_n_correct"] = utils.item(n_correct.data)
192
+ logging_output["total"] = utils.item(total.data)
193
+ else:
194
+ logging_output["ctc_loss"] = loss.item()
195
+ loss = (1 - self.dec_weight) * loss
196
+ logging_output["dec_loss"] = 0
197
+ logging_output["dec_nll_loss"] = 0
198
+ logging_output["dec_sample_size"] = 1
199
+ if self.report_accuracy:
200
+ logging_output["dec_n_correct"] = 0
201
+ logging_output["total"] = 1
202
+
203
+ logging_output = {
204
+ "loss": utils.item(loss.data), # * sample['ntokens'],
205
+ "ntokens": ntokens,
206
+ "nsentences": sample["id"].numel(),
207
+ "sample_size": sample_size,
208
+ **logging_output,
209
+ }
210
+
211
+ if not model.training and self.dec_weight < 1.0:
212
+ import editdistance
213
+
214
+ with torch.no_grad():
215
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
216
+
217
+ c_err = 0
218
+ c_len = 0
219
+ w_errs = 0
220
+ w_len = 0
221
+ wv_errs = 0
222
+ for lp, t, inp_l in zip(
223
+ lprobs_t,
224
+ sample["target_label"]
225
+ if "target_label" in sample
226
+ else sample["target"],
227
+ input_lengths,
228
+ ):
229
+ lp = lp[:inp_l].unsqueeze(0)
230
+
231
+ decoded = None
232
+ if self.w2l_decoder is not None:
233
+ decoded = self.w2l_decoder.decode(lp)
234
+ if len(decoded) < 1:
235
+ decoded = None
236
+ else:
237
+ decoded = decoded[0]
238
+ if len(decoded) < 1:
239
+ decoded = None
240
+ else:
241
+ decoded = decoded[0]
242
+
243
+ p = (t != self.task.target_dictionary.pad()) & (
244
+ t != self.task.target_dictionary.eos()
245
+ )
246
+ targ = t[p]
247
+ targ_units = self.task.target_dictionary.string(targ)
248
+ targ_units_arr = targ.tolist()
249
+
250
+ toks = lp.argmax(dim=-1).unique_consecutive()
251
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
252
+
253
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
254
+ c_len += len(targ_units_arr)
255
+
256
+ targ_words = post_process(targ_units, self.post_process).split()
257
+
258
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
259
+ pred_words_raw = post_process(pred_units, self.post_process).split()
260
+
261
+ if decoded is not None and "words" in decoded:
262
+ pred_words = decoded["words"]
263
+ w_errs += editdistance.eval(pred_words, targ_words)
264
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
265
+ else:
266
+ dist = editdistance.eval(pred_words_raw, targ_words)
267
+ w_errs += dist
268
+ wv_errs += dist
269
+
270
+ w_len += len(targ_words)
271
+
272
+ logging_output["wv_errors"] = wv_errs
273
+ logging_output["w_errors"] = w_errs
274
+ logging_output["w_total"] = w_len
275
+ logging_output["c_errors"] = c_err
276
+ logging_output["c_total"] = c_len
277
+
278
+ return loss, sample_size, logging_output
279
+
280
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
281
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
282
+ loss, nll_loss = label_smoothed_nll_loss(
283
+ lprobs,
284
+ target,
285
+ self.eps,
286
+ ignore_index=self.pad_idx,
287
+ reduce=reduce,
288
+ )
289
+ return loss, nll_loss
290
+
291
+ def compute_accuracy(self, model, net_output, sample):
292
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
293
+ mask = target.ne(self.pad_idx)
294
+ n_correct = torch.sum(
295
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
296
+ )
297
+ total = torch.sum(mask)
298
+ return n_correct, total
299
+
300
+ def get_lprobs_and_target(self, model, net_output, sample):
301
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
302
+ target = sample["decoder_target"]
303
+ if self.ignore_prefix_size > 0:
304
+ if getattr(lprobs, "batch_first", False):
305
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
306
+ target = target[:, self.ignore_prefix_size :].contiguous()
307
+ else:
308
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
309
+ target = target[self.ignore_prefix_size :, :].contiguous()
310
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
311
+
312
+
313
+ @staticmethod
314
+ def reduce_metrics(logging_outputs) -> None:
315
+ """Aggregate logging outputs from data parallel training."""
316
+
317
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
318
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
319
+ nsentences = utils.item(
320
+ sum(log.get("nsentences", 0) for log in logging_outputs)
321
+ )
322
+ sample_size = utils.item(
323
+ sum(log.get("sample_size", 0) for log in logging_outputs)
324
+ )
325
+
326
+ metrics.log_scalar(
327
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
328
+ )
329
+ metrics.log_scalar("ntokens", ntokens)
330
+ metrics.log_scalar("nsentences", nsentences)
331
+ if sample_size != ntokens:
332
+ metrics.log_scalar(
333
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
334
+ )
335
+
336
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
337
+ metrics.log_scalar("_c_errors", c_errors)
338
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
339
+ metrics.log_scalar("_c_total", c_total)
340
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
341
+ metrics.log_scalar("_w_errors", w_errors)
342
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
343
+ metrics.log_scalar("_wv_errors", wv_errors)
344
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
345
+ metrics.log_scalar("_w_total", w_total)
346
+
347
+ if c_total > 0:
348
+ metrics.log_derived(
349
+ "uer",
350
+ lambda meters: safe_round(
351
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
352
+ )
353
+ if meters["_c_total"].sum > 0
354
+ else float("nan"),
355
+ )
356
+ if w_total > 0:
357
+ metrics.log_derived(
358
+ "wer",
359
+ lambda meters: safe_round(
360
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
361
+ )
362
+ if meters["_w_total"].sum > 0
363
+ else float("nan"),
364
+ )
365
+ metrics.log_derived(
366
+ "raw_wer",
367
+ lambda meters: safe_round(
368
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
369
+ )
370
+ if meters["_w_total"].sum > 0
371
+ else float("nan"),
372
+ )
373
+
374
+ if "dec_loss" in logging_outputs[0]:
375
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
376
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
377
+ dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
378
+ dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
379
+ metrics.log_scalar(
380
+ "dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
381
+ )
382
+ metrics.log_scalar(
383
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
384
+ )
385
+ metrics.log_scalar(
386
+ "dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
387
+ )
388
+ metrics.log_derived(
389
+ "dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
390
+ )
391
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
392
+ if total > 0:
393
+ metrics.log_scalar("total", total)
394
+ n_correct = utils.item(
395
+ sum(log.get("dec_n_correct", 0) for log in logging_outputs)
396
+ )
397
+ metrics.log_scalar("dec_n_correct", n_correct)
398
+ metrics.log_derived(
399
+ "dec_accuracy",
400
+ lambda meters: round(
401
+ meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
402
+ )
403
+ if meters["total"].sum > 0
404
+ else float("nan"),
405
+ )
406
+
407
+ @staticmethod
408
+ def logging_outputs_can_be_summed() -> bool:
409
+ """
410
+ Whether the logging outputs returned by `forward` can be summed
411
+ across workers prior to calling `reduce_metrics`. Setting this
412
+ to True will improves distributed training speed.
413
+ """
414
+ return True
SpeechT5/Speech2S/speech2s/criterions/speechut_criterion.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
4
+ # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
5
+ #
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # ----------------------------------------------------------------------------
9
+
10
+ import logging
11
+ import math
12
+ import re
13
+ from dataclasses import dataclass, field
14
+ from typing import List, Optional
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from fairseq import metrics, utils
20
+ from fairseq.criterions import FairseqCriterion, register_criterion
21
+ from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
22
+ from fairseq.dataclass import FairseqDataclass
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ @dataclass
27
+ class SpeechUTCriterionConfig(FairseqDataclass):
28
+ pred_masked_weight: float = field(
29
+ default=1.0,
30
+ metadata={"help": "weight for predictive loss for masked frames"},
31
+ )
32
+ pred_nomask_weight: float = field(
33
+ default=0.0,
34
+ metadata={"help": "weight for predictive loss for unmasked frames"},
35
+ )
36
+ loss_weights: Optional[List[float]] = field(
37
+ default=None,
38
+ metadata={"help": "weights for additional loss terms (not first one)"},
39
+ )
40
+ log_keys: List[str] = field(
41
+ default_factory=lambda: [],
42
+ metadata={"help": "output keys to log"},
43
+ )
44
+ u2t_ed_weight: float = field(
45
+ default=0.1,
46
+ metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
47
+ )
48
+ u2t_ctc_weight: float = field(
49
+ default=0.0,
50
+ metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
51
+ )
52
+ text_mum_weight: float = field(
53
+ default=0.0,
54
+ metadata={"help": "masked unit modeling weight from the text end"},
55
+ )
56
+ report_accuracy: bool = field(
57
+ default=True,
58
+ metadata={"help": "report decoder accuracy metric"},
59
+ )
60
+ ignore_prefix_size: int = field(
61
+ default=0,
62
+ metadata={"help": "Ignore first N tokens"},
63
+ )
64
+ label_smoothing: float = field(
65
+ default=0.0,
66
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
67
+ )
68
+ no_ctc_blank: bool = field(
69
+ default=False,
70
+ metadata={"help": "mask out the blank of ctc, only when dec_loss_type=ctc"},
71
+ )
72
+ label_smoothing: float = field(
73
+ default=0.0,
74
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
75
+ )
76
+
77
+ @register_criterion("speechut_criterion", dataclass=SpeechUTCriterionConfig)
78
+ class SpeechUTCriterion(FairseqCriterion):
79
+ def __init__(
80
+ self,
81
+ task,
82
+ pred_masked_weight,
83
+ pred_nomask_weight,
84
+ loss_weights=None,
85
+ log_keys=None,
86
+ u2t_ed_weight=0.1,
87
+ u2t_ctc_weight=0,
88
+ text_mum_weight=0,
89
+ report_accuracy=False,
90
+ ignore_prefix_size=0,
91
+ label_smoothing=0,
92
+ no_ctc_blank=False,
93
+ ):
94
+ super().__init__(task)
95
+ self.pred_masked_weight = pred_masked_weight
96
+ self.pred_nomask_weight = pred_nomask_weight
97
+ self.loss_weights = loss_weights
98
+ self.log_keys = [] if log_keys is None else log_keys
99
+ self.u2t_ed_weight = u2t_ed_weight
100
+ self.u2t_ctc_weight = u2t_ctc_weight
101
+ self.text_mum_weight = text_mum_weight
102
+ self.report_accuracy = report_accuracy
103
+ self.ignore_prefix_size = ignore_prefix_size
104
+ self.eps = label_smoothing
105
+ self.no_ctc_blank = no_ctc_blank
106
+ self.padding_idx = task.dictionaries[0].pad()
107
+ self.eos_idx = task.dictionaries[0].eos()
108
+ self.blank_idx = task.dictionaries[0].bos()
109
+
110
+ def compute_hubert_loss(self, model, net_output, reduction, preffix='', suffix=''):
111
+ loss = 0
112
+ sample_size = []
113
+ logging_output = {}
114
+ loss_m_list = []
115
+ logp_m_list = model.get_logits(net_output, True)
116
+ targ_m_list = model.get_targets(net_output, True)
117
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
118
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
119
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
120
+ loss_m_list.append(loss_m)
121
+ logging_output[f"{preffix}loss_m_{i}"] = loss_m.detach().item()
122
+ if self.pred_masked_weight > 0:
123
+ loss += self.pred_masked_weight * sum(loss_m_list)
124
+ sample_size.append(targ_m_list[0].numel())
125
+
126
+ loss_u_list = []
127
+ logp_u_list = model.get_logits(net_output, False)
128
+ targ_u_list = model.get_targets(net_output, False)
129
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
130
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
131
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
132
+ loss_u_list.append(loss_u)
133
+ logging_output[f"{preffix}loss_u_{i}"] = loss_u.detach().item()
134
+ if self.pred_nomask_weight > 0:
135
+ loss += self.pred_nomask_weight * sum(loss_u_list)
136
+ sample_size.append(targ_u_list[0].numel())
137
+
138
+ sample_size = np.mean(sample_size)
139
+
140
+ def compute_correct(logits, targets):
141
+ if logits.numel() == 0:
142
+ return 0, 0
143
+ else:
144
+ assert logits.dim() > 1, logits.shape
145
+ max = logits.argmax(-1) == targets
146
+ min = logits.argmin(-1) == targets
147
+ both = max & min
148
+ corr = max.long().sum().item() - both.long().sum().item()
149
+ count = max.numel()
150
+ return corr, count
151
+
152
+ with torch.no_grad():
153
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
154
+ corr_m, count_m = compute_correct(logp_m, targ_m)
155
+ logging_output[f"correct_m_{i}{suffix}"] = corr_m
156
+ logging_output[f"count_m_{i}{suffix}"] = count_m
157
+
158
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
159
+ corr_u, count_u = compute_correct(logp_u, targ_u)
160
+ logging_output[f"correct_u_{i}{suffix}"] = corr_u
161
+ logging_output[f"count_u_{i}{suffix}"] = count_u
162
+
163
+ return loss, sample_size, logging_output
164
+
165
+
166
+ def forward(self, model, sample, reduce=True, log_pred=False):
167
+ """Compute the loss for the given sample.
168
+ Returns a tuple with three elements:
169
+ 1) the loss
170
+ 2) the sample size, which is used as the denominator for the gradient
171
+ 3) logging outputs to display while training
172
+ """
173
+ reduction = "sum" if reduce else "none"
174
+
175
+ if "net_input" in sample:
176
+ unit_sample = text_sample = None
177
+ else:
178
+ unit_sample = sample.get("text_mono", None)
179
+ text_sample = sample.get("text_paired", None)
180
+ assert unit_sample is not None or text_sample is not None
181
+ sample = sample.get("speech")
182
+
183
+ ### 1. S2U: do hubert forward and loss computation
184
+ sample["modality"] = "speech"
185
+ net_output = model(target_list=sample["target_list"], **sample["net_input"])
186
+ loss, sample_size, logging_output = self.compute_hubert_loss(
187
+ model,
188
+ net_output,
189
+ reduction,
190
+ )
191
+ if self.loss_weights is not None:
192
+ assert hasattr(model, "get_extra_losses")
193
+ extra_losses, names = model.get_extra_losses(net_output)
194
+ if torch.is_tensor(extra_losses):
195
+ extra_losses = [extra_losses]
196
+ names = [names]
197
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
198
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
199
+ assert len(extra_losses) == len(
200
+ self.loss_weights
201
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
202
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
203
+ if coef != 0 and p is not None:
204
+ p = coef * p.float() * sample_size
205
+ loss += p
206
+ logging_output[f"loss_{n}"] = p.item()
207
+ for lk in self.log_keys:
208
+ if lk in net_output:
209
+ logging_output[lk] = float((net_output[lk]))
210
+
211
+ ### 2. do text U2T forward and loss computation
212
+ if text_sample is not None and (self.u2t_ctc_weight + self.u2t_ed_weight) > 0:
213
+ ## 2.1 re-loading "target_list", in default case, target_list = [src_tokens],
214
+ ## while in case of using "unit-phone-char" structure, target_list will be [ref_tokens]
215
+ text_sample["net_input"]["target_list"] = [
216
+ text_sample.get("ref_tokens", text_sample["net_input"]["src_tokens"].clone()),
217
+ ]
218
+ text_net_output = model(**text_sample["net_input"])
219
+ text_sample_size = text_sample["ntokens"]
220
+
221
+ ### 2.1 U2T_UCTC
222
+ if self.u2t_ctc_weight > 0:
223
+ text_ctc_loss = self.compute_ctc_loss(model, text_net_output, text_sample["target"], reduction=reduction)
224
+ loss += self.u2t_ctc_weight * text_ctc_loss * sample_size / text_sample_size
225
+ logging_output["text_ctc_loss"] = utils.item(text_ctc_loss)
226
+ logging_output["text_sample_size"] = text_sample_size
227
+
228
+ ### 2.2 U2T_ED
229
+ if self.u2t_ed_weight > 0:
230
+ text_dec_loss, text_dec_nll_loss = self.compute_ce_loss(model, text_net_output["decoder_out"], text_sample, reduce=reduce)
231
+ loss += self.u2t_ed_weight * text_dec_loss * sample_size / text_sample_size
232
+ logging_output["text_dec_loss"] = utils.item(text_dec_loss)
233
+ logging_output["text_dec_nll_loss"] = utils.item(text_dec_nll_loss)
234
+ logging_output["text_sample_size"] = text_sample_size
235
+ if self.report_accuracy:
236
+ n_correct, total = self.compute_accuracy(model, text_net_output["decoder_out"], text_sample)
237
+ logging_output["correct_text_dec"] = utils.item(n_correct.data)
238
+ logging_output["count_text_dec"] = utils.item(total.data)
239
+
240
+ ### 3. do unit MUM forward and loss computation
241
+ if unit_sample is not None and self.text_mum_weight > 0:
242
+ src_tokens = unit_sample["net_input"]["src_tokens"]
243
+ target = unit_sample.get("target", None)
244
+ target = src_tokens.clone() if target is None else target
245
+ unit_net_output = model.forward_mum(src_tokens, target)
246
+ loss_num, sample_size_mum, logging_output_mum = self.compute_hubert_loss(
247
+ model,
248
+ unit_net_output,
249
+ reduction,
250
+ preffix="mum_",
251
+ suffix="_mum",
252
+ )
253
+ loss += self.text_mum_weight * loss_num * sample_size / sample_size_mum
254
+ logging_output["unit_sample_size"] = sample_size_mum
255
+ logging_output.update(logging_output_mum)
256
+
257
+ logging_output = {
258
+ "loss": utils.item(loss) if reduce else loss,
259
+ "ntokens": sample_size,
260
+ "nsentences": sample["id"].numel() + (text_sample["id"].numel() if text_sample is not None else 0),
261
+ "sample_size": sample_size,
262
+ **logging_output,
263
+ }
264
+
265
+ return loss, sample_size, logging_output
266
+
267
+ def compute_ctc_loss(self, model, net_output, target, reduction):
268
+ logits = net_output["encoder_out_ctc"][0] # (T, B, C) from the code-encoder
269
+ if self.no_ctc_blank:
270
+ ## set prob of <blank> to -inf
271
+ logits = logits.float()
272
+ logits[:, :, self.blank_idx] = -1000000.0
273
+
274
+ lprobs = F.log_softmax(logits.float(), dim=-1)
275
+
276
+ encoder_padding_mask = net_output["encoder_padding_mask"][0]
277
+ non_padding_mask = ~encoder_padding_mask
278
+ input_lengths = non_padding_mask.long().sum(-1)
279
+ pad_mask = (target != self.padding_idx) & (target != self.eos_idx)
280
+ targets_flat = target.masked_select(pad_mask)
281
+ target_lengths = pad_mask.sum(-1)
282
+
283
+ with torch.backends.cudnn.flags(enabled=False):
284
+ loss = F.ctc_loss(
285
+ lprobs,
286
+ targets_flat,
287
+ input_lengths,
288
+ target_lengths,
289
+ blank=self.blank_idx,
290
+ reduction=reduction,
291
+ zero_infinity=True,
292
+ )
293
+ return loss
294
+
295
+ def compute_ce_loss(self, model, net_output, sample, reduce=True):
296
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
297
+ loss, nll_loss = label_smoothed_nll_loss(
298
+ lprobs,
299
+ target,
300
+ self.eps,
301
+ ignore_index=self.padding_idx,
302
+ reduce=reduce,
303
+ )
304
+ return loss, nll_loss
305
+
306
+ def compute_accuracy(self, model, net_output, sample):
307
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
308
+ mask = target.ne(self.padding_idx)
309
+ n_correct = torch.sum(
310
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
311
+ )
312
+ total = torch.sum(mask)
313
+ return n_correct, total
314
+
315
+ def get_lprobs_and_target(self, model, net_output, sample):
316
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
317
+ target = sample["target"]
318
+
319
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
320
+
321
+ @staticmethod
322
+ def reduce_metrics(logging_outputs) -> None:
323
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
324
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
325
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
326
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
327
+
328
+ metrics.log_scalar(
329
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
330
+ )
331
+ if sample_size != ntokens:
332
+ metrics.log_scalar(
333
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
334
+ )
335
+ metrics.log_derived(
336
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
337
+ )
338
+ else:
339
+ metrics.log_derived(
340
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
341
+ )
342
+
343
+ counts = {}
344
+ for lk in logging_outputs[0].keys():
345
+ if lk.startswith("count_"):
346
+ val = sum(log.get(lk, 0) for log in logging_outputs)
347
+ metrics.log_scalar(lk, val)
348
+ counts[lk] = val
349
+
350
+ for lk in logging_outputs[0].keys():
351
+ if lk.startswith("loss_"):
352
+ val = sum(log.get(lk, 0) for log in logging_outputs)
353
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
354
+ elif lk.startswith("correct_"):
355
+ val = sum(log.get(lk, 0) for log in logging_outputs)
356
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
357
+
358
+ if "text_sample_size" in logging_outputs[0]:
359
+ text_sample_size = sum(log.get("text_sample_size", 0) for log in logging_outputs)
360
+ for lk in logging_outputs[0].keys():
361
+ if lk.startswith("text_") and lk.endswith("_loss"):
362
+ val = sum(log.get(lk, 0) for log in logging_outputs)
363
+ metrics.log_scalar(lk, val / text_sample_size / math.log(2), round=3)
364
+
365
+ if "unit_sample_size" in logging_outputs[0]:
366
+ unit_sample_size = sum(log.get("unit_sample_size", 0) for log in logging_outputs)
367
+ for lk in logging_outputs[0].keys():
368
+ if lk.startswith("mum_loss_"):
369
+ val = sum(log.get(lk, 0) for log in logging_outputs)
370
+ metrics.log_scalar(lk, val / unit_sample_size / math.log(2), round=3)
371
+
372
+ @staticmethod
373
+ def aggregate_logging_outputs(logging_outputs):
374
+ """Aggregate logging outputs from data parallel training."""
375
+ raise NotImplementedError()
376
+
377
+ @staticmethod
378
+ def logging_outputs_can_be_summed() -> bool:
379
+ """
380
+ Whether the logging outputs returned by `forward` can be summed
381
+ across workers prior to calling `reduce_metrics`. Setting this
382
+ to True will improves distributed training speed.
383
+ """
384
+ return False
SpeechT5/Speech2S/speech2s/data/concat_dataset.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ import bisect
9
+
10
+ import numpy as np
11
+ from torch.utils.data.dataloader import default_collate
12
+
13
+ from fairseq.data import FairseqDataset
14
+
15
+
16
+ class ConcatDataset(FairseqDataset):
17
+ @staticmethod
18
+ def cumsum(sequence, sample_ratios):
19
+ r, s = [], 0
20
+ for e, ratio in zip(sequence, sample_ratios):
21
+ curr_len = int(ratio * len(e))
22
+ r.append(curr_len + s)
23
+ s += curr_len
24
+ return r
25
+
26
+ def __init__(self, datasets, sample_ratios=1):
27
+ super(ConcatDataset, self).__init__()
28
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
29
+ self.datasets = list(datasets)
30
+ if isinstance(sample_ratios, int):
31
+ sample_ratios = [sample_ratios] * len(self.datasets)
32
+ self.sample_ratios = sample_ratios
33
+ self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
34
+ self.real_sizes = [len(d) for d in self.datasets]
35
+
36
+ def __len__(self):
37
+ return self.cumulative_sizes[-1]
38
+
39
+ def __getitem__(self, idx):
40
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
41
+ return self.datasets[dataset_idx][sample_idx]
42
+
43
+ def _get_dataset_and_sample_index(self, idx: int):
44
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
45
+ if dataset_idx == 0:
46
+ sample_idx = idx
47
+ else:
48
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
49
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
50
+ return dataset_idx, sample_idx
51
+
52
+ def collater(self, samples, **extra_args):
53
+ # For now only supports datasets with same underlying collater implementations
54
+ if hasattr(self.datasets[0], "collater"):
55
+ return self.datasets[0].collater(samples, **extra_args)
56
+ else:
57
+ return default_collate(samples, **extra_args)
58
+
59
+ def size(self, idx: int):
60
+ """
61
+ Return an example's size as a float or tuple.
62
+ """
63
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
64
+ return self.datasets[dataset_idx].size(sample_idx)
65
+
66
+ def num_tokens(self, index: int):
67
+ return np.max(self.size(index))
68
+
69
+ def attr(self, attr: str, index: int):
70
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
71
+ return getattr(self.datasets[dataset_idx], attr, None)
72
+
73
+ @property
74
+ def sizes(self):
75
+ _dataset_sizes = []
76
+ for ds, sr in zip(self.datasets, self.sample_ratios):
77
+ if isinstance(ds.sizes, np.ndarray):
78
+ _dataset_sizes.append(np.tile(ds.sizes, sr))
79
+ else:
80
+ # Only support underlying dataset with single size array.
81
+ assert isinstance(ds.sizes, list)
82
+ _dataset_sizes.append(np.tile(ds.sizes[0], sr))
83
+ return np.concatenate(_dataset_sizes)
84
+
85
+ @property
86
+ def supports_prefetch(self):
87
+ return all(d.supports_prefetch for d in self.datasets)
88
+
89
+ def ordered_indices(self):
90
+ """
91
+ Returns indices sorted by length. So less padding is needed.
92
+ """
93
+ if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
94
+ # special handling for concatenating lang_pair_datasets
95
+ if getattr(self.datasets[0], "shuffle", False):
96
+ indices = np.random.permutation(len(self)).astype(np.int64)
97
+ else:
98
+ indices = np.arange(len(self), dtype=np.int64)
99
+ sizes = self.sizes
100
+ tgt_sizes = (
101
+ sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
102
+ )
103
+ src_sizes = (
104
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
105
+ )
106
+ # sort by target length, then source length
107
+ if tgt_sizes is not None:
108
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
109
+ return indices[np.argsort(src_sizes[indices], kind="mergesort")]
110
+ else:
111
+ return np.argsort(self.sizes)
112
+
113
+ def prefetch(self, indices):
114
+ frm = 0
115
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
116
+ real_size = len(ds)
117
+ if getattr(ds, "supports_prefetch", False):
118
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
119
+ frm = to
120
+
121
+ @property
122
+ def can_reuse_epoch_itr_across_epochs(self):
123
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
124
+
125
+ def set_epoch(self, epoch):
126
+ super().set_epoch(epoch)
127
+ for ds in self.datasets:
128
+ if hasattr(ds, "set_epoch"):
129
+ ds.set_epoch(epoch)
SpeechT5/Speech2S/speech2s/data/hubert_dataset.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ import itertools
9
+ import logging
10
+ import io
11
+ import os
12
+ import sys
13
+ import time
14
+ from pathlib import Path
15
+ from typing import Any, List, Optional, Union, Tuple
16
+
17
+ import numpy as np
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from fairseq.data import data_utils, Dictionary
22
+ from fairseq.data.fairseq_dataset import FairseqDataset
23
+ from fairseq.data.audio.audio_utils import (
24
+ read_from_stored_zip,
25
+ is_sf_audio_data,
26
+ )
27
+
28
+ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ def parse_path(path: str) -> Tuple[str, List[int]]:
33
+ """Parse data path which is either a path to
34
+ 1. a .npy/.wav/.flac/.ogg file
35
+ 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
36
+
37
+ Args:
38
+ path (str): the data path to parse
39
+
40
+ Returns:
41
+ file_path (str): the file path
42
+ slice_ptr (list of int): empty in case 1;
43
+ byte offset and length for the slice in case 2
44
+ """
45
+
46
+ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
47
+ _path, slice_ptr = path, []
48
+ else:
49
+ _path, *slice_ptr = path.split(":")
50
+ if not Path(_path).is_file():
51
+ raise FileNotFoundError(f"File not found: {_path}")
52
+ assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
53
+ slice_ptr = [int(i) for i in slice_ptr]
54
+ return _path, slice_ptr
55
+
56
+ def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
57
+ n_long, n_short = 0, 0
58
+ names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
59
+ for i in range(retry_times):
60
+ with open(manifest_path) as f:
61
+ root = f.readline().strip()
62
+ for ind, line in enumerate(f):
63
+ items = line.strip().split("\t")
64
+ assert len(items) == 2, line
65
+ sz = int(items[1])
66
+ if min_keep is not None and sz < min_keep:
67
+ n_short += 1
68
+ elif max_keep is not None and sz > max_keep:
69
+ n_long += 1
70
+ else:
71
+ fname = items[0].split(":")
72
+ if len(fname) > 2:
73
+ if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
74
+ chunk_names.append(fname[0])
75
+ chunk_indices.append(len(names))
76
+ names.append(items[0])
77
+ inds.append(ind)
78
+ sizes.append(sz)
79
+ if len(names) == 0:
80
+ logger.warn(f"Fail to load manifest for the {i} time")
81
+ time.sleep(1)
82
+ continue
83
+ else:
84
+ break
85
+ tot = ind + 1
86
+ logger.info(
87
+ (
88
+ f"max_keep={max_keep}, min_keep={min_keep}, "
89
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
90
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
91
+ )
92
+ )
93
+ return root, names, inds, tot, sizes, chunk_names, chunk_indices
94
+
95
+
96
+ def load_label(label_path, inds, tot, retry_times=5):
97
+ for i in range(retry_times):
98
+ with open(label_path) as f:
99
+ labels = [line.rstrip() for line in f]
100
+ if len(labels) == 0:
101
+ logger.warn(f"Fail to load label for the {i} time")
102
+ time.sleep(1)
103
+ continue
104
+ else:
105
+ break
106
+ assert (
107
+ len(labels) == tot
108
+ ), f"number of labels does not match ({len(labels)} != {tot})"
109
+ labels = [labels[i] for i in inds]
110
+ return labels
111
+
112
+
113
+ def load_label_offset(label_path, inds, tot, retry_times=5):
114
+ for i in range(retry_times):
115
+ with open(label_path) as f:
116
+ code_lengths = [len(line.encode("utf-8")) for line in f]
117
+ if len(code_lengths) == 0:
118
+ logger.warn(f"Fail to load label for the {i} time")
119
+ time.sleep(1)
120
+ continue
121
+ else:
122
+ break
123
+ assert (
124
+ len(code_lengths) == tot
125
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
126
+ offsets = list(itertools.accumulate([0] + code_lengths))
127
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
128
+ return offsets
129
+
130
+
131
+ def verify_label_lengths(
132
+ audio_sizes,
133
+ audio_rate,
134
+ label_path,
135
+ label_rate,
136
+ inds,
137
+ tot,
138
+ tol=0.1, # tolerance in seconds
139
+ ):
140
+ if label_rate < 0:
141
+ logger.info(f"{label_path} is sequence label. skipped")
142
+ return
143
+
144
+ with open(label_path) as f:
145
+ lengths = [len(line.rstrip().split()) for line in f]
146
+ assert len(lengths) == tot
147
+ lengths = [lengths[i] for i in inds]
148
+ num_invalid = 0
149
+ for i, ind in enumerate(inds):
150
+ dur_from_audio = audio_sizes[i] / audio_rate
151
+ dur_from_label = lengths[i] / label_rate
152
+ if abs(dur_from_audio - dur_from_label) > tol:
153
+ logger.warning(
154
+ (
155
+ f"audio and label duration differ too much "
156
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
157
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
158
+ f"is correctly set (currently {label_rate}). "
159
+ f"num. of samples = {audio_sizes[i]}; "
160
+ f"label length = {lengths[i]}"
161
+ )
162
+ )
163
+ num_invalid += 1
164
+ if num_invalid > 0:
165
+ logger.warning(
166
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
167
+ )
168
+
169
+
170
+ class HubertDataset(FairseqDataset):
171
+ def __init__(
172
+ self,
173
+ manifest_path: str,
174
+ sample_rate: float,
175
+ label_paths: List[str],
176
+ label_rates: Union[List[float], float], # -1 for sequence labels
177
+ pad_list: List[str],
178
+ eos_list: List[str],
179
+ label_processors: Optional[List[Any]] = None,
180
+ max_keep_sample_size: Optional[int] = None,
181
+ min_keep_sample_size: Optional[int] = None,
182
+ max_sample_size: Optional[int] = None,
183
+ shuffle: bool = True,
184
+ pad_audio: bool = False,
185
+ normalize: bool = False,
186
+ store_labels: bool = True,
187
+ random_crop: bool = False,
188
+ single_target: bool = False,
189
+ tgt_dict: Optional[Dictionary] = None,
190
+ add_decoder_target: bool = False,
191
+ fine_tuning: bool = False,
192
+ tgt_lang_idx: int = None,
193
+ tokenizer = None,
194
+ mbart_style_lang_id: bool = False,
195
+ retry_times: int = 5,
196
+ reduce_label_for_dec: bool = True,
197
+ ):
198
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.chunk_names, self.chunk_indices = load_audio(
199
+ manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
200
+ )
201
+ self.sample_rate = sample_rate
202
+ self.shuffle = shuffle
203
+ self.random_crop = random_crop
204
+ self.tgt_dict = tgt_dict
205
+ self.add_decoder_target = add_decoder_target
206
+ self.fine_tuning = fine_tuning
207
+
208
+ self.num_labels = len(label_paths)
209
+ self.pad_list = pad_list
210
+ self.eos_list = eos_list
211
+ self.label_processors = label_processors
212
+ self.single_target = single_target
213
+ self.epoch = 0
214
+
215
+ self.label_rates = (
216
+ [label_rates for _ in range(len(label_paths))]
217
+ if isinstance(label_rates, int)
218
+ else label_rates
219
+ )
220
+ self.store_labels = store_labels
221
+ if store_labels:
222
+ self.label_list = [load_label(p, inds, tot, retry_times) for p in label_paths]
223
+ else:
224
+ self.label_paths = label_paths
225
+ self.label_offsets_list = [
226
+ load_label_offset(p, inds, tot, retry_times) for p in label_paths
227
+ ]
228
+ assert label_processors is None or len(label_processors) == self.num_labels
229
+ for label_path, label_rate in zip(label_paths, self.label_rates):
230
+ verify_label_lengths(
231
+ self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
232
+ )
233
+
234
+ self.max_sample_size = (
235
+ max_sample_size if max_sample_size is not None else sys.maxsize
236
+ )
237
+ self.pad_audio = pad_audio
238
+ self.normalize = normalize
239
+ self.tgt_lang_idx = tgt_lang_idx
240
+ self.tokenizer = tokenizer
241
+ self.mbart_style_lang_id = mbart_style_lang_id
242
+ self.retry_times = retry_times
243
+ self.reduce_label_for_dec = reduce_label_for_dec
244
+ logger.info(
245
+ f"pad_audio={pad_audio}, random_crop={random_crop}, tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
246
+ f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, max_sample_size={self.max_sample_size}"
247
+ )
248
+
249
+ def set_epoch(self, epoch):
250
+ self.epoch = epoch
251
+
252
+ def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1):
253
+ self.max_tokens = max_tokens
254
+ self.max_sentences = max_sentences
255
+ self.required_batch_size_multiple = required_batch_size_multiple
256
+ if isinstance(indices[0], np.ndarray):
257
+ batch_list = []
258
+ for indice in indices:
259
+ batch = super(HubertDataset, self).batch_by_size(indice, max_tokens, max_sentences, required_batch_size_multiple)
260
+ batch_list.append(batch)
261
+ return batch_list
262
+ else:
263
+ return super(HubertDataset, self).batch_by_size(indices, max_tokens, max_sentences, required_batch_size_multiple)
264
+ def shuffle_batches(self, batches, seed):
265
+ if isinstance(batches[0], list):
266
+ new_batches = []
267
+ with data_utils.numpy_seed(seed):
268
+ np.random.shuffle(batches)
269
+ for batch in batches:
270
+ np.random.shuffle(batch)
271
+ new_batches.extend(batch)
272
+ return new_batches
273
+ else:
274
+ with data_utils.numpy_seed(seed):
275
+ np.random.shuffle(batches)
276
+ return batches
277
+
278
+ def get_audio(self, index):
279
+ import soundfile as sf
280
+
281
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
282
+ _path, slice_ptr = parse_path(wav_path)
283
+ if len(slice_ptr) == 1:
284
+ import kaldiio
285
+ feat = kaldiio.load_mat(wav_path)
286
+ feat = torch.from_numpy(feat).float()
287
+ if self.normalize:
288
+ with torch.no_grad():
289
+ feat = F.layer_norm(feat, feat.shape[-1])
290
+ return feat
291
+ else:
292
+ if len(slice_ptr) == 2:
293
+ byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
294
+ assert is_sf_audio_data(byte_data)
295
+ wav_path = io.BytesIO(byte_data)
296
+ for i in range(self.retry_times):
297
+ if i < self.retry_times - 1:
298
+ try:
299
+ wav, cur_sample_rate = sf.read(wav_path)
300
+ break
301
+ except Exception as e:
302
+ logger.warn(f"Fail to load wav for the {i} time")
303
+ logger.warn(e)
304
+ time.sleep(1)
305
+ continue
306
+ else:
307
+ wav, cur_sample_rate = sf.read(wav_path)
308
+
309
+ wav = torch.from_numpy(wav).float()
310
+ wav = self.postprocess(wav, cur_sample_rate)
311
+ return wav
312
+
313
+ def get_label(self, index, label_idx):
314
+ if self.store_labels:
315
+ label = self.label_list[label_idx][index]
316
+ else:
317
+ with open(self.label_paths[label_idx]) as f:
318
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
319
+ f.seek(offset_s)
320
+ label = f.read(offset_e - offset_s)
321
+
322
+ if self.tokenizer is not None and self.fine_tuning:
323
+ label = self.tokenizer.encode(label)
324
+
325
+ if self.label_processors is not None:
326
+ label = self.label_processors[label_idx](label)
327
+ return label
328
+
329
+ def get_labels(self, index):
330
+ return [self.get_label(index, i) for i in range(self.num_labels)]
331
+
332
+ def __getitem__(self, index):
333
+ wav = self.get_audio(index)
334
+ labels = self.get_labels(index)
335
+ return {"id": index, "source": wav, "label_list": labels}
336
+
337
+ def __len__(self):
338
+ return len(self.wav_sizes)
339
+
340
+ def crop_to_max_size(self, wav, target_size):
341
+ size = len(wav)
342
+ diff = size - target_size
343
+ if diff <= 0:
344
+ return wav, 0
345
+
346
+ start, end = 0, target_size
347
+ if self.random_crop:
348
+ start = np.random.randint(0, diff + 1)
349
+ end = size - diff + start
350
+ return wav[start:end], start
351
+
352
+ def collater(self, samples):
353
+ # target = max(sizes) -> random_crop not used
354
+ # target = max_sample_size -> random_crop used for long
355
+ samples = [s for s in samples if s["source"] is not None]
356
+ if len(samples) == 0:
357
+ return {}
358
+
359
+ audios = [s["source"] for s in samples]
360
+ audio_sizes = [len(s) for s in audios]
361
+ if self.pad_audio:
362
+ audio_size = min(max(audio_sizes), self.max_sample_size)
363
+ else:
364
+ audio_size = min(min(audio_sizes), self.max_sample_size)
365
+ feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
366
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
367
+ audios, audio_size, feat_dim,
368
+ )
369
+
370
+ targets_by_label = [
371
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
372
+ ]
373
+ targets_list, lengths_list, ntokens_list = self.collater_label(
374
+ targets_by_label, audio_size, audio_starts
375
+ )
376
+
377
+ if self.add_decoder_target:
378
+ if self.fine_tuning:
379
+ decoder_label = [
380
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
381
+ for i in range(targets_list[0].size(0))
382
+ ]
383
+ else:
384
+ if self.tokenizer is not None:
385
+ decoder_label = [
386
+ # Set 48 for translate int to char and avoid \n
387
+ torch.cat(
388
+ (
389
+ torch.tensor(
390
+ self.tokenizer.sp.Encode(
391
+ "".join(
392
+ [chr(j + 48) for j in (
393
+ targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]]
394
+ ).tolist()]
395
+ ), out_type=int
396
+ )
397
+ ),
398
+ torch.tensor([self.tgt_dict.eos()])
399
+ ), dim=0
400
+ ).long()
401
+ for i in range(targets_list[0].size(0))
402
+ ]
403
+ else:
404
+ decoder_label = [
405
+ torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
406
+ for i in range(targets_list[0].size(0))
407
+ ]
408
+
409
+ if self.mbart_style_lang_id:
410
+ decoder_label = [
411
+ torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
412
+ for i in range(targets_list[0].size(0))
413
+ ]
414
+
415
+ dec_ntokens = sum(x.size(0) for x in decoder_label)
416
+ decoder_target = data_utils.collate_tokens(
417
+ decoder_label,
418
+ self.tgt_dict.pad(),
419
+ self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
420
+ left_pad=False,
421
+ move_eos_to_beginning=False,
422
+ )
423
+ decoder_target_lengths = torch.tensor(
424
+ [x.size(0) for x in decoder_label], dtype=torch.long
425
+ )
426
+ prev_output_tokens = data_utils.collate_tokens(
427
+ decoder_label,
428
+ self.tgt_dict.pad(),
429
+ self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
430
+ left_pad=False,
431
+ move_eos_to_beginning=True,
432
+ )
433
+
434
+ if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
435
+ assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
436
+ prev_output_tokens[:, 0] = self.tgt_lang_idx
437
+
438
+ net_input = {
439
+ "source": collated_audios,
440
+ "padding_mask": padding_mask,
441
+ "prev_output_tokens": prev_output_tokens,
442
+ }
443
+ batch = {
444
+ "id": torch.LongTensor([s["id"] for s in samples]),
445
+ "net_input": net_input,
446
+ "decoder_target": decoder_target,
447
+ "decoder_target_lengths": decoder_target_lengths,
448
+ "dec_ntokens": dec_ntokens,
449
+ "lang_idx": self.tgt_lang_idx,
450
+ }
451
+ else:
452
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
453
+ batch = {
454
+ "id": torch.LongTensor([s["id"] for s in samples]),
455
+ "net_input": net_input,
456
+ }
457
+
458
+ if self.single_target:
459
+ batch["target_lengths"] = lengths_list[0]
460
+ batch["ntokens"] = ntokens_list[0]
461
+ batch["target"] = targets_list[0]
462
+ else:
463
+ batch["target_lengths_list"] = lengths_list
464
+ batch["ntokens_list"] = ntokens_list
465
+ batch["target_list"] = targets_list
466
+ return batch
467
+
468
+ def collater_audio(self, audios, audio_size, feat_dim=1):
469
+ collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
470
+ padding_mask = (
471
+ torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
472
+ # if self.pad_audio else None
473
+ )
474
+ audio_starts = [0 for _ in audios]
475
+ for i, audio in enumerate(audios):
476
+ audio = audio.view(-1, feat_dim)
477
+ diff = len(audio) - audio_size
478
+ if diff == 0:
479
+ collated_audios[i] = audio
480
+ elif diff < 0:
481
+ assert self.pad_audio
482
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff, feat_dim), 0.0)])
483
+ padding_mask[i, diff:] = True
484
+ else:
485
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
486
+ audio, audio_size
487
+ )
488
+ return collated_audios.squeeze(-1), padding_mask, audio_starts
489
+
490
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
491
+ assert label_rate > 0
492
+ s2f = label_rate / self.sample_rate
493
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
494
+ frm_size = int(round(audio_size * s2f))
495
+ if not self.pad_audio:
496
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
497
+ frm_size = min(frm_size, *rem_size)
498
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
499
+ logger.debug(f"audio_starts={audio_starts}")
500
+ logger.debug(f"frame_starts={frm_starts}")
501
+ logger.debug(f"frame_size={frm_size}")
502
+
503
+ lengths = torch.LongTensor([len(t) for t in targets])
504
+ ntokens = lengths.sum().item()
505
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
506
+ return targets, lengths, ntokens
507
+
508
+ def collater_seq_label(self, targets, pad):
509
+ lengths = torch.LongTensor([len(t) for t in targets])
510
+ ntokens = lengths.sum().item()
511
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
512
+ return targets, lengths, ntokens
513
+
514
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
515
+ targets_list, lengths_list, ntokens_list = [], [], []
516
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
517
+ for targets, label_rate, pad in itr:
518
+ if label_rate == -1:
519
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
520
+ else:
521
+ targets, lengths, ntokens = self.collater_frm_label(
522
+ targets, audio_size, audio_starts, label_rate, pad
523
+ )
524
+ targets_list.append(targets)
525
+ lengths_list.append(lengths)
526
+ ntokens_list.append(ntokens)
527
+ return targets_list, lengths_list, ntokens_list
528
+
529
+ def num_tokens(self, index):
530
+ return self.size(index)
531
+
532
+ def size(self, index):
533
+ if self.pad_audio:
534
+ return self.wav_sizes[index]
535
+ return min(self.wav_sizes[index], self.max_sample_size)
536
+
537
+ @property
538
+ def sizes(self):
539
+ return np.array(self.wav_sizes)
540
+
541
+ def ordered_indices(self):
542
+ """Return an ordered list of indices. Batches will be constructed based
543
+ on this order."""
544
+
545
+ if self.shuffle:
546
+ if len(self.chunk_names) > 0:
547
+ logger.info(f"ordered indices for epoch {self.epoch}")
548
+ with data_utils.numpy_seed(self.epoch):
549
+ self.chunk_order = np.random.permutation(len(self.chunk_names))
550
+ chunk_count = 0
551
+ tmp_sizes = []
552
+ tmp_indices = []
553
+ indice = []
554
+ for i in self.chunk_order:
555
+ chunk_count += 1
556
+ start = self.chunk_indices[i]
557
+ end = self.chunk_indices[i+1] if i < len(self.chunk_names) - 1 else len(self)
558
+ size = list(self.sizes[start:end])
559
+ tmp_indices.extend(list(np.arange(start, end)))
560
+ tmp_sizes.extend(size)
561
+ if chunk_count % 10 == 0 or i == self.chunk_order[0]:
562
+ order = [np.random.permutation(len(tmp_indices))]
563
+ order.append(
564
+ np.minimum(
565
+ np.array(tmp_sizes),
566
+ self.max_sample_size,
567
+ )
568
+ )
569
+ sort_idx = np.lexsort(order)[::-1]
570
+ indice.append(np.array([tmp_indices[k] for k in sort_idx]))
571
+ tmp_indices = []
572
+ tmp_sizes =[]
573
+ return indice
574
+ else:
575
+ order = [np.random.permutation(len(self))]
576
+ order.append(
577
+ np.minimum(
578
+ np.array(self.sizes),
579
+ self.max_sample_size,
580
+ )
581
+ )
582
+ return np.lexsort(order)[::-1]
583
+ else:
584
+ return np.arange(len(self))
585
+
586
+ def postprocess(self, wav, cur_sample_rate):
587
+ if wav.dim() == 2:
588
+ wav = wav.mean(-1)
589
+ assert wav.dim() == 1, wav.dim()
590
+
591
+ if cur_sample_rate != self.sample_rate:
592
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
593
+
594
+ if self.normalize:
595
+ with torch.no_grad():
596
+ wav = F.layer_norm(wav, wav.shape)
597
+ return wav
SpeechT5/Speech2S/speech2s/data/language_trible_dataset.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ import numpy as np
10
+ import torch
11
+ import os
12
+ import itertools
13
+
14
+ from fairseq.data import FairseqDataset, data_utils
15
+ from fairseq.data import (
16
+ AppendTokenDataset,
17
+ ConcatDataset,
18
+ PrependTokenDataset,
19
+ data_utils,
20
+ indexed_dataset,
21
+ )
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ def load_langtriple_dataset(
26
+ data_path,
27
+ split,
28
+ src,
29
+ src_dict,
30
+ ref,
31
+ ref_dict,
32
+ tgt,
33
+ tgt_dict,
34
+ combine,
35
+ dataset_impl,
36
+ upsample_primary,
37
+ left_pad_source,
38
+ left_pad_target,
39
+ max_source_positions,
40
+ max_target_positions,
41
+ prepend_bos=False,
42
+ load_alignments=False,
43
+ truncate_source=False,
44
+ append_source_id=False,
45
+ num_buckets=0,
46
+ shuffle=True,
47
+ pad_to_multiple=1,
48
+ prepend_bos_src=None,
49
+ lang_format="[{}]",
50
+ ):
51
+ assert not truncate_source
52
+ def split_exists(split, src, ref, tgt, lang, data_path):
53
+ filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
54
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
55
+
56
+ src_datasets = []
57
+ ref_datasets = []
58
+ tgt_datasets = []
59
+
60
+ for k in itertools.count():
61
+ split_k = split + (str(k) if k > 0 else "")
62
+
63
+ # infer langcode
64
+ if split_exists(split_k, src, ref, tgt, src, data_path):
65
+ prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
66
+ elif split_exists(split_k, tgt, ref, src, src, data_path):
67
+ prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
68
+ else:
69
+ if k > 0:
70
+ break
71
+ else:
72
+ raise FileNotFoundError(
73
+ "Dataset not found: {} ({})".format(split, data_path)
74
+ )
75
+
76
+ src_dataset = data_utils.load_indexed_dataset(
77
+ prefix + src, src_dict, dataset_impl
78
+ )
79
+ src_datasets.append(src_dataset)
80
+
81
+ ref_dataset = data_utils.load_indexed_dataset(
82
+ prefix + ref, ref_dict, dataset_impl
83
+ )
84
+ ref_datasets.append(ref_dataset)
85
+
86
+ tgt_dataset = data_utils.load_indexed_dataset(
87
+ prefix + tgt, tgt_dict, dataset_impl
88
+ )
89
+ if tgt_dataset is not None:
90
+ tgt_datasets.append(tgt_dataset)
91
+
92
+ logger.info(
93
+ "{} {} {}-{}-{} {} examples".format(
94
+ data_path, split_k, src, ref, tgt, len(src_datasets[-1])
95
+ )
96
+ )
97
+
98
+ if not combine:
99
+ break
100
+
101
+ assert len(src_datasets) == len(ref_datasets)
102
+ assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
103
+
104
+ if len(src_datasets) == 1:
105
+ src_dataset = src_datasets[0]
106
+ ref_dataset = ref_datasets[0]
107
+ tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
108
+ else:
109
+ sample_ratios = [1] * len(src_datasets)
110
+ sample_ratios[0] = upsample_primary
111
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
112
+ ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
113
+ if len(tgt_datasets) > 0:
114
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
115
+ else:
116
+ tgt_dataset = None
117
+
118
+ if prepend_bos:
119
+ assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
120
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
121
+ ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
122
+ if tgt_dataset is not None:
123
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
124
+ elif prepend_bos_src is not None:
125
+ logger.info(f"prepending src bos: {prepend_bos_src}")
126
+ src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
127
+ ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)
128
+
129
+ eos = None
130
+ if append_source_id:
131
+ src_dataset = AppendTokenDataset(
132
+ src_dataset, src_dict.index(lang_format.format(src))
133
+ )
134
+ ref_dataset = AppendTokenDataset(
135
+ ref_dataset, ref_dict.index(lang_format.format(ref))
136
+ )
137
+ if tgt_dataset is not None:
138
+ tgt_dataset = AppendTokenDataset(
139
+ tgt_dataset, tgt_dict.index(lang_format.format(tgt))
140
+ )
141
+ eos = tgt_dict.index(lang_format.format(tgt))
142
+
143
+ align_dataset = None
144
+ if load_alignments:
145
+ align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
146
+ if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
147
+ align_dataset = data_utils.load_indexed_dataset(
148
+ align_path, None, dataset_impl
149
+ )
150
+
151
+ tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
152
+ return LanguageTripleDataset(
153
+ src_dataset,
154
+ src_dataset.sizes,
155
+ src_dict,
156
+ ref_dataset,
157
+ ref_dataset.sizes,
158
+ ref_dict,
159
+ tgt_dataset,
160
+ tgt_dataset_sizes,
161
+ tgt_dict,
162
+ left_pad_source=left_pad_source,
163
+ left_pad_target=left_pad_target,
164
+ align_dataset=align_dataset,
165
+ eos=eos,
166
+ num_buckets=num_buckets,
167
+ shuffle=shuffle,
168
+ pad_to_multiple=pad_to_multiple,
169
+ )
170
+
171
+
172
+ def collate(
173
+ samples,
174
+ pad_idx,
175
+ eos_idx,
176
+ left_pad_source=True,
177
+ left_pad_target=False,
178
+ input_feeding=True,
179
+ pad_to_length=None,
180
+ pad_to_multiple=1,
181
+ ):
182
+ if len(samples) == 0:
183
+ return {}
184
+
185
+ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
186
+ return data_utils.collate_tokens(
187
+ [s[key] for s in samples],
188
+ pad_idx,
189
+ None,
190
+ left_pad,
191
+ move_eos_to_beginning,
192
+ pad_to_length=pad_to_length,
193
+ pad_to_multiple=pad_to_multiple,
194
+ )
195
+
196
+ def check_alignment(alignment, src_len, tgt_len):
197
+ if alignment is None or len(alignment) == 0:
198
+ return False
199
+ if (
200
+ alignment[:, 0].max().item() >= src_len - 1
201
+ or alignment[:, 1].max().item() >= tgt_len - 1
202
+ ):
203
+ logger.warning("alignment size mismatch found, skipping alignment!")
204
+ return False
205
+ return True
206
+
207
+ def compute_alignment_weights(alignments):
208
+ """
209
+ Given a tensor of shape [:, 2] containing the source-target indices
210
+ corresponding to the alignments, a weight vector containing the
211
+ inverse frequency of each target index is computed.
212
+ For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
213
+ a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
214
+ index 3 is repeated twice)
215
+ """
216
+ align_tgt = alignments[:, 1]
217
+ _, align_tgt_i, align_tgt_c = torch.unique(
218
+ align_tgt, return_inverse=True, return_counts=True
219
+ )
220
+ align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
221
+ return 1.0 / align_weights.float()
222
+
223
+ id = torch.LongTensor([s["id"] for s in samples])
224
+ src_tokens = merge(
225
+ "source",
226
+ left_pad=left_pad_source,
227
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
228
+ )
229
+ ref_tokens = merge(
230
+ "reference",
231
+ left_pad=left_pad_source,
232
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
233
+ )
234
+ # sort by descending source length
235
+ src_lengths = torch.LongTensor(
236
+ [s["source"].ne(pad_idx).long().sum() for s in samples]
237
+ )
238
+ ref_lengths = torch.LongTensor(
239
+ [s["reference"].ne(pad_idx).long().sum() for s in samples]
240
+ )
241
+ src_lengths, sort_order = src_lengths.sort(descending=True)
242
+ id = id.index_select(0, sort_order)
243
+ src_tokens = src_tokens.index_select(0, sort_order)
244
+ ref_lengths = ref_lengths.index_select(0, sort_order)
245
+ ref_tokens = ref_tokens.index_select(0, sort_order)
246
+
247
+ prev_output_tokens = None
248
+ target = None
249
+ if samples[0].get("target", None) is not None:
250
+ target = merge(
251
+ "target",
252
+ left_pad=left_pad_target,
253
+ pad_to_length=pad_to_length["target"]
254
+ if pad_to_length is not None
255
+ else None,
256
+ )
257
+ target = target.index_select(0, sort_order)
258
+ tgt_lengths = torch.LongTensor(
259
+ [s["target"].ne(pad_idx).long().sum() for s in samples]
260
+ ).index_select(0, sort_order)
261
+ ntokens = tgt_lengths.sum().item()
262
+
263
+ if samples[0].get("prev_output_tokens", None) is not None:
264
+ prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
265
+ elif input_feeding:
266
+ # we create a shifted version of targets for feeding the
267
+ # previous output token(s) into the next decoder step
268
+ prev_output_tokens = merge(
269
+ "target",
270
+ left_pad=left_pad_target,
271
+ move_eos_to_beginning=True,
272
+ pad_to_length=pad_to_length["target"]
273
+ if pad_to_length is not None
274
+ else None,
275
+ )
276
+ else:
277
+ ntokens = src_lengths.sum().item()
278
+
279
+ batch = {
280
+ "id": id,
281
+ "nsentences": len(samples),
282
+ "ntokens": ntokens,
283
+ "net_input": {
284
+ "src_tokens": src_tokens,
285
+ "src_lengths": src_lengths,
286
+ },
287
+ "target": target,
288
+ "ref_tokens": ref_tokens,
289
+ "ref_lengths": ref_lengths,
290
+ }
291
+ if prev_output_tokens is not None:
292
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
293
+ 0, sort_order
294
+ )
295
+
296
+ if samples[0].get("alignment", None) is not None:
297
+ bsz, tgt_sz = batch["target"].shape
298
+ src_sz = batch["net_input"]["src_tokens"].shape[1]
299
+
300
+ offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
301
+ offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
302
+ if left_pad_source:
303
+ offsets[:, 0] += src_sz - src_lengths
304
+ if left_pad_target:
305
+ offsets[:, 1] += tgt_sz - tgt_lengths
306
+
307
+ alignments = [
308
+ alignment + offset
309
+ for align_idx, offset, src_len, tgt_len in zip(
310
+ sort_order, offsets, src_lengths, tgt_lengths
311
+ )
312
+ for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
313
+ if check_alignment(alignment, src_len, tgt_len)
314
+ ]
315
+
316
+ if len(alignments) > 0:
317
+ alignments = torch.cat(alignments, dim=0)
318
+ align_weights = compute_alignment_weights(alignments)
319
+
320
+ batch["alignments"] = alignments
321
+ batch["align_weights"] = align_weights
322
+
323
+ if samples[0].get("constraints", None) is not None:
324
+ # Collate the packed constraints across the samples, padding to
325
+ # the length of the longest sample.
326
+ lens = [sample.get("constraints").size(0) for sample in samples]
327
+ max_len = max(lens)
328
+ constraints = torch.zeros((len(samples), max(lens))).long()
329
+ for i, sample in enumerate(samples):
330
+ constraints[i, 0 : lens[i]] = samples[i].get("constraints")
331
+ batch["constraints"] = constraints.index_select(0, sort_order)
332
+
333
+ return batch
334
+
335
+
336
+ class LanguageTripleDataset(FairseqDataset):
337
+ """
338
+ A pair of torch.utils.data.Datasets.
339
+
340
+ Args:
341
+ src (torch.utils.data.Dataset): source dataset to wrap
342
+ src_sizes (List[int]): source sentence lengths
343
+ src_dict (~fairseq.data.Dictionary): source vocabulary
344
+ tgt (torch.utils.data.Dataset, optional): target dataset to wrap
345
+ tgt_sizes (List[int], optional): target sentence lengths
346
+ tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
347
+ left_pad_source (bool, optional): pad source tensors on the left side
348
+ (default: True).
349
+ left_pad_target (bool, optional): pad target tensors on the left side
350
+ (default: False).
351
+ shuffle (bool, optional): shuffle dataset elements before batching
352
+ (default: True).
353
+ input_feeding (bool, optional): create a shifted version of the targets
354
+ to be passed into the model for teacher forcing (default: True).
355
+ remove_eos_from_source (bool, optional): if set, removes eos from end
356
+ of source if it's present (default: False).
357
+ append_eos_to_target (bool, optional): if set, appends eos to end of
358
+ target if it's absent (default: False).
359
+ align_dataset (torch.utils.data.Dataset, optional): dataset
360
+ containing alignments.
361
+ constraints (Tensor, optional): 2d tensor with a concatenated, zero-
362
+ delimited list of constraints for each sentence.
363
+ append_bos (bool, optional): if set, appends bos to the beginning of
364
+ source/target sentence.
365
+ num_buckets (int, optional): if set to a value greater than 0, then
366
+ batches will be bucketed into the given number of batch shapes.
367
+ src_lang_id (int, optional): source language ID, if set, the collated batch
368
+ will contain a field 'src_lang_id' in 'net_input' which indicates the
369
+ source language of the samples.
370
+ tgt_lang_id (int, optional): target language ID, if set, the collated batch
371
+ will contain a field 'tgt_lang_id' which indicates the target language
372
+ of the samples.
373
+ """
374
+
375
+ def __init__(
376
+ self,
377
+ src,
378
+ src_sizes,
379
+ src_dict,
380
+ ref,
381
+ ref_sizes,
382
+ ref_dict,
383
+ tgt=None,
384
+ tgt_sizes=None,
385
+ tgt_dict=None,
386
+ left_pad_source=True,
387
+ left_pad_target=False,
388
+ shuffle=True,
389
+ input_feeding=True,
390
+ remove_eos_from_source=False,
391
+ append_eos_to_target=False,
392
+ align_dataset=None,
393
+ constraints=None,
394
+ append_bos=False,
395
+ eos=None,
396
+ num_buckets=0,
397
+ src_lang_id=None,
398
+ tgt_lang_id=None,
399
+ pad_to_multiple=1,
400
+ ):
401
+ if tgt_dict is not None:
402
+ assert src_dict.pad() == tgt_dict.pad()
403
+ assert src_dict.eos() == tgt_dict.eos()
404
+ assert src_dict.unk() == tgt_dict.unk()
405
+ if tgt is not None:
406
+ assert len(src) == len(
407
+ tgt
408
+ ), "Source and target must contain the same number of examples"
409
+ assert len(src) == len(
410
+ ref
411
+ ), "Source and reference must contain the same number of examples"
412
+ self.src = src
413
+ self.ref = ref
414
+ self.tgt = tgt
415
+ self.src_sizes = np.array(src_sizes)
416
+ self.ref_sizes = np.array(ref_sizes)
417
+ self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
418
+ self.sizes = (
419
+ np.vstack((self.src_sizes, self.tgt_sizes)).T
420
+ if self.tgt_sizes is not None
421
+ else self.src_sizes
422
+ )
423
+ self.src_dict = src_dict
424
+ self.ref_dict = ref_dict
425
+ self.tgt_dict = tgt_dict
426
+ self.left_pad_source = left_pad_source
427
+ self.left_pad_target = left_pad_target
428
+ self.shuffle = shuffle
429
+ self.input_feeding = input_feeding
430
+ self.remove_eos_from_source = remove_eos_from_source
431
+ self.append_eos_to_target = append_eos_to_target
432
+ self.align_dataset = align_dataset
433
+ if self.align_dataset is not None:
434
+ assert (
435
+ self.tgt_sizes is not None
436
+ ), "Both source and target needed when alignments are provided"
437
+ self.constraints = constraints
438
+ self.append_bos = append_bos
439
+ self.eos = eos if eos is not None else src_dict.eos()
440
+ self.src_lang_id = src_lang_id
441
+ self.tgt_lang_id = tgt_lang_id
442
+ if num_buckets > 0:
443
+ from fairseq.data import BucketPadLengthDataset
444
+
445
+ self.src = BucketPadLengthDataset(
446
+ self.src,
447
+ sizes=self.src_sizes,
448
+ num_buckets=num_buckets,
449
+ pad_idx=self.src_dict.pad(),
450
+ left_pad=self.left_pad_source,
451
+ )
452
+ self.src_sizes = self.src.sizes
453
+ logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
454
+ self.ref = BucketPadLengthDataset(
455
+ self.ref,
456
+ sizes=self.ref_sizes,
457
+ num_buckets=num_buckets,
458
+ pad_idx=self.ref_dict.pad(),
459
+ left_pad=self.left_pad_source,
460
+ )
461
+ self.ref_sizes = self.ref.sizes
462
+ logger.info("bucketing reference lengths: {}".format(list(self.src.buckets)))
463
+ if self.tgt is not None:
464
+ self.tgt = BucketPadLengthDataset(
465
+ self.tgt,
466
+ sizes=self.tgt_sizes,
467
+ num_buckets=num_buckets,
468
+ pad_idx=self.tgt_dict.pad(),
469
+ left_pad=self.left_pad_target,
470
+ )
471
+ self.tgt_sizes = self.tgt.sizes
472
+ logger.info(
473
+ "bucketing target lengths: {}".format(list(self.tgt.buckets))
474
+ )
475
+
476
+ # determine bucket sizes using self.num_tokens, which will return
477
+ # the padded lengths (thanks to BucketPadLengthDataset)
478
+ num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
479
+ self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
480
+ self.buckets = [
481
+ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
482
+ ]
483
+ else:
484
+ self.buckets = None
485
+ self.pad_to_multiple = pad_to_multiple
486
+
487
+ def get_batch_shapes(self):
488
+ return self.buckets
489
+
490
+ def __getitem__(self, index):
491
+ tgt_item = self.tgt[index] if self.tgt is not None else None
492
+ src_item = self.src[index]
493
+ ref_item = self.ref[index]
494
+ # Append EOS to end of tgt sentence if it does not have an EOS and remove
495
+ # EOS from end of src sentence if it exists. This is useful when we use
496
+ # use existing datasets for opposite directions i.e., when we want to
497
+ # use tgt_dataset as src_dataset and vice versa
498
+ if self.append_eos_to_target:
499
+ eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
500
+ if self.tgt and self.tgt[index][-1] != eos:
501
+ tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
502
+
503
+ if self.append_bos:
504
+ bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
505
+ if self.tgt and self.tgt[index][0] != bos:
506
+ tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
507
+
508
+ bos = self.src_dict.bos()
509
+ if self.src[index][0] != bos:
510
+ src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
511
+ if self.ref[index][0] != bos:
512
+ ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])
513
+
514
+ if self.remove_eos_from_source:
515
+ eos = self.src_dict.eos()
516
+ if self.src[index][-1] == eos:
517
+ src_item = self.src[index][:-1]
518
+ if self.ref[index][-1] == eos:
519
+ ref_item = self.ref[index][:-1]
520
+
521
+ example = {
522
+ "id": index,
523
+ "source": src_item,
524
+ "reference": ref_item,
525
+ "target": tgt_item,
526
+ }
527
+ if self.align_dataset is not None:
528
+ example["alignment"] = self.align_dataset[index]
529
+ if self.constraints is not None:
530
+ example["constraints"] = self.constraints[index]
531
+ return example
532
+
533
+ def __len__(self):
534
+ return len(self.src)
535
+
536
+ def collater(self, samples, pad_to_length=None):
537
+ """Merge a list of samples to form a mini-batch.
538
+
539
+ Args:
540
+ samples (List[dict]): samples to collate
541
+ pad_to_length (dict, optional): a dictionary of
542
+ {'source': source_pad_to_length, 'target': target_pad_to_length}
543
+ to indicate the max length to pad to in source and target respectively.
544
+
545
+ Returns:
546
+ dict: a mini-batch with the following keys:
547
+
548
+ - `id` (LongTensor): example IDs in the original input order
549
+ - `ntokens` (int): total number of tokens in the batch
550
+ - `net_input` (dict): the input to the Model, containing keys:
551
+
552
+ - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
553
+ the source sentence of shape `(bsz, src_len)`. Padding will
554
+ appear on the left if *left_pad_source* is ``True``.
555
+ - `src_lengths` (LongTensor): 1D Tensor of the unpadded
556
+ lengths of each source sentence of shape `(bsz)`
557
+ - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
558
+ tokens in the target sentence, shifted right by one
559
+ position for teacher forcing, of shape `(bsz, tgt_len)`.
560
+ This key will not be present if *input_feeding* is
561
+ ``False``. Padding will appear on the left if
562
+ *left_pad_target* is ``True``.
563
+ - `src_lang_id` (LongTensor): a long Tensor which contains source
564
+ language IDs of each sample in the batch
565
+
566
+ - `target` (LongTensor): a padded 2D Tensor of tokens in the
567
+ target sentence of shape `(bsz, tgt_len)`. Padding will appear
568
+ on the left if *left_pad_target* is ``True``.
569
+ - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
570
+ IDs of each sample in the batch
571
+ """
572
+ res = collate(
573
+ samples,
574
+ pad_idx=self.src_dict.pad(),
575
+ eos_idx=self.eos,
576
+ left_pad_source=self.left_pad_source,
577
+ left_pad_target=self.left_pad_target,
578
+ input_feeding=self.input_feeding,
579
+ pad_to_length=pad_to_length,
580
+ pad_to_multiple=self.pad_to_multiple,
581
+ )
582
+ if self.src_lang_id is not None or self.tgt_lang_id is not None:
583
+ src_tokens = res["net_input"]["src_tokens"]
584
+ bsz = src_tokens.size(0)
585
+ if self.src_lang_id is not None:
586
+ res["net_input"]["src_lang_id"] = (
587
+ torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
588
+ )
589
+ if self.tgt_lang_id is not None:
590
+ res["tgt_lang_id"] = (
591
+ torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
592
+ )
593
+ return res
594
+
595
+ def num_tokens(self, index):
596
+ """Return the number of tokens in a sample. This value is used to
597
+ enforce ``--max-tokens`` during batching."""
598
+ return max(
599
+ self.src_sizes[index],
600
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
601
+ )
602
+
603
+ def num_tokens_vec(self, indices):
604
+ """Return the number of tokens for a set of positions defined by indices.
605
+ This value is used to enforce ``--max-tokens`` during batching."""
606
+ sizes = self.src_sizes[indices]
607
+ if self.tgt_sizes is not None:
608
+ sizes = np.maximum(sizes, self.tgt_sizes[indices])
609
+ return sizes
610
+
611
+ def size(self, index):
612
+ """Return an example's size as a float or tuple. This value is used when
613
+ filtering a dataset with ``--max-positions``."""
614
+ return (
615
+ self.src_sizes[index],
616
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
617
+ )
618
+
619
+ def ordered_indices(self):
620
+ """Return an ordered list of indices. Batches will be constructed based
621
+ on this order."""
622
+ if self.shuffle:
623
+ indices = np.random.permutation(len(self)).astype(np.int64)
624
+ else:
625
+ indices = np.arange(len(self), dtype=np.int64)
626
+ if self.buckets is None:
627
+ # sort by target length, then source length
628
+ if self.tgt_sizes is not None:
629
+ indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
630
+ return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
631
+ else:
632
+ # sort by bucketed_num_tokens, which is:
633
+ # max(padded_src_len, padded_tgt_len)
634
+ return indices[
635
+ np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
636
+ ]
637
+
638
+ @property
639
+ def supports_prefetch(self):
640
+ return getattr(self.src, "supports_prefetch", False) and (
641
+ getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
642
+ )
643
+
644
+ def prefetch(self, indices):
645
+ self.src.prefetch(indices)
646
+ if self.tgt is not None:
647
+ self.tgt.prefetch(indices)
648
+ if self.align_dataset is not None:
649
+ self.align_dataset.prefetch(indices)
650
+
651
+ def filter_indices_by_size(self, indices, max_sizes):
652
+ """Filter a list of sample indices. Remove those that are longer
653
+ than specified in max_sizes.
654
+
655
+ Args:
656
+ indices (np.array): original array of sample indices
657
+ max_sizes (int or list[int] or tuple[int]): max sample size,
658
+ can be defined separately for src and tgt (then list or tuple)
659
+
660
+ Returns:
661
+ np.array: filtered sample array
662
+ list: list of removed indices
663
+ """
664
+ return data_utils.filter_paired_dataset_indices_by_size(
665
+ self.src_sizes,
666
+ self.tgt_sizes,
667
+ indices,
668
+ max_sizes,
669
+ )
SpeechT5/Speech2S/speech2s/data/load_langpair_dataset.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ """
9
+ Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
10
+ 1. Add custom lang_format in function load_langpair_dataset
11
+ 2. If truncate_source (default no), use RandomCropDataset instead of TruncateDataset
12
+ """
13
+
14
+ import itertools
15
+ import logging
16
+ import os
17
+
18
+ from fairseq.data import (
19
+ AppendTokenDataset,
20
+ LanguagePairDataset,
21
+ PrependTokenDataset,
22
+ StripTokenDataset,
23
+ TruncateDataset,
24
+ RandomCropDataset,
25
+ data_utils,
26
+ indexed_dataset,
27
+ )
28
+
29
+ from speechut.data.concat_dataset import ConcatDataset
30
+
31
+
32
+ EVAL_BLEU_ORDER = 4
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ def load_langpair_dataset(
39
+ data_path,
40
+ split,
41
+ src,
42
+ src_dict,
43
+ tgt,
44
+ tgt_dict,
45
+ combine,
46
+ dataset_impl,
47
+ upsample_primary,
48
+ left_pad_source,
49
+ left_pad_target,
50
+ max_source_positions,
51
+ max_target_positions,
52
+ prepend_bos=False,
53
+ load_alignments=False,
54
+ truncate_source=False,
55
+ append_source_id=False,
56
+ num_buckets=0,
57
+ shuffle=True,
58
+ pad_to_multiple=1,
59
+ prepend_bos_src=None,
60
+ lang_format="[{}]",
61
+ input_feeding=True,
62
+ ):
63
+ def split_exists(split, src, tgt, lang, data_path):
64
+ filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
65
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
66
+
67
+ src_datasets = []
68
+ tgt_datasets = []
69
+
70
+ for k in itertools.count():
71
+ split_k = split + (str(k) if k > 0 else "")
72
+
73
+ # infer langcode
74
+ if split_exists(split_k, src, tgt, src, data_path):
75
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
76
+ elif split_exists(split_k, tgt, src, src, data_path):
77
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
78
+ else:
79
+ if k > 0:
80
+ break
81
+ else:
82
+ raise FileNotFoundError(
83
+ "Dataset not found: {} ({})".format(split, data_path)
84
+ )
85
+
86
+ src_dataset = data_utils.load_indexed_dataset(
87
+ prefix + src, src_dict, dataset_impl
88
+ )
89
+ if truncate_source:
90
+ src_dataset = AppendTokenDataset(
91
+ RandomCropDataset(
92
+ StripTokenDataset(src_dataset, src_dict.eos()),
93
+ max_source_positions - 1,
94
+ ),
95
+ src_dict.eos(),
96
+ )
97
+ src_datasets.append(src_dataset)
98
+
99
+ tgt_dataset = data_utils.load_indexed_dataset(
100
+ prefix + tgt, tgt_dict, dataset_impl
101
+ )
102
+ if tgt_dataset is not None:
103
+ tgt_datasets.append(tgt_dataset)
104
+
105
+ logger.info(
106
+ "{} {} {}-{} {} examples".format(
107
+ data_path, split_k, src, tgt, len(src_datasets[-1])
108
+ )
109
+ )
110
+
111
+ if not combine:
112
+ break
113
+
114
+ assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
115
+
116
+ if len(src_datasets) == 1:
117
+ src_dataset = src_datasets[0]
118
+ tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
119
+ else:
120
+ sample_ratios = [1] * len(src_datasets)
121
+ sample_ratios[0] = upsample_primary
122
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
123
+ if len(tgt_datasets) > 0:
124
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
125
+ else:
126
+ tgt_dataset = None
127
+
128
+ if prepend_bos:
129
+ assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
130
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
131
+ if tgt_dataset is not None:
132
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
133
+ elif prepend_bos_src is not None:
134
+ logger.info(f"prepending src bos: {prepend_bos_src}")
135
+ src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
136
+
137
+ eos = None
138
+ if append_source_id:
139
+ src_dataset = AppendTokenDataset(
140
+ src_dataset, src_dict.index(lang_format.format(src))
141
+ )
142
+ if tgt_dataset is not None:
143
+ tgt_dataset = AppendTokenDataset(
144
+ tgt_dataset, tgt_dict.index(lang_format.format(tgt))
145
+ )
146
+ eos = tgt_dict.index(lang_format.format(tgt))
147
+
148
+ align_dataset = None
149
+ if load_alignments:
150
+ align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
151
+ if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
152
+ align_dataset = data_utils.load_indexed_dataset(
153
+ align_path, None, dataset_impl
154
+ )
155
+
156
+ tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
157
+ return LanguagePairDataset(
158
+ src_dataset,
159
+ src_dataset.sizes,
160
+ src_dict,
161
+ tgt_dataset,
162
+ tgt_dataset_sizes,
163
+ tgt_dict,
164
+ left_pad_source=left_pad_source,
165
+ left_pad_target=left_pad_target,
166
+ align_dataset=align_dataset,
167
+ eos=eos,
168
+ num_buckets=num_buckets,
169
+ shuffle=shuffle,
170
+ pad_to_multiple=pad_to_multiple,
171
+ input_feeding=input_feeding,
172
+ )
SpeechT5/Speech2S/speech2s/data/multimodal_corpus_dataset.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ from os import replace
10
+ import time
11
+ from collections import OrderedDict
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ import numpy as np
15
+ from fairseq.data import data_utils
16
+
17
+ from fairseq.data import FairseqDataset
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class MultiCorpusDataset(FairseqDataset):
23
+ """
24
+ see fairseq/fairseq/data/multi_corpus_dataset.__doc__
25
+
26
+ Args:
27
+ datasets: a OrderedDict of FairseqDataset instances.
28
+ distribution: a List containing the probability of getting an utterance from
29
+ corresponding dataset
30
+ seed: random seed for sampling the datsets
31
+ sort_indices: if true, will sort the ordered indices by size
32
+ batch_sample: if true, will ensure each batch is from a single dataset
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ datasets: Dict[str, FairseqDataset],
38
+ max_positions: Dict,
39
+ distribution: List[float],
40
+ max_tokens_ratio: List[float],
41
+ seed: int = 1234,
42
+ sort_indices: bool = False,
43
+ check_length: bool = False,
44
+ ):
45
+ super().__init__()
46
+ assert isinstance(datasets, OrderedDict)
47
+ assert len(datasets) == len(distribution)
48
+ # assert sum(distribution) == 1
49
+ self.datasets = datasets
50
+ self.distribution = distribution
51
+ self.max_tokens_ratio = max_tokens_ratio
52
+ self.seed = seed
53
+ self.sort_indices = sort_indices
54
+ self.max_positions = max_positions
55
+ self.check_length = check_length
56
+
57
+ # Avoid repeated conversions to list later
58
+ self.dataset_list = list(datasets.values())
59
+ self.total_num_instances = 0
60
+
61
+ # first_dataset = self.dataset_list[0]
62
+
63
+ self.num_instances_per_dataset = []
64
+ self.dataset_offsets = []
65
+ for i, dataset in enumerate(self.dataset_list):
66
+ assert isinstance(dataset, FairseqDataset)
67
+ # assert type(dataset) is type(first_dataset)
68
+ self.num_instances_per_dataset.append(
69
+ 0 if self.distribution[i] == 0 else len(dataset)
70
+ )
71
+ self.dataset_offsets.append(self.total_num_instances)
72
+ self.total_num_instances += self.num_instances_per_dataset[i]
73
+
74
+ def ordered_indices(self):
75
+ start = time.time()
76
+ with data_utils.numpy_seed(self.seed, self.epoch):
77
+ logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
78
+ sampled_indices = {}
79
+
80
+ # For each dataset i, sample self.distribution[i] * self.total_num_instances
81
+ for i, key in enumerate(self.datasets):
82
+ tp = time.time()
83
+ if self.distribution[i] == 0:
84
+ # skip dataset if sampling probability is 0
85
+ continue
86
+
87
+ if i < len(self.datasets) - 1:
88
+ num_instances = int(self.distribution[i] * self.total_num_instances)
89
+ high = self.dataset_offsets[i + 1]
90
+ else:
91
+ num_instances = int(self.distribution[i] * self.total_num_instances)
92
+ high = self.total_num_instances
93
+
94
+ logger.info(f"sampling {num_instances} from {key} dataset")
95
+
96
+ # First, add k copies of the dataset where k = num_instances // len(dataset).
97
+ # This ensures an equal distribution of the data points as much as possible.
98
+ # For the remaining entries randomly sample them
99
+ dataset_size = len(self.datasets[key])
100
+ num_copies = num_instances // dataset_size
101
+ dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[: num_instances - num_copies * dataset_size]
102
+ if num_copies > 0:
103
+ dataset_indices = np.concatenate(
104
+ (
105
+ np.repeat(
106
+ np.arange(high - self.dataset_offsets[i]), num_copies
107
+ ),
108
+ dataset_indices,
109
+ )
110
+ )
111
+ # filter by size, we should ignore it by setting check_length=False
112
+ # , as it is very time-consuming on large dadaset
113
+ if self.max_positions[key] is not None and self.check_length:
114
+ dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
115
+ dataset_indices,
116
+ self.max_positions[key],
117
+ )
118
+ if len(ignored) > 0:
119
+ logger.warning(
120
+ (
121
+ "{:,} samples have invalid sizes and will be skipped, "
122
+ "max_positions={}, first few sample ids={}"
123
+ ).format(len(ignored), self.max_positions[key], ignored[:10])
124
+ )
125
+
126
+ if self.sort_indices:
127
+ logger.info(" - sampled indices took {}s".format(time.time() - tp))
128
+ tp = time.time()
129
+ dataset_indices = np.sort(dataset_indices)
130
+ ordered_indices = self.datasets[key].ordered_indices()
131
+ if isinstance(ordered_indices[0], np.ndarray): # chunked audio data
132
+ dataset_indices = [order_idx + self.dataset_offsets[i] for order_idx in ordered_indices]
133
+ assert self.dataset_offsets[i] == 0
134
+ # TODO for chunked audio data, now assume len(dataset_indices) == len(dataset). Don't filter any data.
135
+ else:
136
+ dataset_indices = ordered_indices[dataset_indices] + self.dataset_offsets[i]
137
+ logger.info(" - ordered_indices took {}s".format(time.time() - tp))
138
+ else:
139
+ np.random.shuffle(dataset_indices)
140
+
141
+ sampled_indices[key] = dataset_indices
142
+
143
+ logger.info(
144
+ "multi_corpus_dataset ordered_indices took {}s".format(
145
+ time.time() - start
146
+ )
147
+ )
148
+ return sampled_indices
149
+
150
+ def _map_index(self, index: int):
151
+ """
152
+ If dataset A has length N and dataset B has length M
153
+ then index 1 maps to index 1 of dataset A, and index N + 1
154
+ maps to index 1 of B.
155
+ """
156
+ counter = 0
157
+ for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
158
+ if index < counter + num_instances:
159
+ return index - counter, key
160
+ counter += num_instances
161
+ raise ValueError(
162
+ "Invalid index: {}, max: {}".format(index, self.total_num_instances)
163
+ )
164
+
165
+ def __len__(self):
166
+ """
167
+ Length of this dataset is the sum of individual datasets
168
+ """
169
+ return self.total_num_instances
170
+
171
+ def __getitem__(self, index):
172
+ new_index, key = self._map_index(index)
173
+ try:
174
+ item = self.datasets[key][new_index]
175
+ item["full_id"] = index
176
+ return item
177
+ except Exception as e:
178
+ e.args = (f"Error from {key} dataset", *e.args)
179
+ raise
180
+
181
+ def collater(self, samples):
182
+ """
183
+ If we are doing batch sampling, then pick the right collater to use.
184
+
185
+ Otherwise we assume all collaters are the same.
186
+ """
187
+ if len(samples) == 0:
188
+ return None
189
+
190
+ samples_dict = {key: [] for key in self.datasets}
191
+ for s in samples:
192
+ _, key = self._map_index(s["full_id"])
193
+ samples_dict[key].append(s)
194
+
195
+ batch = {}
196
+ for key in samples_dict:
197
+ if len(samples_dict[key]) == 0:
198
+ continue
199
+ batch[key] = self.datasets[key].collater(samples_dict[key])
200
+
201
+ return batch
202
+
203
+
204
+ def num_tokens(self, index: int):
205
+ index, key = self._map_index(index)
206
+ return self.datasets[key].num_tokens(index)
207
+
208
+ def size(self, index: int):
209
+ index, key = self._map_index(index)
210
+ return self.datasets[key].size(index)
211
+
212
+ @property
213
+ def can_reuse_epoch_itr_across_epochs(self):
214
+ return False
215
+
216
+ def set_epoch(self, epoch, **unused):
217
+ super().set_epoch(epoch)
218
+ logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
219
+ for ds in self.dataset_list:
220
+ if hasattr(ds, "set_epoch"):
221
+ ds.set_epoch(epoch)
222
+ self.epoch = epoch
223
+
224
+ @property
225
+ def supports_prefetch(self):
226
+ return False
227
+
228
+ @property
229
+ def supports_fetch_outside_dataloader(self):
230
+ return all(
231
+ self.datasets[key].supports_fetch_outside_dataloader
232
+ for key in self.datasets
233
+ )
234
+
235
+
236
+ def batch_by_size(
237
+ self,
238
+ indices,
239
+ max_tokens=None,
240
+ max_sentences=None,
241
+ required_batch_size_multiple=1,
242
+ ):
243
+ dataset_indices = indices
244
+ batches_dict = {}
245
+ for n, key in enumerate(dataset_indices):
246
+ max_tokens_ratio = self.max_tokens_ratio[n]
247
+ if isinstance(dataset_indices[key][0], np.ndarray): # chunked audio data
248
+ cur_batches = self.datasets[key].batch_by_size(
249
+ dataset_indices[key],
250
+ round(max_tokens * max_tokens_ratio),
251
+ max_sentences,
252
+ required_batch_size_multiple,
253
+ )
254
+ logger.info(f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}")
255
+ else:
256
+ cur_batches = super().batch_by_size(
257
+ np.array(dataset_indices[key], dtype=np.int64),
258
+ round(max_tokens * max_tokens_ratio),
259
+ max_sentences,
260
+ required_batch_size_multiple,
261
+ )
262
+ logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
263
+ batches_dict[key] = cur_batches
264
+
265
+ return batches_dict
266
+
267
+
268
+ def get_batch_sampler(
269
+ self,
270
+ indices,
271
+ num_shards,
272
+ seed,
273
+ max_tokens=None,
274
+ max_sentences=None,
275
+ required_batch_size_multiple=1,
276
+ split_modality_batch=False,
277
+ ):
278
+
279
+ def batch_sampler(dataset, epoch):
280
+ start = time.time()
281
+ batches_dict = dataset.batch_by_size(
282
+ indices,
283
+ max_tokens=max_tokens,
284
+ max_sentences=max_sentences,
285
+ required_batch_size_multiple=required_batch_size_multiple,
286
+ )
287
+ logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
288
+ start = time.time()
289
+ new_batches = []
290
+
291
+ ### shuffle inner group size, split into speech/text batches
292
+ shuffled_batches_list = []
293
+ speech_batches = []
294
+ ### we should specify the speech_batches because: we need concatenate different speech datasets
295
+ # (e.g. ltr or km) instead of loading them parellelly.
296
+ for name, batches in batches_dict.items():
297
+ if name.startswith("speech"):
298
+ if isinstance(batches[0], list): # chunked audio data
299
+ batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
300
+ shuffled_batches_list.append(batches)
301
+ else:
302
+ batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
303
+ batches = batches[: (len(batches) // num_shards) * num_shards]
304
+ if len(batches) == 0:
305
+ logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
306
+ else:
307
+ speech_batches += batches
308
+ else:
309
+ batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
310
+ batches = batches[: (len(batches) // num_shards) * num_shards]
311
+ if len(batches) == 0:
312
+ logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
313
+ else:
314
+ batches = shuffle_buckets(batches, seed=seed+epoch, inner_shuf=False)
315
+ shuffled_batches_list.append(batches)
316
+ if len(speech_batches) > 0:
317
+ speech_batches = shuffle_buckets(speech_batches, seed=seed+epoch, inner_shuf=False)
318
+ shuffled_batches_list.append(speech_batches)
319
+
320
+ ### create the final new_batches
321
+ num_batch = min(len(batches) for batches in shuffled_batches_list)
322
+ if split_modality_batch:
323
+ for i in range(0, num_batch, num_shards):
324
+ for batches in shuffled_batches_list:
325
+ new_batches += batches[i: i + num_shards]
326
+ else:
327
+ for i in range(num_batch):
328
+ new_batches.append(np.concatenate([batches[i] for batches in shuffled_batches_list]))
329
+
330
+ logger.info(f"multi_corpus_dataset sample {len(new_batches)} batches, took {time.time() - start}s")
331
+ return new_batches
332
+
333
+ def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
334
+ """we assert batches is sorted form long to short.
335
+ shuffle samples in a buctet(e.g. 10 batches).
336
+ batches: a list of numpy array"""
337
+ num_batch = len(batches)
338
+ new_batches = []
339
+ num_buckets = len(batches) // bucket_size
340
+ i = 0
341
+ while i < num_batch:
342
+ if (i < bucket_size * thr or
343
+ i >= bucket_size * (num_buckets - thr)
344
+ ):
345
+ new_batches.append(batches[i])
346
+ i += 1
347
+ else:
348
+ group = np.concatenate(batches[i: i+bucket_size])
349
+ with data_utils.numpy_seed(seed):
350
+ np.random.shuffle(group)
351
+ new_batches += np.array_split(group, bucket_size)
352
+ i += bucket_size
353
+ assert all([len(batch) > 0 for batch in new_batches])
354
+ return new_batches
355
+
356
+ def shuffle_buckets(batches, seed, inner_shuf=True):
357
+ if inner_shuf:
358
+ batches = inner_bucket_shuffle(batches, seed, num_shards*10)
359
+ batches = [batches[i: i + num_shards] for i in range(0, len(batches)-num_shards+1, num_shards)]
360
+ assert len(batches[-1]) == num_shards
361
+ new_batches = []
362
+ with data_utils.numpy_seed(seed):
363
+ np.random.shuffle(batches)
364
+ for group in batches:
365
+ new_batches += group
366
+ return new_batches
367
+
368
+ return batch_sampler
SpeechT5/Speech2S/speech2s/models/__init__.py ADDED
File without changes
SpeechT5/Speech2S/speech2s/models/speechut.py ADDED
@@ -0,0 +1,785 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
4
+ # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
5
+ #
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # ----------------------------------------------------------------------------
9
+
10
+ import logging
11
+ from dataclasses import dataclass, field
12
+ from typing import Dict, List, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from fairseq import utils, checkpoint_utils
20
+ from fairseq.data.data_utils import compute_mask_indices
21
+ from fairseq.data.dictionary import Dictionary
22
+ from fairseq.dataclass import ChoiceEnum
23
+ from fairseq.models import BaseFairseqModel, register_model
24
+ from fairseq.models.transformer import Embedding
25
+ from fairseq.file_io import PathManager
26
+ from torch import Tensor
27
+ from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
28
+ from fairseq.modules import GradMultiply, LayerNorm
29
+ from fairseq.tasks.hubert_pretraining import (
30
+ HubertPretrainingConfig,
31
+ HubertPretrainingTask,
32
+ )
33
+ from fairseq.models.hubert import HubertConfig
34
+ from fairseq.models.transformer import TransformerConfig
35
+ from speechut.modules import TransformerEncoder
36
+ from speechut.modules import TransformerEncoderBase
37
+ from speechut.modules import TransformerDecoderBaseScriptable
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
42
+ MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
43
+
44
+
45
+ @dataclass
46
+
47
+ class SpeechutConfig(HubertConfig):
48
+ use_rel_pos_enc: bool = field(
49
+ default=False,
50
+ metadata={"help": "whether to use relative positional encoding"},
51
+ )
52
+ scaling_for_att: float = field(
53
+ default=1.0,
54
+ metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
55
+ )
56
+
57
+ # unit encoder-decoder
58
+ text_transformer: TransformerConfig = TransformerConfig()
59
+ reset_decoder_embedding_config: bool = field(
60
+ default=False,
61
+ metadata={"help": "reset the no_scale_embedding/layernorm_embedding to default for the decoder"},
62
+ )
63
+ add_unit_encoder: bool = field(
64
+ default=False,
65
+ metadata={"help": "add unit encoder"},
66
+ )
67
+ add_decoder: bool = field(
68
+ default=True,
69
+ metadata={"help": "add decoder"},
70
+ )
71
+ add_text_ctc: bool = field(
72
+ default=False,
73
+ metadata={"help": "add_text_ctc head"},
74
+ )
75
+ text_ctc_conv_kernel: int = field(
76
+ default=2,
77
+ metadata={"help": "text_ctc_conv kernel size"},
78
+ )
79
+ mask_u2t: bool = field(
80
+ default=True,
81
+ metadata={"help": "mask the unit input in unit-to-text task"},
82
+ )
83
+
84
+ # embedding mixing
85
+ mix_with_unit: bool = field(
86
+ default=True,
87
+ metadata={"help": "mix with the unit embeddings"},
88
+ )
89
+ use_pred_unit: bool = field(
90
+ default=False,
91
+ metadata={"help": "use the embeddings of predicted units"},
92
+ )
93
+ l2_embedding: bool = field(
94
+ default=False,
95
+ metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
96
+ )
97
+
98
+ # Finetune related
99
+ encoder_dict_size: int = field(
100
+ default=-1,
101
+ metadata={"help": "text encoder dictionary dimension"},
102
+ )
103
+
104
+ decoder_dict_size: int = field(
105
+ default=-1,
106
+ metadata={"help": "decoder dictionary dimension"},
107
+ )
108
+
109
+
110
+ @register_model("speechut", dataclass=SpeechutConfig)
111
+ class SpeechutModel(BaseFairseqModel):
112
+ def __init__(
113
+ self,
114
+ cfg: SpeechutConfig,
115
+ task_cfg: HubertPretrainingConfig,
116
+ dictionaries: List[Dictionary],
117
+ unit_dictionary: Dictionary = None,
118
+ text_tgt_dictionary: Dictionary = None,
119
+ ) -> None:
120
+ super().__init__()
121
+ logger.info(f"SpeechutModel Config: {cfg}")
122
+
123
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
124
+ self.embed = feature_enc_layers[-1][0]
125
+
126
+ self.feature_extractor = ConvFeatureExtractionModel(
127
+ conv_layers=feature_enc_layers,
128
+ dropout=0.0,
129
+ mode=cfg.extractor_mode,
130
+ conv_bias=cfg.conv_bias,
131
+ )
132
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
133
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
134
+
135
+ self.post_extract_proj = (
136
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
137
+ if self.embed != cfg.encoder_embed_dim
138
+ else None
139
+ )
140
+
141
+ self.mask_prob = cfg.mask_prob
142
+ self.mask_selection = cfg.mask_selection
143
+ self.mask_other = cfg.mask_other
144
+ self.mask_length = cfg.mask_length
145
+ self.no_mask_overlap = cfg.no_mask_overlap
146
+ self.mask_min_space = cfg.mask_min_space
147
+
148
+ self.mask_channel_prob = cfg.mask_channel_prob
149
+ self.mask_channel_selection = cfg.mask_channel_selection
150
+ self.mask_channel_other = cfg.mask_channel_other
151
+ self.mask_channel_length = cfg.mask_channel_length
152
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
153
+ self.mask_channel_min_space = cfg.mask_channel_min_space
154
+
155
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
156
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
157
+
158
+ self.feature_grad_mult = cfg.feature_grad_mult
159
+ self.logit_temp = cfg.logit_temp
160
+ self.skip_masked = cfg.skip_masked
161
+ self.skip_nomask = cfg.skip_nomask
162
+
163
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
164
+
165
+ self.mask_emb = nn.Parameter(
166
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
167
+ )
168
+
169
+ self.encoder = TransformerEncoder(cfg)
170
+ self.layer_norm = LayerNorm(self.embed)
171
+
172
+ self.target_glu = None
173
+ if cfg.target_glu:
174
+ self.target_glu = nn.Sequential(
175
+ nn.Linear(final_dim, final_dim * 2), nn.GLU()
176
+ )
177
+
178
+ self.final_dim = final_dim
179
+ assert len(dictionaries) <= 2, f"Only support <=2 kinds of targets, get {len(dictionaries)} dictionaries"
180
+ if len(dictionaries) == 1:
181
+ dictionaries = [dictionaries[0], dictionaries[0]]
182
+ self.num_classes = [len(d) for d in dictionaries]
183
+
184
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
185
+ self.code_encoder_proj = nn.Linear(cfg.text_transformer.encoder.embed_dim, self.num_classes[-1])
186
+ self.final_proj_list = [self.final_proj, self.code_encoder_proj]
187
+
188
+ self.label_embs_concat = nn.Parameter(torch.FloatTensor(self.num_classes[0], final_dim))
189
+ self.label_embs_list = [self.label_embs_concat]
190
+ for p in self.label_embs_list:
191
+ nn.init.uniform_(p)
192
+
193
+ ### build unit encoder:
194
+ self.mask_u2t = cfg.mask_u2t
195
+ self.add_text_ctc = cfg.add_text_ctc
196
+ self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
197
+ self.padding_idx = unit_dictionary.pad()
198
+ self.unit_mask_idx = unit_dictionary.index("<mask>")
199
+
200
+ self.add_unit_encoder = cfg.add_unit_encoder
201
+ self.mix_with_unit = cfg.mix_with_unit
202
+ self.use_pred_unit = cfg.use_pred_unit
203
+ self.l2_embedding = cfg.l2_embedding
204
+ if self.add_unit_encoder:
205
+ assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
206
+ ### build unit pre-net, and shared with hubert label_embs if needed (default: False)
207
+ self.unit_embed_tokens = self.build_embedding(
208
+ unit_dictionary,
209
+ cfg.text_transformer.encoder.embed_dim,
210
+ )
211
+ if self.final_dim == cfg.text_transformer.encoder.embed_dim:
212
+ logger.info("Share label_embs[0] with unit_embed_tokens ...")
213
+ nn.init.uniform_(self.unit_embed_tokens.weight)
214
+ self.label_embs_list[0] = self.unit_embed_tokens.weight
215
+
216
+ ### build unit encoder
217
+ self.unit_encoder = TransformerEncoderBase(
218
+ cfg.text_transformer,
219
+ unit_dictionary,
220
+ self.unit_embed_tokens,
221
+ use_rel_pos_enc=cfg.use_rel_pos_enc,
222
+ scaling_for_att=cfg.scaling_for_att,
223
+ )
224
+
225
+ ### build text ctc head
226
+ if self.add_text_ctc:
227
+ conv = nn.Conv1d(
228
+ cfg.text_transformer.encoder.embed_dim, cfg.text_transformer.encoder.embed_dim,
229
+ self.text_ctc_conv_kernel,
230
+ stride=self.text_ctc_conv_kernel // 2,
231
+ bias=False,
232
+ padding=self.text_ctc_conv_kernel // 2,
233
+ )
234
+ nn.init.kaiming_normal_(conv.weight)
235
+ self.unit_encoder_ctc_head = nn.Sequential(
236
+ Rotate3D(),
237
+ conv,
238
+ nn.Dropout(p=0.1),
239
+ nn.Sequential(
240
+ Rotate3D(),
241
+ Rotate3D(),
242
+ LayerNorm(cfg.text_transformer.encoder.embed_dim),
243
+ ),
244
+ nn.GELU(),
245
+ nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
246
+ )
247
+
248
+ ### build unit2text decoder, not available for now
249
+ self.add_decoder = cfg.add_decoder
250
+ self.text_transformer_cfg = cfg.text_transformer
251
+ if self.add_decoder:
252
+ # To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size or bpe code dict size
253
+ dec_dictionary = self.cutting_dictionary(text_tgt_dictionary, cfg.decoder_dict_size)
254
+ decoder_embed_tokens = self.build_embedding(
255
+ dec_dictionary, cfg.text_transformer.decoder.embed_dim
256
+ )
257
+ if cfg.reset_decoder_embedding_config:
258
+ cfg.text_transformer.no_scale_embedding = False
259
+ cfg.text_transformer.layernorm_embedding = False
260
+ cfg.text_transformer.no_token_positional_embeddings = False
261
+ self.decoder = TransformerDecoderBaseScriptable(cfg.text_transformer, dec_dictionary, decoder_embed_tokens, use_rel_pos_enc=cfg.use_rel_pos_enc)
262
+
263
+
264
+ def cutting_dictionary(self, dictionary, dict_size):
265
+ if dictionary is None or dict_size <= 0:
266
+ return dictionary
267
+ else:
268
+ import copy
269
+ cut_dictionary = copy.deepcopy(dictionary)
270
+ if dict_size > len(cut_dictionary):
271
+ for i in range(dict_size - len(cut_dictionary)):
272
+ cut_dictionary.symbols.append(f'_{i}_')
273
+ else:
274
+ cut_dictionary.symbols = cut_dictionary.symbols[:dict_size]
275
+ return cut_dictionary
276
+
277
+ def build_embedding(self, dictionary, embed_dim):
278
+ num_embeddings = len(dictionary)
279
+ padding_idx = dictionary.pad()
280
+ return Embedding(num_embeddings, embed_dim, padding_idx)
281
+
282
+ def upgrade_state_dict_named(self, state_dict, name):
283
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
284
+
285
+ super().upgrade_state_dict_named(state_dict, name)
286
+ return state_dict
287
+
288
+ @classmethod
289
+ def build_model(cls, cfg: SpeechutConfig, task: HubertPretrainingTask):
290
+ """Build a new model instance."""
291
+ unit_dictionary = getattr(task, "text_src_dictionary", None)
292
+ text_tgt_dictionary = getattr(task, "text_dictionary", None)
293
+ model = SpeechutModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
294
+ return model
295
+
296
+ def apply_mask(self, x, padding_mask, target_list):
297
+ B, T, C = x.shape
298
+ if self.mask_prob > 0:
299
+ mask_indices = compute_mask_indices(
300
+ (B, T),
301
+ padding_mask,
302
+ self.mask_prob,
303
+ self.mask_length,
304
+ self.mask_selection,
305
+ self.mask_other,
306
+ min_masks=2,
307
+ no_overlap=self.no_mask_overlap,
308
+ min_space=self.mask_min_space,
309
+ )
310
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
311
+ x[mask_indices] = self.mask_emb
312
+ else:
313
+ mask_indices = None
314
+
315
+ if self.mask_channel_prob > 0:
316
+ mask_channel_indices = compute_mask_indices(
317
+ (B, C),
318
+ None,
319
+ self.mask_channel_prob,
320
+ self.mask_channel_length,
321
+ self.mask_channel_selection,
322
+ self.mask_channel_other,
323
+ no_overlap=self.no_mask_channel_overlap,
324
+ min_space=self.mask_channel_min_space,
325
+ )
326
+ mask_channel_indices = (
327
+ torch.from_numpy(mask_channel_indices)
328
+ .to(x.device)
329
+ .unsqueeze(1)
330
+ .expand(-1, T, -1)
331
+ )
332
+ x[mask_channel_indices] = 0
333
+
334
+ return x, mask_indices
335
+
336
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
337
+ if self.feature_grad_mult > 0:
338
+ features = self.feature_extractor(source)
339
+ if self.feature_grad_mult != 1.0:
340
+ features = GradMultiply.apply(features, self.feature_grad_mult)
341
+ else:
342
+ with torch.no_grad():
343
+ features = self.feature_extractor(source)
344
+ return features
345
+
346
+ def forward_targets(
347
+ self,
348
+ features: torch.Tensor,
349
+ target_list: List[torch.Tensor],
350
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
351
+ # Trim features to ensure labels exist and then get aligned labels
352
+ feat_tsz = features.size(2)
353
+ targ_tsz = min([t.size(1) for t in target_list])
354
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
355
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
356
+ features = features[..., :feat_tsz]
357
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
358
+ target_inds += np.random.choice(int(self.feat2tar_ratio))
359
+ target_list = [t[:, target_inds.long()] for t in target_list]
360
+ return features, target_list
361
+
362
+ def forward_padding_mask(
363
+ self,
364
+ features: torch.Tensor,
365
+ padding_mask: torch.Tensor,
366
+ ) -> torch.Tensor:
367
+ extra = padding_mask.size(1) % features.size(1)
368
+ if extra > 0:
369
+ padding_mask = padding_mask[:, :-extra]
370
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
371
+ padding_mask = padding_mask.all(-1)
372
+ return padding_mask
373
+
374
+ def get_normalized_probs(
375
+ self,
376
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
377
+ log_probs: bool,
378
+ sample: Optional[Dict[str, Tensor]] = None,
379
+ ):
380
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
381
+ lprobs.batch_first = True
382
+ return lprobs
383
+
384
+ def downsample_ctc_padding_mask(self, padding_mask):
385
+ """
386
+ padding_mask: (B, T)
387
+ """
388
+ stride = self.text_ctc_conv_kernel // 2
389
+ return padding_mask[:, ::stride]
390
+
391
+ def compute_pred(self, proj_x, label_embs):
392
+ if self.target_glu:
393
+ label_embs = self.target_glu(label_embs)
394
+ x = F.normalize(proj_x.float(), dim=-1) # (S, D)
395
+ label_embs = F.normalize(label_embs.float(), dim=-1) # (C, D)
396
+ logits = torch.matmul(x, label_embs.T).type_as(proj_x) # (S, C)
397
+ logits /= self.logit_temp
398
+ return logits
399
+
400
+ def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
401
+ if not self.skip_masked:
402
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
403
+ proj_x_m = proj(x[masked_indices])
404
+ logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
405
+ else:
406
+ logit_m_list = [None]
407
+
408
+ if not self.skip_nomask:
409
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
410
+ proj_x_u = proj(x[nomask_indices])
411
+ logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
412
+ else:
413
+ logit_u_list = [None]
414
+
415
+ return logit_m_list, logit_u_list
416
+
417
+ def compute_ce_logits(self, x, target, proj, padding_mask, mask_indices):
418
+ if not self.skip_masked:
419
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
420
+ logit_m_list = [(proj(x[masked_indices]), target[masked_indices])]
421
+ else:
422
+ logit_m_list = [None]
423
+
424
+ if not self.skip_nomask:
425
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
426
+ logit_u_list = [(proj(x[nomask_indices]), target[nomask_indices])]
427
+ else:
428
+ logit_u_list = [None]
429
+
430
+ return logit_m_list, logit_u_list
431
+
432
+ def convert_embeddings(self,
433
+ x,
434
+ padding_mask,
435
+ target=None,
436
+ mask_indices=None,
437
+ mix_with_unit=False,
438
+ use_pred_unit=False,
439
+ l2_embedding=False,
440
+ remask=False
441
+ ):
442
+ """
443
+ 1. Mix with units if needed (default: True)
444
+ 2. Prepare for unit_encoder inputs
445
+ Inputs:
446
+ x, (B, T, D)
447
+ Return:
448
+ src_tokens, (B, T)
449
+ soft_embeddings, (B, T, D)
450
+ l2_loss, a loss
451
+ """
452
+ soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
453
+ if padding_mask is None:
454
+ padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
455
+ if use_pred_unit:
456
+ src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
457
+ src_tokens[padding_mask] = self.padding_idx
458
+ elif target is not None:
459
+ src_tokens = target
460
+ else:
461
+ src_tokens = padding_mask.long()
462
+
463
+ if l2_embedding | mix_with_unit:
464
+ unit_embeddings = self.unit_embed_tokens(src_tokens) # (B, T, D)
465
+
466
+ l2_loss = 0
467
+ if l2_embedding:
468
+ if mask_indices is not None:
469
+ l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
470
+ scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
471
+ else:
472
+ l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
473
+ scale = unit_embeddings.float().pow(2).sum(dim=-1)
474
+ l2_loss = (l2_loss / scale).mean()
475
+
476
+ if mix_with_unit:
477
+ B, T, D = x.shape
478
+ selected_indices = compute_mask_indices(
479
+ (B, T),
480
+ padding_mask,
481
+ self.mask_prob / 2,
482
+ self.mask_length // 2,
483
+ self.mask_selection,
484
+ self.mask_other,
485
+ min_masks=2,
486
+ no_overlap=self.no_mask_overlap,
487
+ min_space=self.mask_min_space,
488
+ )
489
+ selected_indices = torch.from_numpy(selected_indices).to(x.device)
490
+ if mask_indices is not None:
491
+ if remask:
492
+ remask_indices = torch.logical_and(selected_indices, mask_indices)
493
+ soft_embeddings[remask_indices] = self.mask_emb
494
+ swap_indices = torch.logical_and(selected_indices, ~mask_indices)
495
+ else:
496
+ swap_indices = selected_indices
497
+ soft_embeddings[swap_indices] = unit_embeddings[swap_indices]
498
+
499
+ soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
500
+ return src_tokens, soft_embeddings, l2_loss
501
+
502
+ def forward(
503
+ self,
504
+ source: torch.Tensor = None,
505
+ src_tokens: torch.Tensor = None,
506
+ src_lengths: torch.Tensor = None,
507
+ prev_output_tokens: torch.Tensor = None,
508
+ target_list: Optional[List[torch.Tensor]] = None,
509
+ padding_mask: Optional[torch.Tensor] = None,
510
+ mask: bool = True,
511
+ features_only: bool = False,
512
+ output_layer: Optional[int] = None,
513
+ ) -> Dict[str, torch.Tensor]:
514
+ assert source is not None or src_tokens is not None
515
+ if source is not None:
516
+ return self.forward_speech(
517
+ source=source,
518
+ target_list=target_list,
519
+ padding_mask=padding_mask,
520
+ mask=mask,
521
+ features_only=features_only,
522
+ output_layer=output_layer,
523
+ )
524
+ else:
525
+ return self.forward_text(
526
+ src_tokens=src_tokens,
527
+ src_lengths=src_lengths,
528
+ prev_output_tokens=prev_output_tokens,
529
+ mask=self.mask_u2t,
530
+ features_only=features_only,
531
+ output_layer=output_layer,
532
+ )
533
+
534
+ def forward_speech(
535
+ self,
536
+ source: torch.Tensor = None,
537
+ target_list: Optional[List[torch.Tensor]] = None,
538
+ padding_mask: Optional[torch.Tensor] = None,
539
+ mask: bool = True,
540
+ features_only: bool = False,
541
+ output_layer: Optional[int] = None,
542
+ ) -> Dict[str, torch.Tensor]:
543
+ """output layer is 1-based"""
544
+ features = self.forward_features(source)
545
+ if target_list is not None:
546
+ features, target_list = self.forward_targets(features, target_list)
547
+
548
+ features_pen = features.float().pow(2).mean()
549
+
550
+ features = features.transpose(1, 2)
551
+ features = self.layer_norm(features)
552
+ unmasked_features = features.clone()
553
+
554
+ if padding_mask is not None:
555
+ padding_mask = self.forward_padding_mask(features, padding_mask)
556
+
557
+ if self.post_extract_proj is not None:
558
+ features = self.post_extract_proj(features)
559
+
560
+ features = self.dropout_input(features)
561
+ unmasked_features = self.dropout_features(unmasked_features)
562
+
563
+ if mask:
564
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
565
+ else:
566
+ x = features
567
+ mask_indices = None
568
+
569
+ # feature: (B, T, D), float
570
+ # target: (B, T), long
571
+ # x: (B, T, D), float
572
+ # padding_mask: (B, T), bool
573
+ # mask_indices: (B, T), bool
574
+ x, _ = self.encoder(
575
+ x,
576
+ padding_mask=padding_mask,
577
+ layer=None if output_layer is None else output_layer - 1,
578
+ )
579
+
580
+ if features_only:
581
+ return {"x": x, "padding_mask": padding_mask, "features": features}
582
+
583
+ logit_m_list, logit_u_list = self.compute_hubert_logits(
584
+ x,
585
+ target_list[0],
586
+ self.final_proj_list[0],
587
+ self.label_embs_list[0],
588
+ padding_mask,
589
+ mask_indices,
590
+ )
591
+
592
+ result = {
593
+ "logit_m_list": logit_m_list,
594
+ "logit_u_list": logit_u_list,
595
+ "padding_mask": padding_mask,
596
+ "features_pen": features_pen,
597
+ }
598
+
599
+ if self.add_unit_encoder:
600
+ src_tokens, x_emb, l2_loss = self.convert_embeddings(
601
+ x,
602
+ padding_mask, target_list[0],
603
+ mask_indices=mask_indices,
604
+ mix_with_unit=self.mix_with_unit,
605
+ use_pred_unit=self.use_pred_unit,
606
+ l2_embedding=self.l2_embedding,
607
+ )
608
+ encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)
609
+
610
+ result['encoder_out'] = encoder_out['encoder_out'] # [(T, B, D)]
611
+ result['encoder_padding_mask'] = encoder_out['encoder_padding_mask'] # [(B, T)]
612
+ if self.l2_embedding:
613
+ result['embedding_l2_loss'] = l2_loss
614
+
615
+ code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
616
+ encoder_out['encoder_out'][0].transpose(0, 1), # -> (B, T, C)
617
+ target_list[-1],
618
+ self.final_proj_list[1],
619
+ padding_mask,
620
+ mask_indices,
621
+ )
622
+ result['logit_m_list'] += code_logit_m_list
623
+ result['logit_u_list'] += code_logit_u_list
624
+ return result
625
+
626
+ def forward_text(
627
+ self,
628
+ src_tokens: torch.Tensor = None,
629
+ src_lengths: torch.Tensor = None,
630
+ prev_output_tokens: torch.Tensor = None,
631
+ target_list: Optional[List[torch.Tensor]] = None,
632
+ mask: bool = True,
633
+ features_only: bool = False,
634
+ output_layer: Optional[int] = None,
635
+ ) -> Dict[str, torch.Tensor]:
636
+ assert self.add_unit_encoder, f"Can not forward unit-text branch without unit_encoder!"
637
+
638
+ padding_mask = src_tokens == self.padding_idx
639
+ unit_embeddings = self.unit_embed_tokens(src_tokens)
640
+ if mask:
641
+ unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
642
+
643
+ encoder_out = self.unit_encoder(
644
+ src_tokens,
645
+ token_embeddings=unit_embeddings,
646
+ return_all_hiddens=output_layer is not None,
647
+ )
648
+
649
+ result = {}
650
+ result["encoder_out"] = encoder_out["encoder_out"]
651
+ result["encoder_states"] = encoder_out["encoder_states"]
652
+ result["padding_mask"] = padding_mask
653
+
654
+ if self.add_text_ctc:
655
+ result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
656
+ result["encoder_padding_mask"] = [
657
+ self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
658
+ ]
659
+
660
+ if features_only:
661
+ return result
662
+ if self.add_decoder:
663
+ assert prev_output_tokens is not None
664
+ decoder_out = self.decoder(
665
+ prev_output_tokens=prev_output_tokens, encoder_out=encoder_out,
666
+ )
667
+ result['decoder_out'] = decoder_out
668
+ return result
669
+
670
+ def forward_mum(self, src_tokens, target, mask=True):
671
+ target_list = [target]
672
+ padding_mask = src_tokens.eq(self.unit_encoder.padding_idx)
673
+ unit_embeddings = self.unit_embed_tokens(src_tokens)
674
+ if mask:
675
+ unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, target_list)
676
+ else:
677
+ ### If already applied mask on src_tokens, then the target_list should contains many padding_idx
678
+ mask_indices = target_list[-1] != self.padding_idx
679
+ unit_embeddings[mask_indices] = self.mask_emb
680
+
681
+ encoder_out = self.unit_encoder(
682
+ src_tokens,
683
+ token_embeddings=unit_embeddings,
684
+ )
685
+ code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
686
+ encoder_out["encoder_out"][0].transpose(0, 1),
687
+ target_list[-1],
688
+ self.final_proj_list[1],
689
+ padding_mask,
690
+ mask_indices,
691
+ )
692
+ result = {}
693
+ result["logit_m_list"] = code_logit_m_list
694
+ result["logit_u_list"] = code_logit_u_list
695
+ result["padding_mask"] = padding_mask
696
+ return result
697
+
698
+ def extract_features(
699
+ self,
700
+ source: torch.Tensor,
701
+ padding_mask: Optional[torch.Tensor] = None,
702
+ mask: bool = False,
703
+ ret_conv: bool = False,
704
+ output_layer: Optional[int] = None,
705
+ **kwargs,
706
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
707
+ """Extract encoder features for only speech input"""
708
+ res = self.forward(
709
+ source,
710
+ padding_mask=padding_mask,
711
+ mask=mask,
712
+ features_only=True,
713
+ output_layer=output_layer,
714
+ )
715
+ x = res["x"] # B x T x D
716
+ padding_mask = res["padding_mask"]
717
+
718
+ if self.add_unit_encoder:
719
+ src_tokens, x, _ = self.convert_embeddings(
720
+ x,
721
+ padding_mask,
722
+ mix_with_unit=False,
723
+ use_pred_unit=False,
724
+ )
725
+ encoder_out = self.unit_encoder(
726
+ src_tokens,
727
+ token_embeddings=x,
728
+ return_all_hiddens=output_layer is not None
729
+ )
730
+ res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
731
+
732
+ feature = res["features"] if ret_conv else res["x"]
733
+ if output_layer is not None:
734
+ feature = encoder_out['encoder_states']
735
+
736
+ return feature, padding_mask
737
+
738
+ def get_logits(self, net_output, is_masked=True):
739
+ if is_masked:
740
+ logits_list = net_output["logit_m_list"]
741
+ else:
742
+ logits_list = net_output["logit_u_list"]
743
+ logits_list = [x[0].float() for x in logits_list if x is not None]
744
+ return logits_list
745
+
746
+ def get_targets(self, net_output, is_masked=True):
747
+ if is_masked:
748
+ logits_list = net_output["logit_m_list"]
749
+ else:
750
+ logits_list = net_output["logit_u_list"]
751
+ targets_list = [x[1].long() for x in logits_list if x is not None]
752
+ return targets_list
753
+
754
+ def get_extra_losses(self, net_output):
755
+ extra_losses = []
756
+ names = []
757
+
758
+ if "features_pen" in net_output:
759
+ extra_losses.append(net_output["features_pen"])
760
+ names.append("features_pen")
761
+
762
+ if "embedding_l2_loss" in net_output:
763
+ extra_losses.append(net_output["embedding_l2_loss"])
764
+ names.append("embedding_l2_loss")
765
+
766
+ return extra_losses, names
767
+
768
+ def remove_pretraining_modules(self, step2=False):
769
+ self.target_glu = None
770
+
771
+ def load_checkpoint(self, checkpoint: str):
772
+ if not PathManager.exists(checkpoint):
773
+ raise IOError("Model file not found: {}".format(checkpoint))
774
+ state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
775
+ return state
776
+
777
+ class Rotate3D(nn.Module):
778
+ """
779
+ (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
780
+ """
781
+ def __init__(self):
782
+ super().__init__()
783
+
784
+ def forward(self, x):
785
+ return x.permute(1, 2, 0)
SpeechT5/Speech2S/speech2s/models/speechut_asr.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
4
+ # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
5
+ #
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # ----------------------------------------------------------------------------
9
+
10
+ import contextlib
11
+ import torch
12
+ from dataclasses import dataclass, field
13
+ from fairseq import utils
14
+ from fairseq.models import BaseFairseqModel, register_model
15
+ from fairseq.models.fairseq_encoder import FairseqEncoder
16
+ from fairseq.models.hubert import HubertAsrConfig, HubertEncoder
17
+ from fairseq.tasks import FairseqTask
18
+
19
+ @dataclass
20
+ class SpeechUTASRConfig(HubertAsrConfig):
21
+ add_decoder: bool = field(
22
+ default=True,
23
+ metadata={"help": "add decoder for fine-tune"},
24
+ )
25
+
26
+ @register_model("speechut_asr", dataclass=SpeechUTASRConfig)
27
+ class SpeechUTASR(BaseFairseqModel):
28
+ """
29
+ A encoder-ctc-decoder model if cfg.add_decoder is True, or a encoder-ctc model
30
+ """
31
+ def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
32
+ super().__init__()
33
+ self.cfg = cfg
34
+ self.encoder = encoder
35
+ if not cfg.add_decoder:
36
+ self.encoder.w2v_model.decoder = None
37
+
38
+ def upgrade_state_dict_named(self, state_dict, name):
39
+ super().upgrade_state_dict_named(state_dict, name)
40
+ return state_dict
41
+
42
+ @classmethod
43
+ def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
44
+ """Build a new model instance."""
45
+ encoder = SpeechUTEncoder(cfg, task)
46
+ return cls(cfg, encoder)
47
+
48
+ def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
49
+ encoder_out = self.encoder(source, padding_mask, **kwargs)
50
+
51
+ x = self.encoder.final_dropout(encoder_out['encoder_out'][0]) # (T, B, C)
52
+ if self.encoder.proj:
53
+ x = self.encoder.proj(x)
54
+ if self.encoder.conv_ctc_proj:
55
+ padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(encoder_out["encoder_padding_mask"][0])
56
+ else:
57
+ padding_mask = encoder_out["encoder_padding_mask"]
58
+
59
+ decoder_out = self.decoder(
60
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
61
+ ) if self.cfg.add_decoder else None
62
+
63
+ return {
64
+ "encoder_out_ctc": x, # (T, B, C), for CTC loss
65
+ "padding_mask": padding_mask, # (B, T), for CTC loss
66
+ "decoder_out": decoder_out, # for ED loss
67
+ }
68
+
69
+ def forward_decoder(self, prev_output_tokens, **kwargs):
70
+ return self.decoder(prev_output_tokens, **kwargs)
71
+
72
+ def get_logits(self, net_output):
73
+ """For CTC decoding"""
74
+ logits = net_output["encoder_out"]
75
+ padding = net_output["encoder_padding_mask"]
76
+ if padding is not None and padding.any():
77
+ padding = padding.T
78
+ logits[padding][..., 0] = 0
79
+ logits[padding][..., 1:] = float("-inf")
80
+
81
+ return logits
82
+
83
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
84
+ """For 1) computing CTC loss, 2) decoder decoding."""
85
+
86
+ if "encoder_out_ctc" in net_output:
87
+ logits = net_output["encoder_out_ctc"]
88
+ else:
89
+ return self.decoder.get_normalized_probs(net_output, log_probs, sample)
90
+
91
+ if isinstance(logits, list):
92
+ logits = logits[0]
93
+
94
+ if log_probs:
95
+ return utils.log_softmax(logits.float(), dim=-1)
96
+ else:
97
+ return utils.softmax(logits.float(), dim=-1)
98
+
99
+ @property
100
+ def decoder(self):
101
+ return self.encoder.w2v_model.decoder
102
+
103
+
104
+ class SpeechUTEncoder(HubertEncoder):
105
+ """
106
+ Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
107
+ 1. make it compatible with encoder-decoder model
108
+ """
109
+ def __init__(self, cfg: HubertAsrConfig, task):
110
+ super().__init__(cfg, task)
111
+
112
+ if (task.target_dictionary is not None) and (
113
+ hasattr(self.w2v_model, "unit_encoder_ctc_head")
114
+ ):
115
+ self.proj = self.w2v_model.unit_encoder_ctc_head
116
+ self.conv_ctc_proj = True
117
+ else:
118
+ self.conv_ctc_proj = False
119
+
120
+ def forward(self, source, padding_mask, tbc=True, **kwargs):
121
+ w2v_args = {
122
+ "source": source,
123
+ "padding_mask": padding_mask,
124
+ "mask": self.apply_mask and self.training,
125
+ }
126
+ ft = self.freeze_finetune_updates <= self.num_updates
127
+ with torch.no_grad() if not ft else contextlib.ExitStack():
128
+ x, padding_mask = self.w2v_model.extract_features(**w2v_args)
129
+ if tbc:
130
+ # B x T x C -> T x B x C
131
+ x = x.transpose(0, 1)
132
+ return {
133
+ "encoder_out": [x], # T x B x C
134
+ "encoder_padding_mask": [padding_mask], # B x T
135
+ }
136
+
137
+ def forward_torchscript(self, net_input):
138
+ """A TorchScript-compatible version of forward.
139
+
140
+ Forward the encoder out.
141
+ """
142
+ x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
143
+ # B x T x C -> T x B x C
144
+ x = x.transpose(0, 1)
145
+
146
+ encoder_out = {
147
+ "encoder_out" : [x],
148
+ "encoder_padding_mask" : [padding_mask],
149
+ }
150
+ if self.proj:
151
+ x = self.proj(x)
152
+ encoder_out["encoder_out_ctc"] = x
153
+
154
+ return encoder_out
155
+
156
+ def reorder_encoder_out(self, encoder_out, new_order):
157
+ if encoder_out["encoder_out"] is not None:
158
+ encoder_out["encoder_out"] = [
159
+ x.index_select(1, new_order) for x in encoder_out["encoder_out"]
160
+ ]
161
+ if encoder_out["encoder_padding_mask"] is not None:
162
+ encoder_out["encoder_padding_mask"] = [
163
+ x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
164
+ ]
165
+ return encoder_out
SpeechT5/Speech2S/speech2s/models/speechut_st.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
4
+ # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
5
+ #
6
+ # Copyright (c) 2022 Microsoft
7
+ # Licensed under The MIT License [see LICENSE for details]
8
+ # ----------------------------------------------------------------------------
9
+
10
+ import logging
11
+ import contextlib
12
+ import torch
13
+ import torch.nn as nn
14
+ from argparse import Namespace
15
+ from dataclasses import dataclass
16
+ from typing import Any
17
+ from fairseq import checkpoint_utils, tasks
18
+ from fairseq.models import BaseFairseqModel, register_model
19
+ from fairseq.models.fairseq_encoder import FairseqEncoder
20
+ from fairseq.tasks import FairseqTask
21
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
+ from fairseq.data.data_utils import lengths_to_padding_mask
23
+
24
+ from fairseq.models.hubert import HubertAsrConfig
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ @dataclass
29
+ class SpeechUTS2TConfig(HubertAsrConfig):
30
+ ### the following config is only for the compatibility to fairseq speech_to_text task
31
+ input_feat_per_channel: Any = None
32
+ input_channels: Any = None
33
+ speaker_to_id: Any = None
34
+
35
+ @register_model("speechut_st_legacy", dataclass=SpeechUTS2TConfig)
36
+ class SpeechUTS2T(BaseFairseqModel):
37
+ """An encoder-decoder model."""
38
+ def __init__(self, cfg: SpeechUTS2TConfig, encoder: FairseqEncoder):
39
+ super().__init__()
40
+ self.cfg = cfg
41
+ self.encoder = encoder
42
+
43
+ def upgrade_state_dict_named(self, state_dict, name):
44
+ super().upgrade_state_dict_named(state_dict, name)
45
+ return state_dict
46
+
47
+ @classmethod
48
+ def build_model(cls, cfg: SpeechUTS2TConfig, task: FairseqTask):
49
+ """Build a new model instance."""
50
+ encoder = SpeechUTEncoder(cfg, task)
51
+ return cls(cfg, encoder)
52
+
53
+ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
54
+ encoder_out = self.encoder(src_tokens, src_lengths, **kwargs)
55
+ decoder_out = self.encoder.w2v_model.decoder(
56
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
57
+ )
58
+ return decoder_out
59
+
60
+ def forward_decoder(self, prev_output_tokens, **kwargs):
61
+ return self.encoder.w2v_model.decoder(prev_output_tokens, **kwargs)
62
+
63
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
64
+ """For decoder decoding."""
65
+ return self.encoder.w2v_model.decoder.get_normalized_probs(net_output, log_probs, sample)
66
+
67
+ @property
68
+ def decoder(self):
69
+ return self.encoder.w2v_model.decoder
70
+
71
+
72
+ class SpeechUTEncoder(FairseqEncoder):
73
+ """
74
+ Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
75
+ 1. make it compatible with fairseq speech_to_text task
76
+ 2. make it compatible with encoder-decoder model
77
+ """
78
+ def __init__(self, cfg: SpeechUTS2TConfig, task):
79
+ self.apply_mask = cfg.apply_mask
80
+
81
+ arg_overrides = {
82
+ "dropout": cfg.dropout,
83
+ "activation_dropout": cfg.activation_dropout,
84
+ "dropout_input": cfg.dropout_input,
85
+ "attention_dropout": cfg.attention_dropout,
86
+ "mask_length": cfg.mask_length,
87
+ "mask_prob": cfg.mask_prob,
88
+ "mask_selection": cfg.mask_selection,
89
+ "mask_other": cfg.mask_other,
90
+ "no_mask_overlap": cfg.no_mask_overlap,
91
+ "mask_channel_length": cfg.mask_channel_length,
92
+ "mask_channel_prob": cfg.mask_channel_prob,
93
+ "mask_channel_selection": cfg.mask_channel_selection,
94
+ "mask_channel_other": cfg.mask_channel_other,
95
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
96
+ "encoder_layerdrop": cfg.layerdrop,
97
+ "feature_grad_mult": cfg.feature_grad_mult,
98
+ }
99
+
100
+ if cfg.w2v_args is None:
101
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
102
+ w2v_args = state.get("cfg", None)
103
+ if w2v_args is None:
104
+ w2v_args = convert_namespace_to_omegaconf(state["args"])
105
+ cfg.w2v_args = w2v_args
106
+ else:
107
+ state = None
108
+ w2v_args = cfg.w2v_args
109
+ if isinstance(w2v_args, Namespace):
110
+ cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
111
+
112
+ assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
113
+ "Fine-tuning works best when data normalization is the same. "
114
+ "Please check that --normalize is set or unset for "
115
+ "both pre-training and here"
116
+ )
117
+
118
+ pretrain_task = tasks.setup_task(w2v_args.task, load_local_states=False)
119
+ assert state is not None and "task_state" in state, f"the stored dictionaries not found in checkpoint!"
120
+ # This will load the stored "dictionaries" object
121
+ pretrain_task.load_state_dict(state["task_state"])
122
+
123
+ model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
124
+ if state is not None and not cfg.no_pretrained_weights:
125
+ try:
126
+ model.load_state_dict(state["model"], strict=True)
127
+ except Exception as e:
128
+ logger.warn(e)
129
+ model.load_state_dict(state["model"], strict=False)
130
+
131
+ model.remove_pretraining_modules()
132
+
133
+ super().__init__(pretrain_task.source_dictionary)
134
+
135
+ d = w2v_args.model.encoder_embed_dim
136
+
137
+ self.w2v_model = model
138
+
139
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
140
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
141
+ self.num_updates = 0
142
+
143
+ def set_num_updates(self, num_updates):
144
+ """Set the number of parameters updates."""
145
+ super().set_num_updates(num_updates)
146
+ self.num_updates = num_updates
147
+
148
+ def forward(self, src_tokens=None, src_lengths=None, **kwargs):
149
+
150
+ w2v_args = {
151
+ "source": src_tokens,
152
+ "padding_mask": lengths_to_padding_mask(src_lengths),
153
+ "mask": self.apply_mask and self.training,
154
+ }
155
+
156
+ ft = self.freeze_finetune_updates <= self.num_updates
157
+
158
+ with torch.no_grad() if not ft else contextlib.ExitStack():
159
+ x, padding_mask = self.w2v_model.extract_features(**w2v_args)
160
+ # B x T x C -> T x B x C
161
+ x = x.transpose(0, 1)
162
+
163
+ return {
164
+ "encoder_out": [x], # T x B x C
165
+ "encoder_padding_mask": [padding_mask], # B x T
166
+ "padding_mask": [padding_mask],
167
+ }
168
+
169
+ def forward_torchscript(self, net_input):
170
+ """A TorchScript-compatible version of forward.
171
+
172
+ Forward the encoder out.
173
+ """
174
+ _net_input = {
175
+ "source": net_input["src_tokens"],
176
+ "padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
177
+ "mask": False,
178
+ }
179
+
180
+ x, padding_mask = self.w2v_model.extract_features(**_net_input)
181
+ # B x T x C -> T x B x C
182
+ x = x.transpose(0, 1)
183
+
184
+ encoder_out = {
185
+ "encoder_out" : [x],
186
+ "encoder_padding_mask" : [padding_mask],
187
+ }
188
+ return encoder_out
189
+
190
+ def reorder_encoder_out(self, encoder_out, new_order):
191
+ if encoder_out["encoder_out"] is not None:
192
+ encoder_out["encoder_out"] = [
193
+ x.index_select(1, new_order) for x in encoder_out["encoder_out"]
194
+ ]
195
+ if encoder_out["encoder_padding_mask"] is not None:
196
+ encoder_out["encoder_padding_mask"] = [
197
+ x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
198
+ ]
199
+ return encoder_out
200
+
201
+ def max_positions(self):
202
+ """Maximum input length supported by the encoder."""
203
+ return None
204
+
205
+ def upgrade_state_dict_named(self, state_dict, name):
206
+ return state_dict
207
+
208
+
209
+ def Embedding(num_embeddings, embedding_dim, padding_idx):
210
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
211
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
212
+ nn.init.constant_(m.weight[padding_idx], 0)
213
+ return m
214
+
215
+
216
+ def Linear(in_features, out_features, bias=True):
217
+ m = nn.Linear(in_features, out_features, bias)
218
+ nn.init.xavier_uniform_(m.weight)
219
+ if bias:
220
+ nn.init.constant_(m.bias, 0.0)
221
+ return m
SpeechT5/Speech2S/speech2s/models/t5_transformer_lm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
3
+ # Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ from fairseq.models import (
11
+ register_model_architecture,
12
+ )
13
+ from fairseq.models.transformer_lm import base_lm_architecture
14
+
15
+
16
+ @register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
17
+ def transformer_lm_t5(args):
18
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
19
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
20
+ args.decoder_layers = getattr(args, "decoder_layers", 20)
21
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
22
+ args.dropout = getattr(args, "dropout", 0.1)
23
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
24
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
25
+ base_lm_architecture(args)
SpeechT5/Speech2S/speech2s/modules/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ from .learned_positional_embedding import LearnedPositionalEmbedding
9
+ from .multihead_attention import MultiheadAttention
10
+ from .relative_pos_enc import RelativePositionalEncoding
11
+ from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
12
+ from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
13
+ from .transformer_encoder import TransformerEncoderBase
14
+ from .transformer_decoder import TransformerDecoderScriptable, TransformerDecoderBaseScriptable
15
+
16
+ __all__ = [
17
+ "MultiheadAttention",
18
+ "RelativePositionalEncoding",
19
+ "LearnedPositionalEmbedding",
20
+ "TransformerEncoderLayerBase",
21
+ "TransformerDecoderLayerBase",
22
+ "TransformerEncoder",
23
+ "TransformerSentenceEncoderLayer",
24
+ "TransformerEncoderBase",
25
+ "TransformerDecoderScriptable",
26
+ "TransformerDecoderBaseScriptable",
27
+ ]
SpeechT5/Speech2S/speech2s/modules/ctc_prefix_score.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ import numpy as np
7
+ import six
8
+
9
+
10
+ class CTCPrefixScore(object):
11
+ """Compute CTC label sequence scores
12
+ which is based on Algorithm 2 in WATANABE et al.
13
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
14
+ but extended to efficiently compute the probablities of multiple labels
15
+ simultaneously
16
+ """
17
+
18
+ def __init__(self, x, blank, eos, xp):
19
+ self.xp = xp
20
+ self.logzero = -10000000000.0
21
+ self.blank = blank
22
+ self.eos = eos
23
+ self.input_length = len(x)
24
+ self.x = x
25
+
26
+ def initial_state(self):
27
+ """Obtain an initial CTC state
28
+ :return: CTC state
29
+ """
30
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
31
+ # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
32
+ # superscripts n and b (non-blank and blank), respectively.
33
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
34
+ r[0, 1] = self.x[0, self.blank]
35
+ for i in six.moves.range(1, self.input_length):
36
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
37
+ return r
38
+
39
+ def __call__(self, y, cs, r_prev):
40
+ """Compute CTC prefix scores for next labels
41
+ :param y : prefix label sequence
42
+ :param cs : array of next labels
43
+ :param r_prev: previous CTC state
44
+ :return ctc_scores, ctc_states
45
+ """
46
+ # initialize CTC states
47
+ output_length = len(y) - 1 # ignore sos
48
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
49
+ # that corresponds to r_t^n(h) and r_t^b(h).
50
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
51
+ xs = self.x[:, cs]
52
+ if output_length == 0:
53
+ r[0, 0] = xs[0]
54
+ r[0, 1] = self.logzero
55
+ else:
56
+ r[output_length - 1] = self.logzero
57
+
58
+ # prepare forward probabilities for the last label
59
+ r_sum = self.xp.logaddexp(
60
+ r_prev[:, 0], r_prev[:, 1]
61
+ ) # log(r_t^n(g) + r_t^b(g))
62
+ last = y[-1]
63
+ if output_length > 0 and last in cs:
64
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
65
+ for i in six.moves.range(len(cs)):
66
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
67
+ else:
68
+ log_phi = r_sum
69
+
70
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
71
+ # and log prefix probabilities log(psi)
72
+ start = max(output_length, 1)
73
+ log_psi = r[start - 1, 0]
74
+ for t in six.moves.range(start, self.input_length):
75
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
76
+ r[t, 1] = (
77
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
78
+ )
79
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
80
+
81
+ # get P(...eos|X) that ends with the prefix itself
82
+ eos_pos = self.xp.where(cs == self.eos)[0]
83
+ if len(eos_pos) > 0:
84
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
85
+
86
+ # exclude blank probs
87
+ blank_pos = self.xp.where(cs == self.blank)[0]
88
+ if len(blank_pos) > 0:
89
+ log_psi[blank_pos] = self.logzero
90
+
91
+ # return the log prefix probability and CTC states, where the label axis
92
+ # of the CTC states is moved to the first axis to slice it easily
93
+ return log_psi, self.xp.rollaxis(r, 2)
SpeechT5/Speech2S/speech2s/modules/learned_positional_embedding.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ """
9
+ Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
10
+ 1. Add clamping if the input length exceeds the max-source-tokens
11
+ """
12
+
13
+ from typing import Dict, Optional
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from fairseq import utils
19
+ from torch import Tensor
20
+
21
+
22
+ class LearnedPositionalEmbedding(nn.Embedding):
23
+ """
24
+ This module learns positional embeddings up to a fixed maximum size.
25
+ Padding ids are ignored by either offsetting based on padding_idx
26
+ or by setting padding_idx to None and ensuring that the appropriate
27
+ position ids are passed to the forward function.
28
+ """
29
+
30
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
31
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
32
+ self.onnx_trace = False
33
+ if self.padding_idx is not None:
34
+ self.max_positions = self.num_embeddings - self.padding_idx - 1
35
+ else:
36
+ self.max_positions = self.num_embeddings
37
+
38
+ def forward(
39
+ self,
40
+ input: Tensor,
41
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
42
+ positions: Optional[Tensor] = None,
43
+ ):
44
+ """Input is expected to be of size [bsz x seqlen]."""
45
+ assert (positions is None) or (
46
+ self.padding_idx is None
47
+ ), "If positions is pre-computed then padding_idx should not be set."
48
+
49
+ if positions is None:
50
+ if incremental_state is not None:
51
+ # positions is the same for every token when decoding a single step
52
+ # Without the int() cast, it doesn't work in some cases when exporting to ONNX
53
+ positions = torch.zeros(
54
+ (1, 1), device=input.device, dtype=input.dtype
55
+ ).fill_(int(self.padding_idx + input.size(1)))
56
+ else:
57
+ positions = utils.make_positions(
58
+ input, self.padding_idx, onnx_trace=self.onnx_trace
59
+ )
60
+ positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
61
+ return F.embedding(
62
+ positions,
63
+ self.weight,
64
+ self.padding_idx,
65
+ self.max_norm,
66
+ self.norm_type,
67
+ self.scale_grad_by_freq,
68
+ self.sparse,
69
+ )
SpeechT5/Speech2S/speech2s/modules/multihead_attention.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ from typing import Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import utils
13
+ from torch import Tensor
14
+
15
+ from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
16
+
17
+
18
+ class MultiheadAttention(FairseqMultiheadAttention):
19
+ """Multi-headed attention.
20
+
21
+ See "Attention Is All You Need" for more details.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ embed_dim,
27
+ num_heads,
28
+ kdim=None,
29
+ vdim=None,
30
+ dropout=0.0,
31
+ bias=True,
32
+ add_bias_kv=False,
33
+ add_zero_attn=False,
34
+ self_attention=False,
35
+ encoder_decoder_attention=False,
36
+ q_noise=0.0,
37
+ qn_block_size=8,
38
+ scaling_for_att=1.0
39
+ ):
40
+ super().__init__(
41
+ embed_dim,
42
+ num_heads,
43
+ kdim,
44
+ vdim,
45
+ dropout,
46
+ bias,
47
+ add_bias_kv,
48
+ add_zero_attn,
49
+ self_attention,
50
+ encoder_decoder_attention,
51
+ q_noise,
52
+ qn_block_size,
53
+ )
54
+ self.scaling_for_att = scaling_for_att
55
+
56
+ def forward(
57
+ self,
58
+ query,
59
+ key: Optional[Tensor],
60
+ value: Optional[Tensor],
61
+ key_padding_mask: Optional[Tensor] = None,
62
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
63
+ need_weights: bool = True,
64
+ static_kv: bool = False,
65
+ attn_mask: Optional[Tensor] = None,
66
+ before_softmax: bool = False,
67
+ need_head_weights: bool = False,
68
+ position_bias: Optional[Tensor] = None,
69
+ ) -> Tuple[Tensor, Optional[Tensor]]:
70
+ """Input shape: Time x Batch x Channel
71
+
72
+ Args:
73
+ key_padding_mask (ByteTensor, optional): mask to exclude
74
+ keys that are pads, of shape `(batch, src_len)`, where
75
+ padding elements are indicated by 1s.
76
+ need_weights (bool, optional): return the attention weights,
77
+ averaged over heads (default: False).
78
+ attn_mask (ByteTensor, optional): typically used to
79
+ implement causal attention, where the mask prevents the
80
+ attention from looking forward in time (default: None).
81
+ before_softmax (bool, optional): return the raw attention
82
+ weights and values before the attention softmax.
83
+ need_head_weights (bool, optional): return the attention
84
+ weights for each head. Implies *need_weights*. Default:
85
+ return the average attention weights over all heads.
86
+ """
87
+ if need_head_weights:
88
+ need_weights = True
89
+
90
+ is_tpu = query.device.type == "xla"
91
+
92
+ tgt_len, bsz, embed_dim = query.size()
93
+ src_len = tgt_len
94
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
95
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
96
+ if key is not None:
97
+ src_len, key_bsz, _ = key.size()
98
+ if not torch.jit.is_scripting():
99
+ assert key_bsz == bsz
100
+ assert value is not None
101
+ assert src_len, bsz == value.shape[:2]
102
+
103
+ if (
104
+ not self.onnx_trace
105
+ and not is_tpu # don't use PyTorch version on TPUs
106
+ and incremental_state is None
107
+ and not static_kv
108
+ # A workaround for quantization to work. Otherwise JIT compilation
109
+ # treats bias in linear module as method.
110
+ and not torch.jit.is_scripting()
111
+ and position_bias is None
112
+ ):
113
+ assert key is not None and value is not None
114
+ return F.multi_head_attention_forward(
115
+ query,
116
+ key,
117
+ value,
118
+ self.embed_dim,
119
+ self.num_heads,
120
+ torch.empty([0]),
121
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
122
+ self.bias_k,
123
+ self.bias_v,
124
+ self.add_zero_attn,
125
+ self.dropout_module.p,
126
+ self.out_proj.weight,
127
+ self.out_proj.bias,
128
+ self.training or self.dropout_module.apply_during_inference,
129
+ key_padding_mask,
130
+ need_weights,
131
+ attn_mask,
132
+ use_separate_proj_weight=True,
133
+ q_proj_weight=self.q_proj.weight,
134
+ k_proj_weight=self.k_proj.weight,
135
+ v_proj_weight=self.v_proj.weight,
136
+ )
137
+
138
+ if incremental_state is not None:
139
+ saved_state = self._get_input_buffer(incremental_state)
140
+ if saved_state is not None and "prev_key" in saved_state:
141
+ # previous time steps are cached - no need to recompute
142
+ # key and value if they are static
143
+ if static_kv:
144
+ assert self.encoder_decoder_attention and not self.self_attention
145
+ key = value = None
146
+ else:
147
+ saved_state = None
148
+
149
+ if self.self_attention:
150
+ q = self.q_proj(query)
151
+ k = self.k_proj(query)
152
+ v = self.v_proj(query)
153
+ elif self.encoder_decoder_attention:
154
+ # encoder-decoder attention
155
+ q = self.q_proj(query)
156
+ if key is None:
157
+ assert value is None
158
+ k = v = None
159
+ else:
160
+ k = self.k_proj(key)
161
+ v = self.v_proj(key)
162
+
163
+ else:
164
+ assert key is not None and value is not None
165
+ q = self.q_proj(query)
166
+ k = self.k_proj(key)
167
+ v = self.v_proj(value)
168
+ q *= self.scaling
169
+ q *= (1 / self.scaling_for_att)
170
+
171
+ if self.bias_k is not None:
172
+ assert self.bias_v is not None
173
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
174
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
175
+ if attn_mask is not None:
176
+ attn_mask = torch.cat(
177
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
178
+ )
179
+ if key_padding_mask is not None:
180
+ key_padding_mask = torch.cat(
181
+ [
182
+ key_padding_mask,
183
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
184
+ ],
185
+ dim=1,
186
+ )
187
+
188
+ q = (
189
+ q.contiguous()
190
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
191
+ .transpose(0, 1)
192
+ )
193
+ if k is not None:
194
+ k = (
195
+ k.contiguous()
196
+ .view(-1, bsz * self.num_heads, self.head_dim)
197
+ .transpose(0, 1)
198
+ )
199
+ if v is not None:
200
+ v = (
201
+ v.contiguous()
202
+ .view(-1, bsz * self.num_heads, self.head_dim)
203
+ .transpose(0, 1)
204
+ )
205
+
206
+ if saved_state is not None:
207
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
208
+ if "prev_key" in saved_state:
209
+ _prev_key = saved_state["prev_key"]
210
+ assert _prev_key is not None
211
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
212
+ if static_kv:
213
+ k = prev_key
214
+ else:
215
+ assert k is not None
216
+ k = torch.cat([prev_key, k], dim=1)
217
+ src_len = k.size(1)
218
+ if "prev_value" in saved_state:
219
+ _prev_value = saved_state["prev_value"]
220
+ assert _prev_value is not None
221
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
222
+ if static_kv:
223
+ v = prev_value
224
+ else:
225
+ assert v is not None
226
+ v = torch.cat([prev_value, v], dim=1)
227
+ prev_key_padding_mask: Optional[Tensor] = None
228
+ if "prev_key_padding_mask" in saved_state:
229
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
230
+ assert k is not None and v is not None
231
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
232
+ key_padding_mask=key_padding_mask,
233
+ prev_key_padding_mask=prev_key_padding_mask,
234
+ batch_size=bsz,
235
+ src_len=k.size(1),
236
+ static_kv=static_kv,
237
+ )
238
+
239
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
240
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
241
+ saved_state["prev_key_padding_mask"] = key_padding_mask
242
+ # In this branch incremental_state is never None
243
+ assert incremental_state is not None
244
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
245
+ assert k is not None
246
+ assert k.size(1) == src_len
247
+
248
+ # This is part of a workaround to get around fork/join parallelism
249
+ # not supporting Optional types.
250
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
251
+ key_padding_mask = None
252
+
253
+ if key_padding_mask is not None:
254
+ assert key_padding_mask.size(0) == bsz
255
+ assert key_padding_mask.size(1) == src_len
256
+
257
+ if self.add_zero_attn:
258
+ assert v is not None
259
+ src_len += 1
260
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
261
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
262
+ if attn_mask is not None:
263
+ attn_mask = torch.cat(
264
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
265
+ )
266
+ if key_padding_mask is not None:
267
+ key_padding_mask = torch.cat(
268
+ [
269
+ key_padding_mask,
270
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
271
+ key_padding_mask
272
+ ),
273
+ ],
274
+ dim=1,
275
+ )
276
+
277
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
278
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
279
+
280
+ if position_bias is not None: ## first order
281
+ ## position_bias: [241, 241, 64]
282
+ #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
283
+ reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
284
+ #print ("reshape_q: ", reshape_q.size())
285
+ B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
286
+ #print ("B: ", B.size()) ## [241, 492, 241]
287
+ #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
288
+ B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
289
+ #print ("B 2: ", B.size())
290
+ attn_weights += B
291
+
292
+ attn_weights *= self.scaling_for_att
293
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
294
+
295
+ if attn_mask is not None:
296
+ attn_mask = attn_mask.unsqueeze(0)
297
+ if self.onnx_trace:
298
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
299
+ attn_weights += attn_mask
300
+
301
+ if key_padding_mask is not None:
302
+ # don't attend to padding symbols
303
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
304
+ if not is_tpu:
305
+ attn_weights = attn_weights.masked_fill(
306
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
307
+ float("-inf"),
308
+ )
309
+ else:
310
+ attn_weights = attn_weights.transpose(0, 2)
311
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
312
+ attn_weights = attn_weights.transpose(0, 2)
313
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
314
+
315
+ if self.scaling_for_att > 1.0:
316
+ attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]
317
+
318
+ if before_softmax:
319
+ return attn_weights, v
320
+
321
+ attn_weights_float = utils.softmax(
322
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
323
+ )
324
+ attn_weights = attn_weights_float.type_as(attn_weights)
325
+ attn_probs = self.dropout_module(attn_weights)
326
+
327
+ assert v is not None
328
+ attn = torch.bmm(attn_probs, v)
329
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
330
+ if self.onnx_trace and attn.size(1) == 1:
331
+ # when ONNX tracing a single decoder step (sequence length == 1)
332
+ # the transpose is a no-op copy before view, thus unnecessary
333
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
334
+ else:
335
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
336
+ attn = self.out_proj(attn)
337
+ attn_weights: Optional[Tensor] = None
338
+ if need_weights:
339
+ attn_weights = attn_weights_float.view(
340
+ bsz, self.num_heads, tgt_len, src_len
341
+ ).transpose(1, 0)
342
+ if not need_head_weights:
343
+ # average attention weights over heads
344
+ attn_weights = attn_weights.mean(dim=0)
345
+
346
+ return attn, attn_weights
SpeechT5/Speech2S/speech2s/modules/relative_pos_enc.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Copyright (c) 2022 Microsoft
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on fairseq code bases
5
+ # https://github.com/facebookresearch/fairseq
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ class RelativePositionalEncoding(torch.nn.Module):
11
+ def __init__(self, d_model, maxlen=1000, embed_v=False):
12
+ super(RelativePositionalEncoding, self).__init__()
13
+
14
+ self.d_model = d_model
15
+ self.maxlen = maxlen
16
+ self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
17
+ if embed_v:
18
+ self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
19
+ self.embed_v = embed_v
20
+
21
+
22
+ def forward(self, pos_seq, incremental_state=None):
23
+ pos_seq[pos_seq < -self.maxlen] = -self.maxlen
24
+ pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
25
+ pos_seq = pos_seq + self.maxlen
26
+
27
+ if incremental_state is not None:
28
+ pos_seq = pos_seq[-1:]
29
+
30
+ if self.embed_v:
31
+ return self.pe_k(pos_seq), self.pe_v(pos_seq)
32
+ else:
33
+ return self.pe_k(pos_seq), None