rinflan commited on
Commit
5f84dff
1 Parent(s): 7d24597

Upload 17 files

Browse files
Files changed (15) hide show
  1. .gitignore +155 -0
  2. Dockerfile +16 -0
  3. LICENSE +201 -0
  4. README.en.md +151 -0
  5. flask_api.py +63 -0
  6. inference.py +425 -0
  7. inference_svs.py +237 -0
  8. inference_vst.py +217 -0
  9. poetry.lock +0 -0
  10. poetry.toml +2 -0
  11. pyproject.toml +51 -0
  12. requirements.txt +103 -0
  13. train.py +228 -0
  14. tst +1 -0
  15. 开始处理.bat +4 -0
.gitignore ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Text tool
55
+ tools/text/create_symbol_dict.py
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ data
135
+ dataset
136
+ .vscode
137
+
138
+ *.pt
139
+ *.pth
140
+ hifigan/model
141
+ output
142
+ lightning_logs
143
+ logs
144
+ wandb
145
+ *.ckpt
146
+ checkpoints
147
+ filelists
148
+ raw
149
+ results
150
+
151
+ configs/exp_*.py
152
+ exp_*.sh
153
+ .DS_Store
154
+ .vscode
155
+ exported
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.0-cudnn8-devel-ubuntu22.04 AS fish-diffusion
2
+
3
+ # Install Poetry
4
+ RUN apt-get update && apt-get install -y git curl python3 python3-pip build-essential ffmpeg libsm6 libxext6
5
+ RUN curl -sSL https://install.python-poetry.org | python3 -
6
+ ENV PATH="/root/.local/bin:${PATH}"
7
+ RUN poetry config virtualenvs.create false
8
+
9
+ # Install dependencies
10
+ WORKDIR /root
11
+
12
+ RUN pip3 install torch torchvision torchaudio
13
+ RUN git clone https://github.com/fishaudio/fish-diffusion.git && cd fish-diffusion && poetry install
14
+
15
+ WORKDIR /root/fish-diffusion
16
+ RUN python3 tools/download_nsf_hifigan.py --agree-license
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] [Fish Audio]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.en.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <img alt="LOGO" src="https://cdn.jsdelivr.net/gh/fishaudio/fish-diffusion@main/images/logo_512x512.png" width="256" height="256" />
4
+
5
+ # Fish Diffusion
6
+
7
+ <div>
8
+ <a href="https://github.com/fishaudio/fish-diffusion/actions/workflows/ci.yml">
9
+ <img alt="Build Status" src="https://img.shields.io/github/actions/workflow/status/fishaudio/fish-diffusion/ci.yml?style=flat-square&logo=GitHub">
10
+ </a>
11
+ <a href="https://hub.docker.com/r/lengyue233/fish-diffusion">
12
+ <img alt="Docker Hub" src="https://img.shields.io/docker/cloud/build/lengyue233/fish-diffusion?style=flat-square&logo=Docker&logoColor=white">
13
+ </a>
14
+ <a href="https://discord.gg/wbYSRBrW2E">
15
+ <img alt="Discord" src="https://img.shields.io/discord/1044927142900809739?color=%23738ADB&label=Discord&logo=discord&logoColor=white&style=flat-square">
16
+ </a>
17
+ </div>
18
+
19
+ </div>
20
+
21
+ ------
22
+
23
+ An easy to understand TTS / SVS / SVC training framework.
24
+
25
+ > Check our [Wiki](https://fishaudio.github.io/fish-diffusion/) to get started!
26
+
27
+ [中文文档](README.md)
28
+
29
+ ## Summary
30
+ Using Diffusion Model to solve different voice generating tasks. Compared with the original diffsvc repository, the advantages and disadvantages of this repository are as follows:
31
+ + Support multi-speaker
32
+ + The code structure of this repository is simpler and easier to understand, and all modules are decoupled
33
+ + Support [441khz Diff Singer community vocoder](https://openvpi.github.io/vocoders/)
34
+ + Support multi-machine multi-devices training, support half-precision training, save your training speed and memory
35
+
36
+ ## Preparing the environment
37
+ The following commands need to be executed in the conda environment of python 3.10
38
+
39
+ ```bash
40
+ # Install PyTorch related core dependencies, skip if installed
41
+ # Reference: https://pytorch.org/get-started/locally/
42
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
43
+
44
+ # Install Poetry dependency management tool, skip if installed
45
+ # Reference: https://python-poetry.org/docs/#installation
46
+ curl -sSL https://install.python-poetry.org | python3 -
47
+
48
+ # Install the project dependencies
49
+ poetry install
50
+ ```
51
+
52
+ ## Vocoder preparation
53
+ Fish Diffusion requires the [OPENVPI 441khz NSF-HiFiGAN](https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1) vocoder to generate audio.
54
+
55
+ ### Automatic download
56
+ ```bash
57
+ python tools/download_nsf_hifigan.py
58
+ ```
59
+
60
+ If you are using the script to download the model, you can use the `--agree-license` parameter to agree to the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
61
+
62
+ ```bash
63
+ python tools/download_nsf_hifigan.py --agree-license
64
+ ```
65
+
66
+ ### Manual download
67
+ Download and unzip `nsf_hifigan_20221211.zip` from [441khz vocoder](https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1)
68
+
69
+ Copy the `nsf_hifigan` folder to the `checkpoints` directory (create if not exist)
70
+
71
+ ## Dataset preparation
72
+ You only need to put the dataset into the `dataset` directory in the following file structure
73
+
74
+ ```shell
75
+ dataset
76
+ ├───train
77
+ │ ├───xxx1-xxx1.wav
78
+ │ ├───...
79
+ │ ├───Lxx-0xx8.wav
80
+ │ └───speaker0 (Subdirectory is also supported)
81
+ │ └───xxx1-xxx1.wav
82
+ └───valid
83
+ ├───xx2-0xxx2.wav
84
+ ├───...
85
+ └───xxx7-xxx007.wav
86
+ ```
87
+
88
+ ```bash
89
+ # Extract all data features, such as pitch, text features, mel features, etc.
90
+ python tools/preprocessing/extract_features.py --config configs/svc_hubert_soft.py --path dataset --clean
91
+ ```
92
+
93
+ ## Baseline training
94
+ > The project is under active development, please backup your config file
95
+ > The project is under active development, please backup your config file
96
+ > The project is under active development, please backup your config file
97
+
98
+ ```bash
99
+ # Single machine single card / multi-card training
100
+ python train.py --config configs/svc_hubert_soft.py
101
+
102
+ # Resume training
103
+ python train.py --config configs/svc_hubert_soft.py --resume [checkpoint]
104
+
105
+ # Fine-tune the pre-trained model
106
+ # Note: You should adjust the learning rate scheduler in the config file to warmup_cosine_finetune
107
+ python train.py --config configs/svc_hubert_soft.py --pretrained [checkpoint]
108
+ ```
109
+
110
+ ## Inference
111
+ ```bash
112
+ # Inference using shell, you can use --help to view more parameters
113
+ python inference.py --config [config] \
114
+ --checkpoint [checkpoint] \
115
+ --input [input audio] \
116
+ --output [output audio]
117
+
118
+
119
+ # Gradio Web Inference, other parameters will be used as gradio default parameters
120
+ python inference/gradio_inference.py --config [config] \
121
+ --checkpoint [checkpoint] \
122
+ --gradio
123
+ ```
124
+
125
+ ## Convert a DiffSVC model to Fish Diffusion
126
+ ```bash
127
+ python tools/diff_svc_converter.py --config configs/svc_hubert_soft_diff_svc.py \
128
+ --input-path [DiffSVC ckpt] \
129
+ --output-path [Fish Diffusion ckpt]
130
+ ```
131
+
132
+ ## Contributing
133
+ If you have any questions, please submit an issue or pull request.
134
+ You should run `tools/lint.sh` before submitting a pull request.
135
+
136
+ Real-time documentation can be generated by
137
+ ```bash
138
+ sphinx-autobuild docs docs/_build/html
139
+ ```
140
+
141
+ ## Credits
142
+ + [diff-svc original](https://github.com/prophesier/diff-svc)
143
+ + [diff-svc optimized](https://github.com/innnky/diff-svc/)
144
+ + [DiffSinger](https://github.com/openvpi/DiffSinger/)
145
+ + [SpeechSplit](https://github.com/auspicious3000/SpeechSplit)
146
+
147
+ ## Thanks to all contributors for their efforts
148
+
149
+ <a href="https://github.com/fishaudio/fish-diffusion/graphs/contributors" target="_blank">
150
+ <img src="https://contrib.rocks/image?repo=fishaudio/fish-diffusion" />
151
+ </a>
flask_api.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+
4
+ import librosa
5
+ import soundfile
6
+ from flask import Flask, request, send_file
7
+ from flask_cors import CORS
8
+
9
+ #from infer_tools.infer_tool import Svc
10
+ from inference_vst import SvcFish
11
+ #from utils.hparams import hparams
12
+
13
+ app = Flask(__name__)
14
+
15
+ CORS(app)
16
+
17
+ logging.getLogger('numba').setLevel(logging.WARNING)
18
+
19
+
20
+ @app.route("/voiceChangeModel", methods=["POST"])
21
+ def voice_change_model():
22
+ request_form = request.form
23
+ wave_file = request.files.get("sample", None)
24
+ # 变调信息
25
+ f_pitch_change = float(request_form.get("fPitchChange", 0))
26
+ # 获取spkid
27
+ int_speak_Id = int(request_form.get("sSpeakId", 0))
28
+ # DAW所需的采样率
29
+ daw_sample = int(float(request_form.get("sampleRate", 0)))
30
+ # http获得wav文件并转换
31
+ input_wav_path = io.BytesIO(wave_file.read())
32
+ # 模型推理
33
+ _audio, _model_sr = svc_model.infer(input_wav_path, f_pitch_change, int_speak_Id, daw_sample)
34
+ tar_audio = librosa.resample(_audio, _model_sr, daw_sample)
35
+ # 返回音频
36
+ out_wav_path = io.BytesIO()
37
+ soundfile.write(out_wav_path, tar_audio, daw_sample, format="wav")
38
+ out_wav_path.seek(0)
39
+ return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
40
+
41
+
42
+ if __name__ == '__main__':
43
+ # fish下只需传入下列参数
44
+ checkpoint_path = 'logs/DiffSVC/version_0/checkpoints/epoch=123-step=300000-valid_loss=0.17.ckpt'
45
+ config_path = 'configs/svc_cn_hubert_soft_ms.py'
46
+ # 加速倍率,None即采用配置文件的值
47
+ sampler_interval = None
48
+ # 是否提取人声,是否合成非人声,以及人声响度增益
49
+ extract_vocals = True
50
+ merge_non_vocals = False
51
+ vocals_loudness_gain = 0.0
52
+ # 最大切片时长
53
+ max_slice_duration = 30.0
54
+ # 静音阈值
55
+ silence_threshold = 60
56
+
57
+ svc_model = SvcFish(checkpoint_path, config_path, sampler_interval=sampler_interval,
58
+ extract_vocals=extract_vocals,merge_non_vocals=merge_non_vocals,
59
+ vocals_loudness_gain=vocals_loudness_gain,silence_threshold=silence_threshold,
60
+ max_slice_duration=max_slice_duration)
61
+
62
+ # 此处与vst插件对应,不建议更改
63
+ app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
inference.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from functools import partial
5
+ from typing import Union
6
+
7
+ import gradio as gr
8
+ import librosa
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import torch
12
+ from fish_audio_preprocess.utils import loudness_norm, separate_audio
13
+ from loguru import logger
14
+ from mmengine import Config
15
+
16
+ from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
17
+ from fish_diffusion.utils.audio import get_mel_from_audio, slice_audio
18
+ from fish_diffusion.utils.inference import load_checkpoint
19
+ from fish_diffusion.utils.tensor import repeat_expand
20
+
21
+
22
+ @torch.no_grad()
23
+ def inference(
24
+ config,
25
+ checkpoint,
26
+ input_path,
27
+ output_path,
28
+ speaker_id=0,
29
+ pitch_adjust=0,
30
+ silence_threshold=30,
31
+ max_slice_duration=5,
32
+ extract_vocals=True,
33
+ merge_non_vocals=True,
34
+ vocals_loudness_gain=0.0,
35
+ sampler_interval=None,
36
+ sampler_progress=False,
37
+ device="cuda",
38
+ gradio_progress=None,
39
+ ):
40
+ """Inference
41
+
42
+ Args:
43
+ config: config
44
+ checkpoint: checkpoint path
45
+ input_path: input path
46
+ output_path: output path
47
+ speaker_id: speaker id
48
+ pitch_adjust: pitch adjust
49
+ silence_threshold: silence threshold of librosa.effects.split
50
+ max_slice_duration: maximum duration of each slice
51
+ extract_vocals: extract vocals
52
+ merge_non_vocals: merge non-vocals, only works when extract_vocals is True
53
+ vocals_loudness_gain: loudness gain of vocals (dB)
54
+ sampler_interval: sampler interval, lower value means higher quality
55
+ sampler_progress: show sampler progress
56
+ device: device
57
+ gradio_progress: gradio progress callback
58
+ """
59
+
60
+ if sampler_interval is not None:
61
+ config.model.diffusion.sampler_interval = sampler_interval
62
+
63
+ if os.path.isdir(checkpoint):
64
+ # Find the latest checkpoint
65
+ checkpoints = sorted(os.listdir(checkpoint))
66
+ logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
67
+ checkpoint = os.path.join(checkpoint, checkpoints[-1])
68
+
69
+ audio, sr = librosa.load(input_path, sr=config.sampling_rate, mono=True)
70
+
71
+ # Extract vocals
72
+
73
+ if extract_vocals:
74
+ logger.info("Extracting vocals...")
75
+
76
+ if gradio_progress is not None:
77
+ gradio_progress(0, "Extracting vocals...")
78
+
79
+ model = separate_audio.init_model("htdemucs", device=device)
80
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=model.samplerate)[None]
81
+
82
+ # To two channels
83
+ audio = np.concatenate([audio, audio], axis=0)
84
+ audio = torch.from_numpy(audio).to(device)
85
+ tracks = separate_audio.separate_audio(
86
+ model, audio, shifts=1, num_workers=0, progress=True
87
+ )
88
+ audio = separate_audio.merge_tracks(tracks, filter=["vocals"]).cpu().numpy()
89
+ non_vocals = (
90
+ separate_audio.merge_tracks(tracks, filter=["drums", "bass", "other"])
91
+ .cpu()
92
+ .numpy()
93
+ )
94
+
95
+ audio = librosa.resample(audio[0], orig_sr=model.samplerate, target_sr=sr)
96
+ non_vocals = librosa.resample(
97
+ non_vocals[0], orig_sr=model.samplerate, target_sr=sr
98
+ )
99
+
100
+ # Normalize loudness
101
+ non_vocals = loudness_norm.loudness_norm(non_vocals, sr)
102
+
103
+ # Normalize loudness
104
+ audio = loudness_norm.loudness_norm(audio, sr)
105
+
106
+ # Slice into segments
107
+ segments = list(
108
+ slice_audio(
109
+ audio, sr, max_duration=max_slice_duration, top_db=silence_threshold
110
+ )
111
+ )
112
+ logger.info(f"Sliced into {len(segments)} segments")
113
+
114
+ # Load models
115
+ text_features_extractor = FEATURE_EXTRACTORS.build(
116
+ config.preprocessing.text_features_extractor
117
+ ).to(device)
118
+ text_features_extractor.eval()
119
+
120
+ model = load_checkpoint(config, checkpoint, device=device)
121
+
122
+ pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
123
+ assert pitch_extractor is not None, "Pitch extractor not found"
124
+
125
+ generated_audio = np.zeros_like(audio)
126
+ audio_torch = torch.from_numpy(audio).to(device)[None]
127
+
128
+ for idx, (start, end) in enumerate(segments):
129
+ if gradio_progress is not None:
130
+ gradio_progress(idx / len(segments), "Generating audio...")
131
+
132
+ segment = audio_torch[:, start:end]
133
+ logger.info(
134
+ f"Processing segment {idx + 1}/{len(segments)}, duration: {segment.shape[-1] / sr:.2f}s"
135
+ )
136
+
137
+ # Extract mel
138
+ mel = get_mel_from_audio(segment, sr)
139
+
140
+ # Extract pitch (f0)
141
+ pitch = pitch_extractor(segment, sr, pad_to=mel.shape[-1]).float()
142
+ pitch *= 2 ** (pitch_adjust / 12)
143
+
144
+ # Extract text features
145
+ text_features = text_features_extractor(segment, sr)[0]
146
+ text_features = repeat_expand(text_features, mel.shape[-1]).T
147
+
148
+ # Predict
149
+ src_lens = torch.tensor([mel.shape[-1]]).to(device)
150
+
151
+ features = model.model.forward_features(
152
+ speakers=torch.tensor([speaker_id]).long().to(device),
153
+ contents=text_features[None].to(device),
154
+ src_lens=src_lens,
155
+ max_src_len=max(src_lens),
156
+ mel_lens=src_lens,
157
+ max_mel_len=max(src_lens),
158
+ pitches=pitch[None].to(device),
159
+ )
160
+
161
+ result = model.model.diffusion(features["features"], progress=sampler_progress)
162
+ wav = model.vocoder.spec2wav(result[0].T, f0=pitch).cpu().numpy()
163
+ max_wav_len = generated_audio.shape[-1] - start
164
+ generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]
165
+
166
+ # Loudness normalization
167
+ generated_audio = loudness_norm.loudness_norm(generated_audio, sr)
168
+
169
+ # Loudness gain
170
+ loudness_float = 10 ** (vocals_loudness_gain / 20)
171
+ generated_audio = generated_audio * loudness_float
172
+
173
+ # Merge non-vocals
174
+ if extract_vocals and merge_non_vocals:
175
+ generated_audio = (generated_audio + non_vocals) / 2
176
+
177
+ logger.info("Done")
178
+
179
+ if output_path is not None:
180
+ sf.write(output_path, generated_audio, sr)
181
+
182
+ return generated_audio, sr
183
+
184
+
185
+ def parse_args():
186
+ parser = argparse.ArgumentParser()
187
+
188
+ parser.add_argument(
189
+ "--config",
190
+ type=str,
191
+ required=True,
192
+ help="Path to the config file",
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--checkpoint",
197
+ type=str,
198
+ required=True,
199
+ help="Path to the checkpoint file",
200
+ )
201
+
202
+ parser.add_argument(
203
+ "--gradio",
204
+ action="store_true",
205
+ help="Run in gradio mode",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--gradio_share",
210
+ action="store_true",
211
+ help="Share gradio app",
212
+ )
213
+
214
+ parser.add_argument(
215
+ "--input",
216
+ type=str,
217
+ required=False,
218
+ help="Path to the input audio file",
219
+ )
220
+
221
+ parser.add_argument(
222
+ "--output",
223
+ type=str,
224
+ required=False,
225
+ help="Path to the output audio file",
226
+ )
227
+
228
+ parser.add_argument(
229
+ "--speaker_id",
230
+ type=int,
231
+ default=0,
232
+ help="Speaker id",
233
+ )
234
+
235
+ parser.add_argument(
236
+ "--speaker_mapping",
237
+ type=str,
238
+ default=None,
239
+ help="Speaker mapping file (gradio mode only)",
240
+ )
241
+
242
+ parser.add_argument(
243
+ "--pitch_adjust",
244
+ type=int,
245
+ default=0,
246
+ help="Pitch adjustment in semitones",
247
+ )
248
+
249
+ parser.add_argument(
250
+ "--extract_vocals",
251
+ action="store_true",
252
+ help="Extract vocals",
253
+ )
254
+
255
+ parser.add_argument(
256
+ "--merge_non_vocals",
257
+ action="store_true",
258
+ help="Merge non-vocals",
259
+ )
260
+
261
+ parser.add_argument(
262
+ "--vocals_loudness_gain",
263
+ type=float,
264
+ default=0,
265
+ help="Loudness gain for vocals",
266
+ )
267
+
268
+ parser.add_argument(
269
+ "--sampler_interval",
270
+ type=int,
271
+ default=None,
272
+ required=False,
273
+ help="Sampler interval, if not specified, will be taken from config",
274
+ )
275
+
276
+ parser.add_argument(
277
+ "--sampler_progress",
278
+ action="store_true",
279
+ help="Show sampler progress",
280
+ )
281
+
282
+ parser.add_argument(
283
+ "--device",
284
+ type=str,
285
+ default=None,
286
+ required=False,
287
+ help="Device to use",
288
+ )
289
+
290
+ return parser.parse_args()
291
+
292
+
293
+ def run_inference(
294
+ config_path: str,
295
+ model_path: str,
296
+ input_path: str,
297
+ speaker: Union[int, str],
298
+ pitch_adjust: int,
299
+ sampler_interval: int,
300
+ extract_vocals: bool,
301
+ device: str,
302
+ progress=gr.Progress(),
303
+ speaker_mapping: dict = None,
304
+ ):
305
+ if speaker_mapping is not None and isinstance(speaker, str):
306
+ speaker = speaker_mapping[speaker]
307
+
308
+ audio, sr = inference(
309
+ Config.fromfile(config_path),
310
+ model_path,
311
+ input_path=input_path,
312
+ output_path=None,
313
+ speaker_id=speaker,
314
+ pitch_adjust=pitch_adjust,
315
+ sampler_interval=round(sampler_interval),
316
+ extract_vocals=extract_vocals,
317
+ merge_non_vocals=False,
318
+ device=device,
319
+ gradio_progress=progress,
320
+ )
321
+
322
+ return (sr, audio)
323
+
324
+
325
+ def launch_gradio(args):
326
+ with gr.Blocks(title="Fish Diffusion") as app:
327
+ gr.Markdown("# Fish Diffusion SVC Inference")
328
+
329
+ with gr.Row():
330
+ with gr.Column():
331
+ input_audio = gr.Audio(
332
+ label="Input Audio",
333
+ type="filepath",
334
+ value=args.input,
335
+ )
336
+ output_audio = gr.Audio(label="Output Audio")
337
+
338
+ with gr.Column():
339
+ if args.speaker_mapping is not None:
340
+ speaker_mapping = json.load(open(args.speaker_mapping))
341
+
342
+ speaker = gr.Dropdown(
343
+ label="Speaker Name (Used for Multi-Speaker Models)",
344
+ choices=list(speaker_mapping.keys()),
345
+ value=list(speaker_mapping.keys())[0],
346
+ )
347
+ else:
348
+ speaker_mapping = None
349
+ speaker = gr.Number(
350
+ label="Speaker ID (Used for Multi-Speaker Models)",
351
+ value=args.speaker_id,
352
+ )
353
+
354
+ pitch_adjust = gr.Number(
355
+ label="Pitch Adjust (Semitones)", value=args.pitch_adjust
356
+ )
357
+ sampler_interval = gr.Slider(
358
+ label="Sampler Interval (⬆️ Faster Generation, ⬇️ Better Quality)",
359
+ value=args.sampler_interval or 10,
360
+ minimum=1,
361
+ maximum=100,
362
+ )
363
+ extract_vocals = gr.Checkbox(
364
+ label="Extract Vocals (For low quality audio)",
365
+ value=args.extract_vocals,
366
+ )
367
+ device = gr.Radio(
368
+ label="Device", choices=["cuda", "cpu"], value=args.device or "cuda"
369
+ )
370
+
371
+ run_btn = gr.Button(label="Run")
372
+
373
+ run_btn.click(
374
+ partial(
375
+ run_inference,
376
+ args.config,
377
+ args.checkpoint,
378
+ speaker_mapping=speaker_mapping,
379
+ ),
380
+ [
381
+ input_audio,
382
+ speaker,
383
+ pitch_adjust,
384
+ sampler_interval,
385
+ extract_vocals,
386
+ device,
387
+ ],
388
+ output_audio,
389
+ )
390
+
391
+ app.queue(concurrency_count=2).launch(share=args.gradio_share)
392
+
393
+
394
+ if __name__ == "__main__":
395
+ args = parse_args()
396
+
397
+ assert args.gradio or (
398
+ args.input is not None and args.output is not None
399
+ ), "Either --gradio or --input and --output should be specified"
400
+
401
+ if args.device is None:
402
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
403
+ else:
404
+ device = torch.device(args.device)
405
+
406
+ if args.gradio:
407
+ args.device = device
408
+ launch_gradio(args)
409
+
410
+ else:
411
+
412
+ inference(
413
+ Config.fromfile(args.config),
414
+ args.checkpoint,
415
+ args.input,
416
+ args.output,
417
+ speaker_id=args.speaker_id,
418
+ pitch_adjust=args.pitch_adjust,
419
+ extract_vocals=args.extract_vocals,
420
+ merge_non_vocals=args.merge_non_vocals,
421
+ vocals_loudness_gain=args.vocals_loudness_gain,
422
+ sampler_interval=args.sampler_interval,
423
+ sampler_progress=args.sampler_progress,
424
+ device=device,
425
+ )
inference_svs.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import math
4
+ import os
5
+
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import torch
9
+ from fish_audio_preprocess.utils import loudness_norm
10
+ from loguru import logger
11
+ from mmengine import Config
12
+
13
+ from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
14
+ from fish_diffusion.utils.tensor import repeat_expand
15
+ from train import FishDiffusion
16
+
17
+
18
+ @torch.no_grad()
19
+ def inference(
20
+ config,
21
+ checkpoint,
22
+ input_path,
23
+ output_path,
24
+ dictionary_path="dictionaries/opencpop-strict.txt",
25
+ speaker_id=0,
26
+ sampler_interval=None,
27
+ sampler_progress=False,
28
+ device="cuda",
29
+ ):
30
+ """Inference
31
+
32
+ Args:
33
+ config: config
34
+ checkpoint: checkpoint path
35
+ input_path: input path
36
+ output_path: output path
37
+ dictionary_path: dictionary path
38
+ speaker_id: speaker id
39
+ sampler_interval: sampler interval, lower value means higher quality
40
+ sampler_progress: show sampler progress
41
+ device: device
42
+ """
43
+
44
+ if sampler_interval is not None:
45
+ config.model.diffusion.sampler_interval = sampler_interval
46
+
47
+ if os.path.isdir(checkpoint):
48
+ # Find the latest checkpoint
49
+ checkpoints = sorted(os.listdir(checkpoint))
50
+ logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
51
+ checkpoint = os.path.join(checkpoint, checkpoints[-1])
52
+
53
+ # Load models
54
+ phoneme_features_extractor = FEATURE_EXTRACTORS.build(
55
+ config.preprocessing.phoneme_features_extractor
56
+ ).to(device)
57
+ phoneme_features_extractor.eval()
58
+
59
+ model = FishDiffusion(config)
60
+ state_dict = torch.load(checkpoint, map_location="cpu")
61
+
62
+ if "state_dict" in state_dict: # Checkpoint is saved by pl
63
+ state_dict = state_dict["state_dict"]
64
+
65
+ model.load_state_dict(state_dict)
66
+ model.to(device)
67
+ model.eval()
68
+
69
+ pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
70
+ assert pitch_extractor is not None, "Pitch extractor not found"
71
+
72
+ # Load dictionary
73
+ phones_list = []
74
+ for i in open(dictionary_path):
75
+ _, phones = i.strip().split("\t")
76
+ for j in phones.split():
77
+ if j not in phones_list:
78
+ phones_list.append(j)
79
+
80
+ phones_list = ["<PAD>", "<EOS>", "<UNK>", "AP", "SP"] + sorted(phones_list)
81
+
82
+ # Load ds file
83
+ with open(input_path) as f:
84
+ ds = json.load(f)
85
+
86
+ generated_audio = np.zeros(
87
+ math.ceil(
88
+ (
89
+ float(ds[-1]["offset"])
90
+ + float(ds[-1]["f0_timestep"]) * len(ds[-1]["f0_seq"].split(" "))
91
+ )
92
+ * config.sampling_rate
93
+ )
94
+ )
95
+
96
+ for idx, chunk in enumerate(ds):
97
+ offset = float(chunk["offset"])
98
+
99
+ phones = np.array([phones_list.index(i) for i in chunk["ph_seq"].split(" ")])
100
+ durations = np.array([0] + [float(i) for i in chunk["ph_dur"].split(" ")])
101
+ durations = np.cumsum(durations)
102
+
103
+ f0_timestep = float(chunk["f0_timestep"])
104
+ f0_seq = torch.FloatTensor([float(i) for i in chunk["f0_seq"].split(" ")])
105
+ f0_seq *= 2 ** (6 / 12)
106
+
107
+ total_duration = f0_timestep * len(f0_seq)
108
+
109
+ logger.info(
110
+ f"Processing segment {idx + 1}/{len(ds)}, duration: {total_duration:.2f}s"
111
+ )
112
+
113
+ n_mels = round(total_duration * config.sampling_rate / 512)
114
+ f0_seq = repeat_expand(f0_seq, n_mels, mode="linear")
115
+ f0_seq = f0_seq.to(device)
116
+
117
+ # aligned is in 20ms
118
+ aligned_phones = torch.zeros(int(total_duration * 50), dtype=torch.long)
119
+ for i, phone in enumerate(phones):
120
+ start = int(durations[i] / f0_timestep / 4)
121
+ end = int(durations[i + 1] / f0_timestep / 4)
122
+ aligned_phones[start:end] = phone
123
+
124
+ # Extract text features
125
+ phoneme_features = phoneme_features_extractor.forward(
126
+ aligned_phones.to(device)
127
+ )[0]
128
+
129
+ phoneme_features = repeat_expand(phoneme_features, n_mels).T
130
+
131
+ # Predict
132
+ src_lens = torch.tensor([phoneme_features.shape[0]]).to(device)
133
+
134
+ features = model.model.forward_features(
135
+ speakers=torch.tensor([speaker_id]).long().to(device),
136
+ contents=phoneme_features[None].to(device),
137
+ src_lens=src_lens,
138
+ max_src_len=max(src_lens),
139
+ mel_lens=src_lens,
140
+ max_mel_len=max(src_lens),
141
+ pitches=f0_seq[None],
142
+ )
143
+
144
+ result = model.model.diffusion(features["features"], progress=sampler_progress)
145
+ wav = model.vocoder.spec2wav(result[0].T, f0=f0_seq).cpu().numpy()
146
+ start = round(offset * config.sampling_rate)
147
+ max_wav_len = generated_audio.shape[-1] - start
148
+ generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]
149
+
150
+ # Loudness normalization
151
+ generated_audio = loudness_norm.loudness_norm(generated_audio, config.sampling_rate)
152
+
153
+ sf.write(output_path, generated_audio, config.sampling_rate)
154
+ logger.info("Done")
155
+
156
+
157
+ def parse_args():
158
+ parser = argparse.ArgumentParser()
159
+
160
+ parser.add_argument(
161
+ "--config",
162
+ type=str,
163
+ default="configs/svc_hubert_soft.py",
164
+ help="Path to the config file",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--checkpoint",
169
+ type=str,
170
+ required=True,
171
+ help="Path to the checkpoint file",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--input",
176
+ type=str,
177
+ required=True,
178
+ help="Path to the input audio file",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--output",
183
+ type=str,
184
+ required=True,
185
+ help="Path to the output audio file",
186
+ )
187
+
188
+ parser.add_argument(
189
+ "--speaker_id",
190
+ type=int,
191
+ default=0,
192
+ help="Speaker id",
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--sampler_interval",
197
+ type=int,
198
+ default=None,
199
+ required=False,
200
+ help="Sampler interval, if not specified, will be taken from config",
201
+ )
202
+
203
+ parser.add_argument(
204
+ "--sampler_progress",
205
+ action="store_true",
206
+ help="Show sampler progress",
207
+ )
208
+
209
+ parser.add_argument(
210
+ "--device",
211
+ type=str,
212
+ default=None,
213
+ required=False,
214
+ help="Device to use",
215
+ )
216
+
217
+ return parser.parse_args()
218
+
219
+
220
+ if __name__ == "__main__":
221
+ args = parse_args()
222
+
223
+ if args.device is None:
224
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
225
+ else:
226
+ device = torch.device(args.device)
227
+
228
+ inference(
229
+ Config.fromfile(args.config),
230
+ args.checkpoint,
231
+ args.input,
232
+ args.output,
233
+ speaker_id=args.speaker_id,
234
+ sampler_interval=args.sampler_interval,
235
+ sampler_progress=args.sampler_progress,
236
+ device=device,
237
+ )
inference_vst.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from functools import partial
5
+ from typing import Union
6
+
7
+ import gradio as gr
8
+ import librosa
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import torch
12
+ from fish_audio_preprocess.utils import loudness_norm, separate_audio
13
+ from loguru import logger
14
+ from mmengine import Config
15
+
16
+ from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
17
+ from fish_diffusion.utils.audio import get_mel_from_audio, slice_audio
18
+ from fish_diffusion.utils.inference import load_checkpoint
19
+ from fish_diffusion.utils.tensor import repeat_expand
20
+
21
+
22
+ @torch.no_grad()
23
+ def inference(
24
+ in_sample,
25
+ config_path,
26
+ checkpoint,
27
+ input_path,
28
+ output_path,
29
+ speaker_id=0,
30
+ pitch_adjust=0,
31
+ silence_threshold=60,
32
+ max_slice_duration=30.0,
33
+ extract_vocals=True,
34
+ merge_non_vocals=True,
35
+ vocals_loudness_gain=0.0,
36
+ sampler_interval=None,
37
+ sampler_progress=False,
38
+ device="cuda",
39
+ gradio_progress=None,
40
+ ):
41
+ """Inference
42
+
43
+ Args:
44
+ config: config
45
+ checkpoint: checkpoint path
46
+ input_path: input path
47
+ output_path: output path
48
+ speaker_id: speaker id
49
+ pitch_adjust: pitch adjust
50
+ silence_threshold: silence threshold of librosa.effects.split
51
+ max_slice_duration: maximum duration of each slice
52
+ extract_vocals: extract vocals
53
+ merge_non_vocals: merge non-vocals, only works when extract_vocals is True
54
+ vocals_loudness_gain: loudness gain of vocals (dB)
55
+ sampler_interval: sampler interval, lower value means higher quality
56
+ sampler_progress: show sampler progress
57
+ device: device
58
+ gradio_progress: gradio progress callback
59
+ """
60
+ config = Config.fromfile(config_path)
61
+
62
+ if sampler_interval is not None:
63
+ config.model.diffusion.sampler_interval = sampler_interval
64
+
65
+ if os.path.isdir(checkpoint):
66
+ # Find the latest checkpoint
67
+ checkpoints = sorted(os.listdir(checkpoint))
68
+ logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
69
+ checkpoint = os.path.join(checkpoint, checkpoints[-1])
70
+
71
+ audio, sr = librosa.load(input_path, config.sampling_rate, mono=True)
72
+ #sr = in_sample
73
+ #audio = sf.read(input_path)
74
+
75
+ # Extract vocals
76
+
77
+ if extract_vocals:
78
+ logger.info("Extracting vocals...")
79
+
80
+ if gradio_progress is not None:
81
+ gradio_progress(0, "Extracting vocals...")
82
+
83
+ model = separate_audio.init_model("htdemucs", device=device)
84
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=model.samplerate)[None]
85
+
86
+ # To two channels
87
+ audio = np.concatenate([audio, audio], axis=0)
88
+ audio = torch.from_numpy(audio).to(device)
89
+ tracks = separate_audio.separate_audio(
90
+ model, audio, shifts=1, num_workers=0, progress=True
91
+ )
92
+ audio = separate_audio.merge_tracks(tracks, filter=["vocals"]).cpu().numpy()
93
+ non_vocals = (
94
+ separate_audio.merge_tracks(tracks, filter=["drums", "bass", "other"])
95
+ .cpu()
96
+ .numpy()
97
+ )
98
+
99
+ audio = librosa.resample(audio[0], orig_sr=model.samplerate, target_sr=sr)
100
+ non_vocals = librosa.resample(
101
+ non_vocals[0], orig_sr=model.samplerate, target_sr=sr
102
+ )
103
+
104
+ # Normalize loudness
105
+ non_vocals = loudness_norm.loudness_norm(non_vocals, sr)
106
+
107
+ # Normalize loudness
108
+ audio = loudness_norm.loudness_norm(audio, sr)
109
+
110
+ # Slice into segments
111
+ segments = list(
112
+ slice_audio(
113
+ audio, sr, max_duration=max_slice_duration, top_db=silence_threshold
114
+ )
115
+ )
116
+ logger.info(f"Sliced into {len(segments)} segments")
117
+
118
+ # Load models
119
+ text_features_extractor = FEATURE_EXTRACTORS.build(
120
+ config.preprocessing.text_features_extractor
121
+ ).to(device)
122
+ text_features_extractor.eval()
123
+
124
+ model = load_checkpoint(config, checkpoint, device=device)
125
+
126
+ pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
127
+ assert pitch_extractor is not None, "Pitch extractor not found"
128
+
129
+ generated_audio = np.zeros_like(audio)
130
+ audio_torch = torch.from_numpy(audio).to(device)[None]
131
+
132
+ for idx, (start, end) in enumerate(segments):
133
+ if gradio_progress is not None:
134
+ gradio_progress(idx / len(segments), "Generating audio...")
135
+
136
+ segment = audio_torch[:, start:end]
137
+ logger.info(
138
+ f"Processing segment {idx + 1}/{len(segments)}, duration: {segment.shape[-1] / sr:.2f}s"
139
+ )
140
+
141
+ # Extract mel
142
+ mel = get_mel_from_audio(segment, sr)
143
+
144
+ # Extract pitch (f0)
145
+ pitch = pitch_extractor(segment, sr, pad_to=mel.shape[-1]).float()
146
+ pitch *= 2 ** (pitch_adjust / 12)
147
+
148
+ # Extract text features
149
+ text_features = text_features_extractor(segment, sr)[0]
150
+ text_features = repeat_expand(text_features, mel.shape[-1]).T
151
+
152
+ # Predict
153
+ src_lens = torch.tensor([mel.shape[-1]]).to(device)
154
+
155
+ features = model.model.forward_features(
156
+ speakers=torch.tensor([speaker_id]).long().to(device),
157
+ contents=text_features[None].to(device),
158
+ src_lens=src_lens,
159
+ max_src_len=max(src_lens),
160
+ mel_lens=src_lens,
161
+ max_mel_len=max(src_lens),
162
+ pitches=pitch[None].to(device),
163
+ )
164
+
165
+ result = model.model.diffusion(features["features"], progress=sampler_progress)
166
+ wav = model.vocoder.spec2wav(result[0].T, f0=pitch).cpu().numpy()
167
+ max_wav_len = generated_audio.shape[-1] - start
168
+ generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]
169
+
170
+ # Loudness normalization
171
+ generated_audio = loudness_norm.loudness_norm(generated_audio, sr)
172
+
173
+ # Loudness gain
174
+ loudness_float = 10 ** (vocals_loudness_gain / 20)
175
+ generated_audio = generated_audio * loudness_float
176
+
177
+ # Merge non-vocals
178
+ if extract_vocals and merge_non_vocals:
179
+ generated_audio = (generated_audio + non_vocals) / 2
180
+
181
+ logger.info("Done")
182
+
183
+ if output_path is not None:
184
+ sf.write(output_path, generated_audio, sr)
185
+
186
+ return generated_audio, sr
187
+
188
+ class SvcFish:
189
+ def __init__(self, checkpoint_path, config_path, sampler_interval=None, extract_vocals=True,
190
+ merge_non_vocals=True,vocals_loudness_gain=0.0,silence_threshold=60, max_slice_duration=30.0):
191
+ self.config_path = config_path
192
+ self.checkpoint_path = checkpoint_path
193
+ self.sampler_interval = sampler_interval
194
+ self.silence_threshold = silence_threshold
195
+ self.max_slice_duration = max_slice_duration
196
+ self.extract_vocals = extract_vocals
197
+ self.merge_non_vocals = merge_non_vocals
198
+ self.vocals_loudness_gain = vocals_loudness_gain
199
+ def infer(self, input_path, pitch_adjust, speaker_id, in_sample):
200
+ return inference(
201
+ in_sample=in_sample,
202
+ config_path=self.config_path,
203
+ checkpoint=self.checkpoint_path,
204
+ input_path=input_path,
205
+ output_path=None,
206
+ speaker_id=speaker_id,
207
+ pitch_adjust=pitch_adjust,
208
+ silence_threshold=self.silence_threshold,
209
+ max_slice_duration=self.max_slice_duration,
210
+ extract_vocals=self.extract_vocals,
211
+ merge_non_vocals=self.merge_non_vocals,
212
+ vocals_loudness_gain=self.vocals_loudness_gain,
213
+ sampler_interval=self.sampler_interval,
214
+ sampler_progress=True,
215
+ device="cuda",
216
+ gradio_progress=None,
217
+ )
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
poetry.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [virtualenvs]
2
+ create = false
pyproject.toml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "fish-diffusion"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Lengyue <lengyue@lengyue.me>"]
6
+ license = "Apache"
7
+
8
+ packages = [{ include = "fish_diffusion" }]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "^3.10"
12
+ praat-parselmouth = "^0.4.3"
13
+ soundfile = "^0.11.0"
14
+ librosa = "^0.9.1"
15
+ pytorch-lightning = "^1.8.6"
16
+ numba = "^0.56.4"
17
+ fish-audio-preprocess = "^0.1.9"
18
+ wandb = "^0.13.9"
19
+ transformers = "^4.25.1"
20
+ torchcrepe = "^0.0.17"
21
+ mmengine = "^0.4.0"
22
+ loguru = "^0.6.0"
23
+ click = "^8.1.3"
24
+ tensorboard = "^2.11.2"
25
+ openai-whisper = "^20230124"
26
+ pypinyin = "^0.48.0"
27
+ TextGrid = "^1.5"
28
+ pyworld = "^0.3.2"
29
+ pykakasi = "^2.2.1"
30
+ gradio = "^3.18.0"
31
+ onnxruntime = "^1.14.0"
32
+
33
+ [tool.poetry.group.dev.dependencies]
34
+ isort = "^5.11.4"
35
+ black = "^22.12.0"
36
+
37
+ [tool.poetry.group.docs]
38
+ optional = true
39
+
40
+ [tool.poetry.group.docs.dependencies]
41
+ furo = "^2022.12.7"
42
+ sphinx-autobuild = "^2021.3.14"
43
+ myst-parser = "^0.18.1"
44
+
45
+ [build-system]
46
+ requires = ["poetry-core>=1.2.0"]
47
+ build-backend = "poetry.core.masonry.api"
48
+
49
+ [tool.isort]
50
+ profile = "black"
51
+ extend_skip = ["dataset", "logs"]
requirements.txt ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.3.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ appdirs==1.4.4
5
+ asttokens==2.1.0
6
+ async-timeout==4.0.2
7
+ attrs==22.1.0
8
+ audioread==3.0.0
9
+ backcall==0.2.0
10
+ cachetools==5.2.0
11
+ certifi==2022.9.24
12
+ cffi==1.15.1
13
+ charset-normalizer==2.1.1
14
+ contourpy==1.0.6
15
+ cycler==0.11.0
16
+ debugpy==1.6.3
17
+ decorator==5.1.1
18
+ einops==0.6.0
19
+ entrypoints==0.4
20
+ executing==1.2.0
21
+ fonttools==4.38.0
22
+ frozenlist==1.3.3
23
+ fsspec==2022.11.0
24
+ future==0.18.2
25
+ google-auth==2.14.1
26
+ google-auth-oauthlib==0.4.6
27
+ grpcio==1.50.0
28
+ h5py==3.7.0
29
+ hparams==0.3.0
30
+ idna==3.4
31
+ imageio==2.22.4
32
+ importlib-metadata==5.0.0
33
+ ipykernel==6.17.1
34
+ ipython==8.6.0
35
+ jedi==0.18.1
36
+ joblib==1.2.0
37
+ jupyter_client==7.4.7
38
+ jupyter_core==5.0.0
39
+ kiwisolver==1.4.4
40
+ librosa==0.9.1
41
+ llvmlite==0.39.1
42
+ Markdown==3.4.1
43
+ MarkupSafe==2.1.1
44
+ matplotlib==3.6.2
45
+ matplotlib-inline==0.1.6
46
+ multidict==6.0.2
47
+ nest-asyncio==1.5.6
48
+ networkx==2.8.8
49
+ numba==0.56.4
50
+ numpy==1.23.5
51
+ oauthlib==3.2.2
52
+ packaging==21.3
53
+ parso==0.8.3
54
+ pexpect==4.8.0
55
+ pickleshare==0.7.5
56
+ Pillow==9.3.0
57
+ platformdirs==2.5.4
58
+ pooch==1.6.0
59
+ praat-parselmouth==0.5.0
60
+ prompt-toolkit==3.0.32
61
+ protobuf==3.20.3
62
+ psutil==5.9.4
63
+ ptyprocess==0.7.0
64
+ pure-eval==0.2.2
65
+ pyasn1==0.4.8
66
+ pyasn1-modules==0.2.8
67
+ pycparser==2.21
68
+ pycwt==0.3.0a22
69
+ pyDeprecate==0.3.0
70
+ Pygments==2.13.0
71
+ pyloudnorm==0.2.0
72
+ pyparsing==3.0.9
73
+ python-dateutil==2.8.2
74
+ pytorch-lightning==2.0.0
75
+ PyWavelets==1.4.1
76
+ PyYAML==5.4.1
77
+ pyzmq==24.0.1
78
+ requests==2.28.1
79
+ requests-oauthlib==1.3.1
80
+ resampy==0.4.2
81
+ rsa==4.9
82
+ scikit-image==0.19.3
83
+ scikit-learn==1.1.3
84
+ scipy==1.9.3
85
+ six==1.16.0
86
+ soundfile==0.11.0
87
+ stack-data==0.6.1
88
+ tensorboard==3.0.0
89
+ tensorboard-data-server==0.6.1
90
+ tensorboard-plugin-wit==1.8.1
91
+ threadpoolctl==3.1.0
92
+ tifffile==2022.10.10
93
+ tqdm==4.64.1
94
+ traitlets==5.5.0
95
+ typeguard==2.13.3
96
+ typing_extensions==4.4.0
97
+ urllib3==1.26.12
98
+ utils==1.0.1
99
+ wcwidth==0.2.5
100
+ webrtcvad==2.0.10
101
+ Werkzeug==2.2.2
102
+ yarl==1.8.1
103
+ zipp==3.10.0
train.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+
3
+ import matplotlib.pyplot as plt
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ import wandb
7
+ from loguru import logger
8
+ from mmengine import Config
9
+ from mmengine.optim import OPTIMIZERS
10
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
11
+ from torch.utils.data import DataLoader
12
+
13
+ from fish_diffusion.archs.diffsinger import DiffSinger
14
+ from fish_diffusion.datasets import DATASETS
15
+ from fish_diffusion.datasets.repeat import RepeatDataset
16
+ from fish_diffusion.utils.scheduler import LR_SCHEUDLERS
17
+ from fish_diffusion.utils.viz import viz_synth_sample
18
+ from fish_diffusion.vocoders import VOCODERS
19
+
20
+
21
+ class FishDiffusion(pl.LightningModule):
22
+ def __init__(self, config):
23
+ super().__init__()
24
+ self.save_hyperparameters()
25
+
26
+ self.model = DiffSinger(config.model)
27
+ self.config = config
28
+
29
+ # 音频编码器, 将梅尔谱转换为音频
30
+ self.vocoder = VOCODERS.build(config.model.vocoder)
31
+ self.vocoder.freeze()
32
+
33
+ def configure_optimizers(self):
34
+ self.config.optimizer.params = self.parameters()
35
+ optimizer = OPTIMIZERS.build(self.config.optimizer)
36
+
37
+ self.config.scheduler.optimizer = optimizer
38
+ scheduler = LR_SCHEUDLERS.build(self.config.scheduler)
39
+
40
+ return [optimizer], dict(scheduler=scheduler, interval="step")
41
+
42
+ def _step(self, batch, batch_idx, mode):
43
+ assert batch["pitches"].shape[1] == batch["mels"].shape[1]
44
+
45
+ pitches = batch["pitches"].clone()
46
+ batch_size = batch["speakers"].shape[0]
47
+
48
+ output = self.model(
49
+ speakers=batch["speakers"],
50
+ contents=batch["contents"],
51
+ src_lens=batch["content_lens"],
52
+ max_src_len=batch["max_content_len"],
53
+ mels=batch["mels"],
54
+ mel_lens=batch["mel_lens"],
55
+ max_mel_len=batch["max_mel_len"],
56
+ pitches=batch["pitches"],
57
+ )
58
+
59
+ self.log(f"{mode}_loss", output["loss"], batch_size=batch_size, sync_dist=True)
60
+
61
+ if mode != "valid":
62
+ return output["loss"]
63
+
64
+ x = self.model.diffusion(output["features"])
65
+
66
+ for idx, (gt_mel, gt_pitch, predict_mel, predict_mel_len) in enumerate(
67
+ zip(batch["mels"], pitches, x, batch["mel_lens"])
68
+ ):
69
+ image_mels, wav_reconstruction, wav_prediction = viz_synth_sample(
70
+ gt_mel=gt_mel,
71
+ gt_pitch=gt_pitch,
72
+ predict_mel=predict_mel,
73
+ predict_mel_len=predict_mel_len,
74
+ vocoder=self.vocoder,
75
+ return_image=False,
76
+ )
77
+
78
+ wav_reconstruction = wav_reconstruction.to(torch.float32).cpu().numpy()
79
+ wav_prediction = wav_prediction.to(torch.float32).cpu().numpy()
80
+
81
+ # WanDB logger
82
+ if isinstance(self.logger, WandbLogger):
83
+ self.logger.experiment.log(
84
+ {
85
+ f"reconstruction_mel": wandb.Image(image_mels, caption="mels"),
86
+ f"wavs": [
87
+ wandb.Audio(
88
+ wav_reconstruction,
89
+ sample_rate=44100,
90
+ caption=f"reconstruction (gt)",
91
+ ),
92
+ wandb.Audio(
93
+ wav_prediction,
94
+ sample_rate=44100,
95
+ caption=f"prediction",
96
+ ),
97
+ ],
98
+ },
99
+ )
100
+
101
+ # TensorBoard logger
102
+ if isinstance(self.logger, TensorBoardLogger):
103
+ self.logger.experiment.add_figure(
104
+ f"sample-{idx}/mels",
105
+ image_mels,
106
+ global_step=self.global_step,
107
+ )
108
+ self.logger.experiment.add_audio(
109
+ f"sample-{idx}/wavs/gt",
110
+ wav_reconstruction,
111
+ self.global_step,
112
+ sample_rate=44100,
113
+ )
114
+ self.logger.experiment.add_audio(
115
+ f"sample-{idx}/wavs/prediction",
116
+ wav_prediction,
117
+ self.global_step,
118
+ sample_rate=44100,
119
+ )
120
+
121
+ if isinstance(image_mels, plt.Figure):
122
+ plt.close(image_mels)
123
+
124
+ return output["loss"]
125
+
126
+ def training_step(self, batch, batch_idx):
127
+ return self._step(batch, batch_idx, mode="train")
128
+
129
+ def validation_step(self, batch, batch_idx):
130
+ return self._step(batch, batch_idx, mode="valid")
131
+
132
+
133
+ if __name__ == "__main__":
134
+ pl.seed_everything(42, workers=True)
135
+
136
+ parser = ArgumentParser()
137
+ parser.add_argument("--config", type=str, required=True)
138
+ parser.add_argument("--resume", type=str, default=None)
139
+ parser.add_argument(
140
+ "--tensorboard",
141
+ action="store_true",
142
+ default=False,
143
+ help="Use tensorboard logger, default is wandb.",
144
+ )
145
+ parser.add_argument("--resume-id", type=str, default=None, help="Wandb run id.")
146
+ parser.add_argument("--entity", type=str, default=None, help="Wandb entity.")
147
+ parser.add_argument("--name", type=str, default=None, help="Wandb run name.")
148
+ parser.add_argument(
149
+ "--pretrained", type=str, default=None, help="Pretrained model."
150
+ )
151
+ parser.add_argument(
152
+ "--only-train-speaker-embeddings",
153
+ action="store_true",
154
+ default=False,
155
+ help="Only train speaker embeddings.",
156
+ )
157
+
158
+ args = parser.parse_args()
159
+
160
+ cfg = Config.fromfile(args.config)
161
+
162
+ model = FishDiffusion(cfg)
163
+
164
+ # We only load the state_dict of the model, not the optimizer.
165
+ if args.pretrained:
166
+ state_dict = torch.load(args.pretrained, map_location="cpu")
167
+ if "state_dict" in state_dict:
168
+ state_dict = state_dict["state_dict"]
169
+
170
+ result = model.load_state_dict(state_dict, strict=False)
171
+
172
+ missing_keys = set(result.missing_keys)
173
+ unexpected_keys = set(result.unexpected_keys)
174
+
175
+ # Make sure incorrect keys are just noise predictor keys.
176
+ unexpected_keys = unexpected_keys - set(
177
+ i.replace(".naive_noise_predictor.", ".") for i in missing_keys
178
+ )
179
+
180
+ assert len(unexpected_keys) == 0
181
+
182
+ if args.only_train_speaker_embeddings:
183
+ for name, param in model.named_parameters():
184
+ if "speaker_encoder" not in name:
185
+ param.requires_grad = False
186
+
187
+ logger.info(
188
+ "Only train speaker embeddings, all other parameters are frozen."
189
+ )
190
+
191
+ logger = (
192
+ TensorBoardLogger("logs", name=cfg.model.type)
193
+ if args.tensorboard
194
+ else WandbLogger(
195
+ project=cfg.model.type,
196
+ save_dir="logs",
197
+ log_model=True,
198
+ name=args.name,
199
+ entity=args.entity,
200
+ resume="must" if args.resume_id else False,
201
+ id=args.resume_id,
202
+ )
203
+ )
204
+
205
+ trainer = pl.Trainer(
206
+ logger=logger,
207
+ **cfg.trainer,
208
+ )
209
+
210
+ train_dataset = DATASETS.build(cfg.dataset.train)
211
+ train_loader = DataLoader(
212
+ train_dataset,
213
+ collate_fn=train_dataset.collate_fn,
214
+ **cfg.dataloader.train,
215
+ )
216
+
217
+ valid_dataset = DATASETS.build(cfg.dataset.valid)
218
+ valid_dataset = RepeatDataset(
219
+ valid_dataset, repeat=trainer.num_devices, collate_fn=valid_dataset.collate_fn
220
+ )
221
+
222
+ valid_loader = DataLoader(
223
+ valid_dataset,
224
+ collate_fn=valid_dataset.collate_fn,
225
+ **cfg.dataloader.valid,
226
+ )
227
+
228
+ trainer.fit(model, train_loader, valid_loader, ckpt_path=args.resume)
tst ADDED
@@ -0,0 +1 @@
 
 
1
+ python inference.py --config configs\svc_cn_hubert_soft_finetune_crepe.py --checkpoint checkpoints\epoch=909-step=20000-valid_loss=0.23.ckpt--gradio
开始处理.bat ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ @echo off
2
+ env310\python.exe train.py --config configs/train_my_config.py --pretrained checkpoints\hubert\cn-hubert-soft-600-singers-pretrained-v1.ckpt
3
+
4
+ pause