Thefrudi78 committed on
Commit
b0be382
1 Parent(s): 68b29a9

Upload 552 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .editorconfig +11 -0
  2. .gitignore +140 -0
  3. LICENSE +24 -0
  4. README.md +548 -10
  5. api_key.txt +1 -0
  6. constants.py +49 -0
  7. data/models/coqui/.placeholder +2 -0
  8. data/models/rvc/.placeholder +3 -0
  9. data/tmp/.placeholder +2 -0
  10. docker/Dockerfile +35 -0
  11. docker/docker-compose.yml +23 -0
  12. docker/readme.md +10 -0
  13. modules/classify/classify_module.py +41 -0
  14. modules/speech_recognition/streaming_module.py +121 -0
  15. modules/speech_recognition/vosk_module.py +77 -0
  16. modules/speech_recognition/whisper_module.py +56 -0
  17. modules/text_to_speech/coqui/coqui_module.py +333 -0
  18. modules/utils.py +15 -0
  19. modules/voice_conversion/fairseq/LICENSE +21 -0
  20. modules/voice_conversion/fairseq/__init__.py +45 -0
  21. modules/voice_conversion/fairseq/binarizer.py +381 -0
  22. modules/voice_conversion/fairseq/checkpoint_utils.py +905 -0
  23. modules/voice_conversion/fairseq/data/__init__.py +130 -0
  24. modules/voice_conversion/fairseq/data/add_target_dataset.py +83 -0
  25. modules/voice_conversion/fairseq/data/append_token_dataset.py +41 -0
  26. modules/voice_conversion/fairseq/data/audio/__init__.py +93 -0
  27. modules/voice_conversion/fairseq/data/audio/audio_utils.py +389 -0
  28. modules/voice_conversion/fairseq/data/audio/data_cfg.py +387 -0
  29. modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py +53 -0
  30. modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py +61 -0
  31. modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +105 -0
  32. modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py +43 -0
  33. modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py +37 -0
  34. modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
  35. modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
  36. modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py +41 -0
  37. modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py +205 -0
  38. modules/voice_conversion/fairseq/data/audio/hubert_dataset.py +356 -0
  39. modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py +284 -0
  40. modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py +393 -0
  41. modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py +379 -0
  42. modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py +733 -0
  43. modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py +359 -0
  44. modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py +250 -0
  45. modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py +48 -0
  46. modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py +201 -0
  47. modules/voice_conversion/fairseq/data/backtranslation_dataset.py +165 -0
  48. modules/voice_conversion/fairseq/data/base_wrapper_dataset.py +78 -0
  49. modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py +78 -0
  50. modules/voice_conversion/fairseq/data/codedataset.py +576 -0
.editorconfig ADDED
@@ -0,0 +1,11 @@
1
+ root = true
2
+
3
+ [*]
4
+ end_of_line = lf
5
+ insert_final_newline = true
6
+ trim_trailing_whitespace = true
7
+
8
+ [*.{py,js,html,css,scss,md}]
9
+ charset = utf-8
10
+ indent_style = space
11
+ indent_size = 4
.gitignore ADDED
@@ -0,0 +1,140 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ data/
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # pipenv
89
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
91
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
92
+ # install all needed dependencies.
93
+ #Pipfile.lock
94
+
95
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .env
107
+ .venv
108
+ env/
109
+ venv/
110
+ ENV/
111
+ env.bak/
112
+ venv.bak/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ debug.png
133
+ test.wav
134
+ /tts_samples
135
+ model.pt
136
+ .DS_Store
137
+ .chroma
138
+ /.chroma_db
139
+ api_key.txt
140
+ .vscode
LICENSE ADDED
@@ -0,0 +1,24 @@
1
+ This is free and unencumbered software released into the public domain.
2
+
3
+ Anyone is free to copy, modify, publish, use, compile, sell, or
4
+ distribute this software, either in source code form or as a compiled
5
+ binary, for any purpose, commercial or non-commercial, and by any
6
+ means.
7
+
8
+ In jurisdictions that recognize copyright laws, the author or authors
9
+ of this software dedicate any and all copyright interest in the
10
+ software to the public domain. We make this dedication for the benefit
11
+ of the public at large and to the detriment of our heirs and
12
+ successors. We intend this dedication to be an overt act of
13
+ relinquishment in perpetuity of all present and future rights to this
14
+ software under copyright law.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19
+ IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22
+ OTHER DEALINGS IN THE SOFTWARE.
23
+
24
+ For more information, please refer to <https://unlicense.org>
README.md CHANGED
@@ -1,10 +1,548 @@
1
- ---
2
- title: Extra
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # SillyTavern - Extras
2
+
3
+ ## Recent news
4
+
5
+ * July 25 2023 - Extras now requires Python 3.11 to run; some of the new modules will be incompatible with old Python 3.10 installs. To migrate using conda, remove the old environment with `conda remove --name extras --all` and reinstall using the instructions below.
6
+
7
+ ## What is this
8
+ A set of APIs for various SillyTavern extensions.
9
+
10
+ **You need to run the latest version of SillyTavern. Grab it here: [How to install](https://docs.sillytavern.app/installation/windows/), [Git repository](https://github.com/SillyTavern/SillyTavern)**
11
+
12
+ All modules except Stable Diffusion run on the CPU by default, but they can be configured to use CUDA (with the `--cuda` command-line option). When running all modules simultaneously, expect approximately 6 GB of RAM usage. Loading Stable Diffusion adds another couple of GB.
13
+
14
+ Try on Colab (will give you a link to Extras API): <a target="_blank" href="https://colab.research.google.com/github/SillyTavern/SillyTavern/blob/release/colab/GPU.ipynb">
15
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
16
+ </a>
17
+
18
+ Colab link:
19
+ https://colab.research.google.com/github/SillyTavern/SillyTavern/blob/release/colab/GPU.ipynb
20
+
21
+ Documentation:
22
+ https://docs.sillytavern.app/
23
+
24
+ ## How to run
25
+ ### :exclamation: **IMPORTANT!**
26
+ The default **requirements.txt** contains only the basic packages for text processing.
27
+
28
+ If you want to use the most advanced features (like Stable Diffusion or TTS), change that to **requirements-complete.txt** in the commands below. See the [Modules](#modules) section for more details.
29
+
30
+ If you run on Apple Silicon (M1/M2), use the **requirements-silicon.txt** file instead.
31
+
32
+ ### Getting an error when installing from requirements-complete.txt?
33
+
34
+ > ERROR: Could not build wheels for hnswlib, which is required to install pyproject.toml-based projects
35
+
36
+ Installing the chromadb package requires one of the following:
37
+
38
+ 1. Have Visual C++ build tools installed: https://visualstudio.microsoft.com/visual-cpp-build-tools/
39
+ 2. Install hnswlib from conda: `conda install -c conda-forge hnswlib`
40
+
41
+ ### Missing modules reported by SillyTavern extensions menu?
42
+
43
+ You must specify a list of module names to run via the `--enable-modules` argument (`caption` is provided as an example). See the [Modules](#modules) section.
44
+
45
+ ### ☁️ Colab
46
+ * Open colab link
47
+ * Select desired "extra" options and start the cell
48
+ * Wait for it to finish
49
+ * Get an API URL link from colab output under the `### SillyTavern Extensions LINK ###` title
50
+ * Start SillyTavern with extensions support: set `enableExtensions` to `true` in config.conf
51
+ * Navigate to the SillyTavern extensions menu, enter the API URL, and tap "Connect" to load the extensions
52
+
53
+ ### What about mobile/Android/Termux? 🤔
54
+
55
+ There are some folks in the community having success running Extras on their phones via Ubuntu on Termux. This project wasn't made with mobile support in mind, so this guide is provided strictly for your information only: https://rentry.org/STAI-Termux#downloading-and-running-tai-extras
56
+
57
+ #### ❗ IMPORTANT!
58
+
59
+ We will NOT provide any support for running this on Android. Direct all your questions to the creator of this guide.
60
+
61
+ #### Talkinghead module on Linux
62
+
63
+ The talkinghead module requires an additional package that is not installed automatically because it is incompatible with Colab. Run this after installing the other requirements:
64
+
65
+ `pip install wxpython==4.2.1`
66
+
67
+ ### 💻 Locally
68
+ #### Option 1 - Conda (recommended) 🐍
69
+
70
+ **PREREQUISITES**
71
+ * Install Miniconda: https://docs.conda.io/en/latest/miniconda.html
72
+ * _(Important!) Read how to use Conda: https://conda.io/projects/conda/en/latest/user-guide/getting-started.html_
73
+ * Install git: https://git-scm.com/downloads
74
+
75
+ **EXECUTE THESE COMMANDS ONE BY ONE IN THE _CONDA COMMAND PROMPT_.**
76
+
77
+ **TYPE/PASTE EACH COMMAND INTO THE PROMPT, HIT ENTER AND WAIT FOR IT TO FINISH!**
78
+
79
+ * Before the first run, create an environment (let's call it `extras`):
80
+ ```
81
+ conda create -n extras
82
+ ```
83
+ * Now activate the newly created env
84
+ ```
85
+ conda activate extras
86
+ ```
87
+ * Install Python 3.11
88
+ ```
89
+ conda install python=3.11
90
+ ```
91
+ * Install the required system packages
92
+ ```
93
+ conda install git
94
+ ```
95
+ * Clone this repository
96
+ ```
97
+ git clone https://github.com/SillyTavern/SillyTavern-extras
98
+ ```
99
+ * Navigate to the freshly cloned repository
100
+ ```
101
+ cd SillyTavern-extras
102
+ ```
103
+ * Install the project requirements
104
+ ```
105
+ pip install -r requirements.txt
106
+ ```
107
+ * Run the Extensions API server
108
+ ```
109
+ python server.py --enable-modules=caption,summarize,classify
110
+ ```
111
+ * Copy the Extras server API URL listed in the console window once it finishes loading. On local installs, this defaults to `http://localhost:5100`.
112
+ * Open your SillyTavern config.conf file (located in the base install folder), and look for a line "`const enableExtensions`". Make sure that line has "`= true`", and not "`= false`".
113
+ * Start your SillyTavern server
114
+ * Open the Extensions panel (via the 'Stacked Blocks' icon at the top of the page), paste the API URL into the input box, and click "Connect" to connect to the Extras extension server.
115
+ * To run again, simply activate the environment and run these commands. Be sure to add the additional options for server.py (see below) that your setup requires.
116
+ ```
117
+ conda activate extras
118
+ python server.py
119
+ ```
120
+
121
+ #### Option 2 - Vanilla 🍦
122
+ * Install Python 3.11: https://www.python.org/downloads/release/python-3114/
123
+ * Install git: https://git-scm.com/downloads
124
+ * Clone the repo:
125
+ ```
126
+ git clone https://github.com/SillyTavern/SillyTavern-extras
127
+ cd SillyTavern-extras
128
+ ```
129
+ * Run `python -m pip install -r requirements.txt`
130
+ * Run `python server.py --enable-modules=caption,summarize,classify`
131
+ * Get the API URL. Defaults to `http://localhost:5100` if you run locally.
132
+ * Start SillyTavern with extensions support: set `enableExtensions` to `true` in config.conf
133
+ * Navigate to the SillyTavern extensions menu, enter the API URL, and tap "Connect" to load the extensions
134
+
135
+ ## Modules
136
+
137
+ | Name | Description | Included in default requirements.txt |
138
+ | ----------- | --------------------------------- | ------ |
139
+ | `caption` | Image captioning | ✔️ Yes |
140
+ | `summarize` | Text summarization | ✔️ Yes |
141
+ | `classify` | Text sentiment classification | ✔️ Yes |
142
+ | `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
143
+ | `silero-tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No |
144
+ | `edge-tts` | [Microsoft Edge TTS client](https://github.com/rany2/edge-tts) | ✔️ Yes |
145
+ | `coqui-tts` | [Coqui TTS server](https://github.com/coqui-ai/TTS) | :x: No |
146
+ | `chromadb` | Infinity context server | :x: No |
147
+ | `talkinghead` | Talking Head Sprites | :x: No |
148
+
149
+ ## Additional options
150
+ | Flag | Description |
151
+ | ------------------------ | ---------------------------------------------------------------------- |
152
+ | `--enable-modules` | **Required option**. Provide a list of enabled modules.<br>Expects a comma-separated list of module names. See [Modules](#modules)<br>Example: `--enable-modules=caption,sd` |
153
+ | `--port` | Specify the port on which the application is hosted. Default: **5100** |
154
+ | `--listen` | Host the app on the local network |
155
+ | `--share` | Share the app on CloudFlare tunnel |
156
+ | `--secure` | Adds API key authentication requirements. Highly recommended when paired with share! |
157
+ | `--cpu` | Run the models on the CPU instead of CUDA. Enabled by default. |
158
+ | `--mps` or `--m1` | Run the models on Apple Silicon. Only for M1 and M2 processors. |
159
+ | `--cuda` | Uses CUDA (GPU+VRAM) to run modules if it is available. Otherwise, falls back to using CPU. |
160
+ | `--cuda-device` | Specifies a CUDA device to use. Defaults to `cuda:0` (first available GPU). |
161
+ | `--talkinghead-gpu` | Uses GPU for talkinghead (10x FPS increase in animation). |
162
+ | `--coqui-gpu` | Uses GPU for coqui TTS (if available). |
163
+ | `--coqui-model` | If provided, downloads and preloads a coqui TTS model. Default: none.<br>Example: `tts_models/multilingual/multi-dataset/bark` |
164
+ | `--summarization-model` | Load a custom summarization model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
165
+ | `--classification-model` | Load a custom sentiment classification model.<br>Expects a HuggingFace model ID.<br>Default (6 emotions): [nateraw/bert-base-uncased-emotion](https://huggingface.co/nateraw/bert-base-uncased-emotion)<br>Other solid option is (28 emotions): [joeddav/distilbert-base-uncased-go-emotions-student](https://huggingface.co/joeddav/distilbert-base-uncased-go-emotions-student)<br>For Chinese language: [touch20032003/xuyuan-trial-sentiment-bert-chinese](https://huggingface.co/touch20032003/xuyuan-trial-sentiment-bert-chinese) |
166
+ | `--captioning-model` | Load a custom captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) |
167
+ | `--embedding-model` | Load a custom text embedding model.<br>Expects a HuggingFace model ID.<br>Default: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) |
168
+ | `--chroma-host` | Specifies a host IP for a remote ChromaDB server. |
169
+ | `--chroma-port` | Specifies an HTTP port for a remote ChromaDB server.<br>Default: `8000` |
170
+ | `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
171
+ | `--sd-cpu` | Force the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
172
+ | `--sd-remote` | Use a remote SD backend.<br>**Supported APIs: [sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** |
173
+ | `--sd-remote-host` | Specify the host of the remote SD backend<br>Default: **127.0.0.1** |
174
+ | `--sd-remote-port` | Specify the port of the remote SD backend<br>Default: **7860** |
175
+ | `--sd-remote-ssl` | Use SSL for the remote SD backend<br>Default: **False** |
176
+ | `--sd-remote-auth` | Specify the `username:password` for the remote SD backend (if required) |
177
+
178
+ ## Coqui TTS
179
+
180
+ ### Running on Mac M1
181
+
182
+ #### ImportError: symbol not found
183
+
184
+ If you're getting the following error when running the coqui-tts module on an M1 Mac:
185
+
186
+ ```
187
+ ImportError: dlopen(/Users/user/.../lib/python3.11/site-packages/MeCab/_MeCab.cpython-311-darwin.so, 0x0002): symbol not found in flat namespace '__ZN5MeCab11createModelEPKc'
188
+ ```
189
+
190
+ Do the following:
191
+
192
+ 1. Install homebrew: https://brew.sh/
193
+ 2. Build and install the `mecab` package
194
+
195
+ ```
196
+ brew install --build-from-source mecab
197
+ ARCHFLAGS='-arch arm64' pip install --no-binary :all: --compile --use-pep517 --no-cache-dir --force mecab-python3
198
+ ```
199
+
200
+ ## ChromaDB
201
+ ChromaDB is a blazing-fast, open-source database used for long-term memory when chatting with characters. It can run in-memory or on a local server on your LAN.
202
+
203
+ NOTE: You should **NOT** run ChromaDB on a cloud server. There are no methods for authentication (yet), so unless you want to expose an unauthenticated ChromaDB to the world, run this on a local server in your LAN.
204
+
205
+ ### In-memory setup
206
+
207
+ Run the extras server with the `chromadb` module enabled (recommended).
208
+
209
+ ### Remote setup
210
+
211
+ Use this if you want to use ChromaDB with docker or host it remotely. If you don't know what that means and only want to use ChromaDB with ST on your local device, use the 'in-memory' instructions instead.
212
+
213
+ Prerequisites: Docker, Docker compose (make sure you're running in rootless mode with the systemd service enabled if on Linux).
214
+
215
+ Steps:
216
+
217
+ 1. Run `git clone https://github.com/chroma-core/chroma chromadb` and `cd chromadb`
218
+ 2. Run `docker-compose up -d --build` to build ChromaDB. This may take a long time depending on your system
219
+ 3. Once the build process is finished, ChromaDB should be running in the background. You can check with the command `docker ps`
220
+ 4. On your client machine, specify your local server ip in the `--chroma-host` argument (ex. `--chroma-host=192.168.1.10`)
221
+
222
+
223
+ If you are running ChromaDB on the same machine as SillyTavern, you will have to change the port of one of the services. To do this for ChromaDB:
224
+
225
+ 1. Run `docker ps` to get the container ID and then `docker container stop <container ID>`
226
+ 2. Enter the ChromaDB git repository `cd chromadb`
227
+ 3. Open `docker-compose.yml` and look for the line starting with `uvicorn chromadb.app:app`
228
+ 4. Change the `--port` argument to whatever port you want.
229
+ 5. Look for the `ports` category and change the occurrences of `8000` to whatever port you chose in step 4.
230
+ 6. Save and exit. Then run `docker-compose up --detach`
231
+ 7. On your client machine, make sure to specify the `--chroma-port` argument (ex. `--chroma-port=<your-port-here>`) along with the `--chroma-host` argument.
232
+
233
+ ## API Endpoints
234
+ ### Get active list
235
+ `GET /api/modules`
236
+ #### **Input**
237
+ None
238
+ #### **Output**
239
+ ```
240
+ {"modules":["caption", "classify", "summarize"]}
241
+ ```
242
+
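+ As a quick sanity check, you can query this endpoint from Python. This is a minimal sketch, assuming the server runs locally on the default port and without `--secure` (no API key header is sent):
+
+ ```python
+ import requests
+
+ BASE_URL = "http://localhost:5100"  # assumption: default local install
+
+ # Ask the server which modules were enabled via --enable-modules
+ response = requests.get(f"{BASE_URL}/api/modules")
+ response.raise_for_status()
+ print(response.json()["modules"])  # e.g. ['caption', 'classify', 'summarize']
+ ```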
243
+ ### Image captioning
244
+ `POST /api/caption`
245
+ #### **Input**
246
+ ```
247
+ { "image": "base64 encoded image" }
248
+ ```
249
+ #### **Output**
250
+ ```
251
+ { "caption": "caption of the posted image" }
252
+ ```
253
+
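+ A minimal Python sketch of calling this endpoint, assuming a local server with the `caption` module enabled and a hypothetical `photo.png` in the working directory:
+
+ ```python
+ import base64
+ import requests
+
+ # Encode the image file as base64 text, as the endpoint expects
+ with open("photo.png", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode("utf-8")
+
+ response = requests.post(
+     "http://localhost:5100/api/caption",
+     json={"image": image_b64},
+ )
+ print(response.json()["caption"])
+ ```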
254
+ ### Text summarization
255
+ `POST /api/summarize`
256
+ #### **Input**
257
+ ```
258
+ { "text": "text to be summarize", "params": {} }
259
+ ```
260
+ #### **Output**
261
+ ```
262
+ { "summary": "summarized text" }
263
+ ```
264
+ #### Optional: `params` object for control over summarization:
265
+ | Name | Default value |
266
+ | --------------------- | ------------------------------------------------------------- |
267
+ | `temperature` | 1.0 |
268
+ | `repetition_penalty` | 1.0 |
269
+ | `max_length` | 500 |
270
+ | `min_length` | 200 |
271
+ | `length_penalty` | 1.5 |
272
+ | `bad_words` | ["\n", '"', "*", "[", "]", "{", "}", ":", "(", ")", "<", ">"] |
273
+
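+ For illustration, a hedged Python sketch that posts a long text with a couple of overridden `params` (local server, `summarize` module enabled):
+
+ ```python
+ import requests
+
+ payload = {
+     "text": "A very long chat log or article to compress...",
+     # Any omitted parameter falls back to the defaults listed above
+     "params": {"max_length": 300, "min_length": 100},
+ }
+ response = requests.post("http://localhost:5100/api/summarize", json=payload)
+ print(response.json()["summary"])
+ ```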
274
+ ### Text sentiment classification
275
+ `POST /api/classify`
276
+ #### **Input**
277
+ ```
278
+ { "text": "text to classify sentiment of" }
279
+ ```
280
+ #### **Output**
281
+ ```
282
+ {
283
+ "classification": [
284
+ {
285
+ "label": "joy",
286
+ "score": 1.0
287
+ },
288
+ {
289
+ "label": "anger",
290
+ "score": 0.7
291
+ },
292
+ {
293
+ "label": "love",
294
+ "score": 0.6
295
+ },
296
+ {
297
+ "label": "sadness",
298
+ "score": 0.5
299
+ },
300
+ {
301
+ "label": "fear",
302
+ "score": 0.4
303
+ },
304
+ {
305
+ "label": "surprise",
306
+ "score": 0.3
307
+ }
308
+ ]
309
+ }
310
+ ```
311
+ > **NOTES**
312
+ > 1. Sorted by descending score order
313
+ > 2. The list of categories is defined by the classification model
314
+ > 3. Values range from 0.0 to 1.0
315
+
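+ Since the list is sorted by score, picking the first entry gives the dominant emotion. A minimal sketch, assuming a local server with the `classify` module enabled:
+
+ ```python
+ import requests
+
+ response = requests.post(
+     "http://localhost:5100/api/classify",
+     json={"text": "I can't believe we finally won the match!"},
+ )
+ classification = response.json()["classification"]
+ top = classification[0]  # highest-scoring label comes first
+ print(f"{top['label']} ({top['score']:.2f})")
+ ```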
316
+ ### Stable Diffusion image generation
317
+ `POST /api/image`
318
+ #### **Input**
319
+ ```
320
+ { "prompt": "prompt to be generated", "sampler": "DDIM", "steps": 20, "scale": 6, "model": "model_name" }
321
+ ```
322
+ #### **Output**
323
+ ```
324
+ { "image": "base64 encoded image" }
325
+ ```
326
+ > **NOTES**
327
+ > 1. Only the "prompt" parameter is required
328
+ > 2. Both "sampler" and "model" parameters only work when using a remote SD backend
329
+
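+ A hedged Python sketch that requests an image and writes the decoded base64 payload to disk (local server with the `sd` module enabled or a remote SD backend configured; the output filename is arbitrary):
+
+ ```python
+ import base64
+ import requests
+
+ response = requests.post(
+     "http://localhost:5100/api/image",
+     json={"prompt": "a cozy cabin in a snowy forest"},  # only "prompt" is required
+ )
+ image_b64 = response.json()["image"]
+
+ # Decode the base64 string back into binary image data
+ with open("generated.png", "wb") as f:
+     f.write(base64.b64decode(image_b64))
+ ```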
330
+ ### Get available Stable Diffusion models
331
+ `GET /api/image/models`
332
+ #### **Output**
333
+ ```
334
+ { "models": [list of all available model names] }
335
+ ```
336
+
337
+ ### Get available Stable Diffusion samplers
338
+ `GET /api/image/samplers`
339
+ #### **Output**
340
+ ```
341
+ { "samplers": [list of all available sampler names] }
342
+ ```
343
+
344
+ ### Get currently loaded Stable Diffusion model
345
+ `GET /api/image/model`
346
+ #### **Output**
347
+ ```
348
+ { "model": "name of the current loaded model" }
349
+ ```
350
+
351
+ ### Load a Stable Diffusion model (remote)
352
+ `POST /api/image/model`
353
+ #### **Input**
354
+ ```
355
+ { "model": "name of the model to load" }
356
+ ```
357
+ #### **Output**
358
+ ```
359
+ { "previous_model": "name of the previous model", "current_model": "name of the newly loaded model" }
360
+ ```
361
+
362
+ ### Generate Silero TTS voice
363
+ `POST /api/tts/generate`
364
+ #### **Input**
365
+ ```
366
+ { "speaker": "speaker voice_id", "text": "text to narrate" }
367
+ ```
368
+ #### **Output**
369
+ WAV audio file.
370
+
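+ A minimal Python sketch, assuming the `silero-tts` module is enabled and using the `en_0` voice shown in the speakers list below; the raw response body is the WAV data:
+
+ ```python
+ import requests
+
+ response = requests.post(
+     "http://localhost:5100/api/tts/generate",
+     json={"speaker": "en_0", "text": "Hello from SillyTavern Extras."},
+ )
+ with open("narration.wav", "wb") as f:
+     f.write(response.content)  # response body is the WAV file itself
+ ```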
371
+ ### Get Silero TTS voices
372
+ `GET /api/tts/speakers`
373
+ #### **Output**
374
+ ```
375
+ [
376
+ {
377
+ "name": "en_0",
378
+ "preview_url": "http://127.0.0.1:5100/api/tts/sample/en_0",
379
+ "voice_id": "en_0"
380
+ }
381
+ ]
382
+ ```
383
+
384
+ ### Get Silero TTS voice sample
385
+ `GET /api/tts/sample/<voice_id>`
386
+ #### **Output**
387
+ WAV audio file.
388
+
389
+ ### Add messages to chromadb
390
+ `POST /api/chromadb`
391
+ #### **Input**
392
+ ```
393
+ {
394
+ "chat_id": "chat1 - 2023-12-31",
395
+ "messages": [
396
+ {
397
+ "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
398
+ "date": 1684164339877,
399
+ "role": "user",
400
+ "content": "Hello, AI world!",
401
+ "meta": "this is meta"
402
+ },
403
+ {
404
+ "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
405
+ "date": 1684164411759,
406
+ "role": "assistant",
407
+ "content": "Hello, Hooman!"
408
+ },
409
+ ]
410
+ }
411
+ ```
412
+ #### **Output**
413
+ ```
414
+ { "count": 2 }
415
+ ```
416
+
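+ A hedged Python sketch of pushing two messages into a collection (local server, `chromadb` module enabled; the IDs and dates are arbitrary example values):
+
+ ```python
+ import requests
+ import time
+ import uuid
+
+ def make_message(role, content):
+     # Each message needs a unique id and a millisecond timestamp
+     return {
+         "id": str(uuid.uuid4()),
+         "date": int(time.time() * 1000),
+         "role": role,
+         "content": content,
+     }
+
+ payload = {
+     "chat_id": "chat1 - 2023-12-31",
+     "messages": [
+         make_message("user", "Hello, AI world!"),
+         make_message("assistant", "Hello, Hooman!"),
+     ],
+ }
+ response = requests.post("http://localhost:5100/api/chromadb", json=payload)
+ print(response.json()["count"])  # number of messages stored
+ ```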
417
+ ### Query chromadb
418
+ `POST /api/chromadb/query`
419
+ #### **Input**
420
+ ```
421
+ {
422
+ "chat_id": "chat1 - 2023-12-31",
423
+ "query": "Hello",
424
+ "n_results": 2,
425
+ }
426
+ ```
427
+ #### **Output**
428
+ ```
429
+ [
430
+ {
431
+ "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
432
+ "date": 1684164339877,
433
+ "role": "user",
434
+ "content": "Hello, AI world!",
435
+ "distance": 0.31,
436
+ "meta": "this is meta"
437
+ },
438
+ {
439
+ "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
440
+ "date": 1684164411759,
441
+ "role": "assistant",
442
+ "content": "Hello, Hooman!",
443
+ "distance": 0.29
444
+ },
445
+ ]
446
+ ```
447
+
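+ A matching query sketch: retrieve the messages most similar to a search string for the same `chat_id` (a lower `distance` means a closer match):
+
+ ```python
+ import requests
+
+ payload = {
+     "chat_id": "chat1 - 2023-12-31",
+     "query": "Hello",
+     "n_results": 2,
+ }
+ response = requests.post("http://localhost:5100/api/chromadb/query", json=payload)
+ for message in response.json():
+     print(message["distance"], message["role"], message["content"])
+ ```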
448
+ ### Delete the messages from chromadb
449
+ `POST /api/chromadb/purge`
450
+ #### **Input**
451
+ ```
452
+ { "chat_id": "chat1 - 2023-04-12" }
453
+ ```
454
+
455
+ ### Get a list of Edge TTS voices
456
+ `GET /api/edge-tts/list`
457
+ #### **Output**
458
+ ```
459
+ [{'Name': 'Microsoft Server Speech Text to Speech Voice (af-ZA, AdriNeural)', 'ShortName': 'af-ZA-AdriNeural', 'Gender': 'Female', 'Locale': 'af-ZA', 'SuggestedCodec': 'audio-24khz-48kbitrate-mono-mp3', 'FriendlyName': 'Microsoft Adri Online (Natural) - Afrikaans (South Africa)', 'Status': 'GA', 'VoiceTag': {'ContentCategories': ['General'], 'VoicePersonalities': ['Friendly', 'Positive']}}]
460
+ ```
461
+
462
+ ### Generate Edge TTS voice
463
+ `POST /api/edge-tts/generate`
464
+ #### **Input**
465
+ ```
466
+ { "text": "Text to narrate", "voice": "af-ZA-AdriNeural", "rate": 0 }
467
+ ```
468
+ #### **Output**
469
+ MP3 audio file.
470
+
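+ A minimal Python sketch that picks a voice from the list endpoint and narrates a line with it (local server, `edge-tts` module enabled; the output filename is arbitrary):
+
+ ```python
+ import requests
+
+ BASE_URL = "http://localhost:5100"
+
+ # Pick the short name of the first available voice
+ voices = requests.get(f"{BASE_URL}/api/edge-tts/list").json()
+ voice = voices[0]["ShortName"]
+
+ response = requests.post(
+     f"{BASE_URL}/api/edge-tts/generate",
+     json={"text": "Text to narrate", "voice": voice, "rate": 0},
+ )
+ with open("narration.mp3", "wb") as f:
+     f.write(response.content)  # response body is the MP3 file itself
+ ```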
471
+ ### Load a Coqui TTS model
472
+ `GET /api/coqui-tts/load`
473
+ #### **Input**
474
+ _model (string, required): The name of the Coqui TTS model to load.
475
+ _gpu (string, optional): Use the GPU to load the model.
476
+ _progress (string, optional): Show a progress bar in the terminal.
477
+ ```
478
+ { "_model": "tts_models--en--jenny--jenny\model.pth" }
479
+ { "_gpu": "False" }
480
+ { "_progress": "True" }
481
+ ```
482
+ #### **Output**
483
+ "Loaded"
484
+
485
+ ### Get a list of Coqui TTS voices
486
+ `GET /api/coqui-tts/list`
487
+ #### **Output**
488
+ ```
489
+ ["tts_models--en--jenny--jenny\\model.pth", "tts_models--en--ljspeech--fast_pitch\\model_file.pth", "tts_models--en--ljspeech--glow-tts\\model_file.pth", "tts_models--en--ljspeech--neural_hmm\\model_file.pth", "tts_models--en--ljspeech--speedy-speech\\model_file.pth", "tts_models--en--ljspeech--tacotron2-DDC\\model_file.pth", "tts_models--en--ljspeech--vits\\model_file.pth", "tts_models--en--ljspeech--vits--neon\\model_file.pth.tar", "tts_models--en--multi-dataset--tortoise-v2", "tts_models--en--vctk--vits\\model_file.pth", "tts_models--et--cv--vits\\model_file.pth.tar", "tts_models--multilingual--multi-dataset--bark", "tts_models--multilingual--multi-dataset--your_tts\\model_file.pth", "tts_models--multilingual--multi-dataset--your_tts\\model_se.pth"]
490
+ ```
491
+
492
+ ### Get a list of the loaded Coqui model speakers
493
+ `GET /api/coqui-tts/multspeaker`
494
+ #### **Output**
495
+ ```
496
+ {"0": "female-en-5", "1": "female-en-5\n", "2": "female-pt-4\n", "3": "male-en-2", "4": "male-en-2\n", "5": "male-pt-3\n"}
497
+ ```
498
+
499
+ ### Get a list of the loaded Coqui model languages
500
+ `GET /api/coqui-tts/multlang`
501
+ #### **Output**
502
+ ```
503
+ {"0": "en", "1": "fr-fr", "2": "pt-br"}
504
+ ```
505
+
506
+ ### Generate Coqui TTS voice
507
+ `POST /api/coqui-tts/generate`
508
+ #### **Input**
509
+ ```
510
+ {
511
+ "text": "Text to narrate",
512
+ "speaker_id": "0",
513
+ "mspker": null,
514
+ "language_id": null,
515
+ "style_wav": null
516
+ }
517
+ ```
518
+ #### **Output**
519
+ WAV audio file.
520
+
521
+ ### Load a talkinghead character by specifying the character's image URL
522
+ `GET /api/talkinghead/load`
523
+ #### **Parameters**
524
+ loadchar (string, required): The URL of the character's image. The URL should point to a PNG image.
525
+ { "loadchar": "http://localhost:8000/characters/Aqua.png" }
526
+ #### **Example**
527
+ 'http://localhost:5100/api/talkinghead/load?loadchar=http://localhost:8000/characters/Aqua.png'
528
+ #### **Output**
529
+ 'OK'
530
+
531
+ ### Animate the talkinghead sprite to start talking
532
+ `GET /api/talkinghead/start_talking`
533
+ #### **Example**
534
+ 'http://localhost:5100/api/talkinghead/start_talking'
535
+ #### **Output**
536
+ "started"
537
+
538
+ ### Animate the talkinghead sprite to stop talking
539
+ `GET /api/talkinghead/stop_talking`
540
+ #### **Example**
541
+ 'http://localhost:5100/api/talkinghead/stop_talking'
542
+ #### **Output**
543
+ "stopped"
544
+
545
+ ### Output the animated talkinghead sprite
546
+ `GET /api/talkinghead/result_feed`
547
+ #### **Output**
548
+ Animated transparent image
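+
+ The talkinghead endpoints are plain GETs, so they can also be driven from Python. A hedged sketch, assuming a character PNG is served at the example URL used above:
+
+ ```python
+ import requests
+
+ BASE_URL = "http://localhost:5100"
+
+ # Load a character sprite from a URL pointing to a PNG (example URL, adjust to your setup)
+ requests.get(f"{BASE_URL}/api/talkinghead/load",
+              params={"loadchar": "http://localhost:8000/characters/Aqua.png"})
+
+ # Toggle the talking animation on and off
+ requests.get(f"{BASE_URL}/api/talkinghead/start_talking")
+ requests.get(f"{BASE_URL}/api/talkinghead/stop_talking")
+ ```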
api_key.txt ADDED
@@ -0,0 +1 @@
1
+ CHANGEME
constants.py ADDED
@@ -0,0 +1,49 @@
1
+ # Constants
2
+ DEFAULT_CUDA_DEVICE = "cuda:0"
3
+ # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
4
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
5
+ # Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
6
+ DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
7
+ # Also try: 'Salesforce/blip-image-captioning-base'
8
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
9
+ DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
10
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
+ DEFAULT_REMOTE_SD_PORT = 7860
13
+ DEFAULT_CHROMA_PORT = 8000
14
+ SILERO_SAMPLES_PATH = "tts_samples"
15
+ SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
16
+ DEFAULT_SUMMARIZE_PARAMS = {
17
+ "temperature": 1.0,
18
+ "repetition_penalty": 1.0,
19
+ "max_length": 500,
20
+ "min_length": 200,
21
+ "length_penalty": 1.5,
22
+ "bad_words": [
23
+ "\n",
24
+ '"',
25
+ "*",
26
+ "[",
27
+ "]",
28
+ "{",
29
+ "}",
30
+ ":",
31
+ "(",
32
+ ")",
33
+ "<",
34
+ ">",
35
+ "Â",
36
+ "The text ends",
37
+ "The story ends",
38
+ "The text is",
39
+ "The story is",
40
+ ],
41
+ }
42
+
43
+ PROMPT_PREFIX = "best quality, absurdres, "
44
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
45
+ error hands, bad hands, error fingers, bad fingers, missing fingers,
46
+ error legs, bad legs, multiple legs, missing legs, error lighting,
47
+ error shadow, error reflection, text, error, extra digit, fewer digits,
48
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
49
+ signature, watermark, username, blurry"""
data/models/coqui/.placeholder ADDED
@@ -0,0 +1,2 @@
1
+ Put Coqui models folders here.
2
+ Must contain both a "model.pth" and a "config.json" file.
data/models/rvc/.placeholder ADDED
@@ -0,0 +1,3 @@
1
+ Put RVC model folders here.
2
+ Each folder must contain a ".pth" file.
3
+ A ".index" file is optional but can help improve processing time/quality.
data/tmp/.placeholder ADDED
@@ -0,0 +1,2 @@
1
+ This is a temporary file folder.
2
+ May contain RVC input/output files for research purposes.
docker/Dockerfile ADDED
@@ -0,0 +1,35 @@
1
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
2
+
3
+ EXPOSE 5100
4
+
5
+ ENV PATH="/root/miniconda3/bin:${PATH}"
6
+ ARG PATH="/root/miniconda3/bin:${PATH}"
7
+
8
+ ENV DEBIAN_FRONTEND noninteractive
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ python3 python3-venv wget build-essential
11
+
12
+ RUN wget \
13
+ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
14
+ && mkdir /root/.conda \
15
+ && bash Miniconda3-latest-Linux-x86_64.sh -b \
16
+ && rm -f Miniconda3-latest-Linux-x86_64.sh
17
+
18
+ RUN conda --version
19
+
20
+ RUN conda init
21
+
22
+ RUN conda create -n extras
23
+
24
+ RUN /bin/bash -c "source activate extras"
25
+
26
+ RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia -c conda-forge
27
+
28
+ WORKDIR /sillytavern-extras/
29
+ COPY . .
30
+
31
+ ARG REQUIREMENTS
32
+ RUN pip install -r $REQUIREMENTS
33
+
34
+ ARG MODULES
35
+ CMD ["python","server.py","--enable-modules=$MODULES"]
docker/docker-compose.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: "3"
2
+ services:
3
+ sillytavern-extras:
4
+ runtime: nvidia
5
+ image: cohee1207/sillytavern-extras
6
+ build:
7
+ context: ../
8
+ dockerfile: docker/Dockerfile
9
+ args:
10
+ REQUIREMENTS: requirements.txt
11
+ MODULES: caption,summarize,classify
12
+ # REQUIREMENTS: requirements-complete.txt
13
+ # MODULES: caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
14
+ volumes:
15
+ #- "./chromadb:/chromadb"
16
+ - "./cache:/root/.cache"
17
+ - "./api_key.txt:/sillytavern-extras/api_key.txt:rw"
18
+ ports:
19
+ - "5100:5100"
20
+ environment:
21
+ - NVIDIA_VISIBLE_DEVICES=all
22
+ command: python server.py --enable-modules=caption,summarize,classify
23
+ # command: python server.py --enable-modules=caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
docker/readme.md ADDED
@@ -0,0 +1,10 @@
1
+ # Docker Usage
2
+
3
+ ## Building the image
4
+
5
+ *This is assuming you have docker and docker compose installed and running.*
6
+
7
+ 1. Open a terminal and set your current directory to the "docker" directory in your clone of this repo.
8
+ 2. Adjust the "docker-compose.yml" file to match your needs. The default selection and the selection with all modules are provided as examples.
9
+ 3. Once ready, run the command "docker compose build" to build the "cohee1207/sillytavern-extras" docker image.
10
+
modules/classify/classify_module.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ Classify module for SillyTavern Extras
3
+
4
+ Authors:
5
+ - Tony Ribeiro (https://github.com/Tony-sama)
6
+ - Cohee (https://github.com/Cohee1207)
7
+
8
+ Provides classification features for text
9
+
10
+ References:
11
+ - https://huggingface.co/tasks/text-classification
12
+ """
13
+
14
+ from transformers import pipeline
15
+
16
+ DEBUG_PREFIX = "<Classify module>"
17
+
18
+ # Models init
19
+
20
+ text_emotion_pipe = None
21
+
22
+ def init_text_emotion_classifier(model_name: str, device: str, torch_dtype: str) -> None:
23
+ global text_emotion_pipe
24
+
25
+ print(DEBUG_PREFIX,"Initializing text classification pipeline with model",model_name)
26
+ text_emotion_pipe = pipeline(
27
+ "text-classification",
28
+ model=model_name,
29
+ top_k=None,
30
+ device=device,
31
+ torch_dtype=torch_dtype,
32
+ )
33
+
34
+
35
+ def classify_text_emotion(text: str) -> list:
36
+ output = text_emotion_pipe(
37
+ text,
38
+ truncation=True,
39
+ max_length=text_emotion_pipe.model.config.max_position_embeddings,
40
+ )[0]
41
+ return sorted(output, key=lambda x: x["score"], reverse=True)
modules/speech_recognition/streaming_module.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Speech-to-text module based on Vosk and Whisper for SillyTavern Extras
3
+ - Vosk website: https://alphacephei.com/vosk/
4
+ - Vosk api: https://github.com/alphacep/vosk-api
5
+ - Whisper github: https://github.com/openai/whisper
6
+
7
+ Authors:
8
+ - Tony Ribeiro (https://github.com/Tony-sama)
9
+
10
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
11
+
12
+ References:
13
+ - Code adapted from:
14
+ - whisper github: https://github.com/openai/whisper
15
+ - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
16
+ - vosk github: https://github.com/alphacep/vosk-api/blob/master/python/example/test_microphone.py
17
+ """
18
+ from flask import jsonify, abort
19
+
20
+ import queue
21
+ import sys
22
+ import sounddevice as sd
23
+ import soundfile as sf
24
+ import io
25
+ import numpy as np
26
+ from scipy.io.wavfile import write
27
+
28
+ import vosk
29
+ import whisper
30
+
31
+ DEBUG_PREFIX = "<stt streaming module>"
32
+ RECORDING_FILE_PATH = "stt_test.wav"
33
+
34
+ whisper_model = None
35
+ vosk_model = None
36
+ device = None
37
+
38
+ def load_model(file_path=None):
39
+ """
40
+ Load the Whisper model and the given Vosk model, or default to Whisper "base.en" and Vosk "en-us".
41
+ Models are downloaded to the user cache folder, for example: C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
42
+ """
43
+
44
+ if file_path is None:
45
+ return (whisper.load_model("base.en"), vosk.Model(lang="en-us"))
46
+ else:
47
+ return (whisper.load_model(file_path), vosk.Model(lang="en-us"))
48
+
49
+ def convert_bytearray_to_wav_ndarray(input_bytearray: bytes, sampling_rate=16000):
50
+ """
51
+ Convert a bytearray to WAV format for writing to a file, for quality-check debugging
52
+ """
53
+ bytes_wav = bytes()
54
+ byte_io = io.BytesIO(bytes_wav)
55
+ write(byte_io, sampling_rate, np.frombuffer(input_bytearray, dtype=np.int16))
56
+ output_wav = byte_io.read()
57
+ output, _ = sf.read(io.BytesIO(output_wav))
58
+ return output
59
+
60
+ def record_and_transcript():
61
+ """
62
+ Continuously record from the microphone and transcribe the voice.
63
+ Return the transcript once no more voice is detected.
64
+ """
65
+ if whisper_model is None:
66
+ print(DEBUG_PREFIX,"Whisper model not initialized yet.")
67
+ return ""
68
+
69
+ q = queue.Queue()
70
+ stream_errors = list()
71
+
72
+ def callback(indata, frames, time, status):
73
+ """This is called (from a separate thread) for each audio block."""
74
+ if status:
75
+ print(status, file=sys.stderr)
76
+ stream_errors.append(status)
77
+ q.put(bytes(indata))
78
+
79
+ try:
80
+ device_info = sd.query_devices(device, "input")
81
+ # soundfile expects an int, sounddevice provides a float:
82
+ samplerate = int(device_info["default_samplerate"])
83
+
84
+ print(DEBUG_PREFIX, "Start recording from:", device_info["name"], "with samplerate", samplerate)
85
+
86
+ with sd.RawInputStream(samplerate=samplerate, blocksize = 8000, device=device, dtype="int16", channels=1, callback=callback):
87
+
88
+ rec = vosk.KaldiRecognizer(vosk_model, samplerate)
89
+ full_recording = bytearray()
90
+ while True:
91
+ data = q.get()
92
+ if len(stream_errors) > 0:
93
+ raise Exception(DEBUG_PREFIX+" Stream errors: "+str(stream_errors))
94
+
95
+ full_recording.extend(data)
96
+
97
+ if rec.AcceptWaveform(data):
98
+ # Extract transcript string
99
+ transcript = rec.Result()[14:-3]
100
+ print(DEBUG_PREFIX, "Transcripted from microphone stream (vosk):", transcript)
101
+
102
+ # ----------------------------------
103
+ # DEBUG: save recording to wav file
104
+ # ----------------------------------
105
+ output_file = convert_bytearray_to_wav_ndarray(input_bytearray=full_recording, sampling_rate=samplerate)
106
+ sf.write(file=RECORDING_FILE_PATH, data=output_file, samplerate=samplerate)
107
+ print(DEBUG_PREFIX, "Recorded message saved to", RECORDING_FILE_PATH)
108
+
109
+ # Whisper HACK
110
+ result = whisper_model.transcribe(RECORDING_FILE_PATH)
111
+ transcript = result["text"]
112
+ print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
113
+ # ----------------------------------
114
+
115
+ return jsonify({"transcript": transcript})
116
+ #else:
117
+ # print(rec.PartialResult())
118
+
119
+ except Exception as e: # No exception observed during test but we never know
120
+ print(e)
121
+ abort(500, DEBUG_PREFIX+" Exception occurs while recording")
modules/speech_recognition/vosk_module.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Speech-to-text module based on Vosk for SillyTavern Extras
3
+ - Vosk website: https://alphacephei.com/vosk/
4
+ - Vosk api: https://github.com/alphacep/vosk-api
5
+
6
+ Authors:
7
+ - Tony Ribeiro (https://github.com/Tony-sama)
8
+
9
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/vosk
10
+
11
+ References:
12
+ - Code adapted from: https://github.com/alphacep/vosk-api/blob/master/python/example/test_simple.py
13
+ """
14
+ from flask import jsonify, abort, request
15
+
16
+ import wave
17
+ from vosk import Model, KaldiRecognizer, SetLogLevel
18
+ import soundfile
19
+
20
+ DEBUG_PREFIX = "<stt vosk module>"
21
+ RECORDING_FILE_PATH = "stt_test.wav"
22
+
23
+ model = None
24
+
25
+ SetLogLevel(-1)
26
+
27
+ def load_model(file_path=None):
28
+ """
29
+ Load given vosk model from file or default to en-us model.
30
+ Download model to user cache folder, example: C:/Users/toto/.cache/vosk
31
+ """
32
+
33
+ if file_path is None:
34
+ return Model(lang="en-us")
35
+ else:
36
+ return Model(file_path)
37
+
38
+ def process_audio():
39
+ """
40
+ Transcribe the request audio file to text using Vosk
41
+ """
42
+
43
+ if model is None:
44
+ print(DEBUG_PREFIX,"Vosk model not initialized yet.")
45
+ return ""
46
+
47
+ try:
48
+ file = request.files.get('AudioFile')
49
+ file.save(RECORDING_FILE_PATH)
50
+
51
+ # Read and rewrite the file with soundfile
52
+ data, samplerate = soundfile.read(RECORDING_FILE_PATH)
53
+ soundfile.write(RECORDING_FILE_PATH, data, samplerate)
54
+
55
+ wf = wave.open(RECORDING_FILE_PATH, "rb")
56
+ if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
57
+ print("Audio file must be WAV format mono PCM.")
58
+ abort(500, DEBUG_PREFIX+" Audio file must be WAV format mono PCM.")
59
+
60
+ rec = KaldiRecognizer(model, wf.getframerate())
61
+ #rec.SetWords(True)
62
+ #rec.SetPartialWords(True)
63
+
64
+ while True:
65
+ data = wf.readframes(4000)
66
+ if len(data) == 0:
67
+ break
68
+ if rec.AcceptWaveform(data):
69
+ break
70
+
71
+ transcript = rec.Result()[14:-3]
72
+ print(DEBUG_PREFIX, "Transcripted from request audio file:", transcript)
73
+ return jsonify({"transcript": transcript})
74
+
75
+ except Exception as e: # No exception observed during test but we never know
76
+ print(e)
77
+ abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
modules/speech_recognition/whisper_module.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ Speech-to-text module based on Whisper for SillyTavern Extras
3
+ - Whisper github: https://github.com/openai/whisper
4
+
5
+ Authors:
6
+ - Tony Ribeiro (https://github.com/Tony-sama)
7
+
8
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper
9
+
10
+ References:
11
+ - Code adapted from:
12
+ - whisper github: https://github.com/openai/whisper
13
+ - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
14
+ """
15
+ from flask import jsonify, abort, request
16
+
17
+ import whisper
18
+
19
+ DEBUG_PREFIX = "<stt whisper module>"
20
+ RECORDING_FILE_PATH = "stt_test.wav"
21
+
22
+ model = None
23
+
24
+ def load_model(file_path=None):
25
+ """
26
+ Load given Whisper model from file or default to the "base.en" model.
27
+ Models are downloaded to the user cache folder, example: C:/Users/toto/.cache/whisper
28
+ """
29
+
30
+ if file_path is None:
31
+ return whisper.load_model("base.en")
32
+ else:
33
+ return whisper.load_model(file_path)
34
+
35
+ def process_audio():
36
+ """
37
+ Transcribe the request audio file to text using Whisper
38
+ """
39
+
40
+ if model is None:
41
+ print(DEBUG_PREFIX,"Whisper model not initialized yet.")
42
+ return ""
43
+
44
+ try:
45
+ file = request.files.get('AudioFile')
46
+ file.save(RECORDING_FILE_PATH)
47
+
48
+ result = model.transcribe(RECORDING_FILE_PATH)
49
+ transcript = result["text"]
50
+ print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
51
+
52
+ return jsonify({"transcript": transcript})
53
+
54
+ except Exception as e: # No exception observed during test but we never know
55
+ print(e)
56
+ abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
modules/text_to_speech/coqui/coqui_module.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Coqui module for SillyTavern Extras
3
+
4
+ Authors:
5
+ - Pyrater (https://github.com/pyrater)
6
+ - Tony Ribeiro (https://github.com/Tony-sama)
7
+
8
+ Models are saved into user cache folder: "C:/Users/<username>/AppData/Local/tts"
9
+
10
+ References:
11
+ - Code adapted from:
12
+ - Coqui TTS https://tts.readthedocs.io/en/latest/
13
+ - Audio-webui: https://github.com/gitmylo/audio-webui
14
+ """
15
+ import json
16
+ import os
17
+ import io
18
+ import shutil
19
+
20
+ from flask import abort, request, send_file, jsonify
21
+
22
+ from TTS.api import TTS
23
+ from TTS.utils.manage import ModelManager
24
+
25
+ from modules.utils import silence_log
26
+
27
+ DEBUG_PREFIX = "<Coqui-TTS module>"
28
+ COQUI_MODELS_PATH = "data/models/coqui/"
29
+ IGNORED_FILES = [".placeholder"]
30
+ COQUI_LOCAL_MODEL_FILE_NAME = "model.pth"
31
+ COQUI_LOCAL_CONFIG_FILE_NAME = "config.json"
32
+
33
+ gpu_mode = False
34
+ is_downloading = False
35
+
36
+ def install_model(model_id):
37
+ global gpu_mode
38
+ audio_buffer = io.BytesIO()
39
+ speaker_id = None
40
+ language_id = None
41
+
42
+ print(DEBUG_PREFIX,"Loading model",model_id)
43
+ try:
44
+ tts = TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
45
+
46
+ if tts.is_multi_lingual:
47
+ language_id = tts.languages[0]
48
+
49
+ if tts.is_multi_speaker:
50
+ speaker_id =tts.speakers[0]
51
+
52
+ tts.tts_to_file(text="this is a test message", file_path=audio_buffer, speaker=speaker_id, language=language_id)
53
+ except Exception as e:
54
+ print(DEBUG_PREFIX,"ERROR:", e)
55
+ print("Model", model_id, "cannot be loaded, maybe wrong model name? Must be one of")
56
+ for i in TTS.list_models():
57
+ print(i)
58
+ return False
59
+
60
+ print(DEBUG_PREFIX,"Success")
61
+ return True
62
+
63
+ def coqui_check_model_state():
64
+ """
65
+ Check if the requested model is installed on the server machine
66
+ """
67
+ try:
68
+ model_state = "absent"
69
+ request_json = request.get_json()
70
+ model_id = request_json["model_id"]
71
+
72
+ print(DEBUG_PREFIX,"Search for model", model_id)
73
+
74
+ coqui_models_folder = ModelManager().output_prefix # models location
75
+
76
+ # Check if tts folder exist
77
+ if os.path.isdir(coqui_models_folder):
78
+
79
+ installed_models = os.listdir(coqui_models_folder)
80
+
81
+ model_folder_exists = False
82
+ model_folder = None
83
+
84
+ for i in installed_models:
85
+ if model_id == i.replace("--","/",3): # Error with model wrong name
86
+ model_folder_exists = True
87
+ model_folder = i
88
+ print(DEBUG_PREFIX,"Folder found:",model_folder)
89
+
90
+ # Check failed download
91
+ if model_folder_exists:
92
+ content = os.listdir(os.path.join(coqui_models_folder,model_folder))
93
+ print(DEBUG_PREFIX,"Checking content:",content)
94
+ for i in content:
95
+ if i == model_folder+".zip":
96
+ print("Corrupt installed found, model download must have failed previously")
97
+ model_state = "corrupted"
98
+ break
99
+
100
+ if model_state != "corrupted":
101
+ model_state = "installed"
102
+
103
+ response = json.dumps({"model_state":model_state})
104
+ return response
105
+
106
+ except Exception as e:
107
+ print(e)
108
+ abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
109
+
110
+ def coqui_install_model():
111
+ """
112
+ Install the requested model on the server machine
113
+ """
114
+ global gpu_mode
115
+ global is_downloading
116
+
117
+ try:
118
+ model_installed = False
119
+ request_json = request.get_json()
120
+ model_id = request_json["model_id"]
121
+ action = request_json["action"]
122
+
123
+ print(DEBUG_PREFIX,"Received request",action,"for model",model_id)
124
+
125
+ if (is_downloading):
126
+ print(DEBUG_PREFIX,"Rejected, already downloading a model")
127
+ return json.dumps({"status":"downloading"})
128
+
129
+ coqui_models_folder = ModelManager().output_prefix # models location
130
+
131
+ # Check if tts folder exist
132
+ if os.path.isdir(coqui_models_folder):
133
+ installed_models = os.listdir(coqui_models_folder)
134
+ model_path = None
135
+
136
+ print(DEBUG_PREFIX,"Found",len(installed_models),"models in",coqui_models_folder)
137
+
138
+ for i in installed_models:
139
+ if model_id == i.replace("--","/"):
140
+ model_installed = True
141
+ model_path = os.path.join(coqui_models_folder,i)
142
+
143
+ if model_installed:
144
+ print(DEBUG_PREFIX,"model found:", model_id)
145
+ else:
146
+ print(DEBUG_PREFIX,"model not found")
147
+
148
+ if action == "download":
149
+ if model_installed:
150
+ abort(500, DEBUG_PREFIX + "Bad request, model already installed.")
151
+
152
+ is_downloading = True
153
+ TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
154
+ is_downloading = False
155
+
156
+ if action == "repare":
157
+ if not model_installed:
158
+ abort(500, DEBUG_PREFIX + " bad request: requesting repare of model not installed")
159
+
160
+
161
+ print(DEBUG_PREFIX,"Deleting corrupted model folder:",model_path)
162
+ shutil.rmtree(model_path, ignore_errors=True)
163
+
164
+ is_downloading = True
165
+ TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
166
+ is_downloading = False
167
+
168
+ response = json.dumps({"status":"done"})
169
+ return response
170
+
171
+ except Exception as e:
172
+ is_downloading = False
173
+ print(e)
174
+ abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
175
+
176
+ def coqui_get_local_models():
177
+ """
178
+ Return the list of local model folder names found in data/models/coqui/
179
+ """
180
+ try:
181
+ print(DEBUG_PREFIX, "Received request for list of RVC models")
182
+
183
+ folder_names = os.listdir(COQUI_MODELS_PATH)
184
+
185
+ print(DEBUG_PREFIX,"Searching model in",COQUI_MODELS_PATH)
186
+
187
+ model_list = []
188
+ for folder_name in folder_names:
189
+ folder_path = COQUI_MODELS_PATH+folder_name
190
+
191
+ if folder_name in IGNORED_FILES:
192
+ continue
193
+
194
+ # Must be a folder
195
+ if not os.path.isdir(folder_path):
196
+ print("> WARNING:",folder_name,"is not a folder, it should not be there, ignored")
197
+ continue
198
+
199
+ print("> Found model folder",folder_name)
200
+
201
+ # Check pth
202
+ valid_folder = False
203
+ for file_name in os.listdir(folder_path):
204
+ if file_name.endswith(".pth"):
205
+ print(" > pth:",file_name)
206
+ valid_folder = True
207
+ if file_name.endswith(".config"):
208
+ print(" > config:",file_name)
209
+
210
+ if valid_folder:
211
+ print(" > Valid folder added to list")
212
+ model_list.append(folder_name)
213
+ else:
214
+ print(" > WARNING: Missing pth or config file, ignored folder")
215
+
216
+ # Return the list of valid folders
217
+ response = json.dumps({"models_list":model_list})
218
+ return response
219
+
220
+ except Exception as e:
221
+ print(e)
222
+ abort(500, DEBUG_PREFIX + " Exception occurs while searching for Coqui models.")
223
+
224
+
225
+
226
+ def coqui_generate_tts():
227
+ """
228
+ Process request text with the requested Coqui TTS model
229
+ - expected request: {
230
+ "text": text,
231
+ "model_id": voiceId,
232
+ "language_id": language,
233
+ "speaker_id": speaker
234
+ }
235
+
236
+ - model_id formats:
237
+ - model_type/language/dataset/model_name
238
+ - model_type/language/dataset/model_name[speaker_id]
239
+ - model_type/language/dataset/model_name[speaker_id][language_id]
240
+ - examples:
241
+ - tts_models/ja/kokoro/tacotron2-DDC
242
+ - tts_models/en/vctk/vits[0]
243
+ - tts_models/multilingual/multi-dataset/your_tts[2][1]
244
+ """
245
+ global gpu_mode
246
+ global is_downloading
247
+ audio_buffer = io.BytesIO()
248
+
249
+ try:
250
+ request_json = request.get_json()
251
+ #print(request_json)
252
+
253
+ print(DEBUG_PREFIX,"Received TTS request for ", request_json)
254
+
255
+ if (is_downloading):
256
+ print(DEBUG_PREFIX,"Rejected, currently downloading a model, cannot perform TTS")
257
+ abort(500, DEBUG_PREFIX + " Requested TTS while downloading a model")
258
+
259
+ text = request_json["text"]
260
+ model_name = request_json["model_id"]
261
+ language_id = None
262
+ speaker_id = None
263
+
264
+ # Local model
265
+ model_type = model_name.split("/")[0]
266
+ if model_type == "local":
267
+ return generate_tts_local(model_name.split("/")[1], text)
268
+
269
+
270
+ if request_json["language_id"] != "none":
271
+ language_id = request_json["language_id"]
272
+
273
+ if request_json["speaker_id"] != "none":
274
+ speaker_id = request_json["speaker_id"]
275
+
276
+ print(DEBUG_PREFIX,"Loading tts \n- model", model_name, "\n - speaker_id: ",speaker_id,"\n - language_id: ",language_id, "\n - using",("GPU" if gpu_mode else "CPU"))
277
+
278
+ is_downloading = True
279
+ tts = TTS(model_name=model_name, progress_bar=True, gpu=gpu_mode)
280
+ is_downloading = False
281
+
282
+ if tts.is_multi_lingual:
283
+ if language_id is None:
284
+ abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-lingual but no language id provided")
285
+ language_id = tts.languages[int(language_id)]
286
+
287
+ if tts.is_multi_speaker:
288
+ if speaker_id is None:
289
+ abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-speaker but no speaker id provided")
290
+ speaker_id = tts.speakers[int(speaker_id)]
291
+
292
+ tts.tts_to_file(text=text, file_path=audio_buffer, speaker=speaker_id, language=language_id)
293
+
294
+ print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
295
+
296
+ # Return the output_audio_path object as a response
297
+ response = send_file(audio_buffer, mimetype="audio/x-wav")
298
+ audio_buffer = io.BytesIO()
299
+
300
+ return response
301
+
302
+ except Exception as e:
303
+ print(e)
304
+ abort(500, DEBUG_PREFIX + " Exception occurred while processing request "+str(request_json))
305
+
306
+ def generate_tts_local(model_folder, text):
307
+ """
308
+ Generate tts using local coqui model
309
+ """
310
+ audio_buffer = io.BytesIO()
311
+
312
+ print(DEBUG_PREFIX,"Request for tts from local coqui model",model_folder)
313
+
314
+ model_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_MODEL_FILE_NAME)
315
+ config_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_CONFIG_FILE_NAME)
316
+
317
+ if not os.path.exists(model_path):
318
+ raise ValueError("File does not exist: " + model_path)
319
+
320
+ if not os.path.exists(config_path):
321
+ raise ValueError("File does not exist: " + config_path)
322
+
323
+ print(DEBUG_PREFIX,"Loading local tts model", model_path,"using",("GPU" if gpu_mode else "CPU"))
324
+ tts = TTS(model_path=model_path, config_path=config_path, progress_bar=True, gpu=gpu_mode)
325
+ tts.tts_to_file(text=text, file_path=audio_buffer)
326
+
327
+ print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
328
+
329
+ # Return the output_audio_path object as a response
330
+ response = send_file(audio_buffer, mimetype="audio/x-wav")
331
+ audio_buffer = io.BytesIO()
332
+
333
+ return response
modules/utils.py ADDED
@@ -0,0 +1,15 @@
1
+
2
+ from contextlib import contextmanager
3
+ import os
+ import sys
4
+
5
+ @contextmanager
6
+ def silence_log():
7
+ old_stdout = sys.stdout
8
+ old_stderr = sys.stderr
9
+ try:
10
+ with open(os.devnull, "w") as new_target:
11
+ sys.stdout = new_target
+ sys.stderr = new_target
12
+ yield new_target
13
+ finally:
14
+ sys.stdout = old_stdout
15
+ sys.stderr = old_stderr
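+
+ # Illustrative usage (placeholder function name): silence the console output of a noisy call.
+ # with silence_log():
+ #     some_noisy_function()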
modules/voice_conversion/fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
modules/voice_conversion/fairseq/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ import os
8
+ import sys
9
+
10
+ try:
11
+ from .version import __version__ # noqa
12
+ except ImportError:
13
+ version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
14
+ with open(version_txt) as f:
15
+ __version__ = f.read().strip()
16
+
17
+ __all__ = ["pdb"]
18
+
19
+ # backwards compatibility to support `from fairseq.X import Y`
20
+ from fairseq.distributed import utils as distributed_utils
21
+ from fairseq.logging import meters, metrics, progress_bar # noqa
22
+
23
+ sys.modules["fairseq.distributed_utils"] = distributed_utils
24
+ sys.modules["fairseq.meters"] = meters
25
+ sys.modules["fairseq.metrics"] = metrics
26
+ sys.modules["fairseq.progress_bar"] = progress_bar
27
+
28
+ # initialize hydra
29
+ #from fairseq.dataclass.initialize import hydra_init
30
+
31
+ #hydra_init()
32
+
33
+ #import fairseq.criterions # noqa
34
+ #import fairseq.distributed # noqa
35
+ #import fairseq.models # noqa
36
+ #import fairseq.modules # noqa
37
+ #import fairseq.optim # noqa
38
+ #import fairseq.optim.lr_scheduler # noqa
39
+ #import fairseq.pdb # noqa
40
+ #import fairseq.scoring # noqa
41
+ #import fairseq.tasks # noqa
42
+ #import fairseq.token_generation_constraints # noqa
43
+
44
+ #import fairseq.benchmark # noqa
45
+ #import fairseq.model_parallel # noqa
modules/voice_conversion/fairseq/binarizer.py ADDED
@@ -0,0 +1,381 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import typing as tp
9
+ from abc import ABC, abstractmethod
10
+ from collections import Counter
11
+ from dataclasses import dataclass
12
+ from multiprocessing import Pool
13
+
14
+ import torch
15
+
16
+ from fairseq.data import Dictionary, indexed_dataset
17
+ from fairseq.file_chunker_utils import Chunker, find_offsets
18
+ from fairseq.file_io import PathManager
19
+ from fairseq.tokenizer import tokenize_line
20
+
21
+ logger = logging.getLogger("binarizer")
22
+
23
+
24
+ @dataclass
25
+ class BinarizeSummary:
26
+ """
27
+ Keep track of what's going on in the binarizer
28
+ """
29
+
30
+ num_seq: int = 0
31
+ replaced: tp.Optional[Counter] = None
32
+ num_tok: int = 0
33
+
34
+ @property
35
+ def num_replaced(self) -> int:
36
+ if self.replaced is None:
37
+ return 0
38
+ return sum(self.replaced.values())
39
+
40
+ @property
41
+ def replaced_percent(self) -> float:
42
+ return 100 * self.num_replaced / self.num_tok
43
+
44
+ def __str__(self) -> str:
45
+ base = f"{self.num_seq} sents, {self.num_tok} tokens"
46
+ if self.replaced is None:
47
+ return base
48
+
49
+ return f"{base}, {self.replaced_percent:.3}% replaced"
50
+
51
+ def merge(self, other: "BinarizeSummary"):
52
+ replaced = None
53
+ if self.replaced is not None:
54
+ replaced = self.replaced
55
+ if other.replaced is not None:
56
+ if replaced is None:
57
+ replaced = other.replaced
58
+ else:
59
+ replaced += other.replaced
60
+ self.replaced = replaced
61
+ self.num_seq += other.num_seq
62
+ self.num_tok += other.num_tok
63
+
64
+
65
+ class Binarizer(ABC):
66
+ """
67
+ a binarizer describes how to take a string and build a tensor out of it
68
+ """
69
+
70
+ @abstractmethod
71
+ def binarize_line(
72
+ self,
73
+ line: str,
74
+ summary: BinarizeSummary,
75
+ ) -> torch.IntTensor:
76
+ ...
77
+
78
+
79
+ def _worker_prefix(output_prefix: str, worker_id: int):
80
+ return f"{output_prefix}.pt{worker_id}"
81
+
82
+
83
+ class FileBinarizer:
84
+ """
85
+ A file binarizer can take a file, tokenize it, and binarize each line to a tensor
86
+ """
87
+
88
+ @classmethod
89
+ def multiprocess_dataset(
90
+ cls,
91
+ input_file: str,
92
+ dataset_impl: str,
93
+ binarizer: Binarizer,
94
+ output_prefix: str,
95
+ vocab_size=None,
96
+ num_workers=1,
97
+ ) -> BinarizeSummary:
98
+ final_summary = BinarizeSummary()
99
+
100
+ offsets = find_offsets(input_file, num_workers)
101
+ # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
102
+ # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
103
+ # we zip the list with itself shifted by one to get all the pairs.
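+ # e.g. offsets [0, 100, 250] -> chunks [(0, 100), (100, 250)] (illustrative values)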
104
+ (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
105
+ pool = None
106
+ if num_workers > 1:
107
+ pool = Pool(processes=num_workers - 1)
108
+ worker_results = [
109
+ pool.apply_async(
110
+ cls._binarize_chunk_and_finalize,
111
+ args=(
112
+ binarizer,
113
+ input_file,
114
+ start_offset,
115
+ end_offset,
116
+ _worker_prefix(
117
+ output_prefix,
118
+ worker_id,
119
+ ),
120
+ dataset_impl,
121
+ ),
122
+ kwds={
123
+ "vocab_size": vocab_size,
124
+ }
125
+ if vocab_size is not None
126
+ else {},
127
+ )
128
+ for worker_id, (start_offset, end_offset) in enumerate(
129
+ more_chunks, start=1
130
+ )
131
+ ]
132
+
133
+ pool.close()
134
+ pool.join()
135
+ for r in worker_results:
136
+ summ = r.get()
137
+ final_summary.merge(summ)
138
+
139
+ # do not close the bin file as we need to merge the worker results in
140
+ final_ds, summ = cls._binarize_file_chunk(
141
+ binarizer,
142
+ input_file,
143
+ offset_start=first_chunk[0],
144
+ offset_end=first_chunk[1],
145
+ output_prefix=output_prefix,
146
+ dataset_impl=dataset_impl,
147
+ vocab_size=vocab_size if vocab_size is not None else None,
148
+ )
149
+ final_summary.merge(summ)
150
+
151
+ if num_workers > 1:
152
+ for worker_id in range(1, num_workers):
153
+ # merge the worker outputs
154
+ worker_output_prefix = _worker_prefix(
155
+ output_prefix,
156
+ worker_id,
157
+ )
158
+ final_ds.merge_file_(worker_output_prefix)
159
+ try:
160
+ os.remove(indexed_dataset.data_file_path(worker_output_prefix))
161
+ os.remove(indexed_dataset.index_file_path(worker_output_prefix))
162
+ except Exception as e:
163
+ logger.error(
164
+ f"couldn't remove {worker_output_prefix}.*", exc_info=e
165
+ )
166
+
167
+ # now we can close the file
168
+ idx_file = indexed_dataset.index_file_path(output_prefix)
169
+ final_ds.finalize(idx_file)
170
+ return final_summary
171
+
172
+ @staticmethod
173
+ def _binarize_file_chunk(
174
+ binarizer: Binarizer,
175
+ filename: str,
176
+ offset_start: int,
177
+ offset_end: int,
178
+ output_prefix: str,
179
+ dataset_impl: str,
180
+ vocab_size=None,
181
+ ) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary)
182
+ """
183
+ creates a dataset builder and append binarized items to it. This function does not
184
+ finalize the builder, this is useful if you want to do other things with your bin file
185
+ like appending/merging other files
186
+ """
187
+ bin_file = indexed_dataset.data_file_path(output_prefix)
188
+ ds = indexed_dataset.make_builder(
189
+ bin_file,
190
+ impl=dataset_impl,
191
+ vocab_size=vocab_size,
192
+ )
193
+ summary = BinarizeSummary()
194
+
195
+ with Chunker(
196
+ PathManager.get_local_path(filename), offset_start, offset_end
197
+ ) as line_iterator:
198
+ for line in line_iterator:
199
+ ds.add_item(binarizer.binarize_line(line, summary))
200
+
201
+ return ds, summary
202
+
203
+ @classmethod
204
+ def _binarize_chunk_and_finalize(
205
+ cls,
206
+ binarizer: Binarizer,
207
+ filename: str,
208
+ offset_start: int,
209
+ offset_end: int,
210
+ output_prefix: str,
211
+ dataset_impl: str,
212
+ vocab_size=None,
213
+ ):
214
+ """
215
+ same as above, but also finalizes the builder
216
+ """
217
+ ds, summ = cls._binarize_file_chunk(
218
+ binarizer,
219
+ filename,
220
+ offset_start,
221
+ offset_end,
222
+ output_prefix,
223
+ dataset_impl,
224
+ vocab_size=vocab_size,
225
+ )
226
+
227
+ idx_file = indexed_dataset.index_file_path(output_prefix)
228
+ ds.finalize(idx_file)
229
+
230
+ return summ
231
+
232
+
233
+ class VocabularyDatasetBinarizer(Binarizer):
234
+ """
235
+ Takes a Dictionary/Vocabulary and assigns ids to each
236
+ token using the dictionary encode_line function.
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ dict: Dictionary,
242
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
243
+ append_eos: bool = True,
244
+ reverse_order: bool = False,
245
+ already_numberized: bool = False,
246
+ ) -> None:
247
+ self.dict = dict
248
+ self.tokenize = tokenize
249
+ self.append_eos = append_eos
250
+ self.reverse_order = reverse_order
251
+ self.already_numberized = already_numberized
252
+ super().__init__()
253
+
254
+ def binarize_line(
255
+ self,
256
+ line: str,
257
+ summary: BinarizeSummary,
258
+ ):
259
+ if summary.replaced is None:
260
+ summary.replaced = Counter()
261
+
262
+ def replaced_consumer(word, idx):
263
+ if idx == self.dict.unk_index and word != self.dict.unk_word:
264
+ summary.replaced.update([word])
265
+
266
+ if self.already_numberized:
267
+ id_strings = line.strip().split()
268
+ id_list = [int(id_string) for id_string in id_strings]
269
+ if self.reverse_order:
270
+ id_list.reverse()
271
+ if self.append_eos:
272
+ id_list.append(self.dict.eos())
273
+ ids = torch.IntTensor(id_list)
274
+ else:
275
+ ids = self.dict.encode_line(
276
+ line=line,
277
+ line_tokenizer=self.tokenize,
278
+ add_if_not_exist=False,
279
+ consumer=replaced_consumer,
280
+ append_eos=self.append_eos,
281
+ reverse_order=self.reverse_order,
282
+ )
283
+
284
+ summary.num_seq += 1
285
+ summary.num_tok += len(ids)
286
+ return ids
287
+
288
+
289
+ class AlignmentDatasetBinarizer(Binarizer):
290
+ """
291
+ binarize by parsing a set of alignments and packing
292
+ them in a tensor (see utils.parse_alignment)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
298
+ ) -> None:
299
+ super().__init__()
300
+ self.alignment_parser = alignment_parser
301
+
302
+ def binarize_line(
303
+ self,
304
+ line: str,
305
+ summary: BinarizeSummary,
306
+ ):
307
+ ids = self.alignment_parser(line)
308
+ summary.num_seq += 1
309
+ summary.num_tok += len(ids)
310
+ return ids
311
+
312
+
313
+ class LegacyBinarizer:
314
+ @classmethod
315
+ def binarize(
316
+ cls,
317
+ filename: str,
318
+ dico: Dictionary,
319
+ consumer: tp.Callable[[torch.IntTensor], None],
320
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
321
+ append_eos: bool = True,
322
+ reverse_order: bool = False,
323
+ offset: int = 0,
324
+ end: int = -1,
325
+ already_numberized: bool = False,
326
+ ) -> tp.Dict[str, int]:
327
+ binarizer = VocabularyDatasetBinarizer(
328
+ dict=dico,
329
+ tokenize=tokenize,
330
+ append_eos=append_eos,
331
+ reverse_order=reverse_order,
332
+ already_numberized=already_numberized,
333
+ )
334
+ return cls._consume_file(
335
+ filename,
336
+ binarizer,
337
+ consumer,
338
+ offset_start=offset,
339
+ offset_end=end,
340
+ )
341
+
342
+ @classmethod
343
+ def binarize_alignments(
344
+ cls,
345
+ filename: str,
346
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
347
+ consumer: tp.Callable[[torch.IntTensor], None],
348
+ offset: int = 0,
349
+ end: int = -1,
350
+ ) -> tp.Dict[str, int]:
351
+ binarizer = AlignmentDatasetBinarizer(alignment_parser)
352
+ return cls._consume_file(
353
+ filename,
354
+ binarizer,
355
+ consumer,
356
+ offset_start=offset,
357
+ offset_end=end,
358
+ )
359
+
360
+ @staticmethod
361
+ def _consume_file(
362
+ filename: str,
363
+ binarizer: Binarizer,
364
+ consumer: tp.Callable[[torch.IntTensor], None],
365
+ offset_start: int,
366
+ offset_end: int,
367
+ ) -> tp.Dict[str, int]:
368
+ summary = BinarizeSummary()
369
+
370
+ with Chunker(
371
+ PathManager.get_local_path(filename), offset_start, offset_end
372
+ ) as line_iterator:
373
+ for line in line_iterator:
374
+ consumer(binarizer.binarize_line(line, summary))
375
+
376
+ return {
377
+ "nseq": summary.num_seq,
378
+ "nunk": summary.num_replaced,
379
+ "ntok": summary.num_tok,
380
+ "replaced": summary.replaced,
381
+ }
modules/voice_conversion/fairseq/checkpoint_utils.py ADDED
@@ -0,0 +1,905 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import collections
8
+ import contextlib
9
+ import inspect
10
+ import logging
11
+ import os
12
+ import re
13
+ import time
14
+ import traceback
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from fairseq.data import data_utils
22
+ from fairseq.dataclass.configs import CheckpointConfig
23
+ from fairseq.dataclass.utils import (
24
+ convert_namespace_to_omegaconf,
25
+ overwrite_args_by_name,
26
+ )
27
+ from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
28
+ from fairseq.file_io import PathManager
29
+ from fairseq.models import FairseqDecoder, FairseqEncoder
30
+ from omegaconf import DictConfig, OmegaConf, open_dict
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
36
+ from fairseq import meters
37
+
38
+ # only one worker should attempt to create the required dir
39
+ if trainer.data_parallel_rank == 0:
40
+ os.makedirs(cfg.save_dir, exist_ok=True)
41
+
42
+ prev_best = getattr(save_checkpoint, "best", val_loss)
43
+ if val_loss is not None:
44
+ best_function = max if cfg.maximize_best_checkpoint_metric else min
45
+ save_checkpoint.best = best_function(val_loss, prev_best)
46
+
47
+ if cfg.no_save:
48
+ return
49
+
50
+ trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
51
+
52
+ if not trainer.should_save_checkpoint_on_current_rank:
53
+ if trainer.always_call_state_dict_during_save_checkpoint:
54
+ trainer.state_dict()
55
+ return
56
+
57
+ write_timer = meters.StopwatchMeter()
58
+ write_timer.start()
59
+
60
+ epoch = epoch_itr.epoch
61
+ end_of_epoch = epoch_itr.end_of_epoch()
62
+ updates = trainer.get_num_updates()
63
+
64
+ logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
65
+
66
+ def is_better(a, b):
67
+ return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
68
+
69
+ suffix = trainer.checkpoint_suffix
70
+ checkpoint_conds = collections.OrderedDict()
71
+ checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
72
+ end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
73
+ )
74
+ checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
75
+ not end_of_epoch
76
+ and cfg.save_interval_updates > 0
77
+ and updates % cfg.save_interval_updates == 0
78
+ )
79
+ checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
80
+ not hasattr(save_checkpoint, "best")
81
+ or is_better(val_loss, save_checkpoint.best)
82
+ )
83
+ if val_loss is not None and cfg.keep_best_checkpoints > 0:
84
+ worst_best = getattr(save_checkpoint, "best", None)
85
+ chkpts = checkpoint_paths(
86
+ cfg.save_dir,
87
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
88
+ cfg.best_checkpoint_metric, suffix
89
+ ),
90
+ )
91
+ if len(chkpts) > 0:
92
+ p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
93
+ worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
94
+ # add random digits to resolve ties
95
+ with data_utils.numpy_seed(epoch, updates, val_loss):
96
+ rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
97
+
98
+ checkpoint_conds[
99
+ "checkpoint.best_{}_{:.3f}{}{}.pt".format(
100
+ cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
101
+ )
102
+ ] = worst_best is None or is_better(val_loss, worst_best)
103
+ checkpoint_conds[
104
+ "checkpoint_last{}.pt".format(suffix)
105
+ ] = not cfg.no_last_checkpoints
106
+
107
+ extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
108
+ if hasattr(save_checkpoint, "best"):
109
+ extra_state.update({"best": save_checkpoint.best})
110
+
111
+ checkpoints = [
112
+ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
113
+ ]
114
+ if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
115
+ trainer.save_checkpoint(checkpoints[0], extra_state)
116
+ for cp in checkpoints[1:]:
117
+ if cfg.write_checkpoints_asynchronously:
118
+ # TODO[ioPath]: Need to implement a delayed asynchronous
119
+ # file copying/moving feature.
120
+ logger.warning(
121
+ f"ioPath is not copying {checkpoints[0]} to {cp} "
122
+ "since async write mode is on."
123
+ )
124
+ else:
125
+ assert PathManager.copy(
126
+ checkpoints[0], cp, overwrite=True
127
+ ), f"Failed to copy {checkpoints[0]} to {cp}"
128
+
129
+ write_timer.stop()
130
+ logger.info(
131
+ "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
132
+ checkpoints[0], epoch, updates, val_loss, write_timer.sum
133
+ )
134
+ )
135
+
136
+ if not end_of_epoch and cfg.keep_interval_updates > 0:
137
+ # remove old checkpoints; checkpoints are sorted in descending order
138
+ if cfg.keep_interval_updates_pattern == -1:
139
+ checkpoints = checkpoint_paths(
140
+ cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
141
+ )
142
+ else:
143
+ checkpoints = checkpoint_paths(
144
+ cfg.save_dir,
145
+ pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
146
+ keep_match=True,
147
+ )
148
+ checkpoints = [
149
+ x[0]
150
+ for x in checkpoints
151
+ if x[1] % cfg.keep_interval_updates_pattern != 0
152
+ ]
153
+
154
+ for old_chk in checkpoints[cfg.keep_interval_updates :]:
155
+ if os.path.lexists(old_chk):
156
+ os.remove(old_chk)
157
+ elif PathManager.exists(old_chk):
158
+ PathManager.rm(old_chk)
159
+
160
+ if cfg.keep_last_epochs > 0:
161
+ # remove old epoch checkpoints; checkpoints are sorted in descending order
162
+ checkpoints = checkpoint_paths(
163
+ cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
164
+ )
165
+ for old_chk in checkpoints[cfg.keep_last_epochs :]:
166
+ if os.path.lexists(old_chk):
167
+ os.remove(old_chk)
168
+ elif PathManager.exists(old_chk):
169
+ PathManager.rm(old_chk)
170
+
171
+ if cfg.keep_best_checkpoints > 0:
172
+ # only keep the best N checkpoints according to validation metric
173
+ checkpoints = checkpoint_paths(
174
+ cfg.save_dir,
175
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
176
+ cfg.best_checkpoint_metric, suffix
177
+ ),
178
+ )
179
+ if not cfg.maximize_best_checkpoint_metric:
180
+ checkpoints = checkpoints[::-1]
181
+ for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
182
+ if os.path.lexists(old_chk):
183
+ os.remove(old_chk)
184
+ elif PathManager.exists(old_chk):
185
+ PathManager.rm(old_chk)
186
+
187
+
188
+ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
189
+ """
190
+ Load a checkpoint and restore the training iterator.
191
+
192
+ *passthrough_args* will be passed through to
193
+ ``trainer.get_train_iterator``.
194
+ """
195
+
196
+ reset_optimizer = cfg.reset_optimizer
197
+ reset_lr_scheduler = cfg.reset_lr_scheduler
198
+ optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
199
+ reset_meters = cfg.reset_meters
200
+ reset_dataloader = cfg.reset_dataloader
201
+
202
+ if cfg.finetune_from_model is not None and (
203
+ reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
204
+ ):
205
+ raise ValueError(
206
+ "--finetune-from-model can not be set together with either --reset-optimizer"
207
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
208
+ )
209
+
210
+ suffix = trainer.checkpoint_suffix
211
+ if (
212
+ cfg.restore_file == "checkpoint_last.pt"
213
+ ): # default value of restore_file is 'checkpoint_last.pt'
214
+ checkpoint_path = os.path.join(
215
+ cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
216
+ )
217
+ first_launch = not PathManager.exists(checkpoint_path)
218
+ if first_launch and getattr(cfg, "continue_once", None) is not None:
219
+ checkpoint_path = cfg.continue_once
220
+ elif cfg.finetune_from_model is not None and first_launch:
221
+ # if there is no last checkpoint to restore, start the finetune from pretrained model
222
+ # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
223
+ if PathManager.exists(cfg.finetune_from_model):
224
+ checkpoint_path = cfg.finetune_from_model
225
+ reset_optimizer = True
226
+ reset_lr_scheduler = True
227
+ reset_meters = True
228
+ reset_dataloader = True
229
+ logger.info(
230
+ f"loading pretrained model from {checkpoint_path}: "
231
+ "optimizer, lr scheduler, meters, dataloader will be reset"
232
+ )
233
+ else:
234
+ raise ValueError(
235
+ f"--finetune-from-model {cfg.finetune_from_model} does not exist"
236
+ )
237
+ elif suffix is not None:
238
+ checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
239
+ else:
240
+ checkpoint_path = cfg.restore_file
241
+
242
+ if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
243
+ raise ValueError(
244
+ "--finetune-from-model and --restore-file (non-default value) "
245
+ "can not be specified together: " + str(cfg)
246
+ )
247
+
248
+ extra_state = trainer.load_checkpoint(
249
+ checkpoint_path,
250
+ reset_optimizer,
251
+ reset_lr_scheduler,
252
+ optimizer_overrides,
253
+ reset_meters=reset_meters,
254
+ )
255
+
256
+ if (
257
+ extra_state is not None
258
+ and "best" in extra_state
259
+ and not reset_optimizer
260
+ and not reset_meters
261
+ ):
262
+ save_checkpoint.best = extra_state["best"]
263
+
264
+ if extra_state is not None and not reset_dataloader:
265
+ # restore iterator from checkpoint
266
+ itr_state = extra_state["train_iterator"]
267
+ epoch_itr = trainer.get_train_iterator(
268
+ epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
269
+ )
270
+ epoch_itr.load_state_dict(itr_state)
271
+ else:
272
+ epoch_itr = trainer.get_train_iterator(
273
+ epoch=1, load_dataset=True, **passthrough_args
274
+ )
275
+
276
+ trainer.lr_step(epoch_itr.epoch)
277
+
278
+ return extra_state, epoch_itr
279
+
280
+
281
+ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
282
+ """Loads a checkpoint to CPU (with upgrading for backward compatibility).
283
+
284
+ If doing single-GPU training or if the checkpoint is only being loaded by at
285
+ most one process on each node (current default behavior is for only rank 0
286
+ to read the checkpoint from disk), load_on_all_ranks should be False to
287
+ avoid errors from torch.distributed not having been initialized or
288
+ torch.distributed.barrier() hanging.
289
+
290
+ If all processes on each node may be loading the checkpoint
291
+ simultaneously, load_on_all_ranks should be set to True to avoid I/O
292
+ conflicts.
293
+
294
+ There's currently no support for > 1 but < all processes loading the
295
+ checkpoint on each node.
296
+ """
297
+ local_path = PathManager.get_local_path(path)
298
+ # The locally cached file returned by get_local_path() may be stale for
299
+ # remote files that are periodically updated/overwritten (ex:
300
+ # checkpoint_last.pt) - so we remove the local copy, sync across processes
301
+ # (if needed), and then download a fresh copy.
302
+ if local_path != path and PathManager.path_requires_pathmanager(path):
303
+ try:
304
+ os.remove(local_path)
305
+ except FileNotFoundError:
306
+ # With potentially multiple processes removing the same file, the
307
+ # file being missing is benign (missing_ok isn't available until
308
+ # Python 3.8).
309
+ pass
310
+ if load_on_all_ranks:
311
+ torch.distributed.barrier()
312
+ local_path = PathManager.get_local_path(path)
313
+
314
+ with open(local_path, "rb") as f:
315
+ state = torch.load(f, map_location=torch.device("cpu"))
316
+
317
+ if "args" in state and state["args"] is not None and arg_overrides is not None:
318
+ args = state["args"]
319
+ for arg_name, arg_val in arg_overrides.items():
320
+ setattr(args, arg_name, arg_val)
321
+
322
+ if "cfg" in state and state["cfg"] is not None:
323
+
324
+ # hack to be able to set Namespace in dict config. this should be removed when we update to newer
325
+ # omegaconf version that supports object flags, or when we migrate all existing models
326
+ from omegaconf import __version__ as oc_version
327
+ from omegaconf import _utils
328
+
329
+ if oc_version < "2.2":
330
+ old_primitive = _utils.is_primitive_type
331
+ _utils.is_primitive_type = lambda _: True
332
+
333
+ state["cfg"] = OmegaConf.create(state["cfg"])
334
+
335
+ _utils.is_primitive_type = old_primitive
336
+ OmegaConf.set_struct(state["cfg"], True)
337
+ else:
338
+ state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
339
+
340
+ if arg_overrides is not None:
341
+ overwrite_args_by_name(state["cfg"], arg_overrides)
342
+
343
+ state = _upgrade_state_dict(state)
344
+ return state
345
+
346
+
347
+ def load_model_ensemble(
348
+ filenames,
349
+ arg_overrides: Optional[Dict[str, Any]] = None,
350
+ task=None,
351
+ strict=True,
352
+ suffix="",
353
+ num_shards=1,
354
+ state=None,
355
+ ):
356
+ """Loads an ensemble of models.
357
+
358
+ Args:
359
+ filenames (List[str]): checkpoint files to load
360
+ arg_overrides (Dict[str,Any], optional): override model args that
361
+ were used during model training
362
+ task (fairseq.tasks.FairseqTask, optional): task to use for loading
363
+ """
364
+ assert not (
365
+ strict and num_shards > 1
366
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
367
+ ensemble, args, _task = load_model_ensemble_and_task(
368
+ filenames,
369
+ arg_overrides,
370
+ task,
371
+ strict,
372
+ suffix,
373
+ num_shards,
374
+ state,
375
+ )
376
+ return ensemble, args
377
+
378
+
379
+ def get_maybe_sharded_checkpoint_filename(
380
+ filename: str, suffix: str, shard_idx: int, num_shards: int
381
+ ) -> str:
382
+ orig_filename = filename
383
+ filename = filename.replace(".pt", suffix + ".pt")
384
+ fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
385
+ model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
386
+ if PathManager.exists(fsdp_filename):
387
+ return fsdp_filename
388
+ elif num_shards > 1:
389
+ return model_parallel_filename
390
+ else:
391
+ return filename
392
+
393
+
394
+ def load_model_ensemble_and_task(
395
+ filenames,
396
+ arg_overrides: Optional[Dict[str, Any]] = None,
397
+ task=None,
398
+ strict=True,
399
+ suffix="",
400
+ num_shards=1,
401
+ state=None,
402
+ ):
403
+ assert state is None or len(filenames) == 1
404
+
405
+ from fairseq import tasks
406
+
407
+ assert not (
408
+ strict and num_shards > 1
409
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
410
+ ensemble = []
411
+ cfg = None
412
+ for filename in filenames:
413
+ orig_filename = filename
414
+ model_shard_state = {"shard_weights": [], "shard_metadata": []}
415
+ assert num_shards > 0
416
+ st = time.time()
417
+ for shard_idx in range(num_shards):
418
+ filename = get_maybe_sharded_checkpoint_filename(
419
+ orig_filename, suffix, shard_idx, num_shards
420
+ )
421
+
422
+ if not PathManager.exists(filename):
423
+ raise IOError("Model file not found: {}".format(filename))
424
+ if state is None:
425
+ state = load_checkpoint_to_cpu(filename, arg_overrides)
426
+ if "args" in state and state["args"] is not None:
427
+ cfg = convert_namespace_to_omegaconf(state["args"])
428
+ elif "cfg" in state and state["cfg"] is not None:
429
+ cfg = state["cfg"]
430
+ else:
431
+ raise RuntimeError(
432
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
433
+ )
434
+
435
+ if task is None:
436
+ task = tasks.setup_task(cfg.task)
437
+
438
+ if "task_state" in state:
439
+ task.load_state_dict(state["task_state"])
440
+
441
+ if "fsdp_metadata" in state and num_shards > 1:
442
+ model_shard_state["shard_weights"].append(state["model"])
443
+ model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
444
+ # check FSDP import before the code goes too far
445
+ if not has_FSDP:
446
+ raise ImportError(
447
+ "Cannot find FullyShardedDataParallel. "
448
+ "Please install fairscale with: pip install fairscale"
449
+ )
450
+ if shard_idx == num_shards - 1:
451
+ consolidated_model_state = FSDP.consolidate_shard_weights(
452
+ shard_weights=model_shard_state["shard_weights"],
453
+ shard_metadata=model_shard_state["shard_metadata"],
454
+ )
455
+ model = task.build_model(cfg.model)
456
+ if (
457
+ "optimizer_history" in state
458
+ and len(state["optimizer_history"]) > 0
459
+ and "num_updates" in state["optimizer_history"][-1]
460
+ ):
461
+ model.set_num_updates(
462
+ state["optimizer_history"][-1]["num_updates"]
463
+ )
464
+ model.load_state_dict(
465
+ consolidated_model_state, strict=strict, model_cfg=cfg.model
466
+ )
467
+ else:
468
+ # model parallel checkpoint or unsharded checkpoint
469
+ # support old external tasks
470
+
471
+ argspec = inspect.getfullargspec(task.build_model)
472
+ if "from_checkpoint" in argspec.args:
473
+ model = task.build_model(cfg.model, from_checkpoint=True)
474
+ else:
475
+ model = task.build_model(cfg.model)
476
+ if (
477
+ "optimizer_history" in state
478
+ and len(state["optimizer_history"]) > 0
479
+ and "num_updates" in state["optimizer_history"][-1]
480
+ ):
481
+ model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
482
+ model.load_state_dict(
483
+ state["model"], strict=strict, model_cfg=cfg.model
484
+ )
485
+
486
+ # reset state so it gets loaded for the next model in ensemble
487
+ state = None
488
+ if shard_idx % 10 == 0 and shard_idx > 0:
489
+ elapsed = time.time() - st
490
+ logger.info(
491
+ f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
492
+ )
493
+
494
+ # build model for ensemble
495
+ ensemble.append(model)
496
+ return ensemble, cfg, task
497
+
498
+
499
+ def load_model_ensemble_and_task_from_hf_hub(
500
+ model_id,
501
+ cache_dir: Optional[str] = None,
502
+ arg_overrides: Optional[Dict[str, Any]] = None,
503
+ **kwargs: Any,
504
+ ):
505
+ try:
506
+ from huggingface_hub import snapshot_download
507
+ except ImportError:
508
+ raise ImportError(
509
+ "You need to install huggingface_hub to use `load_from_hf_hub`. "
510
+ "See https://pypi.org/project/huggingface-hub/ for installation."
511
+ )
512
+
513
+ library_name = "fairseq"
514
+ cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
515
+ cache_dir = snapshot_download(
516
+ model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
517
+ )
518
+
519
+ _arg_overrides = arg_overrides or {}
520
+ _arg_overrides["data"] = cache_dir
521
+ return load_model_ensemble_and_task(
522
+ [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
523
+ arg_overrides=_arg_overrides,
524
+ )
525
+
526
+
527
+ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
528
+ """Retrieves all checkpoints found in `path` directory.
529
+
530
+ Checkpoints are identified by matching filename to the specified pattern. If
531
+ the pattern contains groups, the result will be sorted by the first group in
532
+ descending order.
533
+ """
534
+ pt_regexp = re.compile(pattern)
535
+ files = PathManager.ls(path)
536
+
537
+ entries = []
538
+ for i, f in enumerate(files):
539
+ m = pt_regexp.fullmatch(f)
540
+ if m is not None:
541
+ idx = float(m.group(1)) if len(m.groups()) > 0 else i
542
+ entries.append((idx, m.group(0)))
543
+ if keep_match:
544
+ return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
545
+ else:
546
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
547
+
548
+
549
+ def torch_persistent_save(obj, filename, async_write: bool = False):
550
+ if async_write:
551
+ with PathManager.opena(filename, "wb") as f:
552
+ _torch_persistent_save(obj, f)
553
+ else:
554
+ if PathManager.supports_rename(filename):
555
+ # do atomic save
556
+ with PathManager.open(filename + ".tmp", "wb") as f:
557
+ _torch_persistent_save(obj, f)
558
+ PathManager.rename(filename + ".tmp", filename)
559
+ else:
560
+ # fallback to non-atomic save
561
+ with PathManager.open(filename, "wb") as f:
562
+ _torch_persistent_save(obj, f)
563
+
564
+
565
+ def _torch_persistent_save(obj, f):
566
+ if isinstance(f, str):
567
+ with PathManager.open(f, "wb") as h:
568
+ torch_persistent_save(obj, h)
569
+ return
570
+ for i in range(3):
571
+ try:
572
+ return torch.save(obj, f)
573
+ except Exception:
574
+ if i == 2:
575
+ logger.error(traceback.format_exc())
576
+ raise
577
+
578
+
579
+ def _upgrade_state_dict(state):
580
+ """Helper for upgrading old model checkpoints."""
581
+
582
+ # add optimizer_history
583
+ if "optimizer_history" not in state:
584
+ state["optimizer_history"] = [
585
+ {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
586
+ ]
587
+ state["last_optimizer_state"] = state["optimizer"]
588
+ del state["optimizer"]
589
+ del state["best_loss"]
590
+ # move extra_state into sub-dictionary
591
+ if "epoch" in state and "extra_state" not in state:
592
+ state["extra_state"] = {
593
+ "epoch": state["epoch"],
594
+ "batch_offset": state["batch_offset"],
595
+ "val_loss": state["val_loss"],
596
+ }
597
+ del state["epoch"]
598
+ del state["batch_offset"]
599
+ del state["val_loss"]
600
+ # reduce optimizer history's memory usage (only keep the last state)
601
+ if "optimizer" in state["optimizer_history"][-1]:
602
+ state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
603
+ for optim_hist in state["optimizer_history"]:
604
+ del optim_hist["optimizer"]
605
+ # record the optimizer class name
606
+ if "optimizer_name" not in state["optimizer_history"][-1]:
607
+ state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
608
+ # move best_loss into lr_scheduler_state
609
+ if "lr_scheduler_state" not in state["optimizer_history"][-1]:
610
+ state["optimizer_history"][-1]["lr_scheduler_state"] = {
611
+ "best": state["optimizer_history"][-1]["best_loss"]
612
+ }
613
+ del state["optimizer_history"][-1]["best_loss"]
614
+ # keep track of number of updates
615
+ if "num_updates" not in state["optimizer_history"][-1]:
616
+ state["optimizer_history"][-1]["num_updates"] = 0
617
+ # use stateful training data iterator
618
+ if "train_iterator" not in state["extra_state"]:
619
+ state["extra_state"]["train_iterator"] = {
620
+ "epoch": state["extra_state"].get("epoch", 0),
621
+ "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
622
+ }
623
+
624
+ # backward compatibility, cfg updates
625
+ if "args" in state and state["args"] is not None:
626
+ # old model checkpoints may not have separate source/target positions
627
+ if hasattr(state["args"], "max_positions") and not hasattr(
628
+ state["args"], "max_source_positions"
629
+ ):
630
+ state["args"].max_source_positions = state["args"].max_positions
631
+ state["args"].max_target_positions = state["args"].max_positions
632
+ # default to translation task
633
+ if not hasattr(state["args"], "task"):
634
+ state["args"].task = "translation"
635
+ # --raw-text and --lazy-load are deprecated
636
+ if getattr(state["args"], "raw_text", False):
637
+ state["args"].dataset_impl = "raw"
638
+ elif getattr(state["args"], "lazy_load", False):
639
+ state["args"].dataset_impl = "lazy"
640
+ # epochs start at 1
641
+ if state["extra_state"]["train_iterator"] is not None:
642
+ state["extra_state"]["train_iterator"]["epoch"] = max(
643
+ state["extra_state"]["train_iterator"].get("epoch", 1), 1
644
+ )
645
+ # --remove-bpe ==> --postprocess
646
+ if hasattr(state["args"], "remove_bpe"):
647
+ state["args"].post_process = state["args"].remove_bpe
648
+ # --min-lr ==> --stop-min-lr
649
+ if hasattr(state["args"], "min_lr"):
650
+ state["args"].stop_min_lr = state["args"].min_lr
651
+ del state["args"].min_lr
652
+ # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
653
+ if hasattr(state["args"], "criterion") and state["args"].criterion in [
654
+ "binary_cross_entropy",
655
+ "kd_binary_cross_entropy",
656
+ ]:
657
+ state["args"].criterion = "wav2vec"
658
+ # remove log_keys if it's None (criteria will supply a default value of [])
659
+ if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
660
+ delattr(state["args"], "log_keys")
661
+ # speech_pretraining => audio pretraining
662
+ if (
663
+ hasattr(state["args"], "task")
664
+ and state["args"].task == "speech_pretraining"
665
+ ):
666
+ state["args"].task = "audio_pretraining"
667
+ # audio_cpc => wav2vec
668
+ if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
669
+ state["args"].arch = "wav2vec"
670
+ # convert legacy float learning rate to List[float]
671
+ if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
672
+ state["args"].lr = [state["args"].lr]
673
+ # convert task data arg to a string instead of List[string]
674
+ if (
675
+ hasattr(state["args"], "data")
676
+ and isinstance(state["args"].data, list)
677
+ and len(state["args"].data) > 0
678
+ ):
679
+ state["args"].data = state["args"].data[0]
680
+
681
+ state["cfg"] = convert_namespace_to_omegaconf(state["args"])
682
+
683
+ if "cfg" in state and state["cfg"] is not None:
684
+ cfg = state["cfg"]
685
+ with open_dict(cfg):
686
+ # any upgrades for Hydra-based configs
687
+ if (
688
+ "task" in cfg
689
+ and "eval_wer_config" in cfg.task
690
+ and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
691
+ ):
692
+ cfg.task.eval_wer_config.print_alignment = "hard"
693
+ if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
694
+ cfg.generation.print_alignment = (
695
+ "hard" if cfg.generation.print_alignment else None
696
+ )
697
+ if (
698
+ "model" in cfg
699
+ and "w2v_args" in cfg.model
700
+ and cfg.model.w2v_args is not None
701
+ and (
702
+ hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
703
+ )
704
+ and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
705
+ and cfg.model.w2v_args.task.eval_wer_config is not None
706
+ and isinstance(
707
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
708
+ )
709
+ ):
710
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
711
+
712
+ return state
713
+
714
+
715
+ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
716
+ """Prune the given state_dict if desired for LayerDrop
717
+ (https://arxiv.org/abs/1909.11556).
718
+
719
+ Training with LayerDrop allows models to be robust to pruning at inference
720
+ time. This function prunes state_dict to allow smaller models to be loaded
721
+ from a larger model and re-maps the existing state_dict for this to occur.
722
+
723
+ It's called by functions that load models from checkpoints and does not
724
+ need to be called directly.
725
+ """
726
+ arch = None
727
+ if model_cfg is not None:
728
+ arch = (
729
+ model_cfg._name
730
+ if isinstance(model_cfg, DictConfig)
731
+ else getattr(model_cfg, "arch", None)
732
+ )
733
+
734
+ if not model_cfg or arch is None or arch == "ptt_transformer":
735
+ # args should not be none, but don't crash if it is.
736
+ return state_dict
737
+
738
+ encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
739
+ decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
740
+
741
+ if not encoder_layers_to_keep and not decoder_layers_to_keep:
742
+ return state_dict
743
+
744
+ # apply pruning
745
+ logger.info(
746
+ "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
747
+ )
748
+
749
+ def create_pruning_pass(layers_to_keep, layer_name):
750
+ keep_layers = sorted(
751
+ int(layer_string) for layer_string in layers_to_keep.split(",")
752
+ )
753
+ mapping_dict = {}
754
+ for i in range(len(keep_layers)):
755
+ mapping_dict[str(keep_layers[i])] = str(i)
756
+
757
+ regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
758
+ return {"substitution_regex": regex, "mapping_dict": mapping_dict}
759
+
760
+ pruning_passes = []
761
+ if encoder_layers_to_keep:
762
+ pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
763
+ if decoder_layers_to_keep:
764
+ pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
765
+
766
+ new_state_dict = {}
767
+ for layer_name in state_dict.keys():
768
+ match = re.search(r"\.layers\.(\d+)\.", layer_name)
769
+ # if layer has no number in it, it is a supporting layer, such as an
770
+ # embedding
771
+ if not match:
772
+ new_state_dict[layer_name] = state_dict[layer_name]
773
+ continue
774
+
775
+ # otherwise, layer should be pruned.
776
+ original_layer_number = match.group(1)
777
+ # figure out which mapping dict to replace from
778
+ for pruning_pass in pruning_passes:
779
+ if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
780
+ "substitution_regex"
781
+ ].search(layer_name):
782
+ new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
783
+ substitution_match = pruning_pass["substitution_regex"].search(
784
+ layer_name
785
+ )
786
+ new_state_key = (
787
+ layer_name[: substitution_match.start(1)]
788
+ + new_layer_number
789
+ + layer_name[substitution_match.end(1) :]
790
+ )
791
+ new_state_dict[new_state_key] = state_dict[layer_name]
792
+
793
+ # Since layers are now pruned, *_layers_to_keep are no longer needed.
794
+ # This is more of "It would make it work fix" rather than a proper fix.
795
+ if isinstance(model_cfg, DictConfig):
796
+ context = open_dict(model_cfg)
797
+ else:
798
+ context = contextlib.ExitStack()
799
+ with context:
800
+ if hasattr(model_cfg, "encoder_layers_to_keep"):
801
+ model_cfg.encoder_layers_to_keep = None
802
+ if hasattr(model_cfg, "decoder_layers_to_keep"):
803
+ model_cfg.decoder_layers_to_keep = None
804
+
805
+ return new_state_dict
806
+
807
+
808
+ def load_pretrained_component_from_model(
809
+ component: Union[FairseqEncoder, FairseqDecoder],
810
+ checkpoint: str,
811
+ strict: bool = True,
812
+ ):
813
+ """
814
+ Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
815
+ provided `component` object. If state_dict fails to load, there may be a
816
+ mismatch in the architecture of the corresponding `component` found in the
817
+ `checkpoint` file.
818
+ """
819
+ if not PathManager.exists(checkpoint):
820
+ raise IOError("Model file not found: {}".format(checkpoint))
821
+ state = load_checkpoint_to_cpu(checkpoint)
822
+ if isinstance(component, FairseqEncoder):
823
+ component_type = "encoder"
824
+ elif isinstance(component, FairseqDecoder):
825
+ component_type = "decoder"
826
+ else:
827
+ raise ValueError(
828
+ "component to load must be either a FairseqEncoder or "
829
+ "FairseqDecoder. Loading other component types are not supported."
830
+ )
831
+ component_state_dict = OrderedDict()
832
+ for key in state["model"].keys():
833
+ if key.startswith(component_type):
834
+ # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
835
+ component_subkey = key[len(component_type) + 1 :]
836
+ component_state_dict[component_subkey] = state["model"][key]
837
+ component.load_state_dict(component_state_dict, strict=strict)
838
+ return component
839
+
840
+
841
+ def verify_checkpoint_directory(save_dir: str) -> None:
842
+ if not os.path.exists(save_dir):
843
+ os.makedirs(save_dir, exist_ok=True)
844
+ temp_file_path = os.path.join(save_dir, "dummy")
845
+ try:
846
+ with open(temp_file_path, "w"):
847
+ pass
848
+ except OSError as e:
849
+ logger.warning(
850
+ "Unable to access checkpoint save directory: {}".format(save_dir)
851
+ )
852
+ raise e
853
+ else:
854
+ os.remove(temp_file_path)
855
+
856
+
857
+ def save_ema_as_checkpoint(src_path, dst_path):
858
+ state = load_ema_from_checkpoint(src_path)
859
+ torch_persistent_save(state, dst_path)
860
+
861
+
862
+ def load_ema_from_checkpoint(fpath):
863
+ """Loads exponential moving averaged (EMA) checkpoint from input and
864
+ returns a model with ema weights.
865
+
866
+ Args:
867
+ fpath: A string path of checkpoint to load from.
868
+
869
+ Returns:
870
+ A dict of string keys mapping to various values. The 'model' key
871
+ from the returned dict should correspond to an OrderedDict mapping
872
+ string parameter names to torch Tensors.
873
+ """
874
+ params_dict = collections.OrderedDict()
875
+ new_state = None
876
+
877
+ with PathManager.open(fpath, "rb") as f:
878
+ new_state = torch.load(
879
+ f,
880
+ map_location=(
881
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
882
+ ),
883
+ )
884
+
885
+ # EMA model is stored in a separate "extra state"
886
+ model_params = new_state["extra_state"]["ema"]
887
+
888
+ for key in list(model_params.keys()):
889
+ p = model_params[key]
890
+ if isinstance(p, torch.HalfTensor):
891
+ p = p.float()
892
+ if key not in params_dict:
893
+ params_dict[key] = p.clone()
894
+ # NOTE: clone() is needed in case of p is a shared parameter
895
+ else:
896
+ raise ValueError("Key {} is repeated in EMA model params.".format(key))
897
+
898
+ if len(params_dict) == 0:
899
+ raise ValueError(
900
+ f"Input checkpoint path '{fpath}' does not contain "
901
+ "ema model weights, is this model trained with EMA?"
902
+ )
903
+
904
+ new_state["model"] = params_dict
905
+ return new_state
modules/voice_conversion/fairseq/data/__init__.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ from .dictionary import Dictionary, TruncatedDictionary
8
+
9
+ from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
10
+
11
+ from .base_wrapper_dataset import BaseWrapperDataset
12
+
13
+ from .add_target_dataset import AddTargetDataset
14
+ from .append_token_dataset import AppendTokenDataset
15
+ from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
16
+ from .audio.hubert_dataset import HubertDataset
17
+ from .backtranslation_dataset import BacktranslationDataset
18
+ from .bucket_pad_length_dataset import BucketPadLengthDataset
19
+ from .colorize_dataset import ColorizeDataset
20
+ from .concat_dataset import ConcatDataset
21
+ from .concat_sentences_dataset import ConcatSentencesDataset
22
+ from .denoising_dataset import DenoisingDataset
23
+ from .id_dataset import IdDataset
24
+ from .indexed_dataset import (
25
+ IndexedCachedDataset,
26
+ IndexedDataset,
27
+ IndexedRawTextDataset,
28
+ MMapIndexedDataset,
29
+ )
30
+ from .language_pair_dataset import LanguagePairDataset
31
+ from .list_dataset import ListDataset
32
+ from .lm_context_window_dataset import LMContextWindowDataset
33
+ from .lru_cache_dataset import LRUCacheDataset
34
+ from .mask_tokens_dataset import MaskTokensDataset
35
+ from .monolingual_dataset import MonolingualDataset
36
+ from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
37
+ from .nested_dictionary_dataset import NestedDictionaryDataset
38
+ from .noising import NoisingDataset
39
+ from .numel_dataset import NumelDataset
40
+ from .num_samples_dataset import NumSamplesDataset
41
+ from .offset_tokens_dataset import OffsetTokensDataset
42
+ from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
43
+ from .prepend_dataset import PrependDataset
44
+ from .prepend_token_dataset import PrependTokenDataset
45
+ from .raw_label_dataset import RawLabelDataset
46
+ from .replace_dataset import ReplaceDataset
47
+ from .resampling_dataset import ResamplingDataset
48
+ from .roll_dataset import RollDataset
49
+ from .round_robin_zip_datasets import RoundRobinZipDatasets
50
+ from .sort_dataset import SortDataset
51
+ from .strip_token_dataset import StripTokenDataset
52
+ from .subsample_dataset import SubsampleDataset
53
+ from .token_block_dataset import TokenBlockDataset
54
+ from .transform_eos_dataset import TransformEosDataset
55
+ from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
56
+ from .shorten_dataset import TruncateDataset, RandomCropDataset
57
+ from .multilingual.sampled_multi_dataset import SampledMultiDataset
58
+ from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
59
+ from .fasta_dataset import FastaDataset, EncodedFastaDataset
60
+ from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
61
+
62
+ from .iterators import (
63
+ CountingIterator,
64
+ EpochBatchIterator,
65
+ GroupedIterator,
66
+ ShardedIterator,
67
+ )
68
+
69
+ __all__ = [
70
+ "AddTargetDataset",
71
+ "AppendTokenDataset",
72
+ "BacktranslationDataset",
73
+ "BaseWrapperDataset",
74
+ "BinarizedAudioDataset",
75
+ "BucketPadLengthDataset",
76
+ "ColorizeDataset",
77
+ "ConcatDataset",
78
+ "ConcatSentencesDataset",
79
+ "CountingIterator",
80
+ "DenoisingDataset",
81
+ "Dictionary",
82
+ "EncodedFastaDataset",
83
+ "EpochBatchIterator",
84
+ "FairseqDataset",
85
+ "FairseqIterableDataset",
86
+ "FastaDataset",
87
+ "FileAudioDataset",
88
+ "GroupedIterator",
89
+ "HubertDataset",
90
+ "IdDataset",
91
+ "IndexedCachedDataset",
92
+ "IndexedDataset",
93
+ "IndexedRawTextDataset",
94
+ "LanguagePairDataset",
95
+ "LeftPadDataset",
96
+ "ListDataset",
97
+ "LMContextWindowDataset",
98
+ "LRUCacheDataset",
99
+ "MaskTokensDataset",
100
+ "MMapIndexedDataset",
101
+ "MonolingualDataset",
102
+ "MultiCorpusSampledDataset",
103
+ "NestedDictionaryDataset",
104
+ "NoisingDataset",
105
+ "NumelDataset",
106
+ "NumSamplesDataset",
107
+ "OffsetTokensDataset",
108
+ "PadDataset",
109
+ "PrependDataset",
110
+ "PrependTokenDataset",
111
+ "RandomCropDataset",
112
+ "RawLabelDataset",
113
+ "ResamplingDataset",
114
+ "ReplaceDataset",
115
+ "RightPadDataset",
116
+ "RollDataset",
117
+ "RoundRobinZipDatasets",
118
+ "SampledMultiDataset",
119
+ "SampledMultiEpochDataset",
120
+ "ShardedIterator",
121
+ "SortDataset",
122
+ "StripTokenDataset",
123
+ "SubsampleDataset",
124
+ "TokenBlockDataset",
125
+ "TransformEosDataset",
126
+ "TransformEosLangPairDataset",
127
+ "TransformEosConcatLangPairDataset",
128
+ "TruncateDataset",
129
+ "TruncatedDictionary",
130
+ ]
modules/voice_conversion/fairseq/data/add_target_dataset.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset, data_utils
9
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
10
+
11
+
12
+ class AddTargetDataset(BaseWrapperDataset):
13
+ def __init__(
14
+ self,
15
+ dataset,
16
+ labels,
17
+ pad,
18
+ eos,
19
+ batch_targets,
20
+ process_label=None,
21
+ label_len_fn=None,
22
+ add_to_input=False,
23
+ text_compression_level=TextCompressionLevel.none,
24
+ ):
25
+ super().__init__(dataset)
26
+ self.labels = labels
27
+ self.batch_targets = batch_targets
28
+ self.pad = pad
29
+ self.eos = eos
30
+ self.process_label = process_label
31
+ self.label_len_fn = label_len_fn
32
+ self.add_to_input = add_to_input
33
+ self.text_compressor = TextCompressor(level=text_compression_level)
34
+
35
+ def get_label(self, index, process_fn=None):
36
+ lbl = self.labels[index]
37
+ lbl = self.text_compressor.decompress(lbl)
38
+ return lbl if process_fn is None else process_fn(lbl)
39
+
40
+ def __getitem__(self, index):
41
+ item = self.dataset[index]
42
+ item["label"] = self.get_label(index, process_fn=self.process_label)
43
+ return item
44
+
45
+ def size(self, index):
46
+ sz = self.dataset.size(index)
47
+ own_sz = self.label_len_fn(self.get_label(index))
48
+ return sz, own_sz
49
+
50
+ def collater(self, samples):
51
+ collated = self.dataset.collater(samples)
52
+ if len(collated) == 0:
53
+ return collated
54
+ indices = set(collated["id"].tolist())
55
+ target = [s["label"] for s in samples if s["id"] in indices]
56
+
57
+ if self.add_to_input:
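+ # teacher forcing inputs: EOS is prepended to prev_output_tokens and appended to the target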
58
+ eos = torch.LongTensor([self.eos])
59
+ prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
60
+ target = [torch.cat([t, eos], axis=-1) for t in target]
61
+ collated["net_input"]["prev_output_tokens"] = prev_output_tokens
62
+
63
+ if self.batch_targets:
64
+ collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
65
+ target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
66
+ collated["ntokens"] = collated["target_lengths"].sum().item()
67
+ if collated["net_input"].get("prev_output_tokens", None) is not None:
68
+ collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
69
+ collated["net_input"]["prev_output_tokens"],
70
+ pad_idx=self.pad,
71
+ left_pad=False,
72
+ )
73
+ else:
74
+ collated["ntokens"] = sum([len(t) for t in target])
75
+
76
+ collated["target"] = target
77
+ return collated
78
+
79
+ def filter_indices_by_size(self, indices, max_sizes):
80
+ indices, ignored = data_utils._filter_by_size_dynamic(
81
+ indices, self.size, max_sizes
82
+ )
83
+ return indices, ignored
modules/voice_conversion/fairseq/data/append_token_dataset.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
+ class AppendTokenDataset(BaseWrapperDataset):
13
+ def __init__(self, dataset, token=None):
14
+ super().__init__(dataset)
15
+ self.token = token
16
+ if token is not None:
17
+ self._sizes = np.array(dataset.sizes) + 1
18
+ else:
19
+ self._sizes = dataset.sizes
20
+
21
+ def __getitem__(self, idx):
22
+ item = self.dataset[idx]
23
+ if self.token is not None:
24
+ item = torch.cat([item, item.new([self.token])])
25
+ return item
26
+
27
+ @property
28
+ def sizes(self):
29
+ return self._sizes
30
+
31
+ def num_tokens(self, index):
32
+ n = self.dataset.num_tokens(index)
33
+ if self.token is not None:
34
+ n += 1
35
+ return n
36
+
37
+ def size(self, index):
38
+ n = self.dataset.size(index)
39
+ if self.token is not None:
40
+ n += 1
41
+ return n
modules/voice_conversion/fairseq/data/audio/__init__.py ADDED
@@ -0,0 +1,93 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional
3
+ import importlib
4
+ import os
5
+ import numpy as np
6
+
7
+
8
+ class AudioTransform(ABC):
9
+ @classmethod
10
+ @abstractmethod
11
+ def from_config_dict(cls, config: Optional[Dict] = None):
12
+ pass
13
+
14
+
15
+ class CompositeAudioTransform(AudioTransform):
16
+ def _from_config_dict(
17
+ cls,
18
+ transform_type,
19
+ get_audio_transform,
20
+ composite_cls,
21
+ config=None,
22
+ return_empty=False,
23
+ ):
24
+ _config = {} if config is None else config
25
+ _transforms = _config.get(f"{transform_type}_transforms")
26
+
27
+ if _transforms is None:
28
+ if return_empty:
29
+ _transforms = []
30
+ else:
31
+ return None
32
+
33
+ transforms = [
34
+ get_audio_transform(_t).from_config_dict(_config.get(_t))
35
+ for _t in _transforms
36
+ ]
37
+ return composite_cls(transforms)
38
+
39
+ def __init__(self, transforms):
40
+ self.transforms = [t for t in transforms if t is not None]
41
+
42
+ def __call__(self, x):
43
+ for t in self.transforms:
44
+ x = t(x)
45
+ return x
46
+
47
+ def __repr__(self):
48
+ format_string = (
49
+ [self.__class__.__name__ + "("]
50
+ + [f" {t.__repr__()}" for t in self.transforms]
51
+ + [")"]
52
+ )
53
+ return "\n".join(format_string)
54
+
55
+
56
+ def register_audio_transform(name, cls_type, registry, class_names):
57
+ def register_audio_transform_cls(cls):
58
+ if name in registry:
59
+ raise ValueError(f"Cannot register duplicate transform ({name})")
60
+ if not issubclass(cls, cls_type):
61
+ raise ValueError(
62
+ f"Transform ({name}: {cls.__name__}) must extend "
63
+ f"{cls_type.__name__}"
64
+ )
65
+ if cls.__name__ in class_names:
66
+ raise ValueError(
67
+ f"Cannot register audio transform with duplicate "
68
+ f"class name ({cls.__name__})"
69
+ )
70
+ registry[name] = cls
71
+ class_names.add(cls.__name__)
72
+ return cls
73
+
74
+ return register_audio_transform_cls
75
+
76
+
77
+ def import_transforms(transforms_dir, transform_type):
78
+ for file in os.listdir(transforms_dir):
79
+ path = os.path.join(transforms_dir, file)
80
+ if (
81
+ not file.startswith("_")
82
+ and not file.startswith(".")
83
+ and (file.endswith(".py") or os.path.isdir(path))
84
+ ):
85
+ name = file[: file.find(".py")] if file.endswith(".py") else file
86
+ importlib.import_module(
87
+ f"fairseq.data.audio.{transform_type}_transforms." + name
88
+ )
89
+
90
+
91
+ # Utility fn for uniform numbers in transforms
92
+ def rand_uniform(a, b):
93
+ return np.random.uniform() * (b - a) + a
modules/voice_conversion/fairseq/data/audio/audio_utils.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import mmap
8
+ from pathlib import Path
9
+ import io
10
+ from typing import BinaryIO, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
17
+
18
+ SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
19
+ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
20
+
21
+
22
+ def convert_waveform(
23
+ waveform: Union[np.ndarray, torch.Tensor],
24
+ sample_rate: int,
25
+ normalize_volume: bool = False,
26
+ to_mono: bool = False,
27
+ to_sample_rate: Optional[int] = None,
28
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
29
+ """convert a waveform:
30
+ - to a target sample rate
31
+ - from multi-channel to mono channel
32
+ - volume normalization
33
+
34
+ Args:
35
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
36
+ (channels x length)
37
+ sample_rate (int): original sample rate
38
+ normalize_volume (bool): perform volume normalization
39
+ to_mono (bool): convert to mono channel if having multiple channels
40
+ to_sample_rate (Optional[int]): target sample rate
41
+ Returns:
42
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
43
+ sample_rate (float): target sample rate
44
+ """
45
+ try:
46
+ import torchaudio.sox_effects as ta_sox
47
+ except ImportError:
48
+ raise ImportError("Please install torchaudio: pip install torchaudio")
49
+
50
+ effects = []
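+ # build a sox effect chain: volume normalization, resampling, downmix to mono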
51
+ if normalize_volume:
52
+ effects.append(["gain", "-n"])
53
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
54
+ effects.append(["rate", f"{to_sample_rate}"])
55
+ if to_mono and waveform.shape[0] > 1:
56
+ effects.append(["channels", "1"])
57
+ if len(effects) > 0:
58
+ is_np_input = isinstance(waveform, np.ndarray)
59
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
60
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
61
+ _waveform, sample_rate, effects
62
+ )
63
+ if is_np_input:
64
+ converted = converted.numpy()
65
+ return converted, converted_sample_rate
66
+ return waveform, sample_rate
67
+
68
+
69
+ def get_waveform(
70
+ path_or_fp: Union[str, BinaryIO],
71
+ normalization: bool = True,
72
+ mono: bool = True,
73
+ frames: int = -1,
74
+ start: int = 0,
75
+ always_2d: bool = True,
76
+ output_sample_rate: Optional[int] = None,
77
+ normalize_volume: bool = False,
78
+ waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
79
+ ) -> Tuple[np.ndarray, int]:
80
+ """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
81
+
82
+ Args:
83
+ path_or_fp (str or BinaryIO): the path or file-like object
84
+ normalization (bool): normalize values to [-1, 1] (Default: True)
85
+ mono (bool): convert multi-channel audio to mono-channel one
86
+ frames (int): the number of frames to read. (-1 for reading all)
87
+ start (int): Where to start reading. A negative value counts from the end.
88
+ always_2d (bool): always return 2D array even for mono-channel audios
89
+ output_sample_rate (Optional[int]): output sample rate
90
+ normalize_volume (bool): normalize volume
91
+ Returns:
92
+ waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
93
+ sample_rate (float): sample rate
94
+ """
95
+ if isinstance(path_or_fp, str):
96
+ ext = Path(path_or_fp).suffix
97
+ if ext not in SF_AUDIO_FILE_EXTENSIONS:
98
+ raise ValueError(f"Unsupported audio format: {ext}")
99
+
100
+ try:
101
+ import soundfile as sf
102
+ except ImportError:
103
+ raise ImportError("Please install soundfile: pip install soundfile")
104
+
105
+ waveform, sample_rate = sf.read(
106
+ path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
107
+ )
108
+ waveform = waveform.T # T x C -> C x T
109
+ waveform, sample_rate = convert_waveform(
110
+ waveform,
111
+ sample_rate,
112
+ normalize_volume=normalize_volume,
113
+ to_mono=mono,
114
+ to_sample_rate=output_sample_rate,
115
+ )
116
+
117
+ if not normalization:
118
+ waveform *= 2**15 # denormalized to 16-bit signed integers
119
+
120
+ if waveform_transforms is not None:
121
+ waveform, sample_rate = waveform_transforms(waveform, sample_rate)
122
+
123
+ if not always_2d:
124
+ waveform = waveform.squeeze(axis=0)
125
+
126
+ return waveform, sample_rate
127
+
128
+
129
+ def get_features_from_npy_or_audio(path, waveform_transforms=None):
130
+ ext = Path(path).suffix
131
+ if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
132
+ raise ValueError(f'Unsupported file format for "{path}"')
133
+ return (
134
+ np.load(path)
135
+ if ext == ".npy"
136
+ else get_fbank(path, waveform_transforms=waveform_transforms)
137
+ )
138
+
139
+
140
+ def get_features_or_waveform_from_stored_zip(
141
+ path,
142
+ byte_offset,
143
+ byte_size,
144
+ need_waveform=False,
145
+ use_sample_rate=None,
146
+ waveform_transforms=None,
147
+ ):
148
+ assert path.endswith(".zip")
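+ # the ZIP entry must be stored uncompressed so its raw bytes can be sliced directly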
149
+ data = read_from_stored_zip(path, byte_offset, byte_size)
150
+ f = io.BytesIO(data)
151
+ if is_npy_data(data):
152
+ features_or_waveform = np.load(f)
153
+ elif is_sf_audio_data(data):
154
+ features_or_waveform = (
155
+ get_waveform(
156
+ f,
157
+ always_2d=False,
158
+ output_sample_rate=use_sample_rate,
159
+ waveform_transforms=waveform_transforms,
160
+ )[0]
161
+ if need_waveform
162
+ else get_fbank(f, waveform_transforms=waveform_transforms)
163
+ )
164
+ else:
165
+ raise ValueError(f'Unknown file format for "{path}"')
166
+ return features_or_waveform
167
+
168
+
169
+ def get_features_or_waveform(
170
+ path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
171
+ ):
172
+ """Get speech features from .npy file or waveform from .wav/.flac file.
173
+ The file may be inside an uncompressed ZIP file and is accessed via byte
174
+ offset and length.
175
+
176
+ Args:
177
+ path (str): File path in the format of "<.npy/.wav/.flac path>" or
178
+ "<zip path>:<byte offset>:<byte length>".
179
+ need_waveform (bool): return waveform instead of features.
180
+ use_sample_rate (int): change sample rate for the input wave file
181
+
182
+ Returns:
183
+ features_or_waveform (numpy.ndarray): speech features or waveform.
184
+ """
185
+ _path, slice_ptr = parse_path(path)
186
+ if len(slice_ptr) == 0:
187
+ if need_waveform:
188
+ return get_waveform(
189
+ _path,
190
+ always_2d=False,
191
+ output_sample_rate=use_sample_rate,
192
+ waveform_transforms=waveform_transforms,
193
+ )[0]
194
+ return get_features_from_npy_or_audio(
195
+ _path, waveform_transforms=waveform_transforms
196
+ )
197
+ elif len(slice_ptr) == 2:
198
+ features_or_waveform = get_features_or_waveform_from_stored_zip(
199
+ _path,
200
+ slice_ptr[0],
201
+ slice_ptr[1],
202
+ need_waveform=need_waveform,
203
+ use_sample_rate=use_sample_rate,
204
+ waveform_transforms=waveform_transforms,
205
+ )
206
+ else:
207
+ raise ValueError(f"Invalid path: {path}")
208
+
209
+ return features_or_waveform
210
+
211
+
212
+ def _get_kaldi_fbank(
213
+ waveform: np.ndarray, sample_rate: int, n_bins=80
214
+ ) -> Optional[np.ndarray]:
215
+ """Get mel-filter bank features via PyKaldi."""
216
+ try:
217
+ from kaldi.feat.fbank import Fbank, FbankOptions
218
+ from kaldi.feat.mel import MelBanksOptions
219
+ from kaldi.feat.window import FrameExtractionOptions
220
+ from kaldi.matrix import Vector
221
+
222
+ mel_opts = MelBanksOptions()
223
+ mel_opts.num_bins = n_bins
224
+ frame_opts = FrameExtractionOptions()
225
+ frame_opts.samp_freq = sample_rate
226
+ opts = FbankOptions()
227
+ opts.mel_opts = mel_opts
228
+ opts.frame_opts = frame_opts
229
+ fbank = Fbank(opts=opts)
230
+ features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
231
+ return features
232
+ except ImportError:
233
+ return None
234
+
235
+
236
+ def _get_torchaudio_fbank(
237
+ waveform: np.ndarray, sample_rate, n_bins=80
238
+ ) -> Optional[np.ndarray]:
239
+ """Get mel-filter bank features via TorchAudio."""
240
+ try:
241
+ import torchaudio.compliance.kaldi as ta_kaldi
242
+
243
+ waveform = torch.from_numpy(waveform)
244
+ features = ta_kaldi.fbank(
245
+ waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
246
+ )
247
+ return features.numpy()
248
+ except ImportError:
249
+ return None
250
+
251
+
252
+ def get_fbank(
253
+ path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
254
+ ) -> np.ndarray:
255
+ """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
256
+ (faster CPP implementation) to TorchAudio (Python implementation). Note that
257
+ Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
258
+ waveform should not be normalized."""
259
+ waveform, sample_rate = get_waveform(
260
+ path_or_fp, normalization=False, waveform_transforms=waveform_transforms
261
+ )
262
+
263
+ features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
264
+ if features is None:
265
+ features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
266
+ if features is None:
267
+ raise ImportError(
268
+ "Please install pyKaldi or torchaudio to enable "
269
+ "online filterbank feature extraction"
270
+ )
271
+
272
+ return features
273
+
274
+
275
+ def is_npy_data(data: bytes) -> bool:
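+ # .npy files start with the magic string b"\x93NUMPY" (0x93 == 147, ord("N") == 78)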
276
+ return data[0] == 147 and data[1] == 78
277
+
278
+
279
+ def is_sf_audio_data(data: bytes) -> bool:
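+ # check container magic bytes: "RIF" (WAV/RIFF), "fLa" (fLaC) or "Ogg" (OggS)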
280
+ is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
281
+ is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
282
+ is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
283
+ return is_wav or is_flac or is_ogg
284
+
285
+
286
+ def mmap_read(path: str, offset: int, length: int) -> bytes:
287
+ with open(path, "rb") as f:
288
+ with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
289
+ data = mmap_o[offset : offset + length]
290
+ return data
291
+
292
+
293
+ def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
294
+ return mmap_read(zip_path, offset, length)
295
+
296
+
297
+ def parse_path(path: str) -> Tuple[str, List[int]]:
298
+ """Parse data path which is either a path to
299
+ 1. a .npy/.wav/.flac/.ogg file
300
+ 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
301
+
302
+ Args:
303
+ path (str): the data path to parse
304
+
305
+ Returns:
306
+ file_path (str): the file path
307
+ slice_ptr (list of int): empty in case 1;
308
+ byte offset and length for the slice in case 2
309
+ """
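+ # e.g. "feat.npy" -> ("feat.npy", []) and "data.zip:1024:2048" -> ("data.zip", [1024, 2048])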
310
+
311
+ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
312
+ _path, slice_ptr = path, []
313
+ else:
314
+ _path, *slice_ptr = path.split(":")
315
+ if not Path(_path).is_file():
316
+ raise FileNotFoundError(f"File not found: {_path}")
317
+ assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
318
+ slice_ptr = [int(i) for i in slice_ptr]
319
+ return _path, slice_ptr
320
+
321
+
322
+ def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
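+ # center the analysis window inside the FFT frame by zero-padding it symmetrically to n_fft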
323
+ padding = n_fft - win_length
324
+ assert padding >= 0
325
+ return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
326
+
327
+
328
+ def get_fourier_basis(n_fft: int) -> torch.Tensor:
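+ # DFT matrix rows up to the Nyquist bin, real part stacked on top of the imaginary part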
329
+ basis = np.fft.fft(np.eye(n_fft))
330
+ basis = np.vstack(
331
+ [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
332
+ )
333
+ return torch.from_numpy(basis).float()
334
+
335
+
336
+ def get_mel_filters(
337
+ sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
338
+ ) -> torch.Tensor:
339
+ try:
340
+ import librosa
341
+ except ImportError:
342
+ raise ImportError("Please install librosa: pip install librosa")
343
+ basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
344
+ return torch.from_numpy(basis).float()
345
+
346
+
347
+ class TTSSpectrogram(torch.nn.Module):
348
+ def __init__(
349
+ self,
350
+ n_fft: int,
351
+ win_length: int,
352
+ hop_length: int,
353
+ window_fn: callable = torch.hann_window,
354
+ return_phase: bool = False,
355
+ ) -> None:
356
+ super(TTSSpectrogram, self).__init__()
357
+ self.n_fft = n_fft
358
+ self.hop_length = hop_length
359
+ self.return_phase = return_phase
360
+
361
+ basis = get_fourier_basis(n_fft).unsqueeze(1)
362
+ basis *= get_window(window_fn, n_fft, win_length)
363
+ self.register_buffer("basis", basis)
364
+
365
+ def forward(
366
+ self, waveform: torch.Tensor
367
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
368
+ padding = (self.n_fft // 2, self.n_fft // 2)
369
+ x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
370
+ x = F.conv1d(x, self.basis, stride=self.hop_length)
371
+ real_part = x[:, : self.n_fft // 2 + 1, :]
372
+ imag_part = x[:, self.n_fft // 2 + 1 :, :]
373
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
374
+ if self.return_phase:
375
+ phase = torch.atan2(imag_part, real_part)
376
+ return magnitude, phase
377
+ return magnitude
378
+
379
+
380
+ class TTSMelScale(torch.nn.Module):
381
+ def __init__(
382
+ self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
383
+ ) -> None:
384
+ super(TTSMelScale, self).__init__()
385
+ basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
386
+ self.register_buffer("basis", basis)
387
+
388
+ def forward(self, specgram: torch.Tensor) -> torch.Tensor:
389
+ return torch.matmul(self.basis, specgram)
modules/voice_conversion/fairseq/data/audio/data_cfg.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from argparse import Namespace
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+ from typing import Dict, Optional
11
+
12
+ from fairseq.data import Dictionary
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def get_config_from_yaml(yaml_path: Path):
18
+ try:
19
+ import yaml
20
+ except ImportError:
21
+ raise ImportError("Please install PyYAML: pip install PyYAML")
22
+ config = {}
23
+ if yaml_path.is_file():
24
+ try:
25
+ with open(yaml_path) as f:
26
+ config = yaml.load(f, Loader=yaml.FullLoader)
27
+ except Exception as e:
28
+ raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
29
+ else:
30
+ raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
31
+
32
+ return config
33
+
34
+
35
+ class S2TDataConfig(object):
36
+ """Wrapper class for data config YAML"""
37
+
38
+ def __init__(self, yaml_path: Path):
39
+ self.config = get_config_from_yaml(yaml_path)
40
+ self.root = yaml_path.parent
41
+
42
+ def _auto_convert_to_abs_path(self, x):
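+ # resolve paths relative to the directory containing the config YAML when they are not found as given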
43
+ if isinstance(x, str):
44
+ if not Path(x).exists() and (self.root / x).exists():
45
+ return (self.root / x).as_posix()
46
+ elif isinstance(x, dict):
47
+ return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
48
+ return x
49
+
50
+ @property
51
+ def vocab_filename(self):
52
+ """fairseq vocabulary file under data root"""
53
+ return self.config.get("vocab_filename", "dict.txt")
54
+
55
+ @property
56
+ def speaker_set_filename(self):
57
+ """speaker set file under data root"""
58
+ return self.config.get("speaker_set_filename", None)
59
+
60
+ @property
61
+ def shuffle(self) -> bool:
62
+ """Shuffle dataset samples before batching"""
63
+ return self.config.get("shuffle", False)
64
+
65
+ @property
66
+ def pre_tokenizer(self) -> Dict:
67
+ """Pre-tokenizer to apply before subword tokenization. Returning
68
+ a dictionary with `tokenizer` providing the tokenizer name and
69
+ the other items providing the tokenizer-specific arguments.
70
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
71
+ tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
72
+ return self._auto_convert_to_abs_path(tokenizer)
73
+
74
+ @property
75
+ def bpe_tokenizer(self) -> Dict:
76
+ """Subword tokenizer to apply after pre-tokenization. Returning
77
+ a dictionary with `bpe` providing the tokenizer name and
78
+ the other items providing the tokenizer-specific arguments.
79
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
80
+ tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
81
+ return self._auto_convert_to_abs_path(tokenizer)
82
+
83
+ @property
84
+ def prepend_tgt_lang_tag(self) -> bool:
85
+ """Prepend target lang ID token as the target BOS (e.g. for to-many
86
+ multilingual setting). During inference, this requires `--prefix-size 1`
87
+ to force BOS to be lang ID token."""
88
+ return self.config.get("prepend_tgt_lang_tag", False)
89
+
90
+ @property
91
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
92
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
93
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
94
+
95
+ @property
96
+ def input_feat_per_channel(self):
97
+ """The dimension of input features (per audio channel)"""
98
+ return self.config.get("input_feat_per_channel", 80)
99
+
100
+ @property
101
+ def input_channels(self):
102
+ """The number of channels in the input audio"""
103
+ return self.config.get("input_channels", 1)
104
+
105
+ @property
106
+ def sample_rate(self):
107
+ return self.config.get("sample_rate", 16_000)
108
+
109
+ @property
110
+ def sampling_alpha(self):
111
+ """Hyper-parameter alpha = 1/T for temperature-based resampling.
112
+ (alpha = 1 for no resampling)"""
113
+ return self.config.get("sampling_alpha", 1.0)
114
+
115
+ @property
116
+ def use_audio_input(self):
117
+ """Needed by the dataset loader to see if the model requires
118
+ raw audio as inputs."""
119
+ return self.config.get("use_audio_input", False)
120
+
121
+ def standardize_audio(self) -> bool:
122
+ return self.use_audio_input and self.config.get("standardize_audio", False)
123
+
124
+ @property
125
+ def use_sample_rate(self):
126
+ """Needed by the dataset loader to see if the model requires
127
+ raw audio with specific sample rate as inputs."""
128
+ return self.config.get("use_sample_rate", 16000)
129
+
130
+ @property
131
+ def audio_root(self):
132
+ """Audio paths in the manifest TSV can be relative and this provides
133
+ the root path. Set this to empty string when using absolute paths."""
134
+ return self.config.get("audio_root", "")
135
+
136
+ def get_transforms(self, transform_type, split, is_train):
137
+ """Split-specific feature transforms. Allowing train set
138
+ wildcard `_train`, evaluation set wildcard `_eval` and general
139
+ wildcard `*` for matching."""
140
+ from copy import deepcopy
141
+
142
+ cfg = deepcopy(self.config)
143
+ _cur = cfg.get(f"{transform_type}transforms", {})
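+ # lookup order: exact split name, then the _train/_eval wildcard, then the general "*" wildcard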
144
+ cur = _cur.get(split)
145
+ cur = _cur.get("_train") if cur is None and is_train else cur
146
+ cur = _cur.get("_eval") if cur is None and not is_train else cur
147
+ cur = _cur.get("*") if cur is None else cur
148
+ return cur
149
+
150
+ def get_feature_transforms(self, split, is_train):
151
+ cfg = deepcopy(self.config)
152
+ # TODO: deprecate transforms
153
+ cur = self.get_transforms("", split, is_train)
154
+ if cur is not None:
155
+ logger.warning(
156
+ "Auto converting transforms into feature_transforms, "
157
+ "but transforms will be deprecated in the future. Please "
158
+ "update this in the config."
159
+ )
160
+ ft_transforms = self.get_transforms("feature_", split, is_train)
161
+ if ft_transforms:
162
+ cur.extend(ft_transforms)
163
+ else:
164
+ cur = self.get_transforms("feature_", split, is_train)
165
+ cfg["feature_transforms"] = cur
166
+ return cfg
167
+
168
+ def get_waveform_transforms(self, split, is_train):
169
+ cfg = deepcopy(self.config)
170
+ cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
171
+ return cfg
172
+
173
+ def get_dataset_transforms(self, split, is_train):
174
+ cfg = deepcopy(self.config)
175
+ cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
176
+ return cfg
177
+
178
+ @property
179
+ def global_cmvn_stats_npz(self) -> Optional[str]:
180
+ path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
181
+ return self._auto_convert_to_abs_path(path)
182
+
183
+ @property
184
+ def vocoder(self) -> Dict[str, str]:
185
+ vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
186
+ return self._auto_convert_to_abs_path(vocoder)
187
+
188
+ @property
189
+ def hub(self) -> Dict[str, str]:
190
+ return self.config.get("hub", {})
191
+
192
+
193
+ class S2SDataConfig(S2TDataConfig):
194
+ """Wrapper class for data config YAML"""
195
+
196
+ @property
197
+ def vocab_filename(self):
198
+ """fairseq vocabulary file under data root"""
199
+ return self.config.get("vocab_filename", None)
200
+
201
+ @property
202
+ def pre_tokenizer(self) -> Dict:
203
+ return None
204
+
205
+ @property
206
+ def bpe_tokenizer(self) -> Dict:
207
+ return None
208
+
209
+ @property
210
+ def input_transformed_channels(self):
211
+ """The number of channels in the audio after feature transforms"""
212
+ # TODO: move this into individual transforms
213
+ # TODO: deprecate transforms
214
+ _cur = self.config.get("transforms", {})
215
+ ft_transforms = self.config.get("feature_transforms", {})
216
+ if _cur and ft_transforms:
217
+ _cur.update(ft_transforms)
218
+ else:
219
+ _cur = self.config.get("feature_transforms", {})
220
+ cur = _cur.get("_train", [])
221
+
222
+ _channels = self.input_channels
223
+ if "delta_deltas" in cur:
224
+ _channels *= 3
225
+
226
+ return _channels
227
+
228
+ @property
229
+ def output_sample_rate(self):
230
+ """The audio sample rate of output target speech"""
231
+ return self.config.get("output_sample_rate", 22050)
232
+
233
+ @property
234
+ def target_speaker_embed(self):
235
+ """Target speaker embedding file (one line per target audio sample)"""
236
+ return self.config.get("target_speaker_embed", None)
237
+
238
+ @property
239
+ def prepend_tgt_lang_tag_as_bos(self) -> bool:
240
+ """Prepend target lang ID token as the target BOS."""
241
+ return self.config.get("prepend_tgt_lang_tag_as_bos", False)
242
+
243
+
244
+ class MultitaskConfig(object):
245
+ """Wrapper class for data config YAML"""
246
+
247
+ def __init__(self, yaml_path: Path):
248
+ config = get_config_from_yaml(yaml_path)
249
+ self.config = {}
250
+ for k, v in config.items():
251
+ self.config[k] = SingleTaskConfig(k, v)
252
+
253
+ def get_all_tasks(self):
254
+ return self.config
255
+
256
+ def get_single_task(self, name):
257
+ assert name in self.config, f"multitask '{name}' does not exist!"
258
+ return self.config[name]
259
+
260
+ @property
261
+ def first_pass_decoder_task_index(self):
262
+ """Return the task index of the first-pass text decoder.
263
+ If there are multiple 'is_first_pass_decoder: True' in the config file,
264
+ the last task is used for the first-pass decoder.
265
+ If there is no 'is_first_pass_decoder: True' in the config file,
266
+ the last task whose task_name includes 'target' and decoder_type is not ctc.
267
+ """
268
+ idx = -1
269
+ for i, (k, v) in enumerate(self.config.items()):
270
+ if v.is_first_pass_decoder:
271
+ idx = i
272
+ if idx < 0:
273
+ for i, (k, v) in enumerate(self.config.items()):
274
+ if k.startswith("target") and v.decoder_type == "transformer":
275
+ idx = i
276
+ return idx
277
+
278
+
279
+ class SingleTaskConfig(object):
280
+ def __init__(self, name, config):
281
+ self.task_name = name
282
+ self.config = config
283
+ dict_path = config.get("dict", "")
284
+ self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
285
+
286
+ @property
287
+ def data(self):
288
+ return self.config.get("data", "")
289
+
290
+ @property
291
+ def decoder_type(self):
292
+ return self.config.get("decoder_type", "transformer")
293
+
294
+ @property
295
+ def decoder_args(self):
296
+ """Decoder arch related args"""
297
+ args = self.config.get("decoder_args", {})
298
+ return Namespace(**args)
299
+
300
+ @property
301
+ def criterion_cfg(self):
302
+ """cfg for the multitask criterion"""
303
+ if self.decoder_type == "ctc":
304
+ from fairseq.criterions.ctc import CtcCriterionConfig
305
+
306
+ cfg = CtcCriterionConfig
307
+ cfg.zero_infinity = self.config.get("zero_infinity", True)
308
+ else:
309
+ from fairseq.criterions.label_smoothed_cross_entropy import (
310
+ LabelSmoothedCrossEntropyCriterionConfig,
311
+ )
312
+
313
+ cfg = LabelSmoothedCrossEntropyCriterionConfig
314
+ cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
315
+ return cfg
316
+
317
+ @property
318
+ def input_from(self):
319
+ """Condition on encoder/decoder of the main model"""
320
+ return "decoder" if "decoder_layer" in self.config else "encoder"
321
+
322
+ @property
323
+ def input_layer(self):
324
+ if self.input_from == "decoder":
325
+ return self.config["decoder_layer"] - 1
326
+ else:
327
+ # default using the output from the last encoder layer (-1)
328
+ return self.config.get("encoder_layer", 0) - 1
329
+
330
+ @property
331
+ def loss_weight_schedule(self):
332
+ return (
333
+ "decay"
334
+ if "loss_weight_max" in self.config
335
+ and "loss_weight_decay_steps" in self.config
336
+ else "fixed"
337
+ )
338
+
339
+ def get_loss_weight(self, num_updates):
340
+ if self.loss_weight_schedule == "fixed":
341
+ weight = self.config.get("loss_weight", 1.0)
342
+ else: # "decay"
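+ # linear decay from loss_weight_max toward loss_weight_min over loss_weight_decay_steps updates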
343
+ assert (
344
+ self.config.get("loss_weight_decay_steps", 0) > 0
345
+ ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
346
+ loss_weight_min = self.config.get("loss_weight_min", 0.0001)
347
+ loss_weight_decay_stepsize = (
348
+ self.config["loss_weight_max"] - loss_weight_min
349
+ ) / self.config["loss_weight_decay_steps"]
350
+ weight = max(
351
+ self.config["loss_weight_max"]
352
+ - loss_weight_decay_stepsize * num_updates,
353
+ loss_weight_min,
354
+ )
355
+ return weight
356
+
357
+ @property
358
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
359
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
360
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
361
+
362
+ @property
363
+ def eos_token(self):
364
+ """EOS token during generation"""
365
+ return self.config.get("eos_token", "<eos>")
366
+
367
+ @property
368
+ def rdrop_alpha(self):
369
+ return self.config.get("rdrop_alpha", 0.0)
370
+
371
+ @property
372
+ def is_first_pass_decoder(self):
373
+ flag = self.config.get("is_first_pass_decoder", False)
374
+ if flag:
375
+ if self.decoder_type == "ctc":
376
+ raise ValueError(
377
+ "First-pass decoder in the multi-decoder model must not be CTC."
378
+ )
379
+ if "target" not in self.task_name:
380
+ raise Warning(
381
+ 'The name of the first-pass decoder does not include "target".'
382
+ )
383
+ return flag
384
+
385
+ @property
386
+ def get_lang_tag_mapping(self):
387
+ return self.config.get("lang_tag_mapping", {})
modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioDatasetTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_DATASET_TRANSFORM_REGISTRY = {}
15
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_dataset_transform(name):
19
+ return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_dataset_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioDatasetTransform,
26
+ AUDIO_DATASET_TRANSFORM_REGISTRY,
27
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "dataset")
32
+
33
+
34
+ class CompositeAudioDatasetTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "dataset",
40
+ get_audio_dataset_transform,
41
+ CompositeAudioDatasetTransform,
42
+ config,
43
+ return_empty=True,
44
+ )
45
+
46
+ def get_transform(self, cls):
47
+ for t in self.transforms:
48
+ if isinstance(t, cls):
49
+ return t
50
+ return None
51
+
52
+ def has_transform(self, cls):
53
+ return self.get_transform(cls) is not None
modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import List
2
+ import numpy as np
3
+
4
+ from fairseq.data.audio.dataset_transforms import (
5
+ AudioDatasetTransform,
6
+ register_audio_dataset_transform,
7
+ )
8
+
9
+ _DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
10
+
11
+
12
+ @register_audio_dataset_transform("concataugment")
13
+ class ConcatAugment(AudioDatasetTransform):
14
+ @classmethod
15
+ def from_config_dict(cls, config=None):
16
+ _config = {} if config is None else config
17
+ return ConcatAugment(
18
+ _config.get("rate", _DEFAULTS["rate"]),
19
+ _config.get("max_tokens", _DEFAULTS["max_tokens"]),
20
+ _config.get("attempts", _DEFAULTS["attempts"]),
21
+ )
22
+
23
+ def __init__(
24
+ self,
25
+ rate=_DEFAULTS["rate"],
26
+ max_tokens=_DEFAULTS["max_tokens"],
27
+ attempts=_DEFAULTS["attempts"],
28
+ ):
29
+ self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
30
+
31
+ def __repr__(self):
32
+ return (
33
+ self.__class__.__name__
34
+ + "("
35
+ + ", ".join(
36
+ [
37
+ f"rate={self.rate}",
38
+ f"max_tokens={self.max_tokens}",
39
+ f"attempts={self.attempts}",
40
+ ]
41
+ )
42
+ + ")"
43
+ )
44
+
45
+ def find_indices(self, index: int, n_frames: List[int], n_samples: int):
46
+ # skip conditions: application rate, max_tokens limit exceeded
47
+ if np.random.random() > self.rate:
48
+ return [index]
49
+ if self.max_tokens and n_frames[index] > self.max_tokens:
50
+ return [index]
51
+
52
+ # pick second sample to concatenate
53
+ for _ in range(self.attempts):
54
+ index2 = np.random.randint(0, n_samples)
55
+ if index2 != index and (
56
+ not self.max_tokens
57
+ or n_frames[index] + n_frames[index2] < self.max_tokens
58
+ ):
59
+ return [index, index2]
60
+
61
+ return [index]
modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py ADDED
@@ -0,0 +1,105 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from fairseq.data.audio import rand_uniform
5
+ from fairseq.data.audio.dataset_transforms import (
6
+ AudioDatasetTransform,
7
+ register_audio_dataset_transform,
8
+ )
9
+ from fairseq.data.audio.waveform_transforms.noiseaugment import (
10
+ NoiseAugmentTransform,
11
+ )
12
+
13
+ _DEFAULTS = {
14
+ "rate": 0.25,
15
+ "mixing_noise_rate": 0.1,
16
+ "noise_path": "",
17
+ "noise_snr_min": -5,
18
+ "noise_snr_max": 5,
19
+ "utterance_snr_min": -5,
20
+ "utterance_snr_max": 5,
21
+ }
22
+
23
+
24
+ @register_audio_dataset_transform("noisyoverlapaugment")
25
+ class NoisyOverlapAugment(AudioDatasetTransform):
26
+ @classmethod
27
+ def from_config_dict(cls, config=None):
28
+ _config = {} if config is None else config
29
+ return NoisyOverlapAugment(
30
+ _config.get("rate", _DEFAULTS["rate"]),
31
+ _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
32
+ _config.get("noise_path", _DEFAULTS["noise_path"]),
33
+ _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
34
+ _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
35
+ _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
36
+ _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
37
+ )
38
+
39
+ def __init__(
40
+ self,
41
+ rate=_DEFAULTS["rate"],
42
+ mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
43
+ noise_path=_DEFAULTS["noise_path"],
44
+ noise_snr_min=_DEFAULTS["noise_snr_min"],
45
+ noise_snr_max=_DEFAULTS["noise_snr_max"],
46
+ utterance_snr_min=_DEFAULTS["utterance_snr_min"],
47
+ utterance_snr_max=_DEFAULTS["utterance_snr_max"],
48
+ ):
49
+ self.rate = rate
50
+ self.mixing_noise_rate = mixing_noise_rate
51
+ self.noise_shaper = NoiseAugmentTransform(noise_path)
52
+ self.noise_snr_min = noise_snr_min
53
+ self.noise_snr_max = noise_snr_max
54
+ self.utterance_snr_min = utterance_snr_min
55
+ self.utterance_snr_max = utterance_snr_max
56
+
57
+ def __repr__(self):
58
+ return (
59
+ self.__class__.__name__
60
+ + "("
61
+ + ", ".join(
62
+ [
63
+ f"rate={self.rate}",
64
+ f"mixing_noise_rate={self.mixing_noise_rate}",
65
+ f"noise_snr_min={self.noise_snr_min}",
66
+ f"noise_snr_max={self.noise_snr_max}",
67
+ f"utterance_snr_min={self.utterance_snr_min}",
68
+ f"utterance_snr_max={self.utterance_snr_max}",
69
+ ]
70
+ )
71
+ + ")"
72
+ )
73
+
74
+ def __call__(self, sources):
75
+ for i, source in enumerate(sources):
76
+ if np.random.random() > self.rate:
77
+ continue
78
+
79
+ pri = source.numpy()
80
+
81
+ if np.random.random() > self.mixing_noise_rate:
82
+ sec = sources[np.random.randint(0, len(sources))].numpy()
83
+ snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
84
+ else:
85
+ sec = self.noise_shaper.pick_sample(source.shape)
86
+ snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
87
+
88
+ L1 = pri.shape[-1]
89
+ L2 = sec.shape[-1]
90
+ l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
91
+ s_source = np.random.randint(0, L1 - l)
92
+ s_sec = np.random.randint(0, L2 - l)
93
+
94
+ get_power = lambda x: np.mean(x**2)
95
+ if get_power(sec) == 0:
96
+ continue
97
+
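+ # scale the secondary signal so that the primary-to-secondary power ratio matches the sampled SNR (dB)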
98
+ scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
99
+
100
+ pri[s_source : s_source + l] = np.add(
101
+ pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
102
+ )
103
+ sources[i] = torch.from_numpy(pri).float()
104
+
105
+ return sources
modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioFeatureTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
15
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_feature_transform(name):
19
+ return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_feature_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioFeatureTransform,
26
+ AUDIO_FEATURE_TRANSFORM_REGISTRY,
27
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "feature")
32
+
33
+
34
+ class CompositeAudioFeatureTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "feature",
40
+ get_audio_feature_transform,
41
+ CompositeAudioFeatureTransform,
42
+ config,
43
+ )
modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py ADDED
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+ import torch
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("delta_deltas")
10
+ class DeltaDeltas(AudioFeatureTransform):
11
+ """Expand delta-deltas features from spectrum."""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return DeltaDeltas(_config.get("win_length", 5))
17
+
18
+ def __init__(self, win_length=5):
19
+ self.win_length = win_length
20
+
21
+ def __repr__(self):
22
+ return self.__class__.__name__
23
+
24
+ def __call__(self, spectrogram):
25
+ from torchaudio.functional import compute_deltas
26
+
27
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
28
+ # spectrogram is T x F, while compute_deltas takes (…, F, T)
29
+ spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
30
+ delta = compute_deltas(spectrogram)
31
+ delta_delta = compute_deltas(delta)
32
+
33
+ out_feat = np.concatenate(
34
+ [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
35
+ )
36
+ out_feat = np.transpose(out_feat)
37
+ return out_feat
modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ from fairseq.data.audio.feature_transforms import (
3
+ AudioFeatureTransform,
4
+ register_audio_feature_transform,
5
+ )
6
+
7
+
8
+ @register_audio_feature_transform("global_cmvn")
9
+ class GlobalCMVN(AudioFeatureTransform):
10
+ """Global CMVN (cepstral mean and variance normalization). The global mean
11
+ and variance need to be pre-computed and stored in NumPy format (.npz)."""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return GlobalCMVN(_config.get("stats_npz_path"))
17
+
18
+ def __init__(self, stats_npz_path):
19
+ self.stats_npz_path = stats_npz_path
20
+ stats = np.load(stats_npz_path)
21
+ self.mean, self.std = stats["mean"], stats["std"]
22
+
23
+ def __repr__(self):
24
+ return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
25
+
26
+ def __call__(self, x):
27
+ x = np.subtract(x, self.mean)
28
+ x = np.divide(x, self.std)
29
+ return x
modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py ADDED
@@ -0,0 +1,131 @@
1
+ import math
2
+ import numbers
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from fairseq.data.audio.feature_transforms import (
7
+ AudioFeatureTransform,
8
+ register_audio_feature_transform,
9
+ )
10
+
11
+
12
+ @register_audio_feature_transform("specaugment")
13
+ class SpecAugmentTransform(AudioFeatureTransform):
14
+ """SpecAugment (https://arxiv.org/abs/1904.08779)"""
15
+
16
+ @classmethod
17
+ def from_config_dict(cls, config=None):
18
+ _config = {} if config is None else config
19
+ return SpecAugmentTransform(
20
+ _config.get("time_warp_W", 0),
21
+ _config.get("freq_mask_N", 0),
22
+ _config.get("freq_mask_F", 0),
23
+ _config.get("time_mask_N", 0),
24
+ _config.get("time_mask_T", 0),
25
+ _config.get("time_mask_p", 0.0),
26
+ _config.get("mask_value", None),
27
+ )
28
+
29
+ def __init__(
30
+ self,
31
+ time_warp_w: int = 0,
32
+ freq_mask_n: int = 0,
33
+ freq_mask_f: int = 0,
34
+ time_mask_n: int = 0,
35
+ time_mask_t: int = 0,
36
+ time_mask_p: float = 0.0,
37
+ mask_value: Optional[float] = 0.0,
38
+ ):
39
+ # Sanity checks
40
+ assert mask_value is None or isinstance(
41
+ mask_value, numbers.Number
42
+ ), f"mask_value (type: {type(mask_value)}) must be None or a number"
43
+ if freq_mask_n > 0:
44
+ assert freq_mask_f > 0, (
45
+ f"freq_mask_F ({freq_mask_f}) "
46
+ f"must be larger than 0 when doing freq masking."
47
+ )
48
+ if time_mask_n > 0:
49
+ assert time_mask_t > 0, (
50
+ f"time_mask_T ({time_mask_t}) must be larger than 0 when "
51
+ f"doing time masking."
52
+ )
53
+
54
+ self.time_warp_w = time_warp_w
55
+ self.freq_mask_n = freq_mask_n
56
+ self.freq_mask_f = freq_mask_f
57
+ self.time_mask_n = time_mask_n
58
+ self.time_mask_t = time_mask_t
59
+ self.time_mask_p = time_mask_p
60
+ self.mask_value = mask_value
61
+
62
+ def __repr__(self):
63
+ return (
64
+ self.__class__.__name__
65
+ + "("
66
+ + ", ".join(
67
+ [
68
+ f"time_warp_w={self.time_warp_w}",
69
+ f"freq_mask_n={self.freq_mask_n}",
70
+ f"freq_mask_f={self.freq_mask_f}",
71
+ f"time_mask_n={self.time_mask_n}",
72
+ f"time_mask_t={self.time_mask_t}",
73
+ f"time_mask_p={self.time_mask_p}",
74
+ ]
75
+ )
76
+ + ")"
77
+ )
78
+
79
+ def __call__(self, spectrogram):
80
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
81
+
82
+ distorted = spectrogram.copy() # make a copy of input spectrogram.
83
+ num_frames = spectrogram.shape[0] # or 'tau' in the paper.
84
+ num_freqs = spectrogram.shape[1] # or 'nu' in the paper.
85
+ mask_value = self.mask_value
86
+
87
+ if mask_value is None: # if no value was specified, use local mean.
88
+ mask_value = spectrogram.mean()
89
+
90
+ if num_frames == 0:
91
+ return spectrogram
92
+
93
+ if num_freqs < self.freq_mask_f:
94
+ return spectrogram
95
+
96
+ if self.time_warp_w > 0:
97
+ if 2 * self.time_warp_w < num_frames:
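+ # time warping: split at a random frame w0 and stretch/compress the two pieces by w frames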
98
+ import cv2
99
+
100
+ w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
101
+ w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
102
+ upper, lower = distorted[:w0, :], distorted[w0:, :]
103
+ upper = cv2.resize(
104
+ upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
105
+ )
106
+ lower = cv2.resize(
107
+ lower,
108
+ dsize=(num_freqs, num_frames - w0 - w),
109
+ interpolation=cv2.INTER_LINEAR,
110
+ )
111
+ distorted = np.concatenate((upper, lower), axis=0)
112
+
113
+ for _i in range(self.freq_mask_n):
114
+ f = np.random.randint(0, self.freq_mask_f)
115
+ f0 = np.random.randint(0, num_freqs - f)
116
+ if f != 0:
117
+ distorted[:, f0 : f0 + f] = mask_value
118
+
119
+ max_time_mask_t = min(
120
+ self.time_mask_t, math.floor(num_frames * self.time_mask_p)
121
+ )
122
+ if max_time_mask_t < 1:
123
+ return distorted
124
+
125
+ for _i in range(self.time_mask_n):
126
+ t = np.random.randint(0, max_time_mask_t)
127
+ t0 = np.random.randint(0, num_frames - t)
128
+ if t != 0:
129
+ distorted[t0 : t0 + t, :] = mask_value
130
+
131
+ return distorted
modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("utterance_cmvn")
10
+ class UtteranceCMVN(AudioFeatureTransform):
11
+ """Utterance-level CMVN (cepstral mean and variance normalization)"""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return UtteranceCMVN(
17
+ _config.get("norm_means", True),
18
+ _config.get("norm_vars", True),
19
+ )
20
+
21
+ def __init__(self, norm_means=True, norm_vars=True):
22
+ self.norm_means, self.norm_vars = norm_means, norm_vars
23
+
24
+ def __repr__(self):
25
+ return (
26
+ self.__class__.__name__
27
+ + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
28
+ )
29
+
30
+ def __call__(self, x):
31
+ mean = x.mean(axis=0)
32
+ square_sums = (x**2).sum(axis=0)
33
+
34
+ if self.norm_means:
35
+ x = np.subtract(x, mean)
36
+ if self.norm_vars:
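+ # variance via E[x^2] - E[x]^2, floored at 1e-10 before taking the square root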
37
+ var = square_sums / x.shape[0] - mean**2
38
+ std = np.sqrt(np.maximum(var, 1e-10))
39
+ x = np.divide(x, std)
40
+
41
+ return x
modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import csv
9
+ import logging
10
+ import os.path as op
11
+ from typing import List, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from fairseq.data import Dictionary
16
+ from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
17
+ from fairseq.data.audio.text_to_speech_dataset import (
18
+ TextToSpeechDataset,
19
+ TextToSpeechDatasetCreator,
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class FrmTextToSpeechDataset(TextToSpeechDataset):
26
+ def __init__(
27
+ self,
28
+ split: str,
29
+ is_train_split: bool,
30
+ data_cfg: S2TDataConfig,
31
+ audio_paths: List[str],
32
+ n_frames: List[int],
33
+ src_texts: Optional[List[str]] = None,
34
+ tgt_texts: Optional[List[str]] = None,
35
+ speakers: Optional[List[str]] = None,
36
+ src_langs: Optional[List[str]] = None,
37
+ tgt_langs: Optional[List[str]] = None,
38
+ ids: Optional[List[str]] = None,
39
+ tgt_dict: Optional[Dictionary] = None,
40
+ pre_tokenizer=None,
41
+ bpe_tokenizer=None,
42
+ n_frames_per_step=1,
43
+ speaker_to_id=None,
44
+ do_chunk=False,
45
+ chunk_bound=-1,
46
+ chunk_init=50,
47
+ chunk_incr=5,
48
+ add_eos=True,
49
+ dedup=True,
50
+ ref_fpu=-1,
51
+ ):
52
+ # It assumes texts are encoded at a fixed frame-rate
53
+ super().__init__(
54
+ split=split,
55
+ is_train_split=is_train_split,
56
+ data_cfg=data_cfg,
57
+ audio_paths=audio_paths,
58
+ n_frames=n_frames,
59
+ src_texts=src_texts,
60
+ tgt_texts=tgt_texts,
61
+ speakers=speakers,
62
+ src_langs=src_langs,
63
+ tgt_langs=tgt_langs,
64
+ ids=ids,
65
+ tgt_dict=tgt_dict,
66
+ pre_tokenizer=pre_tokenizer,
67
+ bpe_tokenizer=bpe_tokenizer,
68
+ n_frames_per_step=n_frames_per_step,
69
+ speaker_to_id=speaker_to_id,
70
+ )
71
+
72
+ self.do_chunk = do_chunk
73
+ self.chunk_bound = chunk_bound
74
+ self.chunk_init = chunk_init
75
+ self.chunk_incr = chunk_incr
76
+ self.add_eos = add_eos
77
+ self.dedup = dedup
78
+ self.ref_fpu = ref_fpu
79
+
80
+ self.chunk_size = -1
81
+
82
+ if do_chunk:
83
+ assert self.chunk_incr >= 0
84
+ assert self.pre_tokenizer is None
85
+
86
+ def __getitem__(self, index):
87
+ index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
88
+ if target[-1].item() == self.tgt_dict.eos_index:
89
+ target = target[:-1]
90
+
91
+ fpu = source.size(0) / target.size(0) # frame-per-unit
92
+ fps = self.n_frames_per_step
93
+ assert (
94
+ self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
95
+ ), f"{fpu*fps} != {self.ref_fpu}"
96
+
97
+ # only chunk training split
98
+ if self.is_train_split and self.do_chunk and self.chunk_size > 0:
99
+ lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
100
+ text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
101
+ size = len(text)
102
+ chunk_size = min(self.chunk_size, size)
103
+ chunk_start = np.random.randint(size - chunk_size + 1)
104
+ text = text[chunk_start : chunk_start + chunk_size]
105
+ target = torch.cat((lang, text), 0)
106
+
107
+ f_size = int(np.floor(chunk_size * fpu))
108
+ f_start = int(np.floor(chunk_start * fpu))
109
+ assert f_size > 0
110
+ source = source[f_start : f_start + f_size, :]
111
+
112
+ if self.dedup:
113
+ target = torch.unique_consecutive(target)
114
+
115
+ if self.add_eos:
116
+ eos_idx = self.tgt_dict.eos_index
117
+ target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
118
+
119
+ return index, source, target, speaker_id
120
+
121
+ def set_epoch(self, epoch):
122
+ if self.is_train_split and self.do_chunk:
123
+ old = self.chunk_size
124
+ self.chunk_size = self.chunk_init + epoch * self.chunk_incr
125
+ if self.chunk_bound > 0:
126
+ self.chunk_size = min(self.chunk_size, self.chunk_bound)
127
+ logger.info(
128
+ (
129
+ f"{self.split}: setting chunk size "
130
+ f"from {old} to {self.chunk_size}"
131
+ )
132
+ )
133
+
134
+
135
+ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
136
+ # inherit for key names
137
+ @classmethod
138
+ def from_tsv(
139
+ cls,
140
+ root: str,
141
+ data_cfg: S2TDataConfig,
142
+ split: str,
143
+ tgt_dict,
144
+ pre_tokenizer,
145
+ bpe_tokenizer,
146
+ is_train_split: bool,
147
+ n_frames_per_step: int,
148
+ speaker_to_id,
149
+ do_chunk: bool = False,
150
+ chunk_bound: int = -1,
151
+ chunk_init: int = 50,
152
+ chunk_incr: int = 5,
153
+ add_eos: bool = True,
154
+ dedup: bool = True,
155
+ ref_fpu: float = -1,
156
+ ) -> FrmTextToSpeechDataset:
157
+ tsv_path = op.join(root, f"{split}.tsv")
158
+ if not op.isfile(tsv_path):
159
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
160
+ with open(tsv_path) as f:
161
+ reader = csv.DictReader(
162
+ f,
163
+ delimiter="\t",
164
+ quotechar=None,
165
+ doublequote=False,
166
+ lineterminator="\n",
167
+ quoting=csv.QUOTE_NONE,
168
+ )
169
+ s = [dict(e) for e in reader]
170
+ assert len(s) > 0
171
+
172
+ ids = [ss[cls.KEY_ID] for ss in s]
173
+ audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
174
+ n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
175
+ tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
176
+ src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
177
+ speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
178
+ src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
179
+ tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
180
+
181
+ return FrmTextToSpeechDataset(
182
+ split=split,
183
+ is_train_split=is_train_split,
184
+ data_cfg=data_cfg,
185
+ audio_paths=audio_paths,
186
+ n_frames=n_frames,
187
+ src_texts=src_texts,
188
+ tgt_texts=tgt_texts,
189
+ speakers=speakers,
190
+ src_langs=src_langs,
191
+ tgt_langs=tgt_langs,
192
+ ids=ids,
193
+ tgt_dict=tgt_dict,
194
+ pre_tokenizer=pre_tokenizer,
195
+ bpe_tokenizer=bpe_tokenizer,
196
+ n_frames_per_step=n_frames_per_step,
197
+ speaker_to_id=speaker_to_id,
198
+ do_chunk=do_chunk,
199
+ chunk_bound=chunk_bound,
200
+ chunk_init=chunk_init,
201
+ chunk_incr=chunk_incr,
202
+ add_eos=add_eos,
203
+ dedup=dedup,
204
+ ref_fpu=ref_fpu,
205
+ )
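The chunking curriculum applied by set_epoch above grows the target chunk size linearly with the epoch and clips it at chunk_bound; a tiny standalone sketch of that schedule (the numbers below are arbitrary example values, not defaults read from any config in this repo):

def chunk_size_for_epoch(epoch, chunk_init=50, chunk_incr=5, chunk_bound=120):
    # mirrors FrmTextToSpeechDataset.set_epoch: linear growth, optionally clipped
    size = chunk_init + epoch * chunk_incr
    if chunk_bound > 0:
        size = min(size, chunk_bound)
    return size

print([chunk_size_for_epoch(e) for e in range(0, 20, 4)])  # [50, 70, 90, 110, 120]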
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import sys
10
+ from typing import Any, List, Optional, Union
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data import data_utils
17
+ from fairseq.data.fairseq_dataset import FairseqDataset
18
+ from fairseq.data.audio.audio_utils import (
19
+ parse_path,
20
+ read_from_stored_zip,
21
+ )
22
+ import io
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def load_audio(manifest_path, max_keep, min_keep):
28
+ n_long, n_short = 0, 0
29
+ names, inds, sizes = [], [], []
30
+ with open(manifest_path) as f:
31
+ root = f.readline().strip()
32
+ for ind, line in enumerate(f):
33
+ items = line.strip().split("\t")
34
+ assert len(items) == 2, line
35
+ sz = int(items[1])
36
+ if min_keep is not None and sz < min_keep:
37
+ n_short += 1
38
+ elif max_keep is not None and sz > max_keep:
39
+ n_long += 1
40
+ else:
41
+ names.append(items[0])
42
+ inds.append(ind)
43
+ sizes.append(sz)
44
+ tot = ind + 1
45
+ logger.info(
46
+ (
47
+ f"max_keep={max_keep}, min_keep={min_keep}, "
48
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
49
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
50
+ )
51
+ )
52
+ return root, names, inds, tot, sizes
53
+
54
+
55
+ def load_label(label_path, inds, tot):
56
+ with open(label_path) as f:
57
+ labels = [line.rstrip() for line in f]
58
+ assert (
59
+ len(labels) == tot
60
+ ), f"number of labels does not match ({len(labels)} != {tot})"
61
+ labels = [labels[i] for i in inds]
62
+ return labels
63
+
64
+
65
+ def load_label_offset(label_path, inds, tot):
66
+ with open(label_path) as f:
67
+ code_lengths = [len(line.encode("utf-8")) for line in f]
68
+ assert (
69
+ len(code_lengths) == tot
70
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
71
+ offsets = list(itertools.accumulate([0] + code_lengths))
72
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
73
+ return offsets
74
+
75
+
76
+ def verify_label_lengths(
77
+ audio_sizes,
78
+ audio_rate,
79
+ label_path,
80
+ label_rate,
81
+ inds,
82
+ tot,
83
+ tol=0.1, # tolerance in seconds
84
+ ):
85
+ if label_rate < 0:
86
+ logger.info(f"{label_path} is sequence label. skipped")
87
+ return
88
+
89
+ with open(label_path) as f:
90
+ lengths = [len(line.rstrip().split()) for line in f]
91
+ assert len(lengths) == tot
92
+ lengths = [lengths[i] for i in inds]
93
+ num_invalid = 0
94
+ for i, ind in enumerate(inds):
95
+ dur_from_audio = audio_sizes[i] / audio_rate
96
+ dur_from_label = lengths[i] / label_rate
97
+ if abs(dur_from_audio - dur_from_label) > tol:
98
+ logger.warning(
99
+ (
100
+ f"audio and label duration differ too much "
101
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
102
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
103
+ f"is correctly set (currently {label_rate}). "
104
+ f"num. of samples = {audio_sizes[i]}; "
105
+ f"label length = {lengths[i]}"
106
+ )
107
+ )
108
+ num_invalid += 1
109
+ if num_invalid > 0:
110
+ logger.warning(
111
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
112
+ )
113
+
114
+
115
+ class HubertDataset(FairseqDataset):
116
+ def __init__(
117
+ self,
118
+ manifest_path: str,
119
+ sample_rate: float,
120
+ label_paths: List[str],
121
+ label_rates: Union[List[float], float], # -1 for sequence labels
122
+ pad_list: List[str],
123
+ eos_list: List[str],
124
+ label_processors: Optional[List[Any]] = None,
125
+ max_keep_sample_size: Optional[int] = None,
126
+ min_keep_sample_size: Optional[int] = None,
127
+ max_sample_size: Optional[int] = None,
128
+ shuffle: bool = True,
129
+ pad_audio: bool = False,
130
+ normalize: bool = False,
131
+ store_labels: bool = True,
132
+ random_crop: bool = False,
133
+ single_target: bool = False,
134
+ ):
135
+ self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
136
+ manifest_path, max_keep_sample_size, min_keep_sample_size
137
+ )
138
+ self.sample_rate = sample_rate
139
+ self.shuffle = shuffle
140
+ self.random_crop = random_crop
141
+
142
+ self.num_labels = len(label_paths)
143
+ self.pad_list = pad_list
144
+ self.eos_list = eos_list
145
+ self.label_processors = label_processors
146
+ self.single_target = single_target
147
+ self.label_rates = (
148
+ [label_rates for _ in range(len(label_paths))]
149
+ if isinstance(label_rates, float)
150
+ else label_rates
151
+ )
152
+ self.store_labels = store_labels
153
+ if store_labels:
154
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
155
+ else:
156
+ self.label_paths = label_paths
157
+ self.label_offsets_list = [
158
+ load_label_offset(p, inds, tot) for p in label_paths
159
+ ]
160
+ assert label_processors is None or len(label_processors) == self.num_labels
161
+ for label_path, label_rate in zip(label_paths, self.label_rates):
162
+ verify_label_lengths(
163
+ self.sizes, sample_rate, label_path, label_rate, inds, tot
164
+ )
165
+
166
+ self.max_sample_size = (
167
+ max_sample_size if max_sample_size is not None else sys.maxsize
168
+ )
169
+ self.pad_audio = pad_audio
170
+ self.normalize = normalize
171
+ logger.info(
172
+ f"pad_audio={pad_audio}, random_crop={random_crop}, "
173
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
174
+ )
175
+
176
+ def get_audio(self, index):
177
+ import soundfile as sf
178
+
179
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
180
+ _path, slice_ptr = parse_path(wav_path)
181
+ if len(slice_ptr) == 0:
182
+ wav, cur_sample_rate = sf.read(_path)
183
+ else:
184
+ assert _path.endswith(".zip")
185
+ data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
186
+ f = io.BytesIO(data)
187
+ wav, cur_sample_rate = sf.read(f)
188
+ wav = torch.from_numpy(wav).float()
189
+ wav = self.postprocess(wav, cur_sample_rate)
190
+ return wav
191
+
192
+ def get_label(self, index, label_idx):
193
+ if self.store_labels:
194
+ label = self.label_list[label_idx][index]
195
+ else:
196
+ with open(self.label_paths[label_idx]) as f:
197
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
198
+ f.seek(offset_s)
199
+ label = f.read(offset_e - offset_s)
200
+
201
+ if self.label_processors is not None:
202
+ label = self.label_processors[label_idx](label)
203
+ return label
204
+
205
+ def get_labels(self, index):
206
+ return [self.get_label(index, i) for i in range(self.num_labels)]
207
+
208
+ def __getitem__(self, index):
209
+ wav = self.get_audio(index)
210
+ labels = self.get_labels(index)
211
+ return {"id": index, "source": wav, "label_list": labels}
212
+
213
+ def __len__(self):
214
+ return len(self.sizes)
215
+
216
+ def crop_to_max_size(self, wav, target_size):
217
+ size = len(wav)
218
+ diff = size - target_size
219
+ if diff <= 0:
220
+ return wav, 0
221
+
222
+ start, end = 0, target_size
223
+ if self.random_crop:
224
+ start = np.random.randint(0, diff + 1)
225
+ end = size - diff + start
226
+ return wav[start:end], start
227
+
228
+ def collater(self, samples):
229
+ # target = max(sizes) -> random_crop not used
230
+ # target = max_sample_size -> random_crop used for long
231
+ samples = [s for s in samples if s["source"] is not None]
232
+ if len(samples) == 0:
233
+ return {}
234
+
235
+ audios = [s["source"] for s in samples]
236
+ audio_sizes = [len(s) for s in audios]
237
+ if self.pad_audio:
238
+ audio_size = min(max(audio_sizes), self.max_sample_size)
239
+ else:
240
+ audio_size = min(min(audio_sizes), self.max_sample_size)
241
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
242
+ audios, audio_size
243
+ )
244
+
245
+ targets_by_label = [
246
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
247
+ ]
248
+ targets_list, lengths_list, ntokens_list = self.collater_label(
249
+ targets_by_label, audio_size, audio_starts
250
+ )
251
+
252
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
253
+ batch = {
254
+ "id": torch.LongTensor([s["id"] for s in samples]),
255
+ "net_input": net_input,
256
+ }
257
+
258
+ if self.single_target:
259
+ batch["target_lengths"] = lengths_list[0]
260
+ batch["ntokens"] = ntokens_list[0]
261
+ batch["target"] = targets_list[0]
262
+ else:
263
+ batch["target_lengths_list"] = lengths_list
264
+ batch["ntokens_list"] = ntokens_list
265
+ batch["target_list"] = targets_list
266
+ return batch
267
+
268
+ def collater_audio(self, audios, audio_size):
269
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
270
+ padding_mask = (
271
+ torch.BoolTensor(collated_audios.shape).fill_(False)
272
+ # if self.pad_audio else None
273
+ )
274
+ audio_starts = [0 for _ in audios]
275
+ for i, audio in enumerate(audios):
276
+ diff = len(audio) - audio_size
277
+ if diff == 0:
278
+ collated_audios[i] = audio
279
+ elif diff < 0:
280
+ assert self.pad_audio
281
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
282
+ padding_mask[i, diff:] = True
283
+ else:
284
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
285
+ audio, audio_size
286
+ )
287
+ return collated_audios, padding_mask, audio_starts
288
+
289
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
290
+ assert label_rate > 0
291
+ s2f = label_rate / self.sample_rate
292
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
293
+ frm_size = int(round(audio_size * s2f))
294
+ if not self.pad_audio:
295
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
296
+ frm_size = min(frm_size, *rem_size)
297
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
298
+ logger.debug(f"audio_starts={audio_starts}")
299
+ logger.debug(f"frame_starts={frm_starts}")
300
+ logger.debug(f"frame_size={frm_size}")
301
+
302
+ lengths = torch.LongTensor([len(t) for t in targets])
303
+ ntokens = lengths.sum().item()
304
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
305
+ return targets, lengths, ntokens
306
+
307
+ def collater_seq_label(self, targets, pad):
308
+ lengths = torch.LongTensor([len(t) for t in targets])
309
+ ntokens = lengths.sum().item()
310
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
311
+ return targets, lengths, ntokens
312
+
313
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
314
+ targets_list, lengths_list, ntokens_list = [], [], []
315
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
316
+ for targets, label_rate, pad in itr:
317
+ if label_rate == -1.0:
318
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
319
+ else:
320
+ targets, lengths, ntokens = self.collater_frm_label(
321
+ targets, audio_size, audio_starts, label_rate, pad
322
+ )
323
+ targets_list.append(targets)
324
+ lengths_list.append(lengths)
325
+ ntokens_list.append(ntokens)
326
+ return targets_list, lengths_list, ntokens_list
327
+
328
+ def num_tokens(self, index):
329
+ return self.size(index)
330
+
331
+ def size(self, index):
332
+ if self.pad_audio:
333
+ return self.sizes[index]
334
+ return min(self.sizes[index], self.max_sample_size)
335
+
336
+ def ordered_indices(self):
337
+ if self.shuffle:
338
+ order = [np.random.permutation(len(self))]
339
+ else:
340
+ order = [np.arange(len(self))]
341
+
342
+ order.append(self.sizes)
343
+ return np.lexsort(order)[::-1]
344
+
345
+ def postprocess(self, wav, cur_sample_rate):
346
+ if wav.dim() == 2:
347
+ wav = wav.mean(-1)
348
+ assert wav.dim() == 1, wav.dim()
349
+
350
+ if cur_sample_rate != self.sample_rate:
351
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
352
+
353
+ if self.normalize:
354
+ with torch.no_grad():
355
+ wav = F.layer_norm(wav, wav.shape)
356
+ return wav
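collater_frm_label above maps a crop expressed in audio samples onto the matching slice of frame-level labels through the sample-to-frame ratio s2f = label_rate / sample_rate; a small numeric sketch, assuming a typical 16 kHz waveform with 50 Hz labels (the rates are illustrative):

def audio_crop_to_label_slice(audio_start, audio_size, sample_rate=16000, label_rate=50):
    # mirrors the s2f mapping in HubertDataset.collater_frm_label
    s2f = label_rate / sample_rate            # label frames per audio sample
    frm_start = int(round(audio_start * s2f))
    frm_size = int(round(audio_size * s2f))
    return frm_start, frm_size

print(audio_crop_to_label_slice(8000, 40000))  # (25, 125): a 2.5 s crop starting at 0.5 s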
modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) 2021-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import logging
9
+ import math
10
+ from typing import List, Optional, NamedTuple
11
+
12
+ import numpy as np
13
+ from fairseq.data.resampling_dataset import ResamplingDataset
14
+ import torch
15
+ from fairseq.data import (
16
+ ConcatDataset,
17
+ LanguagePairDataset,
18
+ FileAudioDataset,
19
+ data_utils,
20
+ )
21
+ from fairseq.data import FairseqDataset
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class ModalityDatasetItem(NamedTuple):
27
+ datasetname: str
28
+ dataset: any
29
+ max_positions: List[int]
30
+ max_tokens: Optional[int] = None
31
+ max_sentences: Optional[int] = None
32
+
33
+
34
+ def resampling_dataset_present(ds):
35
+ if isinstance(ds, ResamplingDataset):
36
+ return True
37
+ if isinstance(ds, ConcatDataset):
38
+ return any(resampling_dataset_present(d) for d in ds.datasets)
39
+ if hasattr(ds, "dataset"):
40
+ return resampling_dataset_present(ds.dataset)
41
+ return False
42
+
43
+
44
+ # MultiModalityDataset: concatenates multiple datasets with different modalities.
45
+ # Compared with ConcatDataset it can 1) sample data according to per-dataset ratios and
46
+ # 2) add a mode field indicating which type of dataset each sample comes from.
47
+ # It is used together with GroupedEpochBatchIterator to generate mini-batches whose samples
48
+ # all come from the same type of dataset.
49
+ # If only one dataset is used, it behaves like the original dataset, with mode added.
50
+ class MultiModalityDataset(ConcatDataset):
51
+ def __init__(self, datasets: List[ModalityDatasetItem]):
52
+ id_to_mode = []
53
+ dsets = []
54
+ max_tokens = []
55
+ max_sentences = []
56
+ max_positions = []
57
+ for dset in datasets:
58
+ id_to_mode.append(dset.datasetname)
59
+ dsets.append(dset.dataset)
60
+ max_tokens.append(dset.max_tokens)
61
+ max_positions.append(dset.max_positions)
62
+ max_sentences.append(dset.max_sentences)
63
+ weights = [1.0 for s in dsets]
64
+ super().__init__(dsets, weights)
65
+ self.max_tokens = max_tokens
66
+ self.max_positions = max_positions
67
+ self.max_sentences = max_sentences
68
+ self.id_to_mode = id_to_mode
69
+ self.raw_sub_batch_samplers = []
70
+ self._cur_epoch = 0
71
+
72
+ def set_epoch(self, epoch):
73
+ super().set_epoch(epoch)
74
+ self._cur_epoch = epoch
75
+
76
+ def __getitem__(self, idx):
77
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
78
+ sample = self.datasets[dataset_idx][sample_idx]
79
+ return (dataset_idx, sample)
80
+
81
+ def collater(self, samples):
82
+ if len(samples) == 0:
83
+ return {}
84
+ dataset_idx = samples[0][0]
85
+ # make sure all samples in samples are from same dataset
86
+ assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
87
+ samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
88
+ # add mode
89
+ samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
90
+
91
+ return samples
92
+
93
+ def size(self, index: int):
94
+ if len(self.datasets) == 1:
95
+ return self.datasets[0].size(index)
96
+ return super().size(index)
97
+
98
+ @property
99
+ def sizes(self):
100
+ if len(self.datasets) == 1:
101
+ return self.datasets[0].sizes
102
+ return super().sizes
103
+
104
+ def ordered_indices(self):
105
+ """
106
+ Returns indices sorted by length. So less padding is needed.
107
+ """
108
+ if len(self.datasets) == 1:
109
+ return self.datasets[0].ordered_indices()
110
+ indices_group = []
111
+ for d_idx, ds in enumerate(self.datasets):
112
+ sample_num = self.cumulative_sizes[d_idx]
113
+ if d_idx > 0:
114
+ sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
115
+ assert sample_num == len(ds)
116
+ indices_group.append(ds.ordered_indices())
117
+ return indices_group
118
+
119
+ def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
120
+ with data_utils.numpy_seed(seed):
121
+ indices = self.ordered_indices()
122
+ for i, ds in enumerate(self.datasets):
123
+ # If we have ResamplingDataset, the same id can correspond to a different
124
+ # sample in the next epoch, so we need to rebuild this at every epoch
125
+ if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
126
+ ds
127
+ ):
128
+ logger.info(f"dataset {i} is valid and it is not re-sampled")
129
+ continue
130
+ indices[i] = ds.filter_indices_by_size(
131
+ indices[i],
132
+ self.max_positions[i],
133
+ )[0]
134
+ sub_batch_sampler = ds.batch_by_size(
135
+ indices[i],
136
+ max_tokens=self.max_tokens[i],
137
+ max_sentences=self.max_sentences[i],
138
+ required_batch_size_multiple=required_batch_size_multiple,
139
+ )
140
+ if i < len(self.raw_sub_batch_samplers):
141
+ self.raw_sub_batch_samplers[i] = sub_batch_sampler
142
+ else:
143
+ self.raw_sub_batch_samplers.append(sub_batch_sampler)
144
+
145
+ def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
146
+ self.get_raw_batch_samplers(required_batch_size_multiple, seed)
147
+ batch_samplers = []
148
+ for i, _ in enumerate(self.datasets):
149
+ if i > 0:
150
+ sub_batch_sampler = [
151
+ [y + self.cumulative_sizes[i - 1] for y in x]
152
+ for x in self.raw_sub_batch_samplers[i]
153
+ ]
154
+ else:
155
+ sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
156
+ smp_r = mult_ratios[i]
157
+ if smp_r != 1:
158
+ is_increase = "increased" if smp_r > 1 else "decreased"
159
+ logger.info(
160
+ "number of batch for the dataset {} is {} from {} to {}".format(
161
+ self.id_to_mode[i],
162
+ is_increase,
163
+ len(sub_batch_sampler),
164
+ int(len(sub_batch_sampler) * smp_r),
165
+ )
166
+ )
167
+ mul_samplers = []
168
+ for _ in range(math.floor(smp_r)):
169
+ mul_samplers = mul_samplers + sub_batch_sampler
170
+ if math.floor(smp_r) != smp_r:
171
+ with data_utils.numpy_seed(seed + self._cur_epoch):
172
+ np.random.shuffle(sub_batch_sampler)
173
+ smp_num = int(
174
+ (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
175
+ )
176
+ mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
177
+ sub_batch_sampler = mul_samplers
178
+ else:
179
+ logger.info(
180
+ "dataset {} batch number is {} ".format(
181
+ self.id_to_mode[i], len(sub_batch_sampler)
182
+ )
183
+ )
184
+ batch_samplers.append(sub_batch_sampler)
185
+
186
+ return batch_samplers
187
+
188
+
189
+ class LangPairMaskDataset(FairseqDataset):
190
+ def __init__(
191
+ self,
192
+ dataset: LanguagePairDataset,
193
+ src_eos: int,
194
+ src_bos: Optional[int] = None,
195
+ noise_id: Optional[int] = -1,
196
+ mask_ratio: Optional[float] = 0,
197
+ mask_type: Optional[str] = "random",
198
+ ):
199
+ self.dataset = dataset
200
+ self.src_eos = src_eos
201
+ self.src_bos = src_bos
202
+ self.noise_id = noise_id
203
+ self.mask_ratio = mask_ratio
204
+ self.mask_type = mask_type
205
+ assert mask_type in ("random", "tail")
206
+
207
+ @property
208
+ def src_sizes(self):
209
+ return self.dataset.src_sizes
210
+
211
+ @property
212
+ def tgt_sizes(self):
213
+ return self.dataset.tgt_sizes
214
+
215
+ @property
216
+ def sizes(self):
217
+ # dataset.sizes can be a dynamically computed sizes:
218
+ return self.dataset.sizes
219
+
220
+ def get_batch_shapes(self):
221
+ if hasattr(self.dataset, "get_batch_shapes"):
222
+ return self.dataset.get_batch_shapes()
223
+ return self.dataset.buckets
224
+
225
+ def num_tokens_vec(self, indices):
226
+ return self.dataset.num_tokens_vec(indices)
227
+
228
+ def __len__(self):
229
+ return len(self.dataset)
230
+
231
+ def num_tokens(self, index):
232
+ return self.dataset.num_tokens(index)
233
+
234
+ def size(self, index):
235
+ return self.dataset.size(index)
236
+
237
+ def ordered_indices(self):
238
+ return self.dataset.ordered_indices()
239
+
240
+ @property
241
+ def supports_prefetch(self):
242
+ return getattr(self.dataset, "supports_prefetch", False)
243
+
244
+ def prefetch(self, indices):
245
+ return self.dataset.prefetch(indices)
246
+
247
+ def mask_src_tokens(self, sample):
248
+ src_item = sample["source"]
249
+ mask = None
250
+ if self.mask_type == "random":
251
+ mask = torch.rand(len(src_item)).le(self.mask_ratio)
252
+ else:
253
+ mask = torch.ones(len(src_item))
254
+ mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
255
+ mask = mask.eq(1)
256
+ if src_item[0] == self.src_bos:
257
+ mask[0] = False
258
+ if src_item[-1] == self.src_eos:
259
+ mask[-1] = False
260
+ mask_src_item = src_item.masked_fill(mask, self.noise_id)
261
+ smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
262
+ return smp
263
+
264
+ def __getitem__(self, index):
265
+ sample = self.dataset[index]
266
+ if self.mask_ratio > 0:
267
+ sample = self.mask_src_tokens(sample)
268
+ return sample
269
+
270
+ def collater(self, samples, pad_to_length=None):
271
+ return self.dataset.collater(samples, pad_to_length)
272
+
273
+
274
+ class FileAudioDatasetWrapper(FileAudioDataset):
275
+ def collater(self, samples):
276
+ samples = super().collater(samples)
277
+ if len(samples) == 0:
278
+ return {}
279
+ samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
280
+ samples["net_input"]["prev_output_tokens"] = None
281
+ del samples["net_input"]["source"]
282
+ samples["net_input"]["src_lengths"] = None
283
+ samples["net_input"]["alignment"] = None
284
+ return samples
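The ratio handling in get_batch_samplers above repeats a dataset's batch list floor(ratio) times and then appends a randomly chosen fraction of it for the remainder; a standalone sketch of just that resampling step, using dummy batch lists and Python's random module in place of the numpy-seeded shuffle used above:

import math
import random

def resample_batches(batches, ratio, seed=0):
    # repeat the whole list floor(ratio) times, then add a shuffled fractional slice
    out = []
    for _ in range(math.floor(ratio)):
        out += batches
    frac = ratio - math.floor(ratio)
    if frac > 0:
        shuffled = list(batches)
        random.Random(seed).shuffle(shuffled)
        out += shuffled[: int(frac * len(batches))]
    return out

dummy = [[0, 1], [2, 3], [4, 5], [6, 7]]   # four dummy mini-batches
print(len(resample_batches(dummy, 2.5)))   # 10 = 2 * 4 + 0.5 * 4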
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py ADDED
@@ -0,0 +1,393 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import io
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from .. import FairseqDataset
17
+ from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
18
+ from fairseq.data.audio.audio_utils import (
19
+ parse_path,
20
+ read_from_stored_zip,
21
+ is_sf_audio_data,
22
+ )
23
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class RawAudioDataset(FairseqDataset):
30
+ def __init__(
31
+ self,
32
+ sample_rate,
33
+ max_sample_size=None,
34
+ min_sample_size=0,
35
+ shuffle=True,
36
+ pad=False,
37
+ normalize=False,
38
+ compute_mask_indices=False,
39
+ **mask_compute_kwargs,
40
+ ):
41
+ super().__init__()
42
+
43
+ self.sample_rate = sample_rate
44
+ self.sizes = []
45
+ self.max_sample_size = (
46
+ max_sample_size if max_sample_size is not None else sys.maxsize
47
+ )
48
+ self.min_sample_size = min_sample_size
49
+ self.pad = pad
50
+ self.shuffle = shuffle
51
+ self.normalize = normalize
52
+ self.compute_mask_indices = compute_mask_indices
53
+ if self.compute_mask_indices:
54
+ self.mask_compute_kwargs = mask_compute_kwargs
55
+ self._features_size_map = {}
56
+ self._C = mask_compute_kwargs["encoder_embed_dim"]
57
+ self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
58
+
59
+ def __getitem__(self, index):
60
+ raise NotImplementedError()
61
+
62
+ def __len__(self):
63
+ return len(self.sizes)
64
+
65
+ def postprocess(self, feats, curr_sample_rate):
66
+ if feats.dim() == 2:
67
+ feats = feats.mean(-1)
68
+
69
+ if curr_sample_rate != self.sample_rate:
70
+ raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
71
+
72
+ assert feats.dim() == 1, feats.dim()
73
+
74
+ if self.normalize:
75
+ with torch.no_grad():
76
+ feats = F.layer_norm(feats, feats.shape)
77
+ return feats
78
+
79
+ def crop_to_max_size(self, wav, target_size):
80
+ size = len(wav)
81
+ diff = size - target_size
82
+ if diff <= 0:
83
+ return wav
84
+
85
+ start = np.random.randint(0, diff + 1)
86
+ end = size - diff + start
87
+ return wav[start:end]
88
+
89
+ def _compute_mask_indices(self, dims, padding_mask):
90
+ B, T, C = dims
91
+ mask_indices, mask_channel_indices = None, None
92
+ if self.mask_compute_kwargs["mask_prob"] > 0:
93
+ mask_indices = compute_mask_indices(
94
+ (B, T),
95
+ padding_mask,
96
+ self.mask_compute_kwargs["mask_prob"],
97
+ self.mask_compute_kwargs["mask_length"],
98
+ self.mask_compute_kwargs["mask_selection"],
99
+ self.mask_compute_kwargs["mask_other"],
100
+ min_masks=2,
101
+ no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
102
+ min_space=self.mask_compute_kwargs["mask_min_space"],
103
+ )
104
+ mask_indices = torch.from_numpy(mask_indices)
105
+ if self.mask_compute_kwargs["mask_channel_prob"] > 0:
106
+ mask_channel_indices = compute_mask_indices(
107
+ (B, C),
108
+ None,
109
+ self.mask_compute_kwargs["mask_channel_prob"],
110
+ self.mask_compute_kwargs["mask_channel_length"],
111
+ self.mask_compute_kwargs["mask_channel_selection"],
112
+ self.mask_compute_kwargs["mask_channel_other"],
113
+ no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
114
+ min_space=self.mask_compute_kwargs["mask_channel_min_space"],
115
+ )
116
+ mask_channel_indices = (
117
+ torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
118
+ )
119
+
120
+ return mask_indices, mask_channel_indices
121
+
122
+ @staticmethod
123
+ def _bucket_tensor(tensor, num_pad, value):
124
+ return F.pad(tensor, (0, num_pad), value=value)
125
+
126
+ def collater(self, samples):
127
+ samples = [s for s in samples if s["source"] is not None]
128
+ if len(samples) == 0:
129
+ return {}
130
+
131
+ sources = [s["source"] for s in samples]
132
+ sizes = [len(s) for s in sources]
133
+
134
+ if self.pad:
135
+ target_size = min(max(sizes), self.max_sample_size)
136
+ else:
137
+ target_size = min(min(sizes), self.max_sample_size)
138
+
139
+ collated_sources = sources[0].new_zeros(len(sources), target_size)
140
+ padding_mask = (
141
+ torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
142
+ )
143
+ for i, (source, size) in enumerate(zip(sources, sizes)):
144
+ diff = size - target_size
145
+ if diff == 0:
146
+ collated_sources[i] = source
147
+ elif diff < 0:
148
+ assert self.pad
149
+ collated_sources[i] = torch.cat(
150
+ [source, source.new_full((-diff,), 0.0)]
151
+ )
152
+ padding_mask[i, diff:] = True
153
+ else:
154
+ collated_sources[i] = self.crop_to_max_size(source, target_size)
155
+
156
+ input = {"source": collated_sources}
157
+ out = {"id": torch.LongTensor([s["id"] for s in samples])}
158
+ if self.pad:
159
+ input["padding_mask"] = padding_mask
160
+
161
+ if hasattr(self, "num_buckets") and self.num_buckets > 0:
162
+ assert self.pad, "Cannot bucket without padding first."
163
+ bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
164
+ num_pad = bucket - collated_sources.size(-1)
165
+ if num_pad:
166
+ input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
167
+ input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
168
+
169
+ if self.compute_mask_indices:
170
+ B = input["source"].size(0)
171
+ T = self._get_mask_indices_dims(input["source"].size(-1))
172
+ padding_mask_reshaped = input["padding_mask"].clone()
173
+ extra = padding_mask_reshaped.size(1) % T
174
+ if extra > 0:
175
+ padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
176
+ padding_mask_reshaped = padding_mask_reshaped.view(
177
+ padding_mask_reshaped.size(0), T, -1
178
+ )
179
+ padding_mask_reshaped = padding_mask_reshaped.all(-1)
180
+ input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
181
+ mask_indices, mask_channel_indices = self._compute_mask_indices(
182
+ (B, T, self._C),
183
+ padding_mask_reshaped,
184
+ )
185
+ input["mask_indices"] = mask_indices
186
+ input["mask_channel_indices"] = mask_channel_indices
187
+ out["sample_size"] = mask_indices.sum().item()
188
+
189
+ out["net_input"] = input
190
+ return out
191
+
192
+ def _get_mask_indices_dims(self, size, padding=0, dilation=1):
193
+ if size not in self._features_size_map:
194
+ L_in = size
195
+ for (_, kernel_size, stride) in self._conv_feature_layers:
196
+ L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
197
+ L_out = 1 + L_out // stride
198
+ L_in = L_out
199
+ self._features_size_map[size] = L_out
200
+ return self._features_size_map[size]
201
+
202
+ def num_tokens(self, index):
203
+ return self.size(index)
204
+
205
+ def size(self, index):
206
+ """Return an example's size as a float or tuple. This value is used when
207
+ filtering a dataset with ``--max-positions``."""
208
+ if self.pad:
209
+ return self.sizes[index]
210
+ return min(self.sizes[index], self.max_sample_size)
211
+
212
+ def ordered_indices(self):
213
+ """Return an ordered list of indices. Batches will be constructed based
214
+ on this order."""
215
+
216
+ if self.shuffle:
217
+ order = [np.random.permutation(len(self))]
218
+ order.append(
219
+ np.minimum(
220
+ np.array(self.sizes),
221
+ self.max_sample_size,
222
+ )
223
+ )
224
+ return np.lexsort(order)[::-1]
225
+ else:
226
+ return np.arange(len(self))
227
+
228
+ def set_bucket_info(self, num_buckets):
229
+ self.num_buckets = num_buckets
230
+ if self.num_buckets > 0:
231
+ self._collated_sizes = np.minimum(
232
+ np.array(self.sizes),
233
+ self.max_sample_size,
234
+ )
235
+ self.buckets = get_buckets(
236
+ self._collated_sizes,
237
+ self.num_buckets,
238
+ )
239
+ self._bucketed_sizes = get_bucketed_sizes(
240
+ self._collated_sizes, self.buckets
241
+ )
242
+ logger.info(
243
+ f"{len(self.buckets)} bucket(s) for the audio dataset: "
244
+ f"{self.buckets}"
245
+ )
246
+
247
+
248
+ class FileAudioDataset(RawAudioDataset):
249
+ def __init__(
250
+ self,
251
+ manifest_path,
252
+ sample_rate,
253
+ max_sample_size=None,
254
+ min_sample_size=0,
255
+ shuffle=True,
256
+ pad=False,
257
+ normalize=False,
258
+ num_buckets=0,
259
+ compute_mask_indices=False,
260
+ text_compression_level=TextCompressionLevel.none,
261
+ **mask_compute_kwargs,
262
+ ):
263
+ super().__init__(
264
+ sample_rate=sample_rate,
265
+ max_sample_size=max_sample_size,
266
+ min_sample_size=min_sample_size,
267
+ shuffle=shuffle,
268
+ pad=pad,
269
+ normalize=normalize,
270
+ compute_mask_indices=compute_mask_indices,
271
+ **mask_compute_kwargs,
272
+ )
273
+
274
+ self.text_compressor = TextCompressor(level=text_compression_level)
275
+
276
+ skipped = 0
277
+ self.fnames = []
278
+ sizes = []
279
+ self.skipped_indices = set()
280
+
281
+ with open(manifest_path, "r") as f:
282
+ self.root_dir = f.readline().strip()
283
+ for i, line in enumerate(f):
284
+ items = line.strip().split("\t")
285
+ assert len(items) == 2, line
286
+ sz = int(items[1])
287
+ if min_sample_size is not None and sz < min_sample_size:
288
+ skipped += 1
289
+ self.skipped_indices.add(i)
290
+ continue
291
+ self.fnames.append(self.text_compressor.compress(items[0]))
292
+ sizes.append(sz)
293
+ logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
294
+
295
+ self.sizes = np.array(sizes, dtype=np.int64)
296
+
297
+ try:
298
+ import pyarrow
299
+
300
+ self.fnames = pyarrow.array(self.fnames)
301
+ except:
302
+ logger.debug(
303
+ "Could not create a pyarrow array. Please install pyarrow for better performance"
304
+ )
305
+ pass
306
+
307
+ self.set_bucket_info(num_buckets)
308
+
309
+ def __getitem__(self, index):
310
+ import soundfile as sf
311
+
312
+ fn = self.fnames[index]
313
+ fn = fn if isinstance(self.fnames, list) else fn.as_py()
314
+ fn = self.text_compressor.decompress(fn)
315
+ path_or_fp = os.path.join(self.root_dir, fn)
316
+ _path, slice_ptr = parse_path(path_or_fp)
317
+ if len(slice_ptr) == 2:
318
+ byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
319
+ assert is_sf_audio_data(byte_data)
320
+ path_or_fp = io.BytesIO(byte_data)
321
+
322
+ wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
323
+
324
+ feats = torch.from_numpy(wav).float()
325
+ feats = self.postprocess(feats, curr_sample_rate)
326
+ return {"id": index, "source": feats}
327
+
328
+
329
+ class BinarizedAudioDataset(RawAudioDataset):
330
+ def __init__(
331
+ self,
332
+ data_dir,
333
+ split,
334
+ sample_rate,
335
+ max_sample_size=None,
336
+ min_sample_size=0,
337
+ shuffle=True,
338
+ pad=False,
339
+ normalize=False,
340
+ num_buckets=0,
341
+ compute_mask_indices=False,
342
+ **mask_compute_kwargs,
343
+ ):
344
+ super().__init__(
345
+ sample_rate=sample_rate,
346
+ max_sample_size=max_sample_size,
347
+ min_sample_size=min_sample_size,
348
+ shuffle=shuffle,
349
+ pad=pad,
350
+ normalize=normalize,
351
+ compute_mask_indices=compute_mask_indices,
352
+ **mask_compute_kwargs,
353
+ )
354
+
355
+ from fairseq.data import data_utils, Dictionary
356
+
357
+ self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
358
+
359
+ root_path = os.path.join(data_dir, f"{split}.root")
360
+ if os.path.exists(root_path):
361
+ with open(root_path, "r") as f:
362
+ self.root_dir = next(f).strip()
363
+ else:
364
+ self.root_dir = None
365
+
366
+ fnames_path = os.path.join(data_dir, split)
367
+ self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
368
+ lengths_path = os.path.join(data_dir, f"{split}.lengths")
369
+
370
+ with open(lengths_path, "r") as f:
371
+ for line in f:
372
+ sz = int(line.rstrip())
373
+ assert (
374
+ sz >= min_sample_size
375
+ ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
376
+ self.sizes.append(sz)
377
+
378
+ self.sizes = np.array(self.sizes, dtype=np.int64)
379
+
380
+ self.set_bucket_info(num_buckets)
381
+ logger.info(f"loaded {len(self.fnames)} samples")
382
+
383
+ def __getitem__(self, index):
384
+ import soundfile as sf
385
+
386
+ fname = self.fnames_dict.string(self.fnames[index], separator="")
387
+ if self.root_dir:
388
+ fname = os.path.join(self.root_dir, fname)
389
+
390
+ wav, curr_sample_rate = sf.read(fname)
391
+ feats = torch.from_numpy(wav).float()
392
+ feats = self.postprocess(feats, curr_sample_rate)
393
+ return {"id": index, "source": feats}
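_get_mask_indices_dims above applies the standard 1-D convolution output-length formula, L_out = floor((L_in + 2p - d(k - 1) - 1) / s) + 1, once per configured conv layer; a sketch of that computation with the usual wav2vec 2.0 feature-extractor spec (the layer list is an assumption, not something read from this repo's configs):

def conv_output_length(n_samples, conv_layers, padding=0, dilation=1):
    # L_out = floor((L_in + 2p - d*(k - 1) - 1) / s) + 1, applied layer by layer
    length = n_samples
    for _dim, kernel, stride in conv_layers:
        length = (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return length

wav2vec2_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
print(conv_output_length(16000, wav2vec2_layers))  # 49 frames for one second at 16 kHz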
modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import torch
12
+
13
+ from fairseq.data import ConcatDataset, Dictionary
14
+ from fairseq.data import data_utils as fairseq_data_utils
15
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
16
+ from fairseq.data.audio.data_cfg import S2SDataConfig
17
+ from fairseq.data.audio.speech_to_text_dataset import (
18
+ SpeechToTextDataset,
19
+ SpeechToTextDatasetCreator,
20
+ TextTargetMultitaskData,
21
+ _collate_frames,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ @dataclass
28
+ class SpeechToSpeechDatasetItem(object):
29
+ index: int
30
+ source: torch.Tensor
31
+ target: Optional[torch.Tensor] = None
32
+ target_speaker: Optional[torch.Tensor] = None
33
+ tgt_lang_tag: Optional[int] = None
34
+
35
+
36
+ class SpeechToSpeechDataset(SpeechToTextDataset):
37
+ def __init__(
38
+ self,
39
+ split: str,
40
+ is_train_split: bool,
41
+ data_cfg: S2SDataConfig,
42
+ src_audio_paths: List[str],
43
+ src_n_frames: List[int],
44
+ tgt_audio_paths: List[str],
45
+ tgt_n_frames: List[int],
46
+ src_langs: Optional[List[str]] = None,
47
+ tgt_langs: Optional[List[str]] = None,
48
+ ids: Optional[List[str]] = None,
49
+ target_is_code: bool = False,
50
+ tgt_dict: Dictionary = None,
51
+ n_frames_per_step: int = 1,
52
+ ):
53
+ tgt_texts = tgt_audio_paths if target_is_code else None
54
+ super().__init__(
55
+ split=split,
56
+ is_train_split=is_train_split,
57
+ cfg=data_cfg,
58
+ audio_paths=src_audio_paths,
59
+ n_frames=src_n_frames,
60
+ ids=ids,
61
+ tgt_dict=tgt_dict,
62
+ tgt_texts=tgt_texts,
63
+ src_langs=src_langs,
64
+ tgt_langs=tgt_langs,
65
+ n_frames_per_step=n_frames_per_step,
66
+ )
67
+
68
+ self.tgt_audio_paths = tgt_audio_paths
69
+ self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
70
+
71
+ assert not target_is_code or tgt_dict is not None
72
+ self.target_is_code = target_is_code
73
+
74
+ assert len(tgt_audio_paths) == self.n_samples
75
+ assert len(tgt_n_frames) == self.n_samples
76
+
77
+ self.tgt_speakers = None
78
+ if self.cfg.target_speaker_embed:
79
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
80
+ self.cfg.target_speaker_embed, split
81
+ )
82
+ spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
83
+ self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
84
+ assert len(self.tgt_speakers) == self.n_samples
85
+
86
+ logger.info(self.__repr__())
87
+
88
+ def pack_units(self, input: torch.Tensor) -> torch.Tensor:
89
+ if self.n_frames_per_step <= 1:
90
+ return input
91
+
92
+ offset = 4
93
+ vocab_size = (
94
+ len(self.tgt_dict) - offset
95
+ ) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
96
+
97
+ assert input.dim() == 1
98
+ stacked_input = (
99
+ input[:-1].view(-1, self.n_frames_per_step) - offset
100
+ ) # remove <eos>
101
+ scale = [
102
+ pow(vocab_size, self.n_frames_per_step - 1 - i)
103
+ for i in range(self.n_frames_per_step)
104
+ ]
105
+ scale = torch.LongTensor(scale).squeeze(0)
106
+ res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
107
+ res[:-1] = (stacked_input * scale).sum(dim=1) + offset
108
+
109
+ return res
110
+
111
+ def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
112
+ source = self._get_source_audio(index)
113
+
114
+ tgt_lang_tag = None
115
+ if self.cfg.prepend_tgt_lang_tag_as_bos:
116
+ # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
117
+ tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
118
+
119
+ if not self.target_is_code:
120
+ target = get_features_or_waveform(self.tgt_audio_paths[index])
121
+ target = torch.from_numpy(target).float()
122
+ target = self.pack_frames(target)
123
+ else:
124
+ target = self.tgt_dict.encode_line(
125
+ self.tgt_audio_paths[index],
126
+ add_if_not_exist=False,
127
+ append_eos=True,
128
+ ).long()
129
+ if self.n_frames_per_step > 1:
130
+ n_tgt_frame = target.size(0) - 1 # exclude <eos>
131
+ keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
132
+ target = torch.cat(
133
+ (
134
+ target[:keep_n_tgt_frame],
135
+ target.new_full((1,), self.tgt_dict.eos()),
136
+ ),
137
+ dim=0,
138
+ )
139
+
140
+ if self.tgt_speakers:
141
+ tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
142
+ tgt_spk = torch.from_numpy(tgt_spk).float()
143
+ else:
144
+ tgt_spk = torch.FloatTensor([])
145
+
146
+ return SpeechToSpeechDatasetItem(
147
+ index=index,
148
+ source=source,
149
+ target=target,
150
+ target_speaker=tgt_spk,
151
+ tgt_lang_tag=tgt_lang_tag,
152
+ )
153
+
154
+ def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
155
+ if self.target_is_code:
156
+ target = fairseq_data_utils.collate_tokens(
157
+ [x.target for x in samples],
158
+ self.tgt_dict.pad(),
159
+ self.tgt_dict.eos(),
160
+ left_pad=False,
161
+ move_eos_to_beginning=False,
162
+ )
163
+ # convert stacked units to a single id
164
+ pack_targets = [self.pack_units(x.target) for x in samples]
165
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
166
+ pack_targets,
167
+ self.tgt_dict.pad(),
168
+ self.tgt_dict.eos(),
169
+ left_pad=False,
170
+ move_eos_to_beginning=True,
171
+ )
172
+ target_lengths = torch.tensor(
173
+ [x.size(0) for x in pack_targets], dtype=torch.long
174
+ )
175
+ else:
176
+ target = _collate_frames([x.target for x in samples], is_audio_input=False)
177
+ bsz, _, d = target.size()
178
+ prev_output_tokens = torch.cat(
179
+ (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
180
+ )
181
+ target_lengths = torch.tensor(
182
+ [x.target.size(0) for x in samples], dtype=torch.long
183
+ )
184
+
185
+ return target, prev_output_tokens, target_lengths
186
+
187
+ def collater(
188
+ self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
189
+ ) -> Dict:
190
+ if len(samples) == 0:
191
+ return {}
192
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
193
+ frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
194
+ # sort samples by descending number of frames
195
+ n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
196
+ n_frames, order = n_frames.sort(descending=True)
197
+ indices = indices.index_select(0, order)
198
+ frames = frames.index_select(0, order)
199
+
200
+ target, prev_output_tokens, target_lengths = self._collate_target(samples)
201
+ target = target.index_select(0, order)
202
+ target_lengths = target_lengths.index_select(0, order)
203
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
204
+ ntokens = sum(x.target.size(0) for x in samples)
205
+
206
+ tgt_speakers = None
207
+ if self.cfg.target_speaker_embed:
208
+ tgt_speakers = _collate_frames(
209
+ [x.target_speaker for x in samples], is_audio_input=True
210
+ ).index_select(0, order)
211
+
212
+ net_input = {
213
+ "src_tokens": frames,
214
+ "src_lengths": n_frames,
215
+ "prev_output_tokens": prev_output_tokens,
216
+ "tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
217
+ }
218
+ if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
219
+ for i in range(len(samples)):
220
+ net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
221
+ out = {
222
+ "id": indices,
223
+ "net_input": net_input,
224
+ "speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
225
+ "target": target,
226
+ "target_lengths": target_lengths,
227
+ "ntokens": ntokens,
228
+ "nsentences": len(samples),
229
+ }
230
+ if return_order:
231
+ out["order"] = order
232
+ return out
233
+
234
+
235
+ class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
236
+ def __init__(self, **kwargs):
237
+ super().__init__(**kwargs)
238
+ self.multitask_data = {}
239
+
240
+ def add_multitask_dataset(self, task_name, task_data):
241
+ self.multitask_data[task_name] = task_data
242
+
243
+ def __getitem__(
244
+ self, index: int
245
+ ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
246
+ s2s_data = super().__getitem__(index)
247
+
248
+ multitask_target = {}
249
+ sample_id = self.ids[index]
250
+ tgt_lang = self.tgt_langs[index]
251
+ for task_name, task_dataset in self.multitask_data.items():
252
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
253
+
254
+ return s2s_data, multitask_target
255
+
256
+ def collater(
257
+ self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
258
+ ) -> Dict:
259
+ if len(samples) == 0:
260
+ return {}
261
+
262
+ out = super().collater([s for s, _ in samples], return_order=True)
263
+ order = out["order"]
264
+ del out["order"]
265
+
266
+ for task_name, task_dataset in self.multitask_data.items():
267
+ if "multitask" not in out:
268
+ out["multitask"] = {}
269
+ d = [s[task_name] for _, s in samples]
270
+ task_target = task_dataset.collater(d)
271
+ out["multitask"][task_name] = {
272
+ "target": task_target["target"].index_select(0, order),
273
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
274
+ "ntokens": task_target["ntokens"],
275
+ }
276
+ out["multitask"][task_name]["net_input"] = {
277
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
278
+ 0, order
279
+ ),
280
+ }
281
+
282
+ return out
283
+
284
+
285
+ class SpeechToSpeechDatasetCreator(object):
286
+ # mandatory columns
287
+ KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
288
+ KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
289
+ # optional columns
290
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
291
+ # default values
292
+ DEFAULT_LANG = ""
293
+
294
+ @classmethod
295
+ def _from_list(
296
+ cls,
297
+ split_name: str,
298
+ is_train_split,
299
+ samples: List[Dict],
300
+ data_cfg: S2SDataConfig,
301
+ target_is_code: bool = False,
302
+ tgt_dict: Dictionary = None,
303
+ n_frames_per_step: int = 1,
304
+ multitask: Optional[Dict] = None,
305
+ ) -> SpeechToSpeechDataset:
306
+ audio_root = Path(data_cfg.audio_root)
307
+ ids = [s[cls.KEY_ID] for s in samples]
308
+ src_audio_paths = [
309
+ (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
310
+ ]
311
+ tgt_audio_paths = [
312
+ s[cls.KEY_TGT_AUDIO]
313
+ if target_is_code
314
+ else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
315
+ for s in samples
316
+ ]
317
+ src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
318
+ tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
319
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
320
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
321
+
322
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
323
+ dataset_cls = (
324
+ SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
325
+ )
326
+
327
+ ds = dataset_cls(
328
+ split=split_name,
329
+ is_train_split=is_train_split,
330
+ data_cfg=data_cfg,
331
+ src_audio_paths=src_audio_paths,
332
+ src_n_frames=src_n_frames,
333
+ tgt_audio_paths=tgt_audio_paths,
334
+ tgt_n_frames=tgt_n_frames,
335
+ src_langs=src_langs,
336
+ tgt_langs=tgt_langs,
337
+ ids=ids,
338
+ target_is_code=target_is_code,
339
+ tgt_dict=tgt_dict,
340
+ n_frames_per_step=n_frames_per_step,
341
+ )
342
+
343
+ if has_multitask:
344
+ for task_name, task_obj in multitask.items():
345
+ task_data = TextTargetMultitaskData(
346
+ task_obj.args, split_name, task_obj.target_dictionary
347
+ )
348
+ ds.add_multitask_dataset(task_name, task_data)
349
+ return ds
350
+
351
+ @classmethod
352
+ def from_tsv(
353
+ cls,
354
+ root: str,
355
+ data_cfg: S2SDataConfig,
356
+ splits: str,
357
+ is_train_split: bool,
358
+ epoch: int,
359
+ seed: int,
360
+ target_is_code: bool = False,
361
+ tgt_dict: Dictionary = None,
362
+ n_frames_per_step: int = 1,
363
+ multitask: Optional[Dict] = None,
364
+ ) -> SpeechToSpeechDataset:
365
+ datasets = []
366
+ for split in splits.split(","):
367
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
368
+ ds = cls._from_list(
369
+ split_name=split,
370
+ is_train_split=is_train_split,
371
+ samples=samples,
372
+ data_cfg=data_cfg,
373
+ target_is_code=target_is_code,
374
+ tgt_dict=tgt_dict,
375
+ n_frames_per_step=n_frames_per_step,
376
+ multitask=multitask,
377
+ )
378
+ datasets.append(ds)
379
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
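pack_units above folds every n_frames_per_step target units into a single id by positional (base vocab_size) encoding after subtracting the 4-symbol special-token offset; a small sketch of the packing arithmetic for one stack (the vocab size and unit ids are made up):

def pack_stack(units, vocab_size, offset=4):
    # units: one stack of n_frames_per_step raw dictionary indices (specials already excluded)
    packed = 0
    for u in units:
        packed = packed * vocab_size + (u - offset)   # base-vocab_size positional encoding
    return packed + offset

# e.g. 100 discrete units (plus 4 specials) packed two per step
print(pack_stack([10, 27], vocab_size=100))  # (10-4)*100 + (27-4) + 4 = 627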
modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py ADDED
@@ -0,0 +1,733 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ import logging
8
+ import re
9
+ from argparse import Namespace
10
+ from collections import defaultdict
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
20
+ from fairseq.data import data_utils as fairseq_data_utils
21
+ from fairseq.data import encoders
22
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
23
+ from fairseq.data.audio.data_cfg import S2TDataConfig
24
+ from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
25
+ from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
26
+ from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
27
+ NoisyOverlapAugment,
28
+ )
29
+ from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
30
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def _collate_frames(
36
+ frames: List[torch.Tensor], is_audio_input: bool = False
37
+ ) -> torch.Tensor:
38
+ """
39
+ Convert a list of 2D frames into a padded 3D tensor
40
+ Args:
41
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
42
+ length of i-th frame and f_dim is static dimension of features
43
+ Returns:
44
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
45
+ """
46
+ max_len = max(frame.size(0) for frame in frames)
47
+ if is_audio_input:
48
+ out = frames[0].new_zeros((len(frames), max_len))
49
+ else:
50
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
51
+ for i, v in enumerate(frames):
52
+ out[i, : v.size(0)] = v
53
+ return out
54
+
55
+
56
+ def _is_int_or_np_int(n):
57
+ return isinstance(n, int) or (
58
+ isinstance(n, np.generic) and isinstance(n.item(), int)
59
+ )
60
+
61
+
62
+ @dataclass
63
+ class SpeechToTextDatasetItem(object):
64
+ index: int
65
+ source: torch.Tensor
66
+ target: Optional[torch.Tensor] = None
67
+ speaker_id: Optional[int] = None
68
+
69
+
70
+ class SpeechToTextDataset(FairseqDataset):
71
+ LANG_TAG_TEMPLATE = "<lang:{}>"
72
+
73
+ def __init__(
74
+ self,
75
+ split: str,
76
+ is_train_split: bool,
77
+ cfg: S2TDataConfig,
78
+ audio_paths: List[str],
79
+ n_frames: List[int],
80
+ src_texts: Optional[List[str]] = None,
81
+ tgt_texts: Optional[List[str]] = None,
82
+ speakers: Optional[List[str]] = None,
83
+ src_langs: Optional[List[str]] = None,
84
+ tgt_langs: Optional[List[str]] = None,
85
+ ids: Optional[List[str]] = None,
86
+ tgt_dict: Optional[Dictionary] = None,
87
+ pre_tokenizer=None,
88
+ bpe_tokenizer=None,
89
+ n_frames_per_step=1,
90
+ speaker_to_id=None,
91
+ append_eos=True,
92
+ ):
93
+ self.split, self.is_train_split = split, is_train_split
94
+ self.cfg = cfg
95
+ self.audio_paths, self.n_frames = audio_paths, n_frames
96
+ self.n_samples = len(audio_paths)
97
+ assert len(n_frames) == self.n_samples > 0
98
+ assert src_texts is None or len(src_texts) == self.n_samples
99
+ assert tgt_texts is None or len(tgt_texts) == self.n_samples
100
+ assert speakers is None or len(speakers) == self.n_samples
101
+ assert src_langs is None or len(src_langs) == self.n_samples
102
+ assert tgt_langs is None or len(tgt_langs) == self.n_samples
103
+ assert ids is None or len(ids) == self.n_samples
104
+ assert (tgt_dict is None and tgt_texts is None) or (
105
+ tgt_dict is not None and tgt_texts is not None
106
+ )
107
+ self.src_texts, self.tgt_texts = src_texts, tgt_texts
108
+ self.src_langs, self.tgt_langs = src_langs, tgt_langs
109
+ self.speakers = speakers
110
+ self.tgt_dict = tgt_dict
111
+ self.check_tgt_lang_tag()
112
+ self.ids = ids
113
+ self.shuffle = cfg.shuffle if is_train_split else False
114
+
115
+ self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
116
+ self.cfg.get_feature_transforms(split, is_train_split)
117
+ )
118
+ self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
119
+ self.cfg.get_waveform_transforms(split, is_train_split)
120
+ )
121
+ # TODO: add these to data_cfg.py
122
+ self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
123
+ self.cfg.get_dataset_transforms(split, is_train_split)
124
+ )
125
+
126
+ # check proper usage of transforms
127
+ if self.feature_transforms and self.cfg.use_audio_input:
128
+ logger.warning(
129
+ "Feature transforms will not be applied. To use feature transforms, "
130
+ "set use_audio_input as False in config."
131
+ )
132
+
133
+ self.pre_tokenizer = pre_tokenizer
134
+ self.bpe_tokenizer = bpe_tokenizer
135
+ self.n_frames_per_step = n_frames_per_step
136
+ self.speaker_to_id = speaker_to_id
137
+
138
+ self.tgt_lens = self.get_tgt_lens_and_check_oov()
139
+ self.append_eos = append_eos
140
+
141
+ logger.info(self.__repr__())
142
+
143
+ def get_tgt_lens_and_check_oov(self):
144
+ if self.tgt_texts is None:
145
+ return [0 for _ in range(self.n_samples)]
146
+ tgt_lens = []
147
+ n_tokens, n_oov_tokens = 0, 0
148
+ for i in range(self.n_samples):
149
+ tokenized = self.get_tokenized_tgt_text(i).split(" ")
150
+ oov_tokens = [
151
+ t
152
+ for t in tokenized
153
+ if self.tgt_dict.index(t) == self.tgt_dict.unk_index
154
+ ]
155
+ n_tokens += len(tokenized)
156
+ n_oov_tokens += len(oov_tokens)
157
+ tgt_lens.append(len(tokenized))
158
+ logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
159
+ return tgt_lens
160
+
161
+ def __repr__(self):
162
+ return (
163
+ self.__class__.__name__
164
+ + f'(split="{self.split}", n_samples={self.n_samples:_}, '
165
+ f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
166
+ f"n_frames_per_step={self.n_frames_per_step}, "
167
+ f"shuffle={self.shuffle}, "
168
+ f"feature_transforms={self.feature_transforms}, "
169
+ f"waveform_transforms={self.waveform_transforms}, "
170
+ f"dataset_transforms={self.dataset_transforms})"
171
+ )
172
+
173
+ @classmethod
174
+ def is_lang_tag(cls, token):
175
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
176
+ return re.match(pattern, token)
177
+
178
+ def check_tgt_lang_tag(self):
179
+ if self.cfg.prepend_tgt_lang_tag:
180
+ assert self.tgt_langs is not None and self.tgt_dict is not None
181
+ tgt_lang_tags = [
182
+ self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
183
+ ]
184
+ assert all(t in self.tgt_dict for t in tgt_lang_tags)
185
+
186
+ @classmethod
187
+ def tokenize(cls, tokenizer, text: str):
188
+ return text if tokenizer is None else tokenizer.encode(text)
189
+
190
+ def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
191
+ if _is_int_or_np_int(index):
192
+ text = self.tgt_texts[index]
193
+ else:
194
+ text = " ".join([self.tgt_texts[i] for i in index])
195
+
196
+ text = self.tokenize(self.pre_tokenizer, text)
197
+ text = self.tokenize(self.bpe_tokenizer, text)
198
+ return text
199
+
200
+ def pack_frames(self, feature: torch.Tensor):
201
+ if self.n_frames_per_step == 1:
202
+ return feature
203
+ n_packed_frames = feature.shape[0] // self.n_frames_per_step
204
+ feature = feature[: self.n_frames_per_step * n_packed_frames]
205
+ return feature.reshape(n_packed_frames, -1)
206
+
207
+ @classmethod
208
+ def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
209
+ lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
210
+ assert lang_tag_idx != dictionary.unk()
211
+ return lang_tag_idx
212
+
213
+ def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
214
+ """
215
+ Gives source audio for given index with any relevant transforms
216
+ applied. For ConcatAug, source audios for given indices are
217
+ concatenated in given order.
218
+ Args:
219
+ index (int or List[int]): index—or in the case of ConcatAug,
220
+ indices—to pull the source audio for
221
+ Returns:
222
+ source audios concatenated for given indices with
223
+ relevant transforms applied
224
+ """
225
+ if _is_int_or_np_int(index):
226
+ source = get_features_or_waveform(
227
+ self.audio_paths[index],
228
+ need_waveform=self.cfg.use_audio_input,
229
+ use_sample_rate=self.cfg.use_sample_rate,
230
+ waveform_transforms=self.waveform_transforms,
231
+ )
232
+ else:
233
+ source = np.concatenate(
234
+ [
235
+ get_features_or_waveform(
236
+ self.audio_paths[i],
237
+ need_waveform=self.cfg.use_audio_input,
238
+ use_sample_rate=self.cfg.use_sample_rate,
239
+ waveform_transforms=self.waveform_transforms,
240
+ )
241
+ for i in index
242
+ ]
243
+ )
244
+ if self.cfg.use_audio_input:
245
+ source = torch.from_numpy(source).float()
246
+ if self.cfg.standardize_audio:
247
+ with torch.no_grad():
248
+ source = F.layer_norm(source, source.shape)
249
+ else:
250
+ if self.feature_transforms is not None:
251
+ source = self.feature_transforms(source)
252
+ source = torch.from_numpy(source).float()
253
+ return source
254
+
255
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
256
+ has_concat = self.dataset_transforms.has_transform(ConcatAugment)
257
+ if has_concat:
258
+ concat = self.dataset_transforms.get_transform(ConcatAugment)
259
+ indices = concat.find_indices(index, self.n_frames, self.n_samples)
260
+
261
+ source = self._get_source_audio(indices if has_concat else index)
262
+ source = self.pack_frames(source)
263
+
264
+ target = None
265
+ if self.tgt_texts is not None:
266
+ tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
267
+ target = self.tgt_dict.encode_line(
268
+ tokenized, add_if_not_exist=False, append_eos=self.append_eos
269
+ ).long()
270
+ if self.cfg.prepend_tgt_lang_tag:
271
+ lang_tag_idx = self.get_lang_tag_idx(
272
+ self.tgt_langs[index], self.tgt_dict
273
+ )
274
+ target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
275
+
276
+ if self.cfg.prepend_bos_and_append_tgt_lang_tag:
277
+ bos = torch.LongTensor([self.tgt_dict.bos()])
278
+ lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
279
+ assert lang_tag_idx != self.tgt_dict.unk()
280
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
281
+ target = torch.cat((bos, target, lang_tag_idx), 0)
282
+
283
+ speaker_id = None
284
+ if self.speaker_to_id is not None:
285
+ speaker_id = self.speaker_to_id[self.speakers[index]]
286
+ return SpeechToTextDatasetItem(
287
+ index=index, source=source, target=target, speaker_id=speaker_id
288
+ )
289
+
290
+ def __len__(self):
291
+ return self.n_samples
292
+
293
+ def collater(
294
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
295
+ ) -> Dict:
296
+ if len(samples) == 0:
297
+ return {}
298
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
299
+
300
+ sources = [x.source for x in samples]
301
+ has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
302
+ if has_NOAug and self.cfg.use_audio_input:
303
+ NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
304
+ sources = NOAug(sources)
305
+
306
+ frames = _collate_frames(sources, self.cfg.use_audio_input)
307
+ # sort samples by descending number of frames
308
+ n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
309
+ n_frames, order = n_frames.sort(descending=True)
310
+ indices = indices.index_select(0, order)
311
+ frames = frames.index_select(0, order)
312
+
313
+ target, target_lengths = None, None
314
+ prev_output_tokens = None
315
+ ntokens = None
316
+ if self.tgt_texts is not None:
317
+ target = fairseq_data_utils.collate_tokens(
318
+ [x.target for x in samples],
319
+ self.tgt_dict.pad(),
320
+ self.tgt_dict.eos(),
321
+ left_pad=False,
322
+ move_eos_to_beginning=False,
323
+ )
324
+ target = target.index_select(0, order)
325
+ target_lengths = torch.tensor(
326
+ [x.target.size(0) for x in samples], dtype=torch.long
327
+ ).index_select(0, order)
328
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
329
+ [x.target for x in samples],
330
+ self.tgt_dict.pad(),
331
+ eos_idx=None,
332
+ left_pad=False,
333
+ move_eos_to_beginning=True,
334
+ )
335
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
336
+ ntokens = sum(x.target.size(0) for x in samples)
337
+
338
+ speaker = None
339
+ if self.speaker_to_id is not None:
340
+ speaker = (
341
+ torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
342
+ .index_select(0, order)
343
+ .view(-1, 1)
344
+ )
345
+
346
+ net_input = {
347
+ "src_tokens": frames,
348
+ "src_lengths": n_frames,
349
+ "prev_output_tokens": prev_output_tokens,
350
+ }
351
+ out = {
352
+ "id": indices,
353
+ "net_input": net_input,
354
+ "speaker": speaker,
355
+ "target": target,
356
+ "target_lengths": target_lengths,
357
+ "ntokens": ntokens,
358
+ "nsentences": len(samples),
359
+ }
360
+ if return_order:
361
+ out["order"] = order
362
+ return out
363
+
364
+ def num_tokens(self, index):
365
+ return self.n_frames[index]
366
+
367
+ def size(self, index):
368
+ return self.n_frames[index], self.tgt_lens[index]
369
+
370
+ @property
371
+ def sizes(self):
372
+ return np.array(self.n_frames)
373
+
374
+ @property
375
+ def can_reuse_epoch_itr_across_epochs(self):
376
+ return True
377
+
378
+ def ordered_indices(self):
379
+ if self.shuffle:
380
+ order = [np.random.permutation(len(self))]
381
+ else:
382
+ order = [np.arange(len(self))]
383
+ # first by descending order of # of frames then by original/random order
384
+ order.append([-n for n in self.n_frames])
385
+ return np.lexsort(order)
386
+
387
+ def prefetch(self, indices):
388
+ raise NotImplementedError  # prefetching is not supported for this dataset
389
+
390
+
391
+ class TextTargetMultitaskData(object):
392
+ # mandatory columns
393
+ KEY_ID, KEY_TEXT = "id", "tgt_text"
394
+ LANG_TAG_TEMPLATE = "<lang:{}>"
395
+
396
+ def __init__(self, args, split, tgt_dict):
397
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
398
+ self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
399
+ self.dict = tgt_dict
400
+ self.append_eos = args.decoder_type != "ctc"
401
+ self.pre_tokenizer = self.build_tokenizer(args)
402
+ self.bpe_tokenizer = self.build_bpe(args)
403
+ self.prepend_bos_and_append_tgt_lang_tag = (
404
+ args.prepend_bos_and_append_tgt_lang_tag
405
+ )
406
+ self.eos_token = args.eos_token
407
+ self.lang_tag_mapping = args.get_lang_tag_mapping
408
+
409
+ @classmethod
410
+ def is_lang_tag(cls, token):
411
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
412
+ return re.match(pattern, token)
413
+
414
+ @classmethod
415
+ def tokenize(cls, tokenizer, text: str):
416
+ return text if tokenizer is None else tokenizer.encode(text)
417
+
418
+ def get_tokenized_tgt_text(self, index: int):
419
+ text = self.tokenize(self.pre_tokenizer, self.data[index])
420
+ text = self.tokenize(self.bpe_tokenizer, text)
421
+ return text
422
+
423
+ def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
424
+ lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
425
+ lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
426
+ lang_tag_idx = dictionary.index(lang_tag)
427
+ assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
428
+ return lang_tag_idx
429
+
430
+ def build_tokenizer(self, args):
431
+ pre_tokenizer = args.config.get("pre_tokenizer")
432
+ if pre_tokenizer is not None:
433
+ logger.info(f"pre-tokenizer: {pre_tokenizer}")
434
+ return encoders.build_tokenizer(Namespace(**pre_tokenizer))
435
+ else:
436
+ return None
437
+
438
+ def build_bpe(self, args):
439
+ bpe_tokenizer = args.config.get("bpe_tokenizer")
440
+ if bpe_tokenizer is not None:
441
+ logger.info(f"tokenizer: {bpe_tokenizer}")
442
+ return encoders.build_bpe(Namespace(**bpe_tokenizer))
443
+ else:
444
+ return None
445
+
446
+ def get(self, sample_id, tgt_lang=None):
447
+ if sample_id in self.data:
448
+ tokenized = self.get_tokenized_tgt_text(sample_id)
449
+ target = self.dict.encode_line(
450
+ tokenized,
451
+ add_if_not_exist=False,
452
+ append_eos=self.append_eos,
453
+ )
454
+ if self.prepend_bos_and_append_tgt_lang_tag:
455
+ bos = torch.LongTensor([self.dict.bos()])
456
+ lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
457
+ assert lang_tag_idx != self.dict.unk()
458
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
459
+ target = torch.cat((bos, target, lang_tag_idx), 0)
460
+ return target
461
+ else:
462
+ logger.warning(f"no target for {sample_id}")
463
+ return torch.IntTensor([])
464
+
465
+ def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
466
+ out = fairseq_data_utils.collate_tokens(
467
+ samples,
468
+ self.dict.pad(),
469
+ eos_idx=None,
470
+ left_pad=False,
471
+ move_eos_to_beginning=False,
472
+ ).long()
473
+
474
+ prev_out = fairseq_data_utils.collate_tokens(
475
+ samples,
476
+ self.dict.pad(),
477
+ eos_idx=None,
478
+ left_pad=False,
479
+ move_eos_to_beginning=True,
480
+ ).long()
481
+
482
+ target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
483
+ ntokens = sum(t.size(0) for t in samples)
484
+
485
+ output = {
486
+ "prev_output_tokens": prev_out,
487
+ "target": out,
488
+ "target_lengths": target_lengths,
489
+ "ntokens": ntokens,
490
+ }
491
+
492
+ return output
493
+
494
+
495
+ class SpeechToTextMultitaskDataset(SpeechToTextDataset):
496
+ def __init__(self, **kwargs):
497
+ super().__init__(**kwargs)
498
+ self.multitask_data = {}
499
+
500
+ def add_multitask_dataset(self, task_name, task_data):
501
+ self.multitask_data[task_name] = task_data
502
+
503
+ def __getitem__(
504
+ self, index: int
505
+ ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
506
+ s2t_data = super().__getitem__(index)
507
+
508
+ multitask_target = {}
509
+ sample_id = self.ids[index]
510
+ tgt_lang = self.tgt_langs[index]
511
+ for task_name, task_dataset in self.multitask_data.items():
512
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
513
+
514
+ return s2t_data, multitask_target
515
+
516
+ def collater(
517
+ self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
518
+ ) -> Dict:
519
+ if len(samples) == 0:
520
+ return {}
521
+
522
+ out = super().collater([s for s, _ in samples], return_order=True)
523
+ order = out["order"]
524
+ del out["order"]
525
+
526
+ for task_name, task_dataset in self.multitask_data.items():
527
+ if "multitask" not in out:
528
+ out["multitask"] = {}
529
+ d = [s[task_name] for _, s in samples]
530
+ task_target = task_dataset.collater(d)
531
+ out["multitask"][task_name] = {
532
+ "target": task_target["target"].index_select(0, order),
533
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
534
+ "ntokens": task_target["ntokens"],
535
+ }
536
+ out["multitask"][task_name]["net_input"] = {
537
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
538
+ 0, order
539
+ ),
540
+ }
541
+
542
+ return out
543
+
544
+
545
+ class SpeechToTextDatasetCreator(object):
546
+ # mandatory columns
547
+ KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
548
+ KEY_TGT_TEXT = "tgt_text"
549
+ # optional columns
550
+ KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
551
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
552
+ # default values
553
+ DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
554
+
555
+ @classmethod
556
+ def _from_list(
557
+ cls,
558
+ split_name: str,
559
+ is_train_split,
560
+ samples: List[Dict],
561
+ cfg: S2TDataConfig,
562
+ tgt_dict,
563
+ pre_tokenizer,
564
+ bpe_tokenizer,
565
+ n_frames_per_step,
566
+ speaker_to_id,
567
+ multitask: Optional[Dict] = None,
568
+ ) -> SpeechToTextDataset:
569
+ audio_root = Path(cfg.audio_root)
570
+ ids = [s[cls.KEY_ID] for s in samples]
571
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
572
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
573
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
574
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
575
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
576
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
577
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
578
+
579
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
580
+ dataset_cls = (
581
+ SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
582
+ )
583
+
584
+ ds = dataset_cls(
585
+ split=split_name,
586
+ is_train_split=is_train_split,
587
+ cfg=cfg,
588
+ audio_paths=audio_paths,
589
+ n_frames=n_frames,
590
+ src_texts=src_texts,
591
+ tgt_texts=tgt_texts,
592
+ speakers=speakers,
593
+ src_langs=src_langs,
594
+ tgt_langs=tgt_langs,
595
+ ids=ids,
596
+ tgt_dict=tgt_dict,
597
+ pre_tokenizer=pre_tokenizer,
598
+ bpe_tokenizer=bpe_tokenizer,
599
+ n_frames_per_step=n_frames_per_step,
600
+ speaker_to_id=speaker_to_id,
601
+ )
602
+
603
+ if has_multitask:
604
+ for task_name, task_obj in multitask.items():
605
+ task_data = TextTargetMultitaskData(
606
+ task_obj.args, split_name, task_obj.target_dictionary
607
+ )
608
+ ds.add_multitask_dataset(task_name, task_data)
609
+ return ds
610
+
611
+ @classmethod
612
+ def get_size_ratios(
613
+ cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
614
+ ) -> List[float]:
615
+ """Size ratios for temperature-based sampling
616
+ (https://arxiv.org/abs/1907.05019)"""
617
+
618
+ id_to_lp, lp_to_sz = {}, defaultdict(int)
619
+ for ds in datasets:
620
+ lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
621
+ assert len(lang_pairs) == 1
622
+ lang_pair = list(lang_pairs)[0]
623
+ id_to_lp[ds.split] = lang_pair
624
+ lp_to_sz[lang_pair] += sum(ds.n_frames)
625
+
626
+ sz_sum = sum(v for v in lp_to_sz.values())
627
+ lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
628
+ lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
629
+ prob_sum = sum(v for v in lp_to_tgt_prob.values())
630
+ lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
631
+ lp_to_sz_ratio = {
632
+ k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
633
+ }
634
+ size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
635
+
636
+ p_formatted = {
637
+ k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
638
+ }
639
+ logger.info(f"sampling probability balancing: {p_formatted}")
640
+ sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
641
+ logger.info(f"balanced sampling size ratio: {sr_formatted}")
642
+ return size_ratio
643
+
644
+ @classmethod
645
+ def _load_samples_from_tsv(cls, root: str, split: str):
646
+ tsv_path = Path(root) / f"{split}.tsv"
647
+ if not tsv_path.is_file():
648
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
649
+ with open(tsv_path) as f:
650
+ reader = csv.DictReader(
651
+ f,
652
+ delimiter="\t",
653
+ quotechar=None,
654
+ doublequote=False,
655
+ lineterminator="\n",
656
+ quoting=csv.QUOTE_NONE,
657
+ )
658
+ samples = [dict(e) for e in reader]
659
+ if len(samples) == 0:
660
+ raise ValueError(f"Empty manifest: {tsv_path}")
661
+ return samples
662
+
663
+ @classmethod
664
+ def _from_tsv(
665
+ cls,
666
+ root: str,
667
+ cfg: S2TDataConfig,
668
+ split: str,
669
+ tgt_dict,
670
+ is_train_split: bool,
671
+ pre_tokenizer,
672
+ bpe_tokenizer,
673
+ n_frames_per_step,
674
+ speaker_to_id,
675
+ multitask: Optional[Dict] = None,
676
+ ) -> SpeechToTextDataset:
677
+ samples = cls._load_samples_from_tsv(root, split)
678
+ return cls._from_list(
679
+ split,
680
+ is_train_split,
681
+ samples,
682
+ cfg,
683
+ tgt_dict,
684
+ pre_tokenizer,
685
+ bpe_tokenizer,
686
+ n_frames_per_step,
687
+ speaker_to_id,
688
+ multitask,
689
+ )
690
+
691
+ @classmethod
692
+ def from_tsv(
693
+ cls,
694
+ root: str,
695
+ cfg: S2TDataConfig,
696
+ splits: str,
697
+ tgt_dict,
698
+ pre_tokenizer,
699
+ bpe_tokenizer,
700
+ is_train_split: bool,
701
+ epoch: int,
702
+ seed: int,
703
+ n_frames_per_step: int = 1,
704
+ speaker_to_id=None,
705
+ multitask: Optional[Dict] = None,
706
+ ) -> SpeechToTextDataset:
707
+ datasets = [
708
+ cls._from_tsv(
709
+ root=root,
710
+ cfg=cfg,
711
+ split=split,
712
+ tgt_dict=tgt_dict,
713
+ is_train_split=is_train_split,
714
+ pre_tokenizer=pre_tokenizer,
715
+ bpe_tokenizer=bpe_tokenizer,
716
+ n_frames_per_step=n_frames_per_step,
717
+ speaker_to_id=speaker_to_id,
718
+ multitask=multitask,
719
+ )
720
+ for split in splits.split(",")
721
+ ]
722
+
723
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
724
+ # temperature-based sampling
725
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
726
+ datasets = [
727
+ ResamplingDataset(
728
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
729
+ )
730
+ for r, d in zip(size_ratios, datasets)
731
+ ]
732
+
733
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
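
For orientation, a minimal usage sketch of the dataset creator added above (not part of the uploaded file). The data root, split name and file names are placeholders; it assumes the vendored `fairseq` package is importable under that name, as the module's own imports expect, and that `S2TDataConfig` takes the path to the YAML config as in the bundled `data_cfg.py`.

```python
# Hedged usage sketch for SpeechToTextDatasetCreator.from_tsv (placeholder paths).
from pathlib import Path

from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDatasetCreator

DATA_ROOT = "/path/to/s2t_data"  # must contain config.yaml, dict.txt and dev.tsv

cfg = S2TDataConfig(Path(DATA_ROOT) / "config.yaml")
tgt_dict = Dictionary.load(str(Path(DATA_ROOT) / cfg.vocab_filename))

# dev.tsv needs the mandatory columns defined above: id, audio, n_frames, tgt_text.
dataset = SpeechToTextDatasetCreator.from_tsv(
    root=DATA_ROOT,
    cfg=cfg,
    splits="dev",            # comma-separated list of {split}.tsv manifests
    tgt_dict=tgt_dict,
    pre_tokenizer=None,      # optional fairseq encoders; None means raw text
    bpe_tokenizer=None,
    is_train_split=False,
    epoch=1,
    seed=1,
)

batch = dataset.collater([dataset[i] for i in range(2)])
print(batch["net_input"]["src_tokens"].shape, batch["target_lengths"])
```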
modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Dict, List, NamedTuple, Optional
9
+
10
+ import torch
11
+
12
+ from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
13
+ from fairseq.data import data_utils as fairseq_data_utils
14
+ from fairseq.data.audio.speech_to_text_dataset import (
15
+ S2TDataConfig,
16
+ SpeechToTextDataset,
17
+ SpeechToTextDatasetCreator,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class S2TJointDataConfig(S2TDataConfig):
24
+ """Wrapper class for data config YAML"""
25
+
26
+ @property
27
+ def src_vocab_filename(self):
28
+ """fairseq vocabulary file under data root"""
29
+ return self.config.get("src_vocab_filename", "src_dict.txt")
30
+
31
+ @property
32
+ def src_pre_tokenizer(self) -> Dict:
33
+ """Pre-tokenizer to apply before subword tokenization. Returning
34
+ a dictionary with `tokenizer` providing the tokenizer name and
35
+ the other items providing the tokenizer-specific arguments.
36
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
37
+ return self.config.get("src_pre_tokenizer", {"tokenizer": None})
38
+
39
+ @property
40
+ def src_bpe_tokenizer(self) -> Dict:
41
+ """Subword tokenizer to apply on source text after pre-tokenization.
42
+ Returning a dictionary with `bpe` providing the tokenizer name and
43
+ the other items providing the tokenizer-specific arguments.
44
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
45
+ return self.config.get("src_bpe_tokenizer", {"bpe": None})
46
+
47
+ @property
48
+ def prepend_tgt_lang_tag_no_change(self) -> bool:
49
+ """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
50
+ to-many multilingual setting). No change needed during inference.
51
+ This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
52
+ """
53
+ value = self.config.get("prepend_tgt_lang_tag_no_change", None)
54
+ if value is None:
55
+ return self.config.get("prepend_tgt_lang_tag_as_bos", False)
56
+ return value
57
+
58
+ @property
59
+ def sampling_text_alpha(self):
60
+ """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
61
+ input only) (alpha = 1 for no resampling)"""
62
+ return self.config.get("sampling_text_alpha", 1.0)
63
+
64
+
65
+ class SpeechToTextJointDatasetItem(NamedTuple):
66
+ index: int
67
+ source: torch.Tensor
68
+ target: Optional[torch.Tensor] = None
69
+ src_txt_tokens: Optional[torch.Tensor] = None
70
+ tgt_lang_tag: Optional[int] = None
71
+ src_lang_tag: Optional[int] = None
72
+ tgt_alignment: Optional[torch.Tensor] = None
73
+
74
+
75
+ # use_src_lang_id:
76
+ # 0: don't use src_lang_id
77
+ # 1: attach src_lang_id to the src_txt_tokens as eos
78
+ class SpeechToTextJointDataset(SpeechToTextDataset):
79
+ def __init__(
80
+ self,
81
+ split: str,
82
+ is_train_split: bool,
83
+ cfg: S2TJointDataConfig,
84
+ audio_paths: List[str],
85
+ n_frames: List[int],
86
+ src_texts: Optional[List[str]] = None,
87
+ tgt_texts: Optional[List[str]] = None,
88
+ speakers: Optional[List[str]] = None,
89
+ src_langs: Optional[List[str]] = None,
90
+ tgt_langs: Optional[List[str]] = None,
91
+ ids: Optional[List[str]] = None,
92
+ tgt_dict: Optional[Dictionary] = None,
93
+ src_dict: Optional[Dictionary] = None,
94
+ pre_tokenizer=None,
95
+ bpe_tokenizer=None,
96
+ src_pre_tokenizer=None,
97
+ src_bpe_tokenizer=None,
98
+ append_eos: Optional[bool] = True,
99
+ alignment: Optional[List[str]] = None,
100
+ use_src_lang_id: Optional[int] = 0,
101
+ ):
102
+ super().__init__(
103
+ split,
104
+ is_train_split,
105
+ cfg,
106
+ audio_paths,
107
+ n_frames,
108
+ src_texts=src_texts,
109
+ tgt_texts=tgt_texts,
110
+ speakers=speakers,
111
+ src_langs=src_langs,
112
+ tgt_langs=tgt_langs,
113
+ ids=ids,
114
+ tgt_dict=tgt_dict,
115
+ pre_tokenizer=pre_tokenizer,
116
+ bpe_tokenizer=bpe_tokenizer,
117
+ append_eos=append_eos,
118
+ )
119
+
120
+ self.src_dict = src_dict
121
+ self.src_pre_tokenizer = src_pre_tokenizer
122
+ self.src_bpe_tokenizer = src_bpe_tokenizer
123
+ self.alignment = None
124
+ self.use_src_lang_id = use_src_lang_id
125
+ if alignment is not None:
126
+ self.alignment = [
127
+ [float(s) for s in sample.split()] for sample in alignment
128
+ ]
129
+
130
+ def get_tokenized_src_text(self, index: int):
131
+ text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
132
+ text = self.tokenize(self.src_bpe_tokenizer, text)
133
+ return text
134
+
135
+ def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
136
+ s2t_dataset_item = super().__getitem__(index)
137
+ src_tokens = None
138
+ src_lang_tag = None
139
+ if self.src_texts is not None and self.src_dict is not None:
140
+ src_tokens = self.get_tokenized_src_text(index)
141
+ src_tokens = self.src_dict.encode_line(
142
+ src_tokens, add_if_not_exist=False, append_eos=True
143
+ ).long()
144
+ if self.use_src_lang_id > 0:
145
+ src_lang_tag = self.get_lang_tag_idx(
146
+ self.src_langs[index], self.src_dict
147
+ )
148
+ tgt_lang_tag = None
149
+ if self.cfg.prepend_tgt_lang_tag_no_change:
150
+ # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
151
+ tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
152
+ ali = None
153
+ if self.alignment is not None:
154
+ ali = torch.Tensor(self.alignment[index]).float()
155
+
156
+ return SpeechToTextJointDatasetItem(
157
+ index=index,
158
+ source=s2t_dataset_item.source,
159
+ target=s2t_dataset_item.target,
160
+ src_txt_tokens=src_tokens,
161
+ tgt_lang_tag=tgt_lang_tag,
162
+ src_lang_tag=src_lang_tag,
163
+ tgt_alignment=ali,
164
+ )
165
+
166
+ def __len__(self):
167
+ return self.n_samples
168
+
169
+ def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
170
+ s2t_out = super().collater(samples, return_order=True)
171
+ if s2t_out == {}:
172
+ return s2t_out
173
+ net_input, order = s2t_out["net_input"], s2t_out["order"]
174
+
175
+ if self.src_texts is not None and self.src_dict is not None:
176
+ src_txt_tokens = fairseq_data_utils.collate_tokens(
177
+ [x.src_txt_tokens for x in samples],
178
+ self.src_dict.pad(),
179
+ self.src_dict.eos(),
180
+ left_pad=False,
181
+ move_eos_to_beginning=False,
182
+ )
183
+ src_txt_lengths = torch.tensor(
184
+ [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
185
+ )
186
+ if self.use_src_lang_id > 0:
187
+ src_lang_idxs = torch.tensor(
188
+ [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
189
+ )
190
+ if self.use_src_lang_id == 1: # replace eos with lang_id
191
+ eos_idx = src_txt_lengths - 1
192
+ src_txt_tokens.scatter_(
193
+ 1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
194
+ )
195
+ else:
196
+ raise NotImplementedError("Implementation is required")
197
+
198
+ src_txt_tokens = src_txt_tokens.index_select(0, order)
199
+ src_txt_lengths = src_txt_lengths.index_select(0, order)
200
+ net_input["src_txt_tokens"] = src_txt_tokens
201
+ net_input["src_txt_lengths"] = src_txt_lengths
202
+
203
+ net_input["alignment"] = None
204
+ if self.alignment is not None:
205
+ max_len = max([s.tgt_alignment.size(0) for s in samples])
206
+ alignment = torch.ones(len(samples), max_len).float()
207
+ for i, s in enumerate(samples):
208
+ cur_len = s.tgt_alignment.size(0)
209
+ alignment[i][:cur_len].copy_(s.tgt_alignment)
210
+ net_input["alignment"] = alignment.index_select(0, order)
211
+
212
+ if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
213
+ for i in range(len(samples)):
214
+ net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
215
+
216
+ out = {
217
+ "id": s2t_out["id"],
218
+ "net_input": net_input,
219
+ "target": s2t_out["target"],
220
+ "target_lengths": s2t_out["target_lengths"],
221
+ "ntokens": s2t_out["ntokens"],
222
+ "nsentences": len(samples),
223
+ }
224
+ return out
225
+
226
+
227
+ class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
228
+ KEY_ALIGN = "align"
229
+
230
+ @classmethod
231
+ def _from_list(
232
+ cls,
233
+ split_name: str,
234
+ is_train_split,
235
+ samples: List[Dict],
236
+ cfg: S2TJointDataConfig,
237
+ tgt_dict,
238
+ src_dict,
239
+ pre_tokenizer,
240
+ bpe_tokenizer,
241
+ src_pre_tokenizer,
242
+ src_bpe_tokenizer,
243
+ append_eos,
244
+ use_src_lang_id,
245
+ ) -> SpeechToTextJointDataset:
246
+ audio_root = Path(cfg.audio_root)
247
+ ids = [s[cls.KEY_ID] for s in samples]
248
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
249
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
250
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
251
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
252
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
253
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
254
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
255
+ tgt_alignment = None
256
+ if cls.KEY_ALIGN in samples[0].keys():
257
+ tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
258
+ return SpeechToTextJointDataset(
259
+ split_name,
260
+ is_train_split,
261
+ cfg,
262
+ audio_paths,
263
+ n_frames,
264
+ src_texts=src_texts,
265
+ tgt_texts=tgt_texts,
266
+ speakers=speakers,
267
+ src_langs=src_langs,
268
+ tgt_langs=tgt_langs,
269
+ ids=ids,
270
+ tgt_dict=tgt_dict,
271
+ src_dict=src_dict,
272
+ pre_tokenizer=pre_tokenizer,
273
+ bpe_tokenizer=bpe_tokenizer,
274
+ src_pre_tokenizer=src_pre_tokenizer,
275
+ src_bpe_tokenizer=src_bpe_tokenizer,
276
+ append_eos=append_eos,
277
+ alignment=tgt_alignment,
278
+ use_src_lang_id=use_src_lang_id,
279
+ )
280
+
281
+ @classmethod
282
+ def _from_tsv(
283
+ cls,
284
+ root: str,
285
+ cfg: S2TJointDataConfig,
286
+ split: str,
287
+ tgt_dict,
288
+ src_dict,
289
+ is_train_split: bool,
290
+ pre_tokenizer,
291
+ bpe_tokenizer,
292
+ src_pre_tokenizer,
293
+ src_bpe_tokenizer,
294
+ append_eos: bool,
295
+ use_src_lang_id: int,
296
+ ) -> SpeechToTextJointDataset:
297
+ samples = cls._load_samples_from_tsv(root, split)
298
+ return cls._from_list(
299
+ split,
300
+ is_train_split,
301
+ samples,
302
+ cfg,
303
+ tgt_dict,
304
+ src_dict,
305
+ pre_tokenizer,
306
+ bpe_tokenizer,
307
+ src_pre_tokenizer,
308
+ src_bpe_tokenizer,
309
+ append_eos,
310
+ use_src_lang_id,
311
+ )
312
+
313
+ @classmethod
314
+ def from_tsv(
315
+ cls,
316
+ root: str,
317
+ cfg: S2TJointDataConfig,
318
+ splits: str,
319
+ tgt_dict,
320
+ src_dict,
321
+ pre_tokenizer,
322
+ bpe_tokenizer,
323
+ src_pre_tokenizer,
324
+ src_bpe_tokenizer,
325
+ is_train_split: bool,
326
+ epoch: int,
327
+ seed: int,
328
+ append_eos: Optional[bool] = True,
329
+ use_src_lang_id: Optional[int] = 0,
330
+ ) -> SpeechToTextJointDataset:
331
+ datasets = [
332
+ cls._from_tsv(
333
+ root,
334
+ cfg,
335
+ split,
336
+ tgt_dict,
337
+ src_dict,
338
+ is_train_split,
339
+ pre_tokenizer,
340
+ bpe_tokenizer,
341
+ src_pre_tokenizer,
342
+ src_bpe_tokenizer,
343
+ append_eos=append_eos,
344
+ use_src_lang_id=use_src_lang_id,
345
+ )
346
+ for split in splits.split(",")
347
+ ]
348
+
349
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
350
+ # temperature-based sampling
351
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
352
+ datasets = [
353
+ ResamplingDataset(
354
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
355
+ )
356
+ for r, d in zip(size_ratios, datasets)
357
+ ]
358
+
359
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
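
The extra keys read by `S2TJointDataConfig` are easiest to see next to a concrete config. Below is a small sketch (not part of the uploaded file) that writes an example YAML and reads it back through the properties defined above; the key names mirror those properties, the values are illustrative only, and it assumes the base `S2TDataConfig` constructor takes the YAML path as in the bundled `data_cfg.py`.

```python
# Hedged sketch: example values for the S2TJointDataConfig keys documented above.
import tempfile
from pathlib import Path

import yaml

from fairseq.data.audio.speech_to_text_joint_dataset import S2TJointDataConfig

example = {
    "src_vocab_filename": "src_dict.txt",
    "src_pre_tokenizer": {"tokenizer": "moses", "source_lang": "en"},
    "src_bpe_tokenizer": {"bpe": "sentencepiece", "sentencepiece_model": "spm.model"},
    "prepend_tgt_lang_tag_no_change": True,
    "sampling_text_alpha": 0.5,
}

with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.yaml"
    cfg_path.write_text(yaml.safe_dump(example))
    cfg = S2TJointDataConfig(cfg_path)
    print(cfg.src_vocab_filename, cfg.src_bpe_tokenizer, cfg.sampling_text_alpha)
```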
modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from fairseq.data import Dictionary
16
+ from fairseq.data import data_utils as fairseq_data_utils
17
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
18
+ from fairseq.data.audio.speech_to_text_dataset import (
19
+ S2TDataConfig,
20
+ SpeechToTextDataset,
21
+ SpeechToTextDatasetCreator,
22
+ _collate_frames,
23
+ )
24
+
25
+
26
+ @dataclass
27
+ class TextToSpeechDatasetItem(object):
28
+ index: int
29
+ source: torch.Tensor
30
+ target: Optional[torch.Tensor] = None
31
+ speaker_id: Optional[int] = None
32
+ duration: Optional[torch.Tensor] = None
33
+ pitch: Optional[torch.Tensor] = None
34
+ energy: Optional[torch.Tensor] = None
35
+
36
+
37
+ class TextToSpeechDataset(SpeechToTextDataset):
38
+ def __init__(
39
+ self,
40
+ split: str,
41
+ is_train_split: bool,
42
+ cfg: S2TDataConfig,
43
+ audio_paths: List[str],
44
+ n_frames: List[int],
45
+ src_texts: Optional[List[str]] = None,
46
+ tgt_texts: Optional[List[str]] = None,
47
+ speakers: Optional[List[str]] = None,
48
+ src_langs: Optional[List[str]] = None,
49
+ tgt_langs: Optional[List[str]] = None,
50
+ ids: Optional[List[str]] = None,
51
+ tgt_dict: Optional[Dictionary] = None,
52
+ pre_tokenizer=None,
53
+ bpe_tokenizer=None,
54
+ n_frames_per_step=1,
55
+ speaker_to_id=None,
56
+ durations: Optional[List[List[int]]] = None,
57
+ pitches: Optional[List[str]] = None,
58
+ energies: Optional[List[str]] = None,
59
+ ):
60
+ super(TextToSpeechDataset, self).__init__(
61
+ split,
62
+ is_train_split,
63
+ cfg,
64
+ audio_paths,
65
+ n_frames,
66
+ src_texts=src_texts,
67
+ tgt_texts=tgt_texts,
68
+ speakers=speakers,
69
+ src_langs=src_langs,
70
+ tgt_langs=tgt_langs,
71
+ ids=ids,
72
+ tgt_dict=tgt_dict,
73
+ pre_tokenizer=pre_tokenizer,
74
+ bpe_tokenizer=bpe_tokenizer,
75
+ n_frames_per_step=n_frames_per_step,
76
+ speaker_to_id=speaker_to_id,
77
+ )
78
+ self.durations = durations
79
+ self.pitches = pitches
80
+ self.energies = energies
81
+
82
+ def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
83
+ s2t_item = super().__getitem__(index)
84
+
85
+ duration, pitch, energy = None, None, None
86
+ if self.durations is not None:
87
+ duration = torch.tensor(
88
+ self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
89
+ )
90
+ if self.pitches is not None:
91
+ pitch = get_features_or_waveform(self.pitches[index])
92
+ pitch = torch.from_numpy(
93
+ np.concatenate((pitch, [0])) # pad 0 for EOS
94
+ ).float()
95
+ if self.energies is not None:
96
+ energy = get_features_or_waveform(self.energies[index])
97
+ energy = torch.from_numpy(
98
+ np.concatenate((energy, [0])) # pad 0 for EOS
99
+ ).float()
100
+ return TextToSpeechDatasetItem(
101
+ index=index,
102
+ source=s2t_item.source,
103
+ target=s2t_item.target,
104
+ speaker_id=s2t_item.speaker_id,
105
+ duration=duration,
106
+ pitch=pitch,
107
+ energy=energy,
108
+ )
109
+
110
+ def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
111
+ if len(samples) == 0:
112
+ return {}
113
+
114
+ src_lengths, order = torch.tensor(
115
+ [s.target.shape[0] for s in samples], dtype=torch.long
116
+ ).sort(descending=True)
117
+ id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
118
+ 0, order
119
+ )
120
+ feat = _collate_frames(
121
+ [s.source for s in samples], self.cfg.use_audio_input
122
+ ).index_select(0, order)
123
+ target_lengths = torch.tensor(
124
+ [s.source.shape[0] for s in samples], dtype=torch.long
125
+ ).index_select(0, order)
126
+
127
+ src_tokens = fairseq_data_utils.collate_tokens(
128
+ [s.target for s in samples],
129
+ self.tgt_dict.pad(),
130
+ self.tgt_dict.eos(),
131
+ left_pad=False,
132
+ move_eos_to_beginning=False,
133
+ ).index_select(0, order)
134
+
135
+ speaker = None
136
+ if self.speaker_to_id is not None:
137
+ speaker = (
138
+ torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
139
+ .index_select(0, order)
140
+ .view(-1, 1)
141
+ )
142
+
143
+ bsz, _, d = feat.size()
144
+ prev_output_tokens = torch.cat(
145
+ (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
146
+ )
147
+
148
+ durations, pitches, energies = None, None, None
149
+ if self.durations is not None:
150
+ durations = fairseq_data_utils.collate_tokens(
151
+ [s.duration for s in samples], 0
152
+ ).index_select(0, order)
153
+ assert src_tokens.shape[1] == durations.shape[1]
154
+ if self.pitches is not None:
155
+ pitches = _collate_frames([s.pitch for s in samples], True)
156
+ pitches = pitches.index_select(0, order)
157
+ assert src_tokens.shape[1] == pitches.shape[1]
158
+ if self.energies is not None:
159
+ energies = _collate_frames([s.energy for s in samples], True)
160
+ energies = energies.index_select(0, order)
161
+ assert src_tokens.shape[1] == energies.shape[1]
162
+ src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
163
+
164
+ return {
165
+ "id": id_,
166
+ "net_input": {
167
+ "src_tokens": src_tokens,
168
+ "src_lengths": src_lengths,
169
+ "prev_output_tokens": prev_output_tokens,
170
+ },
171
+ "speaker": speaker,
172
+ "target": feat,
173
+ "durations": durations,
174
+ "pitches": pitches,
175
+ "energies": energies,
176
+ "target_lengths": target_lengths,
177
+ "ntokens": sum(target_lengths).item(),
178
+ "nsentences": len(samples),
179
+ "src_texts": src_texts,
180
+ }
181
+
182
+
183
+ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
184
+ KEY_DURATION = "duration"
185
+ KEY_PITCH = "pitch"
186
+ KEY_ENERGY = "energy"
187
+
188
+ @classmethod
189
+ def _from_list(
190
+ cls,
191
+ split_name: str,
192
+ is_train_split,
193
+ samples: List[Dict],
194
+ cfg: S2TDataConfig,
195
+ tgt_dict,
196
+ pre_tokenizer,
197
+ bpe_tokenizer,
198
+ n_frames_per_step,
199
+ speaker_to_id,
200
+ multitask=None,
201
+ ) -> TextToSpeechDataset:
202
+ audio_root = Path(cfg.audio_root)
203
+ ids = [s[cls.KEY_ID] for s in samples]
204
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
205
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
206
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
207
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
208
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
209
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
210
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
211
+
212
+ durations = [s.get(cls.KEY_DURATION, None) for s in samples]
213
+ durations = [
214
+ None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
215
+ ]
216
+ durations = None if any(dd is None for dd in durations) else durations
217
+
218
+ pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
219
+ pitches = [
220
+ None if pp is None else (audio_root / pp).as_posix() for pp in pitches
221
+ ]
222
+ pitches = None if any(pp is None for pp in pitches) else pitches
223
+
224
+ energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
225
+ energies = [
226
+ None if ee is None else (audio_root / ee).as_posix() for ee in energies
227
+ ]
228
+ energies = None if any(ee is None for ee in energies) else energies
229
+
230
+ return TextToSpeechDataset(
231
+ split_name,
232
+ is_train_split,
233
+ cfg,
234
+ audio_paths,
235
+ n_frames,
236
+ src_texts,
237
+ tgt_texts,
238
+ speakers,
239
+ src_langs,
240
+ tgt_langs,
241
+ ids,
242
+ tgt_dict,
243
+ pre_tokenizer,
244
+ bpe_tokenizer,
245
+ n_frames_per_step,
246
+ speaker_to_id,
247
+ durations,
248
+ pitches,
249
+ energies,
250
+ )
modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioWaveformTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
15
+ AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_waveform_transform(name):
19
+ return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_waveform_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioWaveformTransform,
26
+ AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
27
+ AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "waveform")
32
+
33
+
34
+ class CompositeAudioWaveformTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "waveform",
40
+ get_audio_waveform_transform,
41
+ CompositeAudioWaveformTransform,
42
+ config,
43
+ )
44
+
45
+ def __call__(self, x, sample_rate):
46
+ for t in self.transforms:
47
+ x, sample_rate = t(x, sample_rate)
48
+ return x, sample_rate
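
New waveform transforms plug into the registry set up above in the same way as the noise augmentations in the following `noiseaugment.py` diff. Here is a small sketch (not part of the uploaded file) of a trivial custom transform; the transform name `gainaugment` and the `gain_db` config key are made up for illustration.

```python
# Hedged sketch: registering a custom waveform transform via the decorator above.
import numpy as np

from fairseq.data.audio.waveform_transforms import (
    AudioWaveformTransform,
    register_audio_waveform_transform,
)


@register_audio_waveform_transform("gainaugment")
class GainAugmentTransform(AudioWaveformTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(_config.get("gain_db", 0.0))

    def __init__(self, gain_db: float = 0.0):
        self.scale = 10 ** (gain_db / 20.0)  # dB -> linear amplitude

    def __call__(self, x, sample_rate):
        # Same (waveform, sample_rate) contract as CompositeAudioWaveformTransform.
        return np.asarray(x) * self.scale, sample_rate


if __name__ == "__main__":
    t = GainAugmentTransform.from_config_dict({"gain_db": 6.0})
    wav, sr = t(np.ones(16_000, dtype=np.float32), 16_000)
    print(wav.max(), sr)
```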
modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py ADDED
@@ -0,0 +1,201 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+ from math import ceil
4
+
5
+ from fairseq.data.audio import rand_uniform
6
+ from fairseq.data.audio.waveform_transforms import (
7
+ AudioWaveformTransform,
8
+ register_audio_waveform_transform,
9
+ )
10
+
11
+ SNR_MIN = 5.0
12
+ SNR_MAX = 15.0
13
+ RATE = 0.25
14
+
15
+ NOISE_RATE = 1.0
16
+ NOISE_LEN_MEAN = 0.2
17
+ NOISE_LEN_STD = 0.05
18
+
19
+
20
+ class NoiseAugmentTransform(AudioWaveformTransform):
21
+ @classmethod
22
+ def from_config_dict(cls, config=None):
23
+ _config = {} if config is None else config
24
+ return cls(
25
+ _config.get("samples_path", None),
26
+ _config.get("snr_min", SNR_MIN),
27
+ _config.get("snr_max", SNR_MAX),
28
+ _config.get("rate", RATE),
29
+ )
30
+
31
+ def __init__(
32
+ self,
33
+ samples_path: str,
34
+ snr_min: float = SNR_MIN,
35
+ snr_max: float = SNR_MAX,
36
+ rate: float = RATE,
37
+ ):
38
+ # Sanity checks
39
+ assert (
40
+ samples_path
41
+ ), "need to provide path to audio samples for noise augmentation"
42
+ assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
43
+ assert rate >= 0 and rate <= 1, "rate should be a float between 0 and 1"
44
+
45
+ self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
46
+ self.n_samples = len(self.paths)
47
+ assert self.n_samples > 0, f"no audio files found in {samples_path}"
48
+
49
+ self.snr_min = snr_min
50
+ self.snr_max = snr_max
51
+ self.rate = rate
52
+
53
+ def __repr__(self):
54
+ return (
55
+ self.__class__.__name__
56
+ + "("
57
+ + ", ".join(
58
+ [
59
+ f"n_samples={self.n_samples}",
60
+ f"snr={self.snr_min}-{self.snr_max}dB",
61
+ f"rate={self.rate}",
62
+ ]
63
+ )
64
+ + ")"
65
+ )
66
+
67
+ def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
68
+ from fairseq.data.audio.audio_utils import get_waveform
69
+
70
+ path = self.paths[np.random.randint(0, self.n_samples)]
71
+ sample = get_waveform(
72
+ path, always_2d=always_2d, output_sample_rate=use_sample_rate
73
+ )[0]
74
+
75
+ # Check dimensions match, else silently skip adding noise to sample
76
+ # NOTE: SHOULD THIS QUIT WITH AN ERROR?
77
+ is_2d = len(goal_shape) == 2
78
+ if len(goal_shape) != sample.ndim or (
79
+ is_2d and goal_shape[0] != sample.shape[0]
80
+ ):
81
+ return np.zeros(goal_shape)
82
+
83
+ # Cut/repeat sample to size
84
+ len_dim = len(goal_shape) - 1
85
+ n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
86
+ repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
87
+ start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
88
+ return (
89
+ repeated[:, start : start + goal_shape[len_dim]]
90
+ if is_2d
91
+ else repeated[start : start + goal_shape[len_dim]]
92
+ )
93
+
94
+ def _mix(self, source, noise, snr):
95
+ get_power = lambda x: np.mean(x**2)
96
+ if get_power(noise):
97
+ scl = np.sqrt(
98
+ get_power(source) / (np.power(10, snr / 10) * get_power(noise))
99
+ )
100
+ else:
101
+ scl = 0
102
+ return 1 * source + scl * noise
103
+
104
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
105
+ return self.pick_sample(goal_shape, always_2d, use_sample_rate)
106
+
107
+ def __call__(self, source, sample_rate):
108
+ if np.random.random() > self.rate:
109
+ return source, sample_rate
110
+
111
+ noise = self._get_noise(
112
+ source.shape, always_2d=True, use_sample_rate=sample_rate
113
+ )
114
+
115
+ return (
116
+ self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
117
+ sample_rate,
118
+ )
119
+
120
+
121
+ @register_audio_waveform_transform("musicaugment")
122
+ class MusicAugmentTransform(NoiseAugmentTransform):
123
+ pass
124
+
125
+
126
+ @register_audio_waveform_transform("backgroundnoiseaugment")
127
+ class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
128
+ pass
129
+
130
+
131
+ @register_audio_waveform_transform("babbleaugment")
132
+ class BabbleAugmentTransform(NoiseAugmentTransform):
133
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
134
+ for i in range(np.random.randint(3, 8)):
135
+ speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
136
+ if i == 0:
137
+ agg_noise = speech
138
+ else: # SNR scaled by i (how many noise signals already in agg_noise)
139
+ agg_noise = self._mix(agg_noise, speech, i)
140
+ return agg_noise
141
+
142
+
143
+ @register_audio_waveform_transform("sporadicnoiseaugment")
144
+ class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
145
+ @classmethod
146
+ def from_config_dict(cls, config=None):
147
+ _config = {} if config is None else config
148
+ return cls(
149
+ _config.get("samples_path", None),
150
+ _config.get("snr_min", SNR_MIN),
151
+ _config.get("snr_max", SNR_MAX),
152
+ _config.get("rate", RATE),
153
+ _config.get("noise_rate", NOISE_RATE),
154
+ _config.get("noise_len_mean", NOISE_LEN_MEAN),
155
+ _config.get("noise_len_std", NOISE_LEN_STD),
156
+ )
157
+
158
+ def __init__(
159
+ self,
160
+ samples_path: str,
161
+ snr_min: float = SNR_MIN,
162
+ snr_max: float = SNR_MAX,
163
+ rate: float = RATE,
164
+ noise_rate: float = NOISE_RATE, # noises per second
165
+ noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
166
+ noise_len_std: float = NOISE_LEN_STD,
167
+ ):
168
+ super().__init__(samples_path, snr_min, snr_max, rate)
169
+ self.noise_rate = noise_rate
170
+ self.noise_len_mean = noise_len_mean
171
+ self.noise_len_std = noise_len_std
172
+
173
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
174
+ agg_noise = np.zeros(goal_shape)
175
+ len_dim = len(goal_shape) - 1
176
+ is_2d = len(goal_shape) == 2
177
+
178
+ n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
179
+ start_pointers = [
180
+ round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
181
+ ]
182
+
183
+ for start_pointer in start_pointers:
184
+ noise_shape = list(goal_shape)
185
+ len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
186
+ noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
187
+ end_pointer = start_pointer + noise_shape[len_dim]
188
+ if end_pointer >= goal_shape[len_dim]:
189
+ continue
190
+
191
+ noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
192
+ if is_2d:
193
+ agg_noise[:, start_pointer:end_pointer] = (
194
+ agg_noise[:, start_pointer:end_pointer] + noise
195
+ )
196
+ else:
197
+ agg_noise[start_pointer:end_pointer] = (
198
+ agg_noise[start_pointer:end_pointer] + noise
199
+ )
200
+
201
+ return agg_noise
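
A short sketch (not part of the uploaded file) of driving one of these transforms directly from a config dict, using the keys read by `from_config_dict` above. The samples directory is a placeholder and must contain at least one mono `.wav` file.

```python
# Hedged sketch: applying MusicAugmentTransform to a dummy waveform.
import numpy as np

from fairseq.data.audio.waveform_transforms.noiseaugment import MusicAugmentTransform

transform = MusicAugmentTransform.from_config_dict(
    {
        "samples_path": "/path/to/background_music",  # placeholder, needs .wav files
        "snr_min": 10.0,
        "snr_max": 20.0,
        "rate": 0.5,  # augment roughly half of the calls
    }
)

waveform = (0.1 * np.random.randn(1, 16_000)).astype(np.float32)  # 1 s @ 16 kHz
augmented, sr = transform(waveform, 16_000)
print(augmented.shape, sr)
```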
modules/voice_conversion/fairseq/data/backtranslation_dataset.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq import utils
8
+
9
+ from . import FairseqDataset
10
+
11
+
12
+ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
13
+ """Backtranslate a list of samples.
14
+
15
+ Given an input (*samples*) of the form:
16
+
17
+ [{'id': 1, 'source': 'hallo welt'}]
18
+
19
+ this will return:
20
+
21
+ [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
22
+
23
+ Args:
24
+ samples (List[dict]): samples to backtranslate. Individual samples are
25
+ expected to have a 'source' key, which will become the 'target'
26
+ after backtranslation.
27
+ collate_fn (callable): function to collate samples into a mini-batch
28
+ generate_fn (callable): function to generate backtranslations
29
+ cuda (bool): use GPU for generation (default: ``True``)
30
+
31
+ Returns:
32
+ List[dict]: an updated list of samples with a backtranslated source
33
+ """
34
+ collated_samples = collate_fn(samples)
35
+ s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
36
+ generated_sources = generate_fn(s)
37
+
38
+ id_to_src = {sample["id"]: sample["source"] for sample in samples}
39
+
40
+ # Go through each tgt sentence in batch and its corresponding best
41
+ # generated hypothesis and create a backtranslation data pair
42
+ # {id: id, source: generated backtranslation, target: original tgt}
43
+ return [
44
+ {
45
+ "id": id.item(),
46
+ "target": id_to_src[id.item()],
47
+ "source": hypos[0]["tokens"].cpu(),
48
+ }
49
+ for id, hypos in zip(collated_samples["id"], generated_sources)
50
+ ]
51
+
52
+
53
+ class BacktranslationDataset(FairseqDataset):
54
+ """
55
+ Sets up a backtranslation dataset which takes a tgt batch, generates
56
+ a src using a tgt-src backtranslation function (*backtranslation_fn*),
57
+ and returns the corresponding `{generated src, input tgt}` batch.
58
+
59
+ Args:
60
+ tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
61
+ backtranslated. Only the source side of this dataset will be used.
62
+ After backtranslation, the source sentences in this dataset will be
63
+ returned as the targets.
64
+ src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
65
+ sentences.
66
+ tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
67
+ sentences to be backtranslated.
68
+ backtranslation_fn (callable, optional): function to call to generate
69
+ backtranslations. This is typically the `generate` method of a
70
+ :class:`~fairseq.sequence_generator.SequenceGenerator` object.
71
+ Pass in None when it is not available at initialization time, and
72
+ use set_backtranslation_fn function to set it when available.
73
+ output_collater (callable, optional): function to call on the
74
+ backtranslated samples to create the final batch
75
+ (default: ``tgt_dataset.collater``).
76
+ cuda: use GPU for generation
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ tgt_dataset,
82
+ src_dict,
83
+ tgt_dict=None,
84
+ backtranslation_fn=None,
85
+ output_collater=None,
86
+ cuda=True,
87
+ **kwargs
88
+ ):
89
+ self.tgt_dataset = tgt_dataset
90
+ self.backtranslation_fn = backtranslation_fn
91
+ self.output_collater = (
92
+ output_collater if output_collater is not None else tgt_dataset.collater
93
+ )
94
+ self.cuda = cuda if torch.cuda.is_available() else False
95
+ self.src_dict = src_dict
96
+ self.tgt_dict = tgt_dict
97
+
98
+ def __getitem__(self, index):
99
+ """
100
+ Returns a single sample from *tgt_dataset*. Note that backtranslation is
101
+ not applied in this step; use :func:`collater` instead to backtranslate
102
+ a batch of samples.
103
+ """
104
+ return self.tgt_dataset[index]
105
+
106
+ def __len__(self):
107
+ return len(self.tgt_dataset)
108
+
109
+ def set_backtranslation_fn(self, backtranslation_fn):
110
+ self.backtranslation_fn = backtranslation_fn
111
+
112
+ def collater(self, samples):
113
+ """Merge and backtranslate a list of samples to form a mini-batch.
114
+
115
+ Using the samples from *tgt_dataset*, load a collated target sample to
116
+ feed to the backtranslation model. Then take the backtranslation with
117
+ the best score as the source and the original input as the target.
118
+
119
+ Note: we expect *tgt_dataset* to provide a function `collater()` that
120
+ will collate samples into the format expected by *backtranslation_fn*.
121
+ After backtranslation, we will feed the new list of samples (i.e., the
122
+ `(backtranslated source, original source)` pairs) to *output_collater*
123
+ and return the result.
124
+
125
+ Args:
126
+ samples (List[dict]): samples to backtranslate and collate
127
+
128
+ Returns:
129
+ dict: a mini-batch with keys coming from *output_collater*
130
+ """
131
+ if samples[0].get("is_dummy", False):
132
+ return samples
133
+ samples = backtranslate_samples(
134
+ samples=samples,
135
+ collate_fn=self.tgt_dataset.collater,
136
+ generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
137
+ cuda=self.cuda,
138
+ )
139
+ return self.output_collater(samples)
140
+
141
+ def num_tokens(self, index):
142
+ """Just use the tgt dataset num_tokens"""
143
+ return self.tgt_dataset.num_tokens(index)
144
+
145
+ def ordered_indices(self):
146
+ """Just use the tgt dataset ordered_indices"""
147
+ return self.tgt_dataset.ordered_indices()
148
+
149
+ def size(self, index):
150
+ """Return an example's size as a float or tuple. This value is used
151
+ when filtering a dataset with ``--max-positions``.
152
+
153
+ Note: we use *tgt_dataset* to approximate the length of the source
154
+ sentence, since we do not know the actual length until after
155
+ backtranslation.
156
+ """
157
+ tgt_size = self.tgt_dataset.size(index)[0]
158
+ return (tgt_size, tgt_size)
159
+
160
+ @property
161
+ def supports_prefetch(self):
162
+ return getattr(self.tgt_dataset, "supports_prefetch", False)
163
+
164
+ def prefetch(self, indices):
165
+ return self.tgt_dataset.prefetch(indices)
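
For orientation, a minimal usage sketch of the class above. `tgt_monolingual_dataset`, `src_dict`, `tgt_dict`, `backward_model`, and `generator` are hypothetical placeholders for objects built elsewhere in a training setup (the generator would typically be a fairseq SequenceGenerator, per the docstring), the import assumes the stock fairseq package layout, and the exact generate() call is a placeholder rather than a verified recipe.

# Hedged sketch: wiring BacktranslationDataset to a backward-model generator.
from fairseq.data import BacktranslationDataset

bt_dataset = BacktranslationDataset(
    tgt_dataset=tgt_monolingual_dataset,  # only its source side is used
    src_dict=src_dict,
    tgt_dict=tgt_dict,
    backtranslation_fn=None,  # the generator may not exist yet at init time
    cuda=True,
)

# Once the backward model and its sequence generator are available:
bt_dataset.set_backtranslation_fn(
    lambda sample: generator.generate([backward_model], sample)
)

# collater() both collates and backtranslates a list of target-side samples,
# returning a {generated source, original target} mini-batch.
batch = bt_dataset.collater([bt_dataset[i] for i in range(8)])
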
modules/voice_conversion/fairseq/data/base_wrapper_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
+ class BaseWrapperDataset(FairseqDataset):
12
+ def __init__(self, dataset):
13
+ super().__init__()
14
+ self.dataset = dataset
15
+
16
+ def __getitem__(self, index):
17
+ return self.dataset[index]
18
+
19
+ def __len__(self):
20
+ return len(self.dataset)
21
+
22
+ def collater(self, samples):
23
+ if hasattr(self.dataset, "collater"):
24
+ return self.dataset.collater(samples)
25
+ else:
26
+ return default_collate(samples)
27
+
28
+ @property
29
+ def sizes(self):
30
+ return self.dataset.sizes
31
+
32
+ def num_tokens(self, index):
33
+ return self.dataset.num_tokens(index)
34
+
35
+ def size(self, index):
36
+ return self.dataset.size(index)
37
+
38
+ def ordered_indices(self):
39
+ return self.dataset.ordered_indices()
40
+
41
+ @property
42
+ def supports_prefetch(self):
43
+ return getattr(self.dataset, "supports_prefetch", False)
44
+
45
+ def attr(self, attr: str, index: int):
46
+ return self.dataset.attr(attr, index)
47
+
48
+ def prefetch(self, indices):
49
+ self.dataset.prefetch(indices)
50
+
51
+ def get_batch_shapes(self):
52
+ return self.dataset.get_batch_shapes()
53
+
54
+ def batch_by_size(
55
+ self,
56
+ indices,
57
+ max_tokens=None,
58
+ max_sentences=None,
59
+ required_batch_size_multiple=1,
60
+ ):
61
+ return self.dataset.batch_by_size(
62
+ indices,
63
+ max_tokens=max_tokens,
64
+ max_sentences=max_sentences,
65
+ required_batch_size_multiple=required_batch_size_multiple,
66
+ )
67
+
68
+ def filter_indices_by_size(self, indices, max_sizes):
69
+ return self.dataset.filter_indices_by_size(indices, max_sizes)
70
+
71
+ @property
72
+ def can_reuse_epoch_itr_across_epochs(self):
73
+ return self.dataset.can_reuse_epoch_itr_across_epochs
74
+
75
+ def set_epoch(self, epoch):
76
+ super().set_epoch(epoch)
77
+ if hasattr(self.dataset, "set_epoch"):
78
+ self.dataset.set_epoch(epoch)
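
BaseWrapperDataset is the delegation base that the wrapper datasets in this upload (including BucketPadLengthDataset below) build on: it forwards every FairseqDataset hook to the wrapped dataset, so a subclass only overrides what it changes. A minimal illustrative subclass, where `inner_dataset` is a hypothetical already-built FairseqDataset whose items are tensors:

# Hedged sketch: a wrapper that only transforms items, inheriting everything else.
from fairseq.data import BaseWrapperDataset

class AddOneDataset(BaseWrapperDataset):
    """Illustrative wrapper: returns each wrapped item with 1 added to its values."""
    def __getitem__(self, index):
        return self.dataset[index] + 1  # sizes, collater, prefetch, etc. stay delegated

wrapped = AddOneDataset(inner_dataset)
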
modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from fairseq.data import BaseWrapperDataset
9
+ from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
10
+
11
+
12
+ class BucketPadLengthDataset(BaseWrapperDataset):
13
+ """
14
+ Bucket and pad item lengths to the nearest bucket size. This can be used to
15
+ reduce the number of unique batch shapes, which is important on TPUs since
16
+ each new batch shape requires a recompilation.
17
+
18
+ Args:
19
+ dataset (FairseqDataset): dataset to bucket
20
+ sizes (List[int]): all item sizes
21
+ num_buckets (int): number of buckets to create
22
+ pad_idx (int): padding symbol
23
+ left_pad (bool): if True, pad on the left; otherwise right pad
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dataset,
29
+ sizes,
30
+ num_buckets,
31
+ pad_idx,
32
+ left_pad,
33
+ tensor_key=None,
34
+ ):
35
+ super().__init__(dataset)
36
+ self.pad_idx = pad_idx
37
+ self.left_pad = left_pad
38
+
39
+ assert num_buckets > 0
40
+ self.buckets = get_buckets(sizes, num_buckets)
41
+ self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
42
+ self._tensor_key = tensor_key
43
+
44
+ def _set_tensor(self, item, val):
45
+ if self._tensor_key is None:
46
+ return val
47
+ item[self._tensor_key] = val
48
+ return item
49
+
50
+ def _get_tensor(self, item):
51
+ if self._tensor_key is None:
52
+ return item
53
+ return item[self._tensor_key]
54
+
55
+ def _pad(self, tensor, bucket_size, dim=-1):
56
+ num_pad = bucket_size - tensor.size(dim)
57
+ return F.pad(
58
+ tensor,
59
+ (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
60
+ value=self.pad_idx,
61
+ )
62
+
63
+ def __getitem__(self, index):
64
+ item = self.dataset[index]
65
+ bucket_size = self._bucketed_sizes[index]
66
+ tensor = self._get_tensor(item)
67
+ padded = self._pad(tensor, bucket_size)
68
+ return self._set_tensor(item, padded)
69
+
70
+ @property
71
+ def sizes(self):
72
+ return self._bucketed_sizes
73
+
74
+ def num_tokens(self, index):
75
+ return self._bucketed_sizes[index]
76
+
77
+ def size(self, index):
78
+ return self._bucketed_sizes[index]
modules/voice_conversion/fairseq/data/codedataset.py ADDED
@@ -0,0 +1,576 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import random
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.data
16
+
17
+ from . import data_utils
18
+ from fairseq.data.fairseq_dataset import FairseqDataset
19
+
20
+ F0_FRAME_SPACE = 0.005 # sec
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class ExpressiveCodeDataConfig(object):
27
+ def __init__(self, json_path):
28
+ with open(json_path, "r") as f:
29
+ self.config = json.load(f)
30
+ self._manifests = self.config["manifests"]
31
+
32
+ @property
33
+ def manifests(self):
34
+ return self._manifests
35
+
36
+ @property
37
+ def n_units(self):
38
+ return self.config["n_units"]
39
+
40
+ @property
41
+ def sampling_rate(self):
42
+ return self.config["sampling_rate"]
43
+
44
+ @property
45
+ def code_hop_size(self):
46
+ return self.config["code_hop_size"]
47
+
48
+ @property
49
+ def f0_stats(self):
50
+ """pre-computed f0 statistics path"""
51
+ return self.config.get("f0_stats", None)
52
+
53
+ @property
54
+ def f0_vq_type(self):
55
+ """naive or precomp"""
56
+ return self.config["f0_vq_type"]
57
+
58
+ @property
59
+ def f0_vq_name(self):
60
+ return self.config["f0_vq_name"]
61
+
62
+ def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
63
+ key = "log" if log else "linear"
64
+ if norm_mean and norm_std:
65
+ key += "_mean_std_norm"
66
+ elif norm_mean:
67
+ key += "_mean_norm"
68
+ else:
69
+ key += "_none_norm"
70
+ return self.config["f0_vq_naive_quantizer"][key]
71
+
72
+ @property
73
+ def f0_vq_n_units(self):
74
+ return self.config["f0_vq_n_units"]
75
+
76
+ @property
77
+ def multispkr(self):
78
+ """how to parse speaker label from audio path"""
79
+ return self.config.get("multispkr", None)
80
+
81
+
82
+ def get_f0(audio, rate=16000):
83
+ try:
84
+ import amfm_decompy.basic_tools as basic
85
+ import amfm_decompy.pYAAPT as pYAAPT
86
+ from librosa.util import normalize
87
+ except ImportError:
88
+ raise "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
89
+
90
+ assert audio.ndim == 1
91
+ frame_length = 20.0 # ms
92
+ to_pad = int(frame_length / 1000 * rate) // 2
93
+
94
+ audio = normalize(audio) * 0.95
95
+ audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
96
+ audio = basic.SignalObj(audio, rate)
97
+ pitch = pYAAPT.yaapt(
98
+ audio,
99
+ frame_length=frame_length,
100
+ frame_space=F0_FRAME_SPACE * 1000,
101
+ nccf_thresh1=0.25,
102
+ tda_frame_length=25.0,
103
+ )
104
+ f0 = pitch.samp_values
105
+ return f0
106
+
107
+
108
+ def interpolate_f0(f0):
109
+ try:
110
+ from scipy.interpolate import interp1d
111
+ except ImportError:
112
+ raise "Please install scipy (`pip install scipy`)"
113
+
114
+ orig_t = np.arange(f0.shape[0])
115
+ f0_interp = f0[:]
116
+ ii = f0_interp != 0
117
+ if ii.sum() > 1:
118
+ f0_interp = interp1d(
119
+ orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
120
+ )(orig_t)
121
+ f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
122
+ return f0_interp
123
+
124
+
125
+ def naive_quantize(x, edges):
126
+ bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
127
+ return bin_idx
128
+
129
+
130
+ def load_wav(full_path):
131
+ try:
132
+ import soundfile as sf
133
+ except ImportError:
134
+ raise "Please install soundfile (`pip install SoundFile`)"
135
+ data, sampling_rate = sf.read(full_path)
136
+ return data, sampling_rate
137
+
138
+
139
+ def parse_code(code_str, dictionary, append_eos):
140
+ code, duration = torch.unique_consecutive(
141
+ torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
142
+ )
143
+ code = " ".join(map(str, code.tolist()))
144
+ code = dictionary.encode_line(code, append_eos=append_eos).short()  # keyword arg: encode_line's second positional parameter is line_tokenizer
145
+
146
+ if append_eos:
147
+ duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
148
+ duration = duration.short()
149
+ return code, duration
150
+
151
+
152
+ def parse_manifest(manifest, dictionary):
153
+ audio_files = []
154
+ codes = []
155
+ durations = []
156
+ speakers = []
157
+
158
+ with open(manifest) as info:
159
+ for line in info.readlines():
160
+ sample = eval(line.strip())
161
+ if "cpc_km100" in sample:
162
+ k = "cpc_km100"
163
+ elif "hubert_km100" in sample:
164
+ k = "hubert_km100"
165
+ elif "phone" in sample:
166
+ k = "phone"
167
+ else:
168
+ assert False, "unknown format"
169
+ code = sample[k]
170
+ code, duration = parse_code(code, dictionary, append_eos=True)
171
+
172
+ codes.append(code)
173
+ durations.append(duration)
174
+ audio_files.append(sample["audio"])
175
+ speakers.append(sample.get("speaker", None))
176
+
177
+ return audio_files, codes, durations, speakers
178
+
179
+
180
+ def parse_speaker(path, method):
181
+ if type(path) == str:
182
+ path = Path(path)
183
+
184
+ if method == "parent_name":
185
+ return path.parent.name
186
+ elif method == "parent_parent_name":
187
+ return path.parent.parent.name
188
+ elif method == "_":
189
+ return path.name.split("_")[0]
190
+ elif method == "single":
191
+ return "A"
192
+ elif callable(method):
193
+ return method(path)
194
+ else:
195
+ raise NotImplementedError()
196
+
197
+
198
+ def get_f0_by_filename(filename, tgt_sampling_rate):
199
+ audio, sampling_rate = load_wav(filename)
200
+ if sampling_rate != tgt_sampling_rate:
201
+ raise ValueError(
202
+ "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
203
+ )
204
+
205
+ # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
206
+ f0 = get_f0(audio, rate=tgt_sampling_rate)
207
+ f0 = torch.from_numpy(f0.astype(np.float32))
208
+ return f0
209
+
210
+
211
+ def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
212
+ code_len = durations.sum()
213
+ targ_len = int(f0_code_ratio * code_len)
214
+ diff = f0.size(0) - targ_len
215
+ assert abs(diff) <= tol, (
216
+ f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
217
+ f" > {tol} (dur=\n{durations})"
218
+ )
219
+ if diff > 0:
220
+ f0 = f0[:targ_len]
221
+ elif diff < 0:
222
+ f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
223
+
224
+ f0_offset = 0.0
225
+ seg_f0s = []
226
+ for dur in durations:
227
+ f0_dur = dur.item() * f0_code_ratio
228
+ seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
229
+ seg_f0 = seg_f0[seg_f0 != 0]
230
+ if len(seg_f0) == 0:
231
+ seg_f0 = torch.tensor(0).type(seg_f0.type())
232
+ else:
233
+ seg_f0 = seg_f0.mean()
234
+ seg_f0s.append(seg_f0)
235
+ f0_offset += f0_dur
236
+
237
+ assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
238
+ return torch.tensor(seg_f0s)
239
+
240
+
241
+ class Paddings(object):
242
+ def __init__(self, code_val, dur_val=0, f0_val=-2.0):
243
+ self.code = code_val
244
+ self.dur = dur_val
245
+ self.f0 = f0_val
246
+
247
+
248
+ class Shifts(object):
249
+ def __init__(self, shifts_str, pads):
250
+ self._shifts = list(map(int, shifts_str.split(",")))
251
+ assert len(self._shifts) == 2, self._shifts
252
+ assert all(s >= 0 for s in self._shifts)
253
+ self.extra_length = max(s for s in self._shifts)
254
+ self.pads = pads
255
+
256
+ @property
257
+ def dur(self):
258
+ return self._shifts[0]
259
+
260
+ @property
261
+ def f0(self):
262
+ return self._shifts[1]
263
+
264
+ @staticmethod
265
+ def shift_one(seq, left_pad_num, right_pad_num, pad):
266
+ assert seq.ndim == 1
267
+ bos = seq.new_full((left_pad_num,), pad)
268
+ eos = seq.new_full((right_pad_num,), pad)
269
+ seq = torch.cat([bos, seq, eos])
270
+ mask = torch.ones_like(seq).bool()
271
+ mask[left_pad_num : len(seq) - right_pad_num] = 0
272
+ return seq, mask
273
+
274
+ def __call__(self, code, dur, f0):
275
+ if self.extra_length == 0:
276
+ code_mask = torch.zeros_like(code).bool()
277
+ dur_mask = torch.zeros_like(dur).bool()
278
+ f0_mask = torch.zeros_like(f0).bool()
279
+ return code, code_mask, dur, dur_mask, f0, f0_mask
280
+
281
+ code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
282
+ dur, dur_mask = self.shift_one(
283
+ dur, self.dur, self.extra_length - self.dur, self.pads.dur
284
+ )
285
+ f0, f0_mask = self.shift_one(
286
+ f0, self.f0, self.extra_length - self.f0, self.pads.f0
287
+ )
288
+ return code, code_mask, dur, dur_mask, f0, f0_mask
289
+
290
+
291
+ class CodeDataset(FairseqDataset):
292
+ def __init__(
293
+ self,
294
+ manifest,
295
+ dictionary,
296
+ dur_dictionary,
297
+ f0_dictionary,
298
+ config,
299
+ discrete_dur,
300
+ discrete_f0,
301
+ log_f0,
302
+ normalize_f0_mean,
303
+ normalize_f0_std,
304
+ interpolate_f0,
305
+ return_filename=False,
306
+ strip_filename=True,
307
+ shifts="0,0",
308
+ return_continuous_f0=False,
309
+ ):
310
+ random.seed(1234)
311
+ self.dictionary = dictionary
312
+ self.dur_dictionary = dur_dictionary
313
+ self.f0_dictionary = f0_dictionary
314
+ self.config = config
315
+
316
+ # duration config
317
+ self.discrete_dur = discrete_dur
318
+
319
+ # pitch config
320
+ self.discrete_f0 = discrete_f0
321
+ self.log_f0 = log_f0
322
+ self.normalize_f0_mean = normalize_f0_mean
323
+ self.normalize_f0_std = normalize_f0_std
324
+ self.interpolate_f0 = interpolate_f0
325
+
326
+ self.return_filename = return_filename
327
+ self.strip_filename = strip_filename
328
+ self.f0_code_ratio = config.code_hop_size / (
329
+ config.sampling_rate * F0_FRAME_SPACE
330
+ )
331
+
332
+ # use lazy loading to avoid sharing file handlers across workers
333
+ self.manifest = manifest
334
+ self._codes = None
335
+ self._durs = None
336
+ self._f0s = None
337
+ with open(f"{manifest}.leng.txt", "r") as f:
338
+ lengs = [int(line.rstrip()) for line in f]
339
+ edges = np.cumsum([0] + lengs)
340
+ self.starts, self.ends = edges[:-1], edges[1:]
341
+ with open(f"{manifest}.path.txt", "r") as f:
342
+ self.file_names = [line.rstrip() for line in f]
343
+ logger.info(f"num entries: {len(self.starts)}")
344
+
345
+ if os.path.exists(f"{manifest}.f0_stat.pt"):
346
+ self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
347
+ elif config.f0_stats:
348
+ self.f0_stats = torch.load(config.f0_stats)
349
+
350
+ self.multispkr = config.multispkr
351
+ if config.multispkr:
352
+ with open(f"{manifest}.speaker.txt", "r") as f:
353
+ self.spkrs = [line.rstrip() for line in f]
354
+ self.id_to_spkr = sorted(self.spkrs)
355
+ self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
356
+
357
+ self.pads = Paddings(
358
+ dictionary.pad(),
359
+ 0, # use 0 for duration padding
360
+ f0_dictionary.pad() if discrete_f0 else -5.0,
361
+ )
362
+ self.shifts = Shifts(shifts, pads=self.pads)
363
+ self.return_continuous_f0 = return_continuous_f0
364
+
365
+ def get_data_handlers(self):
366
+ logging.info(f"loading data for {self.manifest}")
367
+ self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
368
+ self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
369
+
370
+ if self.discrete_f0:
371
+ if self.config.f0_vq_type == "precomp":
372
+ self._f0s = np.load(
373
+ f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
374
+ )
375
+ elif self.config.f0_vq_type == "naive":
376
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
377
+ quantizers_path = self.config.get_f0_vq_naive_quantizer(
378
+ self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
379
+ )
380
+ quantizers = torch.load(quantizers_path)
381
+ n_units = self.config.f0_vq_n_units
382
+ self._f0_quantizer = torch.from_numpy(quantizers[n_units])
383
+ else:
384
+ raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
385
+ else:
386
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
387
+
388
+ def preprocess_f0(self, f0, stats):
389
+ """
390
+ 1. interpolate
391
+ 2. log transform (keep unvoiced frame 0)
392
+ """
393
+ # TODO: change this to be dependent on config for naive quantizer
394
+ f0 = f0.clone()
395
+ if self.interpolate_f0:
396
+ f0 = interpolate_f0(f0)
397
+
398
+ mask = f0 != 0 # only process voiced frames
399
+ if self.log_f0:
400
+ f0[mask] = f0[mask].log()
401
+ if self.normalize_f0_mean:
402
+ mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
403
+ f0[mask] = f0[mask] - mean
404
+ if self.normalize_f0_std:
405
+ std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
406
+ f0[mask] = f0[mask] / std
407
+ return f0
408
+
409
+ def _get_raw_item(self, index):
410
+ start, end = self.starts[index], self.ends[index]
411
+ if self._codes is None:
412
+ self.get_data_handlers()
413
+ code = torch.from_numpy(np.array(self._codes[start:end])).long()
414
+ dur = torch.from_numpy(np.array(self._durs[start:end]))
415
+ f0 = torch.from_numpy(np.array(self._f0s[start:end]))
416
+ return code, dur, f0
417
+
418
+ def __getitem__(self, index):
419
+ code, dur, f0 = self._get_raw_item(index)
420
+ code = torch.cat([code.new([self.dictionary.bos()]), code])
421
+
422
+ # use 0 for eos and bos
423
+ dur = torch.cat([dur.new([0]), dur])
424
+ if self.discrete_dur:
425
+ dur = self.dur_dictionary.encode_line(
426
+ " ".join(map(str, dur.tolist())), append_eos=False
427
+ ).long()
428
+ else:
429
+ dur = dur.float()
430
+
431
+ # TODO: find a more elegant approach
432
+ raw_f0 = None
433
+ if self.discrete_f0:
434
+ if self.config.f0_vq_type == "precomp":
435
+ f0 = self.f0_dictionary.encode_line(
436
+ " ".join(map(str, f0.tolist())), append_eos=False
437
+ ).long()
438
+ else:
439
+ f0 = f0.float()
440
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
441
+ if self.return_continuous_f0:
442
+ raw_f0 = f0
443
+ raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
444
+ f0 = naive_quantize(f0, self._f0_quantizer)
445
+ f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
446
+ else:
447
+ f0 = f0.float()
448
+ if self.multispkr:
449
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
450
+ else:
451
+ f0 = self.preprocess_f0(f0, self.f0_stats)
452
+ f0 = torch.cat([f0.new([0]), f0])
453
+
454
+ if raw_f0 is not None:
455
+ *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
456
+ else:
457
+ raw_f0_mask = None
458
+
459
+ code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
460
+ if raw_f0_mask is not None:
461
+ assert (raw_f0_mask == f0_mask).all()
462
+
463
+ # is a padded frame if either input or output is padded
464
+ feats = {
465
+ "source": code[:-1],
466
+ "target": code[1:],
467
+ "mask": code_mask[1:].logical_or(code_mask[:-1]),
468
+ "dur_source": dur[:-1],
469
+ "dur_target": dur[1:],
470
+ "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
471
+ "f0_source": f0[:-1],
472
+ "f0_target": f0[1:],
473
+ "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
474
+ }
475
+
476
+ if raw_f0 is not None:
477
+ feats["raw_f0"] = raw_f0[1:]
478
+
479
+ if self.return_filename:
480
+ fname = self.file_names[index]
481
+ feats["filename"] = (
482
+ fname if not self.strip_filename else Path(fname).with_suffix("").name
483
+ )
484
+ return feats
485
+
486
+ def __len__(self):
487
+ return len(self.starts)
488
+
489
+ def size(self, index):
490
+ return self.ends[index] - self.starts[index] + self.shifts.extra_length
491
+
492
+ def num_tokens(self, index):
493
+ return self.size(index)
494
+
495
+ def collater(self, samples):
496
+ pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
497
+ if len(samples) == 0:
498
+ return {}
499
+
500
+ src_tokens = data_utils.collate_tokens(
501
+ [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
502
+ )
503
+
504
+ tgt_tokens = data_utils.collate_tokens(
505
+ [s["target"] for s in samples],
506
+ pad_idx=pad_idx,
507
+ eos_idx=pad_idx, # appending padding, eos is there already
508
+ left_pad=False,
509
+ )
510
+
511
+ src_durs, tgt_durs = [
512
+ data_utils.collate_tokens(
513
+ [s[k] for s in samples],
514
+ pad_idx=self.pads.dur,
515
+ eos_idx=self.pads.dur,
516
+ left_pad=False,
517
+ )
518
+ for k in ["dur_source", "dur_target"]
519
+ ]
520
+
521
+ src_f0s, tgt_f0s = [
522
+ data_utils.collate_tokens(
523
+ [s[k] for s in samples],
524
+ pad_idx=self.pads.f0,
525
+ eos_idx=self.pads.f0,
526
+ left_pad=False,
527
+ )
528
+ for k in ["f0_source", "f0_target"]
529
+ ]
530
+
531
+ mask, dur_mask, f0_mask = [
532
+ data_utils.collate_tokens(
533
+ [s[k] for s in samples],
534
+ pad_idx=1,
535
+ eos_idx=1,
536
+ left_pad=False,
537
+ )
538
+ for k in ["mask", "dur_mask", "f0_mask"]
539
+ ]
540
+
541
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
542
+ n_tokens = sum(len(s["source"]) for s in samples)
543
+
544
+ result = {
545
+ "nsentences": len(samples),
546
+ "ntokens": n_tokens,
547
+ "net_input": {
548
+ "src_tokens": src_tokens,
549
+ "src_lengths": src_lengths,
550
+ "dur_src": src_durs,
551
+ "f0_src": src_f0s,
552
+ },
553
+ "target": tgt_tokens,
554
+ "dur_target": tgt_durs,
555
+ "f0_target": tgt_f0s,
556
+ "mask": mask,
557
+ "dur_mask": dur_mask,
558
+ "f0_mask": f0_mask,
559
+ }
560
+
561
+ if "filename" in samples[0]:
562
+ result["filename"] = [s["filename"] for s in samples]
563
+
564
+ # TODO: remove this hack into the inference dataset
565
+ if "prefix" in samples[0]:
566
+ result["prefix"] = [s["prefix"] for s in samples]
567
+
568
+ if "raw_f0" in samples[0]:
569
+ raw_f0s = data_utils.collate_tokens(
570
+ [s["raw_f0"] for s in samples],
571
+ pad_idx=self.pads.f0,
572
+ eos_idx=self.pads.f0,
573
+ left_pad=False,
574
+ )
575
+ result["raw_f0"] = raw_f0s
576
+ return result
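
CodeDataset is driven entirely by the JSON file that ExpressiveCodeDataConfig parses. Below is a hedged sketch of that layout: every key mirrors an accessor defined above, while all concrete paths and values are purely illustrative.

# Hedged sketch: writing a config JSON in the shape ExpressiveCodeDataConfig expects.
import json

config = {
    "manifests": {"train": "/data/train", "valid": "/data/valid"},
    "n_units": 100,                    # unit vocabulary size
    "sampling_rate": 16000,
    "code_hop_size": 320,              # audio samples per code frame
    "f0_stats": "/data/f0_stats.pt",   # optional pre-computed F0 statistics
    "f0_vq_type": "naive",             # "naive" or "precomp"
    "f0_vq_name": "f0_vq_name",        # used when f0_vq_type == "precomp"
    "f0_vq_naive_quantizer": {         # keyed by log/linear plus normalization
        "log_mean_std_norm": "/data/f0_naive_quantizer.pt",
    },
    "f0_vq_n_units": 32,
    "multispkr": "parent_name",        # optional: how to parse speaker labels
}

with open("code_dataset_config.json", "w") as f:
    json.dump(config, f, indent=2)

Note that, per __init__ and get_data_handlers above, CodeDataset also expects sibling files next to each manifest: <manifest>.leng.txt, <manifest>.path.txt, <manifest>.code.npy, <manifest>.dur.npy, <manifest>.f0.npy, and optionally <manifest>.speaker.txt and <manifest>.f0_stat.pt.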