chocolatedesue commited on
Commit
223aff6
1 Parent(s): 6641bc1
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # create time : 2023.1.09
2
+ # onnxruntime==1.13.1 not supported py=3.11
3
+ # FROM python:3.11.1-slim-bullseye as compile-image
4
+ FROM python:3.9.15-slim-bullseye as compile-image
5
+
6
+ ENV POETRY_VERSION=1.3.1
7
+
8
+ RUN export DEBIAN_FRONTEND=noninteractive && \
9
+ apt-get update && \
10
+ apt-get install cmake build-essential -y --no-install-recommends && \
11
+ pip install poetry==$POETRY_VERSION
12
+
13
+
14
+
15
+ COPY ./pyproject.toml ./app/init_jptalk.py ./poetry.lock ./
16
+ RUN poetry export -f requirements.txt -o requirements.txt --without dev --without test --without-hashes && \
17
+ python -m venv /opt/venv && \
18
+ /opt/venv/bin/pip install --no-cache-dir -U pip && \
19
+ /opt/venv/bin/pip install --no-cache-dir -r requirements.txt && \
20
+ /opt/venv/bin/python3 init_jptalk.py
21
+
22
+ # FROM python:3.11.1-slim-bullseye as final
23
+ FROM python:3.9.15-slim-bullseye as final
24
+ EXPOSE 7860
25
+ COPY --from=compile-image /opt/venv /opt/venv
26
+ # COPY ./app/init_jptalk.py /app/init_jptalk.py
27
+ ENV TZ=Asia/Shanghai PATH="/opt/venv/bin:$PATH"
28
+
29
+ COPY ./app /app
30
+ WORKDIR /
31
+ CMD ["python", "-m","app.main"]
LICENSE ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 2, June 1991
3
+
4
+ Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6
+ Everyone is permitted to copy and distribute verbatim copies
7
+ of this license document, but changing it is not allowed.
8
+
9
+ Preamble
10
+
11
+ The licenses for most software are designed to take away your
12
+ freedom to share and change it. By contrast, the GNU General Public
13
+ License is intended to guarantee your freedom to share and change free
14
+ software--to make sure the software is free for all its users. This
15
+ General Public License applies to most of the Free Software
16
+ Foundation's software and to any other program whose authors commit to
17
+ using it. (Some other Free Software Foundation software is covered by
18
+ the GNU Lesser General Public License instead.) You can apply it to
19
+ your programs, too.
20
+
21
+ When we speak of free software, we are referring to freedom, not
22
+ price. Our General Public Licenses are designed to make sure that you
23
+ have the freedom to distribute copies of free software (and charge for
24
+ this service if you wish), that you receive source code or can get it
25
+ if you want it, that you can change the software or use pieces of it
26
+ in new free programs; and that you know you can do these things.
27
+
28
+ To protect your rights, we need to make restrictions that forbid
29
+ anyone to deny you these rights or to ask you to surrender the rights.
30
+ These restrictions translate to certain responsibilities for you if you
31
+ distribute copies of the software, or if you modify it.
32
+
33
+ For example, if you distribute copies of such a program, whether
34
+ gratis or for a fee, you must give the recipients all the rights that
35
+ you have. You must make sure that they, too, receive or can get the
36
+ source code. And you must show them these terms so they know their
37
+ rights.
38
+
39
+ We protect your rights with two steps: (1) copyright the software, and
40
+ (2) offer you this license which gives you legal permission to copy,
41
+ distribute and/or modify the software.
42
+
43
+ Also, for each author's protection and ours, we want to make certain
44
+ that everyone understands that there is no warranty for this free
45
+ software. If the software is modified by someone else and passed on, we
46
+ want its recipients to know that what they have is not the original, so
47
+ that any problems introduced by others will not reflect on the original
48
+ authors' reputations.
49
+
50
+ Finally, any free program is threatened constantly by software
51
+ patents. We wish to avoid the danger that redistributors of a free
52
+ program will individually obtain patent licenses, in effect making the
53
+ program proprietary. To prevent this, we have made it clear that any
54
+ patent must be licensed for everyone's free use or not licensed at all.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ GNU GENERAL PUBLIC LICENSE
60
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61
+
62
+ 0. This License applies to any program or other work which contains
63
+ a notice placed by the copyright holder saying it may be distributed
64
+ under the terms of this General Public License. The "Program", below,
65
+ refers to any such program or work, and a "work based on the Program"
66
+ means either the Program or any derivative work under copyright law:
67
+ that is to say, a work containing the Program or a portion of it,
68
+ either verbatim or with modifications and/or translated into another
69
+ language. (Hereinafter, translation is included without limitation in
70
+ the term "modification".) Each licensee is addressed as "you".
71
+
72
+ Activities other than copying, distribution and modification are not
73
+ covered by this License; they are outside its scope. The act of
74
+ running the Program is not restricted, and the output from the Program
75
+ is covered only if its contents constitute a work based on the
76
+ Program (independent of having been made by running the Program).
77
+ Whether that is true depends on what the Program does.
78
+
79
+ 1. You may copy and distribute verbatim copies of the Program's
80
+ source code as you receive it, in any medium, provided that you
81
+ conspicuously and appropriately publish on each copy an appropriate
82
+ copyright notice and disclaimer of warranty; keep intact all the
83
+ notices that refer to this License and to the absence of any warranty;
84
+ and give any other recipients of the Program a copy of this License
85
+ along with the Program.
86
+
87
+ You may charge a fee for the physical act of transferring a copy, and
88
+ you may at your option offer warranty protection in exchange for a fee.
89
+
90
+ 2. You may modify your copy or copies of the Program or any portion
91
+ of it, thus forming a work based on the Program, and copy and
92
+ distribute such modifications or work under the terms of Section 1
93
+ above, provided that you also meet all of these conditions:
94
+
95
+ a) You must cause the modified files to carry prominent notices
96
+ stating that you changed the files and the date of any change.
97
+
98
+ b) You must cause any work that you distribute or publish, that in
99
+ whole or in part contains or is derived from the Program or any
100
+ part thereof, to be licensed as a whole at no charge to all third
101
+ parties under the terms of this License.
102
+
103
+ c) If the modified program normally reads commands interactively
104
+ when run, you must cause it, when started running for such
105
+ interactive use in the most ordinary way, to print or display an
106
+ announcement including an appropriate copyright notice and a
107
+ notice that there is no warranty (or else, saying that you provide
108
+ a warranty) and that users may redistribute the program under
109
+ these conditions, and telling the user how to view a copy of this
110
+ License. (Exception: if the Program itself is interactive but
111
+ does not normally print such an announcement, your work based on
112
+ the Program is not required to print an announcement.)
113
+
114
+ These requirements apply to the modified work as a whole. If
115
+ identifiable sections of that work are not derived from the Program,
116
+ and can be reasonably considered independent and separate works in
117
+ themselves, then this License, and its terms, do not apply to those
118
+ sections when you distribute them as separate works. But when you
119
+ distribute the same sections as part of a whole which is a work based
120
+ on the Program, the distribution of the whole must be on the terms of
121
+ this License, whose permissions for other licensees extend to the
122
+ entire whole, and thus to each and every part regardless of who wrote it.
123
+
124
+ Thus, it is not the intent of this section to claim rights or contest
125
+ your rights to work written entirely by you; rather, the intent is to
126
+ exercise the right to control the distribution of derivative or
127
+ collective works based on the Program.
128
+
129
+ In addition, mere aggregation of another work not based on the Program
130
+ with the Program (or with a work based on the Program) on a volume of
131
+ a storage or distribution medium does not bring the other work under
132
+ the scope of this License.
133
+
134
+ 3. You may copy and distribute the Program (or a work based on it,
135
+ under Section 2) in object code or executable form under the terms of
136
+ Sections 1 and 2 above provided that you also do one of the following:
137
+
138
+ a) Accompany it with the complete corresponding machine-readable
139
+ source code, which must be distributed under the terms of Sections
140
+ 1 and 2 above on a medium customarily used for software interchange; or,
141
+
142
+ b) Accompany it with a written offer, valid for at least three
143
+ years, to give any third party, for a charge no more than your
144
+ cost of physically performing source distribution, a complete
145
+ machine-readable copy of the corresponding source code, to be
146
+ distributed under the terms of Sections 1 and 2 above on a medium
147
+ customarily used for software interchange; or,
148
+
149
+ c) Accompany it with the information you received as to the offer
150
+ to distribute corresponding source code. (This alternative is
151
+ allowed only for noncommercial distribution and only if you
152
+ received the program in object code or executable form with such
153
+ an offer, in accord with Subsection b above.)
154
+
155
+ The source code for a work means the preferred form of the work for
156
+ making modifications to it. For an executable work, complete source
157
+ code means all the source code for all modules it contains, plus any
158
+ associated interface definition files, plus the scripts used to
159
+ control compilation and installation of the executable. However, as a
160
+ special exception, the source code distributed need not include
161
+ anything that is normally distributed (in either source or binary
162
+ form) with the major components (compiler, kernel, and so on) of the
163
+ operating system on which the executable runs, unless that component
164
+ itself accompanies the executable.
165
+
166
+ If distribution of executable or object code is made by offering
167
+ access to copy from a designated place, then offering equivalent
168
+ access to copy the source code from the same place counts as
169
+ distribution of the source code, even though third parties are not
170
+ compelled to copy the source along with the object code.
171
+
172
+ 4. You may not copy, modify, sublicense, or distribute the Program
173
+ except as expressly provided under this License. Any attempt
174
+ otherwise to copy, modify, sublicense or distribute the Program is
175
+ void, and will automatically terminate your rights under this License.
176
+ However, parties who have received copies, or rights, from you under
177
+ this License will not have their licenses terminated so long as such
178
+ parties remain in full compliance.
179
+
180
+ 5. You are not required to accept this License, since you have not
181
+ signed it. However, nothing else grants you permission to modify or
182
+ distribute the Program or its derivative works. These actions are
183
+ prohibited by law if you do not accept this License. Therefore, by
184
+ modifying or distributing the Program (or any work based on the
185
+ Program), you indicate your acceptance of this License to do so, and
186
+ all its terms and conditions for copying, distributing or modifying
187
+ the Program or works based on it.
188
+
189
+ 6. Each time you redistribute the Program (or any work based on the
190
+ Program), the recipient automatically receives a license from the
191
+ original licensor to copy, distribute or modify the Program subject to
192
+ these terms and conditions. You may not impose any further
193
+ restrictions on the recipients' exercise of the rights granted herein.
194
+ You are not responsible for enforcing compliance by third parties to
195
+ this License.
196
+
197
+ 7. If, as a consequence of a court judgment or allegation of patent
198
+ infringement or for any other reason (not limited to patent issues),
199
+ conditions are imposed on you (whether by court order, agreement or
200
+ otherwise) that contradict the conditions of this License, they do not
201
+ excuse you from the conditions of this License. If you cannot
202
+ distribute so as to satisfy simultaneously your obligations under this
203
+ License and any other pertinent obligations, then as a consequence you
204
+ may not distribute the Program at all. For example, if a patent
205
+ license would not permit royalty-free redistribution of the Program by
206
+ all those who receive copies directly or indirectly through you, then
207
+ the only way you could satisfy both it and this License would be to
208
+ refrain entirely from distribution of the Program.
209
+
210
+ If any portion of this section is held invalid or unenforceable under
211
+ any particular circumstance, the balance of the section is intended to
212
+ apply and the section as a whole is intended to apply in other
213
+ circumstances.
214
+
215
+ It is not the purpose of this section to induce you to infringe any
216
+ patents or other property right claims or to contest validity of any
217
+ such claims; this section has the sole purpose of protecting the
218
+ integrity of the free software distribution system, which is
219
+ implemented by public license practices. Many people have made
220
+ generous contributions to the wide range of software distributed
221
+ through that system in reliance on consistent application of that
222
+ system; it is up to the author/donor to decide if he or she is willing
223
+ to distribute software through any other system and a licensee cannot
224
+ impose that choice.
225
+
226
+ This section is intended to make thoroughly clear what is believed to
227
+ be a consequence of the rest of this License.
228
+
229
+ 8. If the distribution and/or use of the Program is restricted in
230
+ certain countries either by patents or by copyrighted interfaces, the
231
+ original copyright holder who places the Program under this License
232
+ may add an explicit geographical distribution limitation excluding
233
+ those countries, so that distribution is permitted only in or among
234
+ countries not thus excluded. In such case, this License incorporates
235
+ the limitation as if written in the body of this License.
236
+
237
+ 9. The Free Software Foundation may publish revised and/or new versions
238
+ of the General Public License from time to time. Such new versions will
239
+ be similar in spirit to the present version, but may differ in detail to
240
+ address new problems or concerns.
241
+
242
+ Each version is given a distinguishing version number. If the Program
243
+ specifies a version number of this License which applies to it and "any
244
+ later version", you have the option of following the terms and conditions
245
+ either of that version or of any later version published by the Free
246
+ Software Foundation. If the Program does not specify a version number of
247
+ this License, you may choose any version ever published by the Free Software
248
+ Foundation.
249
+
250
+ 10. If you wish to incorporate parts of the Program into other free
251
+ programs whose distribution conditions are different, write to the author
252
+ to ask for permission. For software which is copyrighted by the Free
253
+ Software Foundation, write to the Free Software Foundation; we sometimes
254
+ make exceptions for this. Our decision will be guided by the two goals
255
+ of preserving the free status of all derivatives of our free software and
256
+ of promoting the sharing and reuse of software generally.
257
+
258
+ NO WARRANTY
259
+
260
+ 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261
+ FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262
+ OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263
+ PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264
+ OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265
+ MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266
+ TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267
+ PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268
+ REPAIR OR CORRECTION.
269
+
270
+ 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272
+ REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273
+ INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274
+ OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275
+ TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276
+ YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277
+ PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278
+ POSSIBILITY OF SUCH DAMAGES.
279
+
280
+ END OF TERMS AND CONDITIONS
281
+
282
+ How to Apply These Terms to Your New Programs
283
+
284
+ If you develop a new program, and you want it to be of the greatest
285
+ possible use to the public, the best way to achieve this is to make it
286
+ free software which everyone can redistribute and change under these terms.
287
+
288
+ To do so, attach the following notices to the program. It is safest
289
+ to attach them to the start of each source file to most effectively
290
+ convey the exclusion of warranty; and each file should have at least
291
+ the "copyright" line and a pointer to where the full notice is found.
292
+
293
+ <one line to give the program's name and a brief idea of what it does.>
294
+ Copyright (C) <year> <name of author>
295
+
296
+ This program is free software; you can redistribute it and/or modify
297
+ it under the terms of the GNU General Public License as published by
298
+ the Free Software Foundation; either version 2 of the License, or
299
+ (at your option) any later version.
300
+
301
+ This program is distributed in the hope that it will be useful,
302
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
303
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
304
+ GNU General Public License for more details.
305
+
306
+ You should have received a copy of the GNU General Public License along
307
+ with this program; if not, write to the Free Software Foundation, Inc.,
308
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
309
+
310
+ Also add information on how to contact you by electronic and paper mail.
311
+
312
+ If the program is interactive, make it output a short notice like this
313
+ when it starts in an interactive mode:
314
+
315
+ Gnomovision version 69, Copyright (C) year name of author
316
+ Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317
+ This is free software, and you are welcome to redistribute it
318
+ under certain conditions; type `show c' for details.
319
+
320
+ The hypothetical commands `show w' and `show c' should show the appropriate
321
+ parts of the General Public License. Of course, the commands you use may
322
+ be called something other than `show w' and `show c'; they could even be
323
+ mouse-clicks or menu items--whatever suits your program.
324
+
325
+ You should also get your employer (if you work as a programmer) or your
326
+ school, if any, to sign a "copyright disclaimer" for the program, if
327
+ necessary. Here is a sample; alter the names:
328
+
329
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330
+ `Gnomovision' (which makes passes at compilers) written by James Hacker.
331
+
332
+ <signature of Ty Coon>, 1 April 1989
333
+ Ty Coon, President of Vice
334
+
335
+ This General Public License does not permit incorporating your program into
336
+ proprietary programs. If your program is a subroutine library, you may
337
+ consider it more useful to permit linking proprietary applications with the
338
+ library. If this is what you want to do, use the GNU Lesser General
339
+ Public License instead of this License.
README.md CHANGED
@@ -1,11 +1,60 @@
1
- ---
2
- title: 19 Onnx
3
- emoji: 👀
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- license: gpl-2.0
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## onnx inference server in docker container
2
+
3
+ ### Copy the demo web from [link](https://huggingface.co/spaces/skytnt/moe-japanese-tts/tree/main)
4
+ > Thanks a lot to [@CjangCjengh](https://github.com/CjangCjengh)
5
+ > Thanks a lot to [wetts](https://github.com/wenet-e2e/wetts)
6
+
7
+ ***Only used for entertainment.
8
+ Don't used for bussiness***
9
+
10
+ ### quick start
11
+ > To use other model and config<br> please use -v /path/to/dir:/app/.model to mount your model and config
12
+
13
+ ```shell
14
+ export name=vits_onnx
15
+ docker stop $name
16
+ docker rm $name
17
+ docker run -d \
18
+ --name $name \
19
+ -p 7860:7860 \
20
+ ccdesue/vits_demo:onnx
21
+ # -v /path/to/dir:/app/.model
22
+ ```
23
+
24
+
25
+
26
+
27
+ ### dir structure
28
+ ```
29
+
30
+ ├── app # gradio code
31
+ ├── build.sh
32
+ ├── Dockerfile
33
+ ├── export # some util for export model
34
+ ├── LICENSE
35
+ ├── poetry.lock
36
+ ├── __pycache__
37
+ ├── pyproject.toml
38
+ ├── README.md
39
+ ├── setup.sh
40
+ └── util # some posibile util
41
+
42
+ ```
43
+
44
+ ### Helpful info
45
+ 1. please read the source code to better understand
46
+ 2. refer to the demo config.json to tail to your own model config
47
+ 3. refer the dockerfile
48
+
49
+ ### limitation
50
+ 1. only test on japanese_cleaners and japanese_cleaners2 in config.json with [raw vits](https://github.com/jaywalnut310/vits)
51
+
52
+
53
+ ### Reference
54
+ 1. [vits_export_discussion](https://github.com/MasayaKawamura/MB-iSTFT-VITS/issues/8)
55
+ 2. [other_vits_onnx](https://github.com/NaruseMioShirakana/VitsOnnx)
56
+ 3. [wetts](https://github.com/wenet-e2e/wetts)
57
+ 4. [android_vits](https://github.com/weirdseed/Vits-Android-ncnn)
58
+
59
+ ### license
60
+ GPLv2
app/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class HParams():
3
+ def __init__(self, **kwargs):
4
+ for k, v in kwargs.items():
5
+ if type(v) == dict:
6
+ v = HParams(**v)
7
+ self[k] = v
8
+
9
+ def keys(self):
10
+ return self.__dict__.keys()
11
+
12
+ def items(self):
13
+ return self.__dict__.items()
14
+
15
+ def values(self):
16
+ return self.__dict__.values()
17
+
18
+ def __len__(self):
19
+ return len(self.__dict__)
20
+
21
+ def __getitem__(self, key):
22
+ return getattr(self, key)
23
+
24
+ def __setitem__(self, key, value):
25
+ return setattr(self, key, value)
26
+
27
+ def __contains__(self, key):
28
+ return key in self.__dict__
29
+
30
+ def __repr__(self):
31
+ return self.__dict__.__repr__()
app/config.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from loguru import logger
5
+ # from app import CONFIG_URL, MODEL_URL
6
+ from app.util import get_hparams_from_file, get_paths, time_it
7
+ import requests
8
+ from tqdm.auto import tqdm
9
+ import re
10
+ from re import Pattern
11
+ import onnxruntime as ort
12
+ import threading
13
+
14
+
15
+ MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJmckZWcGdCR0xxLWJmU28/root/content"
16
+ CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content"
17
+
18
+
19
+
20
+ class Config:
21
+ hps: dict = None
22
+ pattern: Pattern = None
23
+ # symbol_to_id:dict = None
24
+ speaker_choices: list = None
25
+ ort_sess: ort.InferenceSession = None
26
+ model_is_ok: bool = False
27
+
28
+ @classmethod
29
+ def init(cls):
30
+
31
+ # logger.add(
32
+ # "vits_infer.log", rotation="10 MB", encoding="utf-8", enqueue=True, retention="30 days"
33
+ # )
34
+
35
+ brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"]
36
+ cls.pattern = re.compile('|'.join(map(re.escape, brackets)))
37
+
38
+ dir_path = Path(__file__).parent.absolute() / ".model"
39
+ dir_path.mkdir(
40
+ parents=True, exist_ok=True
41
+ )
42
+ model_path, config_path = get_paths(dir_path)
43
+
44
+ if not model_path or not config_path:
45
+ model_path = dir_path / "model.onnx"
46
+ config_path = dir_path / "config.json"
47
+ logger.warning(
48
+ "unable to find model or config, try to download default model and config"
49
+ )
50
+ cfg = requests.get(CONFIG_URL, timeout=5).content
51
+ with open(str(config_path), 'wb') as f:
52
+ f.write(cfg)
53
+ cls.setup_config(str(config_path))
54
+ t = threading.Thread(target=cls.pdownload,
55
+ args=(MODEL_URL, str(model_path)))
56
+ t.start()
57
+ # cls.pdownload(MODEL_URL, str(model_path))
58
+
59
+ else:
60
+ cls.setup_config(str(config_path))
61
+ cls.setup_model(str(model_path))
62
+
63
+ @classmethod
64
+ @logger.catch
65
+ @time_it
66
+ def setup_model(cls, model_path: str):
67
+ import numpy as np
68
+ cls.ort_sess = ort.InferenceSession(model_path)
69
+ # init the model
70
+ seq = np.random.randint(low=0, high=len(
71
+ cls.hps.symbols), size=(1, 10), dtype=np.int64)
72
+
73
+ # seq_len = torch.IntTensor([seq.size(1)]).long()
74
+ seq_len = np.array([seq.shape[1]], dtype=np.int64)
75
+
76
+ # noise(可用于控制感情等变化程度) lenth(可用于控制整体语速) noisew(控制音素发音长度变化程度)
77
+ # 参考 https://github.com/gbxh/genshinTTS
78
+ # scales = torch.FloatTensor([0.667, 1.0, 0.8])
79
+ scales = np.array([0.667, 1.0, 0.8], dtype=np.float32)
80
+ # make triton dynamic shape happy
81
+ # scales = scales.unsqueeze(0)
82
+ scales.resize(1, 3)
83
+ # sid = torch.IntTensor([0]).long()
84
+ sid = np.array([0], dtype=np.int64)
85
+ # sid = torch.LongTensor([0])
86
+ ort_inputs = {
87
+ 'input': seq,
88
+ 'input_lengths': seq_len,
89
+ 'scales': scales,
90
+ 'sid': sid
91
+ }
92
+ cls.ort_sess.run(None, ort_inputs)
93
+
94
+ cls.model_is_ok = True
95
+
96
+ logger.info(
97
+ f"model init done with model path {model_path}"
98
+ )
99
+
100
+ @classmethod
101
+ def setup_config(cls, config_path: str):
102
+ cls.hps = get_hparams_from_file(config_path)
103
+ cls.speaker_choices = list(
104
+ map(lambda x: str(x[0])+":"+x[1], enumerate(cls.hps.speakers)))
105
+
106
+ logger.info(
107
+ f"config init done with config path {config_path}"
108
+ )
109
+
110
+ @classmethod
111
+ def pdownload(cls, url: str, save_path: str, chunk_size: int = 8192):
112
+ # copy from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_requests.py
113
+ file_size = int(requests.head(url).headers["Content-Length"])
114
+ response = requests.get(url, stream=True)
115
+ with tqdm(total=file_size, unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
116
+ desc="model download") as pbar:
117
+
118
+ with open(save_path, 'wb') as f:
119
+ for chunk in response.iter_content(chunk_size=chunk_size):
120
+ if chunk:
121
+ f.write(chunk)
122
+ pbar.update(chunk_size)
123
+ cls.setup_model(save_path)
app/init_jptalk.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import pyopenjtalk
4
+ import gradio as gr
5
+
6
+
7
+
8
+ pyopenjtalk. _lazy_init()
9
+ # pyopenjtalk._extract_dic()
app/main.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from multiprocessing import Process
2
+ import numpy as np
3
+ from .util import find_path_by_suffix, time_it
4
+ from loguru import logger
5
+ from .util import intersperse
6
+ from .config import Config
7
+ from .text import text_to_sequence
8
+ import gradio as gr
9
+ # import sys
10
+ # sys.path.append('..')
11
+
12
+
13
+ def text_to_seq(text: str):
14
+ text = Config.pattern.sub(' ', text).strip()
15
+ text_norm = text_to_sequence(
16
+ text, Config.hps.symbols, Config.hps.data.text_cleaners)
17
+ if Config.hps.data.add_blank:
18
+ text_norm = intersperse(text_norm, 0)
19
+ return text_norm
20
+
21
+
22
+ @time_it
23
+ @logger.catch
24
+ def tts_fn(text, speaker_id, speed=1.0):
25
+
26
+ if len(text) > 300:
27
+ return "Error: Text is too long, please down it to 300 characters", None
28
+
29
+ if not Config.model_is_ok:
30
+ return "Error: model not loaded, please wait for a while or look the log", None
31
+
32
+ seq = text_to_seq(text)
33
+ x = np.array([seq], dtype=np.int64)
34
+ x_len = np.array([x.shape[1]], dtype=np.int64)
35
+ sid = np.array([speaker_id], dtype=np.int64)
36
+ speed = 1/speed
37
+ scales = np.array([0.667, speed, 0.8], dtype=np.float32)
38
+ scales.resize(1, 3)
39
+ ort_inputs = {
40
+ 'input': x,
41
+ 'input_lengths': x_len,
42
+ 'scales': scales,
43
+ 'sid': sid
44
+ }
45
+ audio = np.squeeze(Config.ort_sess.run(None, ort_inputs))
46
+ audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
47
+ audio = np.clip(audio, -32767.0, 32767.0)
48
+
49
+ return "success", (Config.hps.data.sampling_rate, audio.astype(np.int16))
50
+
51
+
52
+ def set_gradio_view():
53
+ app = gr.Blocks()
54
+
55
+ with app:
56
+ gr.Markdown(
57
+ "a demo of web service of vits, thanks to @CjangCjengh, copy from [link](https://huggingface.co/spaces/skytnt/moe-japanese-tts)")
58
+ with gr.Tabs():
59
+ with gr.TabItem("TTS"):
60
+ with gr.Column():
61
+ tts_input1 = gr.TextArea(
62
+ label="TTS_text", value="わたしの趣味はたくさんあります。でも、一番好きな事は写真をとることです。")
63
+ tts_input2 = gr.Dropdown(
64
+ label="Speaker", choices=Config.speaker_choices, type="index", value=Config.speaker_choices[0])
65
+ tts_input3 = gr.Slider(
66
+ label="Speed", value=1, minimum=0.2, maximum=3, step=0.1)
67
+
68
+ tts_submit = gr.Button("Generate", variant="primary")
69
+ tts_output1 = gr.Textbox(label="Output Message")
70
+ tts_output2 = gr.Audio(label="Output Audio")
71
+
72
+ inputs = [
73
+ tts_input1, tts_input2, tts_input3
74
+ ]
75
+ outputs = [
76
+ tts_output1, tts_output2]
77
+
78
+ tts_submit.click(tts_fn, inputs=inputs, outputs=outputs)
79
+
80
+ app.queue(concurrency_count=3)
81
+ gr.close_all()
82
+ app.launch(server_name='0.0.0.0', show_api=False,
83
+ share=False, server_port=7860)
84
+
85
+
86
+ def main():
87
+ # p = Process(target=Config.init)
88
+ # p.start()
89
+ Config.init()
90
+ set_gradio_view()
91
+
92
+
93
+ if __name__ == '__main__':
94
+ main()
app/text/LICENSE ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2017 Keith Ito
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
app/text/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+ from loguru import logger
3
+ # from app.config import Config
4
+ from . import cleaners
5
+
6
+ _symbol_to_id = None
7
+
8
+ def text_to_sequence(text, symbols, cleaner_names):
9
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
10
+ Args:
11
+ text: string to convert to a sequence
12
+ symbols: list of symbols in the text
13
+ cleaner_names: names of the cleaner functions to run the text through
14
+ Returns:
15
+ List of integers corresponding to the symbols in the text
16
+
17
+
18
+ ATTENTION: unable to access Config variabel , don't know why
19
+ '''
20
+
21
+ global _symbol_to_id
22
+
23
+
24
+ if not _symbol_to_id:
25
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
26
+
27
+
28
+
29
+ clean_text = _clean_text(text, cleaner_names)
30
+
31
+ sequence = [
32
+ _symbol_to_id[symbol] for symbol in clean_text if symbol in _symbol_to_id.keys()
33
+ ]
34
+
35
+ # for symbol in clean_text:
36
+ # if symbol not in _symbol_to_id.keys():
37
+ # continue
38
+ # symbol_id = _symbol_to_id[symbol]
39
+ # sequence += [symbol_id]
40
+ return sequence
41
+
42
+
43
+ def _clean_text(text, cleaner_names):
44
+ for name in cleaner_names:
45
+ cleaner = getattr(cleaners, name)
46
+ if not cleaner:
47
+ raise Exception('Unknown cleaner: %s' % name)
48
+ text = cleaner(text)
49
+ return text
app/text/cleaners.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from unidecode import unidecode
3
+ import pyopenjtalk
4
+
5
+ pyopenjtalk._lazy_init()
6
+
7
+ # Regular expression matching Japanese without punctuation marks:
8
+ _japanese_characters = re.compile(
9
+ r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
10
+
11
+ # Regular expression matching non-Japanese characters or punctuation marks:
12
+ _japanese_marks = re.compile(
13
+ r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
14
+
15
+
16
+ def japanese_cleaners(text):
17
+ '''Pipeline for notating accent in Japanese text.'''
18
+ '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
19
+ sentences = re.split(_japanese_marks, text)
20
+ marks = re.findall(_japanese_marks, text)
21
+ text = ''
22
+ for i, sentence in enumerate(sentences):
23
+ if re.match(_japanese_characters, sentence):
24
+ if text != '':
25
+ text += ' '
26
+ labels = pyopenjtalk.extract_fullcontext(sentence)
27
+ for n, label in enumerate(labels):
28
+ phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
29
+ if phoneme not in ['sil', 'pau']:
30
+ text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
31
+ else:
32
+ continue
33
+ n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
34
+ a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
35
+ a2 = int(re.search(r"\+(\d+)\+", label).group(1))
36
+ a3 = int(re.search(r"\+(\d+)/", label).group(1))
37
+ if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
38
+ a2_next = -1
39
+ else:
40
+ a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
41
+ # Accent phrase boundary
42
+ if a3 == 1 and a2_next == 1:
43
+ text += ' '
44
+ # Falling
45
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
46
+ text += '↓'
47
+ # Rising
48
+ elif a2 == 1 and a2_next == 2:
49
+ text += '↑'
50
+ if i < len(marks):
51
+ text += unidecode(marks[i]).replace(' ', '')
52
+ if re.match('[A-Za-z]', text[-1]):
53
+ text += '.'
54
+ return text
55
+
56
+
57
+ def japanese_cleaners2(text):
58
+ return japanese_cleaners(text).replace('ts','ʦ').replace('...','…')
app/text/symbols.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+
5
+ '''# japanese_cleaners
6
+ _pad = '_'
7
+ _punctuation = ',.!?-'
8
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9
+ '''
10
+ # jp_cleaners
11
+ _pad = '_'
12
+ _punctuation = ',.!?-'
13
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
14
+
15
+
16
+
17
+ # japanese_cleaners2
18
+ # _pad = '_'
19
+ # _punctuation = ',.!?-~…'
20
+ # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
21
+
22
+
23
+ '''# korean_cleaners
24
+ _pad = '_'
25
+ _punctuation = ',.!?…~'
26
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
27
+ '''
28
+
29
+ '''# chinese_cleaners
30
+ _pad = '_'
31
+ _punctuation = ',。!?—…'
32
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
33
+ '''
34
+
35
+ # Export all symbols:
36
+ symbols = [_pad] + list(_punctuation) + list(_letters)
37
+
38
+ # Special symbol ids
39
+ SPACE_ID = symbols.index(" ")
app/util.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import pathlib
4
+ # import tqdm
5
+
6
+ from typing import Optional
7
+ import os
8
+ import threading
9
+
10
+ from loguru import logger
11
+ # from app.common import HParams
12
+ # from __ini import HParams
13
+ from pathlib import Path
14
+ import requests
15
+
16
+ from app import HParams
17
+
18
+
19
+ def find_path_by_suffix(dir_path: Path, suffix: Path):
20
+ assert dir_path.is_dir()
21
+
22
+ for path in dir_path.glob(f"*.{suffix}"):
23
+ return path
24
+
25
+ return None
26
+
27
+
28
+ def get_hparams_from_file(config_path):
29
+ with open(config_path, "r") as f:
30
+ data = f.read()
31
+ config = json.loads(data)
32
+
33
+ hparams = HParams(**config)
34
+ return hparams
35
+
36
+
37
+ def intersperse(lst, item):
38
+ result = [item] * (len(lst) * 2 + 1)
39
+ result[1::2] = lst
40
+ return result
41
+
42
+
43
+ def time_it(func: callable):
44
+ import time
45
+
46
+ def wrapper(*args, **kwargs):
47
+ # start = time.time()
48
+ start = time.perf_counter()
49
+ res = func(*args, **kwargs)
50
+ # end = time.time()
51
+ end = time.perf_counter()
52
+ # print(f"func {func.__name__} cost {end-start} seconds")
53
+ logger.info(f"func {func.__name__} cost {end-start} seconds")
54
+ return res
55
+ return wrapper
56
+
57
+
58
+
59
+
60
+
61
+ # def download_defaults(model_path: pathlib.Path, config_path: pathlib.Path):
62
+
63
+ # config = requests.get(config_url, timeout=10).content
64
+ # with open(str(config_path), 'wb') as f:
65
+ # f.write(config)
66
+
67
+ # t = threading.Thread(target=pdownload, args=(model_url, str(model_path)))
68
+ # t.start()
69
+
70
+
71
+ def get_paths(dir_path: Path):
72
+
73
+ model_path: Path = find_path_by_suffix(dir_path, "onnx")
74
+ config_path: Path = find_path_by_suffix(dir_path, "json")
75
+ # if not model_path or not config_path:
76
+ # model_path = dir_path / "model.onnx"
77
+ # config_path = dir_path / "config.json"
78
+ # logger.warning(
79
+ # "unable to find model or config, try to download default model and config"
80
+ # )
81
+ # download_defaults(model_path, config_path)
82
+
83
+ # model_path = str(model_path)
84
+ # config_path = str(config_path)
85
+ # logger.info(f"model path: {model_path} config path: {config_path}")
86
+ return model_path, config_path
build.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ docker image rm ccdesue/vits_demo:onnx -f --no-prune
2
+ docker build -t ccdesue/vits_demo:onnx .
3
+
4
+ function run()
5
+ {
6
+ export name=vits_onnx
7
+ docker stop $name
8
+ docker rm $name
9
+ docker run -d \
10
+ --name $name \
11
+ -p 7860:7860 \
12
+ ccdesue/vits_demo:onnx
13
+ }
14
+
15
+ # docker run --rm -it -p 7860:7860/tcp ccdesue/vits_demo:onnx bash
16
+
17
+ function push(){
18
+
19
+ docker push ccdesue/vits_demo:onnx
20
+ }
export/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
export/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ > Thanks a lot to [wetts](https://github.com/wenet-e2e/wetts)
2
+ > 欢迎 pr
3
+ ## 修改说明
4
+ 1. 将原仓库的配置文件修改成[@CjangCjengh](https://github.com/CjangCjengh)用的部署文件 详细参考config.json
5
+ 2. 为导出代码添加注释, tensor修改为np.array
6
+ 3. 有问题请认真阅读源码
7
+
8
+
export/export.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ function get_data(){
3
+ mkdir -p model
4
+ cd model
5
+ model_url='https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJiTzdqanlEQXNyWDV4bDA/root/content'
6
+ config_url='https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content'
7
+
8
+ wget -O model.pth $model_url
9
+ wget -O config.json $config_url
10
+ cd ..
11
+ }
12
+
13
+
14
+ # mkdir -p model
15
+
16
+ python vits/export_onnx.py --checkpoint model/model.pth --cfg model/config.json \
17
+ --onnx_model model/model.onnx
18
+
19
+ # https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBb1E3R21YN3hRWkpnYjQtT1VocVdjUFc4VWM5bVE/root/content
export/infer.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ python vits/inference_onnx.py --onnx_model model/model.onnx \
2
+ --cfg model/config.json --test_file test.txt
export/requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ flake8==3.8.2
3
+ flake8-bugbear
4
+ flake8-comprehensions
5
+ flake8-executable
6
+ flake8-pyi==20.5.0
7
+ # mccabe
8
+ pycodestyle==2.6.0
9
+ pyflakes==2.2.0
10
+ # tensorboard
11
+ sklearn
12
+ WeTextProcessing
13
+ monotonic_align
14
+ matplotlib
15
+ librosa
16
+ scipy
17
+ transformers
18
+ # Cython
19
+ pyopenjtalk
20
+ unidecode
21
+ # pip3 install torch --extra-index-url https://download.pytorch.org/whl/cpu
export/setup.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # update pip
2
+
3
+ function setup_py(){
4
+ conda create -n dl python=3.9 -y
5
+ conda init bash
6
+ bash
7
+ }
8
+
9
+ conda activate dl
10
+ export DEBIAN_FRONTEND=noninteractive && \
11
+ sudo apt-get update && \
12
+ sudo apt-get install cmake build-essential -y --no-install-recommends
13
+
14
+
15
+ pip install --upgrade pip
16
+ pip3 install torch --extra-index-url https://download.pytorch.org/whl/cpu
17
+ pip3 install onnxruntime Cython
18
+ pip install -r requirements.txt
export/test.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ test.wav|わたしの趣味はたくさんあります。でも、一番好きな事は写真をとることです。
export/vits/attentions.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ from modules import LayerNorm
9
+
10
+
11
+ class Encoder(nn.Module):
12
+ def __init__(self,
13
+ hidden_channels,
14
+ filter_channels,
15
+ n_heads,
16
+ n_layers,
17
+ kernel_size=1,
18
+ p_dropout=0.,
19
+ window_size=4,
20
+ **kwargs):
21
+ super().__init__()
22
+ self.hidden_channels = hidden_channels
23
+ self.filter_channels = filter_channels
24
+ self.n_heads = n_heads
25
+ self.n_layers = n_layers
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.window_size = window_size
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.attn_layers = nn.ModuleList()
32
+ self.norm_layers_1 = nn.ModuleList()
33
+ self.ffn_layers = nn.ModuleList()
34
+ self.norm_layers_2 = nn.ModuleList()
35
+ for i in range(self.n_layers):
36
+ self.attn_layers.append(
37
+ MultiHeadAttention(hidden_channels,
38
+ hidden_channels,
39
+ n_heads,
40
+ p_dropout=p_dropout,
41
+ window_size=window_size))
42
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
43
+ self.ffn_layers.append(
44
+ FFN(hidden_channels,
45
+ hidden_channels,
46
+ filter_channels,
47
+ kernel_size,
48
+ p_dropout=p_dropout))
49
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
50
+
51
+ def forward(self, x, x_mask):
52
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
53
+ x = x * x_mask
54
+ for i in range(self.n_layers):
55
+ y = self.attn_layers[i](x, x, attn_mask)
56
+ y = self.drop(y)
57
+ x = self.norm_layers_1[i](x + y)
58
+
59
+ y = self.ffn_layers[i](x, x_mask)
60
+ y = self.drop(y)
61
+ x = self.norm_layers_2[i](x + y)
62
+ x = x * x_mask
63
+ return x
64
+
65
+
66
+ class Decoder(nn.Module):
67
+ def __init__(self,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size=1,
73
+ p_dropout=0.,
74
+ proximal_bias=False,
75
+ proximal_init=True,
76
+ **kwargs):
77
+ super().__init__()
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.proximal_bias = proximal_bias
85
+ self.proximal_init = proximal_init
86
+
87
+ self.drop = nn.Dropout(p_dropout)
88
+ self.self_attn_layers = nn.ModuleList()
89
+ self.norm_layers_0 = nn.ModuleList()
90
+ self.encdec_attn_layers = nn.ModuleList()
91
+ self.norm_layers_1 = nn.ModuleList()
92
+ self.ffn_layers = nn.ModuleList()
93
+ self.norm_layers_2 = nn.ModuleList()
94
+ for i in range(self.n_layers):
95
+ self.self_attn_layers.append(
96
+ MultiHeadAttention(hidden_channels,
97
+ hidden_channels,
98
+ n_heads,
99
+ p_dropout=p_dropout,
100
+ proximal_bias=proximal_bias,
101
+ proximal_init=proximal_init))
102
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
103
+ self.encdec_attn_layers.append(
104
+ MultiHeadAttention(hidden_channels,
105
+ hidden_channels,
106
+ n_heads,
107
+ p_dropout=p_dropout))
108
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
109
+ self.ffn_layers.append(
110
+ FFN(hidden_channels,
111
+ hidden_channels,
112
+ filter_channels,
113
+ kernel_size,
114
+ p_dropout=p_dropout,
115
+ causal=True))
116
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
117
+
118
+ def forward(self, x, x_mask, h, h_mask):
119
+ """
120
+ x: decoder input
121
+ h: encoder output
122
+ """
123
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
124
+ device=x.device, dtype=x.dtype)
125
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
126
+ x = x * x_mask
127
+ for i in range(self.n_layers):
128
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
129
+ y = self.drop(y)
130
+ x = self.norm_layers_0[i](x + y)
131
+
132
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
133
+ y = self.drop(y)
134
+ x = self.norm_layers_1[i](x + y)
135
+
136
+ y = self.ffn_layers[i](x, x_mask)
137
+ y = self.drop(y)
138
+ x = self.norm_layers_2[i](x + y)
139
+ x = x * x_mask
140
+ return x
141
+
142
+
143
+ class MultiHeadAttention(nn.Module):
144
+ def __init__(self,
145
+ channels,
146
+ out_channels,
147
+ n_heads,
148
+ p_dropout=0.,
149
+ window_size=None,
150
+ heads_share=True,
151
+ block_length=None,
152
+ proximal_bias=False,
153
+ proximal_init=False):
154
+ super().__init__()
155
+ assert channels % n_heads == 0
156
+
157
+ self.channels = channels
158
+ self.out_channels = out_channels
159
+ self.n_heads = n_heads
160
+ self.p_dropout = p_dropout
161
+ self.window_size = window_size
162
+ self.heads_share = heads_share
163
+ self.block_length = block_length
164
+ self.proximal_bias = proximal_bias
165
+ self.proximal_init = proximal_init
166
+ self.attn = None
167
+
168
+ self.k_channels = channels // n_heads
169
+ self.conv_q = nn.Conv1d(channels, channels, 1)
170
+ self.conv_k = nn.Conv1d(channels, channels, 1)
171
+ self.conv_v = nn.Conv1d(channels, channels, 1)
172
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
173
+ self.drop = nn.Dropout(p_dropout)
174
+
175
+ if window_size is not None:
176
+ n_heads_rel = 1 if heads_share else n_heads
177
+ rel_stddev = self.k_channels**-0.5
178
+ self.emb_rel_k = nn.Parameter(
179
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
180
+ * rel_stddev)
181
+ self.emb_rel_v = nn.Parameter(
182
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
183
+ * rel_stddev)
184
+
185
+ nn.init.xavier_uniform_(self.conv_q.weight)
186
+ nn.init.xavier_uniform_(self.conv_k.weight)
187
+ nn.init.xavier_uniform_(self.conv_v.weight)
188
+ if proximal_init:
189
+ with torch.no_grad():
190
+ self.conv_k.weight.copy_(self.conv_q.weight)
191
+ self.conv_k.bias.copy_(self.conv_q.bias)
192
+
193
+ def forward(self, x, c, attn_mask=None):
194
+ q = self.conv_q(x)
195
+ k = self.conv_k(c)
196
+ v = self.conv_v(c)
197
+
198
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
199
+
200
+ x = self.conv_o(x)
201
+ return x
202
+
203
+ def attention(self, query, key, value, mask=None):
204
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
205
+ b, d, t_s, t_t = (*key.size(), query.size(2))
206
+ query = query.view(b, self.n_heads, self.k_channels,
207
+ t_t).transpose(2, 3)
208
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
209
+ value = value.view(b, self.n_heads, self.k_channels,
210
+ t_s).transpose(2, 3)
211
+
212
+ scores = torch.matmul(query / math.sqrt(self.k_channels),
213
+ key.transpose(-2, -1))
214
+ if self.window_size is not None:
215
+ msg = "Relative attention is only available for self-attention."
216
+ assert t_s == t_t, msg
217
+ key_relative_embeddings = self._get_relative_embeddings(
218
+ self.emb_rel_k, t_s)
219
+ rel_logits = self._matmul_with_relative_keys(
220
+ query / math.sqrt(self.k_channels), key_relative_embeddings)
221
+ scores_local = self._relative_position_to_absolute_position(
222
+ rel_logits)
223
+ scores = scores + scores_local
224
+ if self.proximal_bias:
225
+ msg = "Proximal bias is only available for self-attention."
226
+ assert t_s == t_t, msg
227
+ scores = scores + self._attention_bias_proximal(t_s).to(
228
+ device=scores.device, dtype=scores.dtype)
229
+ if mask is not None:
230
+ scores = scores.masked_fill(mask == 0, -1e4)
231
+ if self.block_length is not None:
232
+ msg = "Local attention is only available for self-attention."
233
+ assert t_s == t_t, msg
234
+ block_mask = torch.ones_like(scores).triu(
235
+ -self.block_length).tril(self.block_length)
236
+ scores = scores.masked_fill(block_mask == 0, -1e4)
237
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
238
+ p_attn = self.drop(p_attn)
239
+ output = torch.matmul(p_attn, value)
240
+ if self.window_size is not None:
241
+ relative_weights = self._absolute_position_to_relative_position(
242
+ p_attn)
243
+ value_relative_embeddings = self._get_relative_embeddings(
244
+ self.emb_rel_v, t_s)
245
+ output = output + self._matmul_with_relative_values(
246
+ relative_weights, value_relative_embeddings)
247
+ output = output.transpose(2, 3).contiguous().view(
248
+ b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
249
+ return output, p_attn
250
+
251
+ def _matmul_with_relative_values(self, x, y):
252
+ """
253
+ x: [b, h, l, m]
254
+ y: [h or 1, m, d]
255
+ ret: [b, h, l, d]
256
+ """
257
+ ret = torch.matmul(x, y.unsqueeze(0))
258
+ return ret
259
+
260
+ def _matmul_with_relative_keys(self, x, y):
261
+ """
262
+ x: [b, h, l, d]
263
+ y: [h or 1, m, d]
264
+ ret: [b, h, l, m]
265
+ """
266
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
267
+ return ret
268
+
269
+ def _get_relative_embeddings(self, relative_embeddings, length):
270
+ max_relative_position = 2 * self.window_size + 1
271
+ # Pad first before slice to avoid using cond ops.
272
+ pad_length = max(length - (self.window_size + 1), 0)
273
+ slice_start_position = max((self.window_size + 1) - length, 0)
274
+ slice_end_position = slice_start_position + 2 * length - 1
275
+ if pad_length > 0:
276
+ padded_relative_embeddings = F.pad(
277
+ relative_embeddings,
278
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length],
279
+ [0, 0]]))
280
+ else:
281
+ padded_relative_embeddings = relative_embeddings
282
+ used_relative_embeddings = padded_relative_embeddings[:,
283
+ slice_start_position:
284
+ slice_end_position]
285
+ return used_relative_embeddings
286
+
287
+ def _relative_position_to_absolute_position(self, x):
288
+ """
289
+ x: [b, h, l, 2*l-1]
290
+ ret: [b, h, l, l]
291
+ """
292
+ batch, heads, length, _ = x.size()
293
+ # Concat columns of pad to shift from relative to absolute indexing.
294
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
295
+ 1]]))
296
+
297
+ # Concat extra elements so that the shape adds up to (len+1, 2*len-1).
298
+ x_flat = x.view([batch, heads, length * 2 * length])
299
+ x_flat = F.pad(
300
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0,
301
+ length - 1]]))
302
+
303
+ # Reshape and slice out the padded elements.
304
+ x_final = x_flat.view([batch, heads, length + 1,
305
+ 2 * length - 1])[:, :, :length, length - 1:]
306
+ return x_final
307
+
308
+ def _absolute_position_to_relative_position(self, x):
309
+ """
310
+ x: [b, h, l, l]
311
+ ret: [b, h, l, 2*l-1]
312
+ """
313
+ batch, heads, length, _ = x.size()
314
+ # padd along column
315
+ x = F.pad(
316
+ x,
317
+ commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
318
+ length - 1]]))
319
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
320
+ # add 0's in the beginning that will skew the elements after reshape
321
+ x_flat = F.pad(
322
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
323
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
324
+ return x_final
325
+
326
+ def _attention_bias_proximal(self, length):
327
+ """Bias for self-attention to encourage attention to close positions.
328
+ Args:
329
+ length: an integer scalar.
330
+ Returns:
331
+ a Tensor with shape [1, 1, length, length]
332
+ """
333
+ r = torch.arange(length, dtype=torch.float32)
334
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
335
+ return torch.unsqueeze(
336
+ torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
337
+
338
+
339
+ class FFN(nn.Module):
340
+ def __init__(self,
341
+ in_channels,
342
+ out_channels,
343
+ filter_channels,
344
+ kernel_size,
345
+ p_dropout=0.,
346
+ activation=None,
347
+ causal=False):
348
+ super().__init__()
349
+ self.in_channels = in_channels
350
+ self.out_channels = out_channels
351
+ self.filter_channels = filter_channels
352
+ self.kernel_size = kernel_size
353
+ self.p_dropout = p_dropout
354
+ self.activation = activation
355
+ self.causal = causal
356
+
357
+ if causal:
358
+ self.padding = self._causal_padding
359
+ else:
360
+ self.padding = self._same_padding
361
+
362
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
363
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
364
+ self.drop = nn.Dropout(p_dropout)
365
+
366
+ def forward(self, x, x_mask):
367
+ x = self.conv_1(self.padding(x * x_mask))
368
+ if self.activation == "gelu":
369
+ x = x * torch.sigmoid(1.702 * x)
370
+ else:
371
+ x = torch.relu(x)
372
+ x = self.drop(x)
373
+ x = self.conv_2(self.padding(x * x_mask))
374
+ return x * x_mask
375
+
376
+ def _causal_padding(self, x):
377
+ if self.kernel_size == 1:
378
+ return x
379
+ pad_l = self.kernel_size - 1
380
+ pad_r = 0
381
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
382
+ x = F.pad(x, commons.convert_pad_shape(padding))
383
+ return x
384
+
385
+ def _same_padding(self, x):
386
+ if self.kernel_size == 1:
387
+ return x
388
+ pad_l = (self.kernel_size - 1) // 2
389
+ pad_r = self.kernel_size // 2
390
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
391
+ x = F.pad(x, commons.convert_pad_shape(padding))
392
+ return x
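The Encoder/Decoder above work on channel-first tensors: x is [batch, hidden_channels, time] and x_mask is [batch, 1, time]. A quick shape check, as a sketch only (run from export/vits so that commons and modules resolve; the hyperparameter values below are illustrative, not taken from this commit):

import torch
from attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1, window_size=4)
x = torch.randn(2, 192, 50)      # [batch, hidden_channels, time]
x_mask = torch.ones(2, 1, 50)    # [batch, 1, time]
print(enc(x, x_mask).shape)      # torch.Size([2, 192, 50])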
export/vits/commons.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def init_weights(m, mean=0.0, std=0.01):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ m.weight.data.normal_(mean, std)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size * dilation - dilation) / 2)
15
+
16
+
17
+ def convert_pad_shape(pad_shape):
18
+ pad_shape = [item for sublist in reversed(pad_shape) for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += 0.5 * (torch.exp(2. * logs_p) +
32
+ ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
33
+ return kl
34
+
35
+
36
+ def rand_gumbel(shape):
37
+ """Sample from the Gumbel distribution, protect from overflows."""
38
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
39
+ return -torch.log(-torch.log(uniform_samples))
40
+
41
+
42
+ def rand_gumbel_like(x):
43
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
44
+ return g
45
+
46
+
47
+ def slice_segments(x, ids_str, segment_size=4):
48
+ ret = torch.zeros_like(x[:, :, :segment_size])
49
+ for i in range(x.size(0)):
50
+ idx_str = ids_str[i]
51
+ idx_end = idx_str + segment_size
52
+ ret[i] = x[i, :, idx_str:idx_end]
53
+ return ret
54
+
55
+
56
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
57
+ b, d, t = x.size()
58
+ if x_lengths is None:
59
+ x_lengths = t
60
+ ids_str_max = x_lengths - segment_size + 1
61
+ ids_str = (torch.rand([b]).to(device=x.device) *
62
+ ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length,
68
+ channels,
69
+ min_timescale=1.0,
70
+ max_timescale=1.0e4):
71
+ position = torch.arange(length, dtype=torch.float)
72
+ num_timescales = channels // 2
73
+ log_timescale_increment = (
74
+ math.log(float(max_timescale) / float(min_timescale)) /
75
+ (num_timescales - 1))
76
+ inv_timescales = min_timescale * torch.exp(
77
+ torch.arange(num_timescales, dtype=torch.float) *
78
+ -log_timescale_increment)
79
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
80
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
81
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
82
+ signal = signal.view(1, channels, length)
83
+ return signal
84
+
85
+
86
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
87
+ b, channels, length = x.size()
88
+ signal = get_timing_signal_1d(length, channels, min_timescale,
89
+ max_timescale)
90
+ return x + signal.to(dtype=x.dtype, device=x.device)
91
+
92
+
93
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale,
96
+ max_timescale)
97
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
+
99
+
100
+ def subsequent_mask(length):
101
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
+ return mask
103
+
104
+
105
+ @torch.jit.script
106
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
+ n_channels_int = n_channels[0]
108
+ in_act = input_a + input_b
109
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
+ acts = t_act * s_act
112
+ return acts
113
+
114
+
115
+ def shift_1d(x):
116
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
117
+ return x
118
+
119
+
120
+ def sequence_mask(length, max_length=None):
121
+ if max_length is None:
122
+ max_length = length.max()
123
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
124
+ return x.unsqueeze(0) < length.unsqueeze(1)
125
+
126
+
127
+ def generate_path(duration, mask):
128
+ """
129
+ duration: [b, 1, t_x]
130
+ mask: [b, 1, t_y, t_x]
131
+ """
132
+ device = duration.device
133
+
134
+ b, _, t_y, t_x = mask.shape
135
+ cum_duration = torch.cumsum(duration, -1)
136
+
137
+ cum_duration_flat = cum_duration.view(b * t_x)
138
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
139
+ path = path.view(b, t_x, t_y)
140
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
141
+ ]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item()**norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm**(1. / norm_type)
161
+ return total_norm
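intersperse() is what the data loader and the inference scripts call when add_blank is set: it inserts the blank token id 0 between (and around) every phoneme id. sequence_mask() builds the boolean padding masks used throughout the model. A tiny worked example:

import torch
from commons import intersperse, sequence_mask

print(intersperse([5, 9, 7], 0))
# [0, 5, 0, 9, 0, 7, 0]

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths, max_length=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])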
export/vits/data_utils.py ADDED
@@ -0,0 +1,307 @@
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ import torchaudio
6
+ import torch.utils.data
7
+
8
+ import commons
9
+ from mel_processing import spectrogram_torch
10
+ from utils import load_filepaths_and_text
11
+
12
+
13
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
14
+ """
15
+ 1) loads audio, speaker_id, text pairs
16
+ 2) normalizes text and converts them to sequences of integers
17
+ 3) computes spectrograms from audio files.
18
+ """
19
+ def __init__(self, audiopaths_sid_text, hparams):
20
+ self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
21
+ # self.text_cleaners = hparams.text_cleaners
22
+ self.max_wav_value = hparams.max_wav_value
23
+ self.sampling_rate = hparams.sampling_rate
24
+ self.filter_length = hparams.filter_length
25
+ self.hop_length = hparams.hop_length
26
+ self.win_length = hparams.win_length
27
+ self.sampling_rate = hparams.sampling_rate
28
+ self.src_sampling_rate = getattr(hparams, "src_sampling_rate",
29
+ self.sampling_rate)
30
+
31
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
32
+
33
+ self.add_blank = hparams.add_blank
34
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
35
+ self.max_text_len = getattr(hparams, "max_text_len", 190)
36
+
37
+ phone_file = getattr(hparams, "phone_table", None)
38
+ self.phone_dict = None
39
+ if phone_file is not None:
40
+ self.phone_dict = {}
41
+ with open(phone_file) as fin:
42
+ for line in fin:
43
+ arr = line.strip().split()
44
+ self.phone_dict[arr[0]] = int(arr[1])
45
+
46
+ speaker_file = getattr(hparams, "speaker_table", None)
47
+ self.speaker_dict = None
48
+ if speaker_file is not None:
49
+ self.speaker_dict = {}
50
+ with open(speaker_file) as fin:
51
+ for line in fin:
52
+ arr = line.strip().split()
53
+ self.speaker_dict[arr[0]] = int(arr[1])
54
+
55
+ random.seed(1234)
56
+ random.shuffle(self.audiopaths_sid_text)
57
+ self._filter()
58
+
59
+ def _filter(self):
60
+ """
61
+ Filter text & store spec lengths
62
+ """
63
+ # Store spectrogram lengths for Bucketing
64
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
65
+ # spec_length = wav_length // hop_length
66
+
67
+ audiopaths_sid_text_new = []
68
+ lengths = []
69
+ for item in self.audiopaths_sid_text:
70
+ audiopath = item[0]
71
+ # filename|text or filename|speaker|text
72
+ text = item[1] if len(item) == 2 else item[2]
73
+ if self.min_text_len <= len(text) and len(
74
+ text) <= self.max_text_len:
75
+ audiopaths_sid_text_new.append(item)
76
+ lengths.append(
77
+ int(
78
+ os.path.getsize(audiopath) * self.sampling_rate /
79
+ self.src_sampling_rate) // (2 * self.hop_length))
80
+ self.audiopaths_sid_text = audiopaths_sid_text_new
81
+ self.lengths = lengths
82
+
83
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
84
+ audiopath = audiopath_sid_text[0]
85
+ if len(audiopath_sid_text) == 2: # filename|text
86
+ sid = 0
87
+ text = audiopath_sid_text[1]
88
+ else: # filename|speaker|text
89
+ sid = self.speaker_dict[audiopath_sid_text[1]]
90
+ text = audiopath_sid_text[2]
91
+ text = self.get_text(text)
92
+ spec, wav = self.get_audio(audiopath)
93
+ sid = self.get_sid(sid)
94
+ return (text, spec, wav, sid)
95
+
96
+ def get_audio(self, filename):
97
+ audio, sampling_rate = torchaudio.load(filename, normalize=False)
98
+ if sampling_rate != self.sampling_rate:
99
+ audio = audio.to(torch.float)
100
+ audio = torchaudio.transforms.Resample(sampling_rate,
101
+ self.sampling_rate)(audio)
102
+ audio = audio.to(torch.int16)
103
+ audio = audio[0] # Get the first channel
104
+ audio_norm = audio / self.max_wav_value
105
+ audio_norm = audio_norm.unsqueeze(0)
106
+ spec = spectrogram_torch(audio_norm,
107
+ self.filter_length,
108
+ self.sampling_rate,
109
+ self.hop_length,
110
+ self.win_length,
111
+ center=False)
112
+ spec = torch.squeeze(spec, 0)
113
+ return spec, audio_norm
114
+
115
+ def get_text(self, text):
116
+ text_norm = [self.phone_dict[phone] for phone in text.split()]
117
+ if self.add_blank:
118
+ text_norm = commons.intersperse(text_norm, 0)
119
+ text_norm = torch.LongTensor(text_norm)
120
+ return text_norm
121
+
122
+ def get_sid(self, sid):
123
+ sid = torch.LongTensor([int(sid)])
124
+ return sid
125
+
126
+ def __getitem__(self, index):
127
+ return self.get_audio_text_speaker_pair(
128
+ self.audiopaths_sid_text[index])
129
+
130
+ def __len__(self):
131
+ return len(self.audiopaths_sid_text)
132
+
133
+
134
+ class TextAudioSpeakerCollate():
135
+ """ Zero-pads model inputs and targets
136
+ """
137
+ def __init__(self, return_ids=False):
138
+ self.return_ids = return_ids
139
+
140
+ def __call__(self, batch):
141
+ """Collate's training batch from normalized text, audio and speaker identities
142
+ PARAMS
143
+ ------
144
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
145
+ """
146
+ # Right zero-pad all one-hot text sequences to max input length
147
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor(
148
+ [x[1].size(1) for x in batch]),
149
+ dim=0,
150
+ descending=True)
151
+
152
+ max_text_len = max([len(x[0]) for x in batch])
153
+ max_spec_len = max([x[1].size(1) for x in batch])
154
+ max_wav_len = max([x[2].size(1) for x in batch])
155
+
156
+ text_lengths = torch.LongTensor(len(batch))
157
+ spec_lengths = torch.LongTensor(len(batch))
158
+ wav_lengths = torch.LongTensor(len(batch))
159
+ sid = torch.LongTensor(len(batch))
160
+
161
+ text_padded = torch.LongTensor(len(batch), max_text_len)
162
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0),
163
+ max_spec_len)
164
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
165
+ text_padded.zero_()
166
+ spec_padded.zero_()
167
+ wav_padded.zero_()
168
+ for i in range(len(ids_sorted_decreasing)):
169
+ row = batch[ids_sorted_decreasing[i]]
170
+
171
+ text = row[0]
172
+ text_padded[i, :text.size(0)] = text
173
+ text_lengths[i] = text.size(0)
174
+
175
+ spec = row[1]
176
+ spec_padded[i, :, :spec.size(1)] = spec
177
+ spec_lengths[i] = spec.size(1)
178
+
179
+ wav = row[2]
180
+ wav_padded[i, :, :wav.size(1)] = wav
181
+ wav_lengths[i] = wav.size(1)
182
+
183
+ sid[i] = row[3]
184
+
185
+ if self.return_ids:
186
+ return (text_padded, text_lengths, spec_padded, spec_lengths,
187
+ wav_padded, wav_lengths, sid, ids_sorted_decreasing)
188
+ return (text_padded, text_lengths, spec_padded, spec_lengths,
189
+ wav_padded, wav_lengths, sid)
190
+
191
+
192
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler
193
+ ):
194
+ """
195
+ Maintain similar input lengths in a batch.
196
+ Length groups are specified by boundaries.
197
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either
198
+ {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
199
+
200
+ It removes samples which are not included in the boundaries.
201
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1
202
+ or length(x) > b3 are discarded.
203
+ """
204
+ def __init__(self,
205
+ dataset,
206
+ batch_size,
207
+ boundaries,
208
+ num_replicas=None,
209
+ rank=None,
210
+ shuffle=True):
211
+ super().__init__(dataset,
212
+ num_replicas=num_replicas,
213
+ rank=rank,
214
+ shuffle=shuffle)
215
+ self.lengths = dataset.lengths
216
+ self.batch_size = batch_size
217
+ self.boundaries = boundaries
218
+
219
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
220
+ self.total_size = sum(self.num_samples_per_bucket)
221
+ self.num_samples = self.total_size // self.num_replicas
222
+
223
+ def _create_buckets(self):
224
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
225
+ for i in range(len(self.lengths)):
226
+ length = self.lengths[i]
227
+ idx_bucket = self._bisect(length)
228
+ if idx_bucket != -1:
229
+ buckets[idx_bucket].append(i)
230
+
231
+ for i in range(len(buckets) - 1, 0, -1):
232
+ if len(buckets[i]) == 0:
233
+ buckets.pop(i)
234
+ self.boundaries.pop(i + 1)
235
+
236
+ num_samples_per_bucket = []
237
+ for i in range(len(buckets)):
238
+ len_bucket = len(buckets[i])
239
+ total_batch_size = self.num_replicas * self.batch_size
240
+ rem = (total_batch_size -
241
+ (len_bucket % total_batch_size)) % total_batch_size
242
+ num_samples_per_bucket.append(len_bucket + rem)
243
+ return buckets, num_samples_per_bucket
244
+
245
+ def __iter__(self):
246
+ # deterministically shuffle based on epoch
247
+ g = torch.Generator()
248
+ g.manual_seed(self.epoch)
249
+
250
+ indices = []
251
+ if self.shuffle:
252
+ for bucket in self.buckets:
253
+ indices.append(
254
+ torch.randperm(len(bucket), generator=g).tolist())
255
+ else:
256
+ for bucket in self.buckets:
257
+ indices.append(list(range(len(bucket))))
258
+
259
+ batches = []
260
+ for i in range(len(self.buckets)):
261
+ bucket = self.buckets[i]
262
+ len_bucket = len(bucket)
263
+ ids_bucket = indices[i]
264
+ num_samples_bucket = self.num_samples_per_bucket[i]
265
+
266
+ # add extra samples to make it evenly divisible
267
+ rem = num_samples_bucket - len_bucket
268
+ ids_bucket = ids_bucket + ids_bucket * (
269
+ rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
270
+
271
+ # subsample
272
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
273
+
274
+ # batching
275
+ for j in range(len(ids_bucket) // self.batch_size):
276
+ batch = [
277
+ bucket[idx]
278
+ for idx in ids_bucket[j * self.batch_size:(j + 1) *
279
+ self.batch_size]
280
+ ]
281
+ batches.append(batch)
282
+
283
+ if self.shuffle:
284
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
285
+ batches = [batches[i] for i in batch_ids]
286
+ self.batches = batches
287
+
288
+ assert len(self.batches) * self.batch_size == self.num_samples
289
+ return iter(self.batches)
290
+
291
+ def _bisect(self, x, lo=0, hi=None):
292
+ if hi is None:
293
+ hi = len(self.boundaries) - 1
294
+
295
+ if hi > lo:
296
+ mid = (hi + lo) // 2
297
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
298
+ return mid
299
+ elif x <= self.boundaries[mid]:
300
+ return self._bisect(x, lo, mid)
301
+ else:
302
+ return self._bisect(x, mid + 1, hi)
303
+ else:
304
+ return -1
305
+
306
+ def __len__(self):
307
+ return self.num_samples // self.batch_size
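The bucketing rule in the DistributedBucketSampler docstring can be restated in a few standalone lines: a sample of length L goes into bucket i when boundaries[i] < L <= boundaries[i+1], and is discarded when L <= boundaries[0] or L > boundaries[-1]. A sketch with illustrative numbers:

def bucket_of(length, boundaries):
    # Same rule as DistributedBucketSampler._bisect, written linearly.
    for i in range(len(boundaries) - 1):
        if boundaries[i] < length <= boundaries[i + 1]:
            return i
    return -1  # discarded

boundaries = [32, 300, 400, 500]
for length in (20, 100, 350, 450, 900):
    print(length, bucket_of(length, boundaries))
# 20 -> -1 (too short), 100 -> 0, 350 -> 1, 450 -> 2, 900 -> -1 (too long)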
export/vits/export_onnx.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import sys
19
+
20
+ import torch
21
+
22
+ from models import SynthesizerTrn
23
+ import utils
24
+
25
+ try:
26
+ import onnxruntime as ort
27
+ except ImportError:
28
+ print('Please install onnxruntime!')
29
+ sys.exit(1)
30
+
31
+
32
+ def to_numpy(tensor):
33
+ return tensor.detach().cpu().numpy() if tensor.requires_grad \
34
+ else tensor.detach().numpy()
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser(description='export onnx model')
39
+ parser.add_argument('--checkpoint', required=True, help='checkpoint')
40
+ parser.add_argument('--cfg', required=True, help='config file')
41
+ parser.add_argument('--onnx_model', required=True, help='onnx model name')
42
+ # parser.add_argument('--phone_table',
43
+ # required=True,
44
+ # help='input phone dict')
45
+ # parser.add_argument('--speaker_table', default=None, help='speaker table')
46
+ # parser.add_argument("--speaker_num", required=True,
47
+ # type=int, help="speaker num")
48
+ parser.add_argument(
49
+ '--providers',
50
+ required=False,
51
+ default='CPUExecutionProvider',
52
+ choices=['CUDAExecutionProvider', 'CPUExecutionProvider'],
53
+ help='the model to send request to')
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
58
+ def get_data_from_cfg(cfg_path: str):
59
+ assert os.path.isfile(cfg_path)
60
+ with open(cfg_path, 'r') as f:
61
+ data = json.load(f)
62
+ symbols = data["symbols"]
63
+ speaker_num = data["data"]["n_speakers"]
64
+ return len(symbols), speaker_num
65
+
66
+
67
+ def main():
68
+ args = get_args()
69
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
70
+
71
+ hps = utils.get_hparams_from_file(args.cfg)
72
+ # with open(args.phone_table) as p_f:
73
+ # phone_num = len(p_f.readlines()) + 1
74
+ # num_speakers = 1
75
+ # if args.speaker_table is not None:
76
+ # num_speakers = len(open(args.speaker_table).readlines()) + 1
77
+ phone_num, num_speakers = get_data_from_cfg(args.cfg)
78
+ net_g = SynthesizerTrn(phone_num,
79
+ hps.data.filter_length // 2 + 1,
80
+ hps.train.segment_size // hps.data.hop_length,
81
+ n_speakers=num_speakers,
82
+ **hps.model)
83
+ utils.load_checkpoint(args.checkpoint, net_g, None)
84
+ net_g.forward = net_g.export_forward
85
+ net_g.eval()
86
+
87
+ seq = torch.randint(low=0, high=phone_num, size=(1, 10), dtype=torch.long)
88
+ seq_len = torch.IntTensor([seq.size(1)]).long()
89
+
90
+ # noise (controls how much the emotion/prosody varies), length (controls the overall speaking speed), noisew (controls how much phoneme durations vary)
91
+ # Reference: https://github.com/gbxh/genshinTTS
92
+ scales = torch.FloatTensor([0.667, 1.0, 0.8])
93
+ # make triton dynamic shape happy
94
+ scales = scales.unsqueeze(0)
95
+ sid = torch.IntTensor([0]).long()
96
+
97
+ dummy_input = (seq, seq_len, scales, sid)
98
+ torch.onnx.export(model=net_g,
99
+ args=dummy_input,
100
+ f=args.onnx_model,
101
+ input_names=['input', 'input_lengths', 'scales', 'sid'],
102
+ output_names=['output'],
103
+ dynamic_axes={
104
+ 'input': {
105
+ 0: 'batch',
106
+ 1: 'phonemes'
107
+ },
108
+ 'input_lengths': {
109
+ 0: 'batch'
110
+ },
111
+ 'scales': {
112
+ 0: 'batch'
113
+ },
114
+ 'sid': {
115
+ 0: 'batch'
116
+ },
117
+ 'output': {
118
+ 0: 'batch',
119
+ 1: 'audio',
120
+ 2: 'audio_length'
121
+ }
122
+ },
123
+ opset_version=13,
124
+ verbose=False)
125
+
126
+ # Verify onnx precision
127
+ torch_output = net_g(seq, seq_len, scales, sid)
128
+ providers = [args.providers]
129
+ ort_sess = ort.InferenceSession(args.onnx_model, providers=providers)
130
+ ort_inputs = {
131
+ 'input': to_numpy(seq),
132
+ 'input_lengths': to_numpy(seq_len),
133
+ 'scales': to_numpy(scales),
134
+ 'sid': to_numpy(sid),
135
+ }
136
+ onnx_output = ort_sess.run(None, ort_inputs)
137
+
138
+
139
+ if __name__ == '__main__':
140
+ main()
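main() above runs the exported graph through onnxruntime but never compares onnx_output with torch_output, so the "Verify onnx precision" step is only half finished. A possible check to append at the end of main(), as a sketch (the tolerances are arbitrary, and whether export_forward returns a bare audio tensor or a tuple depends on models.py, so the indexing is an assumption):

import numpy as np

# Assumption: export_forward returns the audio tensor directly; if it
# returns a tuple, compare its first element instead.
np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0],
                           rtol=1e-3, atol=1e-5)
print('ONNX output matches PyTorch output within tolerance.')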
export/vits/inference.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+
17
+ import numpy as np
18
+ from scipy.io import wavfile
19
+ import torch
20
+
21
+ import commons
22
+ from models import SynthesizerTrn
23
+ import utils
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser(description='inference')
28
+ parser.add_argument('--checkpoint', required=True, help='checkpoint')
29
+ parser.add_argument('--cfg', required=True, help='config file')
30
+ parser.add_argument('--outdir', required=True, help='output directory')
31
+ parser.add_argument('--phone_table',
32
+ required=True,
33
+ help='input phone dict')
34
+ parser.add_argument('--speaker_table', default=None, help='speaker table')
35
+ parser.add_argument('--test_file', required=True, help='test file')
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+ print(args)
43
+ phone_dict = {}
44
+ with open(args.phone_table) as p_f:
45
+ for line in p_f:
46
+ phone_id = line.strip().split()
47
+ phone_dict[phone_id[0]] = int(phone_id[1])
48
+ speaker_dict = {}
49
+ if args.speaker_table is not None:
50
+ with open(args.speaker_table) as p_f:
51
+ for line in p_f:
52
+ arr = line.strip().split()
53
+ assert len(arr) == 2
54
+ speaker_dict[arr[0]] = int(arr[1])
55
+ hps = utils.get_hparams_from_file(args.cfg)
56
+
57
+ net_g = SynthesizerTrn(
58
+ len(phone_dict) + 1,
59
+ hps.data.filter_length // 2 + 1,
60
+ hps.train.segment_size // hps.data.hop_length,
61
+ n_speakers=len(speaker_dict) + 1, # 0 is kept for unknown speaker
62
+ **hps.model).cuda()
63
+ net_g.eval()
64
+ utils.load_checkpoint(args.checkpoint, net_g, None)
65
+
66
+ with open(args.test_file) as fin:
67
+ for line in fin:
68
+ arr = line.strip().split("|")
69
+ audio_path = arr[0]
70
+ if len(arr) == 2:
71
+ sid = 0
72
+ text = arr[1]
73
+ else:
74
+ sid = speaker_dict[arr[1]]
75
+ text = arr[2]
76
+ seq = [phone_dict[symbol] for symbol in text.split()]
77
+ if hps.data.add_blank:
78
+ seq = commons.intersperse(seq, 0)
79
+ seq = torch.LongTensor(seq)
80
+ with torch.no_grad():
81
+ x = seq.cuda().unsqueeze(0)
82
+ x_length = torch.LongTensor([seq.size(0)]).cuda()
83
+ sid = torch.LongTensor([sid]).cuda()
84
+ audio = net_g.infer(
85
+ x,
86
+ x_length,
87
+ sid=sid,
88
+ noise_scale=.667,
89
+ noise_scale_w=0.8,
90
+ length_scale=1)[0][0, 0].data.cpu().float().numpy()
91
+ audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
92
+ audio = np.clip(audio, -32767.0, 32767.0)
93
+ wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
94
+ hps.data.sampling_rate, audio.astype(np.int16))
95
+
96
+
97
+ if __name__ == '__main__':
98
+ main()
export/vits/inference_onnx.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ from text import text_to_sequence
17
+ import numpy as np
18
+ from scipy.io import wavfile
19
+ import torch
20
+ import json
21
+ import commons
22
+ import utils
23
+ import sys
24
+ import pathlib
25
+
26
+ try:
27
+ import onnxruntime as ort
28
+ except ImportError:
29
+ print('Please install onnxruntime!')
30
+ sys.exit(1)
31
+
32
+
33
+ def to_numpy(tensor: torch.Tensor):
34
+ return tensor.detach().cpu().numpy() if tensor.requires_grad \
35
+ else tensor.detach().numpy()
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='inference')
40
+ parser.add_argument('--onnx_model', required=True, help='onnx model')
41
+ parser.add_argument('--cfg', required=True, help='config file')
42
+ parser.add_argument('--outdir', default="onnx_output",
43
+ help='output directory')
44
+ # parser.add_argument('--phone_table',
45
+ # required=True,
46
+ # help='input phone dict')
47
+ # parser.add_argument('--speaker_table', default=None, help='speaker table')
48
+ parser.add_argument('--test_file', required=True, help='test file')
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def get_symbols_from_json(path):
54
+ import os
55
+ assert os.path.isfile(path)
56
+ with open(path, 'r') as f:
57
+ data = json.load(f)
58
+ return data['symbols']
59
+
60
+
61
+ def main():
62
+ args = get_args()
63
+ print(args)
64
+ if not pathlib.Path(args.outdir).exists():
65
+ pathlib.Path(args.outdir).mkdir(exist_ok=True, parents=True)
66
+ # phones =
67
+ symbols = get_symbols_from_json(args.cfg)
68
+ phone_dict = {
69
+ symbol: i for i, symbol in enumerate(symbols)
70
+ }
71
+
72
+ # speaker_dict = {}
73
+ # if args.speaker_table is not None:
74
+ # with open(args.speaker_table) as p_f:
75
+ # for line in p_f:
76
+ # arr = line.strip().split()
77
+ # assert len(arr) == 2
78
+ # speaker_dict[arr[0]] = int(arr[1])
79
+ hps = utils.get_hparams_from_file(args.cfg)
80
+
81
+ ort_sess = ort.InferenceSession(args.onnx_model)
82
+
83
+ with open(args.test_file) as fin:
84
+ for line in fin:
85
+ arr = line.strip().split("|")
86
+ audio_path = arr[0]
87
+
88
+ # TODO: make the speaker id configurable (currently hard-coded below)
89
+ sid = 8
90
+ text = arr[1]
91
+ # else:
92
+ # sid = speaker_dict[arr[1]]
93
+ # text = arr[2]
94
+ seq = text_to_sequence(text, symbols=hps.symbols, cleaner_names=["japanese_cleaners2"]
95
+ )
96
+ if hps.data.add_blank:
97
+ seq = commons.intersperse(seq, 0)
98
+
99
+ # if hps.data.add_blank:
100
+ # seq = commons.intersperse(seq, 0)
101
+ with torch.no_grad():
102
+ # x = torch.LongTensor([seq])
103
+ # x_len = torch.IntTensor([x.size(1)]).long()
104
+ # sid = torch.LongTensor([sid]).long()
105
+ # scales = torch.FloatTensor([0.667, 1.0, 1])
106
+ # # make triton dynamic shape happy
107
+ # scales = scales.unsqueeze(0)
108
+
109
+ # use numpy to replace torch
110
+ x = np.array([seq], dtype=np.int64)
111
+ x_len = np.array([x.shape[1]], dtype=np.int64)
112
+ sid = np.array([sid], dtype=np.int64)
113
+ # noise (controls how much the emotion/prosody varies), length (controls the overall speaking speed), noisew (controls how much phoneme durations vary)
114
+ # Reference: https://github.com/gbxh/genshinTTS
115
+ scales = np.array([0.667, 0.8, 1], dtype=np.float32)
116
+ # scales = scales[np.newaxis, :]
117
+ # scales.reshape(1, -1)
118
+ scales.resize(1, 3)
119
+
120
+ ort_inputs = {
121
+ 'input': x,
122
+ 'input_lengths': x_len,
123
+ 'scales': scales,
124
+ 'sid': sid
125
+ }
126
+
127
+ # ort_inputs = {
128
+ # 'input': to_numpy(x),
129
+ # 'input_lengths': to_numpy(x_len),
130
+ # 'scales': to_numpy(scales),
131
+ # 'sid': to_numpy(sid)
132
+ # }
133
+ import time
134
+ # start_time = time.time()
135
+ start_time = time.perf_counter()
136
+ audio = np.squeeze(ort_sess.run(None, ort_inputs))
137
+ audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
138
+ audio = np.clip(audio, -32767.0, 32767.0)
139
+ end_time = time.perf_counter()
140
+ # end_time = time.time()
141
+ print("infer time cost: ", end_time - start_time, "s")
142
+
143
+ wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
144
+ hps.data.sampling_rate, audio.astype(np.int16))
145
+
146
+
147
+ if __name__ == '__main__':
148
+ main()
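Per the translated comment above, the three entries of scales are [noise, length, noisew]: noise controls prosody variation, length the overall speaking rate, and noisew the variation of phoneme durations. A sketch of building the ONNX inputs with a slower speaking rate (the phoneme ids and the 1.2 length scale are purely illustrative):

import numpy as np

seq = [12, 0, 35, 0, 7]                          # phoneme ids, already interspersed with blanks
x = np.array([seq], dtype=np.int64)              # 'input'
x_len = np.array([x.shape[1]], dtype=np.int64)   # 'input_lengths'
sid = np.array([8], dtype=np.int64)              # 'sid'
scales = np.array([[0.667, 1.2, 0.8]], dtype=np.float32)  # length=1.2 -> slower speech

ort_inputs = {'input': x, 'input_lengths': x_len, 'scales': scales, 'sid': sid}
# audio = np.squeeze(ort_sess.run(None, ort_inputs))   # with an ort.InferenceSession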
export/vits/losses.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+
4
+ def feature_loss(fmap_r, fmap_g):
5
+ loss = 0
6
+ for dr, dg in zip(fmap_r, fmap_g):
7
+ for rl, gl in zip(dr, dg):
8
+ rl = rl.float().detach()
9
+ gl = gl.float()
10
+ loss += torch.mean(torch.abs(rl - gl))
11
+
12
+ return loss * 2
13
+
14
+
15
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
+ loss = 0
17
+ r_losses = []
18
+ g_losses = []
19
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
+ dr = dr.float()
21
+ dg = dg.float()
22
+ r_loss = torch.mean((1 - dr)**2)
23
+ g_loss = torch.mean(dg**2)
24
+ loss += (r_loss + g_loss)
25
+ r_losses.append(r_loss.item())
26
+ g_losses.append(g_loss.item())
27
+
28
+ return loss, r_losses, g_losses
29
+
30
+
31
+ def generator_loss(disc_outputs):
32
+ loss = 0
33
+ gen_losses = []
34
+ for dg in disc_outputs:
35
+ dg = dg.float()
36
+ l = torch.mean((1 - dg)**2)
37
+ gen_losses.append(l)
38
+ loss += l
39
+
40
+ return loss, gen_losses
41
+
42
+
43
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
+ """
45
+ z_p, logs_q: [b, h, t_t]
46
+ m_p, logs_p: [b, h, t_t]
47
+ """
48
+ z_p = z_p.float()
49
+ logs_q = logs_q.float()
50
+ m_p = m_p.float()
51
+ logs_p = logs_p.float()
52
+ z_mask = z_mask.float()
53
+
54
+ kl = logs_p - logs_q - 0.5
55
+ kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
56
+ kl = torch.sum(kl * z_mask)
57
+ l = kl / torch.sum(z_mask)
58
+ return l
export/vits/mel_processing.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+
6
+ MAX_WAV_VALUE = 32768.0
7
+
8
+
9
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
10
+ """
11
+ PARAMS
12
+ ------
13
+ C: compression factor
14
+ """
15
+ return torch.log(torch.clamp(x, min=clip_val) * C)
16
+
17
+
18
+ def dynamic_range_decompression_torch(x, C=1):
19
+ """
20
+ PARAMS
21
+ ------
22
+ C: compression factor used to compress
23
+ """
24
+ return torch.exp(x) / C
25
+
26
+
27
+ def spectral_normalize_torch(magnitudes):
28
+ output = dynamic_range_compression_torch(magnitudes)
29
+ return output
30
+
31
+
32
+ def spectral_de_normalize_torch(magnitudes):
33
+ output = dynamic_range_decompression_torch(magnitudes)
34
+ return output
35
+
36
+
37
+ mel_basis = {}
38
+ hann_window = {}
39
+
40
+
41
+ def spectrogram_torch(y,
42
+ n_fft,
43
+ sampling_rate,
44
+ hop_size,
45
+ win_size,
46
+ center=False):
47
+ if torch.min(y) < -1.:
48
+ print('min value is ', torch.min(y))
49
+ if torch.max(y) > 1.:
50
+ print('max value is ', torch.max(y))
51
+
52
+ global hann_window
53
+ dtype_device = str(y.dtype) + '_' + str(y.device)
54
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
55
+ if wnsize_dtype_device not in hann_window:
56
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
57
+ dtype=y.dtype, device=y.device)
58
+
59
+ y = F.pad(y.unsqueeze(1),
60
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
61
+ mode='reflect')
62
+ y = y.squeeze(1)
63
+
64
+ spec = torch.stft(y,
65
+ n_fft,
66
+ hop_length=hop_size,
67
+ win_length=win_size,
68
+ window=hann_window[wnsize_dtype_device],
69
+ center=center,
70
+ pad_mode='reflect',
71
+ normalized=False,
72
+ onesided=True)
73
+
74
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
75
+ return spec
76
+
77
+
78
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
79
+ global mel_basis
80
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
81
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
82
+ if fmax_dtype_device not in mel_basis:
83
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
84
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
85
+ dtype=spec.dtype, device=spec.device)
86
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
87
+ spec = spectral_normalize_torch(spec)
88
+ return spec
89
+
90
+
91
+ def mel_spectrogram_torch(y,
92
+ n_fft,
93
+ num_mels,
94
+ sampling_rate,
95
+ hop_size,
96
+ win_size,
97
+ fmin,
98
+ fmax,
99
+ center=False):
100
+ if torch.min(y) < -1.:
101
+ print('min value is ', torch.min(y))
102
+ if torch.max(y) > 1.:
103
+ print('max value is ', torch.max(y))
104
+
105
+ global mel_basis, hann_window
106
+ dtype_device = str(y.dtype) + '_' + str(y.device)
107
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
108
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
109
+ if fmax_dtype_device not in mel_basis:
110
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
111
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
112
+ dtype=y.dtype, device=y.device)
113
+ if wnsize_dtype_device not in hann_window:
114
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
115
+ dtype=y.dtype, device=y.device)
116
+
117
+ y = F.pad(y.unsqueeze(1),
118
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
119
+ mode='reflect')
120
+ y = y.squeeze(1)
121
+
122
+ spec = torch.stft(y,
123
+ n_fft,
124
+ hop_length=hop_size,
125
+ win_length=win_size,
126
+ window=hann_window[wnsize_dtype_device],
127
+ center=center,
128
+ pad_mode='reflect',
129
+ normalized=False,
130
+ onesided=True)
131
+
132
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
133
+
134
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
135
+ spec = spectral_normalize_torch(spec)
136
+
137
+ return spec
export/vits/models.py ADDED
@@ -0,0 +1,672 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
8
+ import monotonic_align
9
+
10
+ import commons
11
+ import modules
12
+ import attentions
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ class StochasticDurationPredictor(nn.Module):
17
+ def __init__(self,
18
+ in_channels,
19
+ filter_channels,
20
+ kernel_size,
21
+ p_dropout,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ filter_channels = in_channels # it needs to be removed from future version.
26
+ self.in_channels = in_channels
27
+ self.filter_channels = filter_channels
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.log_flow = modules.Log()
34
+ self.flows = nn.ModuleList()
35
+ self.flows.append(modules.ElementwiseAffine(2))
36
+ for i in range(n_flows):
37
+ self.flows.append(
38
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
39
+ self.flows.append(modules.Flip())
40
+
41
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
42
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
43
+ self.post_convs = modules.DDSConv(filter_channels,
44
+ kernel_size,
45
+ n_layers=3,
46
+ p_dropout=p_dropout)
47
+ self.post_flows = nn.ModuleList()
48
+ self.post_flows.append(modules.ElementwiseAffine(2))
49
+ for i in range(4):
50
+ self.post_flows.append(
51
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
52
+ self.post_flows.append(modules.Flip())
53
+
54
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
55
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
56
+ self.convs = modules.DDSConv(filter_channels,
57
+ kernel_size,
58
+ n_layers=3,
59
+ p_dropout=p_dropout)
60
+ if gin_channels != 0:
61
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
62
+
63
+ def forward(self,
64
+ x,
65
+ x_mask,
66
+ w=None,
67
+ g=None,
68
+ reverse=False,
69
+ noise_scale=1.0):
70
+ x = torch.detach(x)
71
+ x = self.pre(x)
72
+ if g is not None:
73
+ g = torch.detach(g)
74
+ x = x + self.cond(g)
75
+ x = self.convs(x, x_mask)
76
+ x = self.proj(x) * x_mask
77
+
78
+ if not reverse:
79
+ flows = self.flows
80
+ assert w is not None
81
+
82
+ logdet_tot_q = 0
83
+ h_w = self.post_pre(w)
84
+ h_w = self.post_convs(h_w, x_mask)
85
+ h_w = self.post_proj(h_w) * x_mask
86
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(
87
+ device=x.device, dtype=x.dtype) * x_mask
88
+ z_q = e_q
89
+ for flow in self.post_flows:
90
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
91
+ logdet_tot_q += logdet_q
92
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
93
+ u = torch.sigmoid(z_u) * x_mask
94
+ z0 = (w - u) * x_mask
95
+ logdet_tot_q += torch.sum(
96
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
97
+ logq = torch.sum(
98
+ -0.5 * (math.log(2 * math.pi) +
99
+ (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
100
+
101
+ logdet_tot = 0
102
+ z0, logdet = self.log_flow(z0, x_mask)
103
+ logdet_tot += logdet
104
+ z = torch.cat([z0, z1], 1)
105
+ for flow in flows:
106
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
107
+ logdet_tot = logdet_tot + logdet
108
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) +
109
+ (z**2)) * x_mask, [1, 2]) - logdet_tot
110
+ return nll + logq # [b]
111
+ else:
112
+ flows = list(reversed(self.flows))
113
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
114
+ z = torch.randn(x.size(0), 2, x.size(2)).to(
115
+ device=x.device, dtype=x.dtype) * noise_scale
116
+ for flow in flows:
117
+ z = flow(z, x_mask, g=x, reverse=reverse)
118
+ z0, z1 = torch.split(z, [1, 1], 1)
119
+ logw = z0
120
+ return logw
121
+
122
+
123
+ class DurationPredictor(nn.Module):
124
+ def __init__(self,
125
+ in_channels,
126
+ filter_channels,
127
+ kernel_size,
128
+ p_dropout,
129
+ gin_channels=0):
130
+ super().__init__()
131
+
132
+ self.in_channels = in_channels
133
+ self.filter_channels = filter_channels
134
+ self.kernel_size = kernel_size
135
+ self.p_dropout = p_dropout
136
+ self.gin_channels = gin_channels
137
+
138
+ self.drop = nn.Dropout(p_dropout)
139
+ self.conv_1 = nn.Conv1d(in_channels,
140
+ filter_channels,
141
+ kernel_size,
142
+ padding=kernel_size // 2)
143
+ self.norm_1 = modules.LayerNorm(filter_channels)
144
+ self.conv_2 = nn.Conv1d(filter_channels,
145
+ filter_channels,
146
+ kernel_size,
147
+ padding=kernel_size // 2)
148
+ self.norm_2 = modules.LayerNorm(filter_channels)
149
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
150
+
151
+ if gin_channels != 0:
152
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
153
+
154
+ def forward(self, x, x_mask, g=None):
155
+ x = torch.detach(x)
156
+ if g is not None:
157
+ g = torch.detach(g)
158
+ x = x + self.cond(g)
159
+ x = self.conv_1(x * x_mask)
160
+ x = torch.relu(x)
161
+ x = self.norm_1(x)
162
+ x = self.drop(x)
163
+ x = self.conv_2(x * x_mask)
164
+ x = torch.relu(x)
165
+ x = self.norm_2(x)
166
+ x = self.drop(x)
167
+ x = self.proj(x * x_mask)
168
+ return x * x_mask
169
+
170
+
171
+ class TextEncoder(nn.Module):
172
+ def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels,
173
+ n_heads, n_layers, kernel_size, p_dropout):
174
+ super().__init__()
175
+ self.n_vocab = n_vocab
176
+ self.out_channels = out_channels
177
+ self.hidden_channels = hidden_channels
178
+ self.filter_channels = filter_channels
179
+ self.n_heads = n_heads
180
+ self.n_layers = n_layers
181
+ self.kernel_size = kernel_size
182
+ self.p_dropout = p_dropout
183
+
184
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
185
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
186
+
187
+ self.encoder = attentions.Encoder(hidden_channels, filter_channels,
188
+ n_heads, n_layers, kernel_size,
189
+ p_dropout)
190
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
191
+
192
+ def forward(self, x, x_lengths):
193
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
194
+ x = torch.transpose(x, 1, -1) # [b, h, t]
195
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
196
+ 1).to(x.dtype)
197
+
198
+ x = self.encoder(x * x_mask, x_mask)
199
+ stats = self.proj(x) * x_mask
200
+
201
+ m, logs = torch.split(stats, self.out_channels, dim=1)
202
+ return x, m, logs, x_mask
203
+
204
+
205
+ class ResidualCouplingBlock(nn.Module):
206
+ def __init__(self,
207
+ channels,
208
+ hidden_channels,
209
+ kernel_size,
210
+ dilation_rate,
211
+ n_layers,
212
+ n_flows=4,
213
+ gin_channels=0):
214
+ super().__init__()
215
+ self.channels = channels
216
+ self.hidden_channels = hidden_channels
217
+ self.kernel_size = kernel_size
218
+ self.dilation_rate = dilation_rate
219
+ self.n_layers = n_layers
220
+ self.n_flows = n_flows
221
+ self.gin_channels = gin_channels
222
+
223
+ self.flows = nn.ModuleList()
224
+ for i in range(n_flows):
225
+ self.flows.append(
226
+ modules.ResidualCouplingLayer(channels,
227
+ hidden_channels,
228
+ kernel_size,
229
+ dilation_rate,
230
+ n_layers,
231
+ gin_channels=gin_channels,
232
+ mean_only=True))
233
+ self.flows.append(modules.Flip())
234
+
235
+ def forward(self, x, x_mask, g=None, reverse=False):
236
+ if not reverse:
237
+ for flow in self.flows:
238
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
239
+ else:
240
+ for flow in reversed(self.flows):
241
+ x = flow(x, x_mask, g=g, reverse=reverse)
242
+ return x
243
+
244
+
245
+ class PosteriorEncoder(nn.Module):
246
+ def __init__(self,
247
+ in_channels,
248
+ out_channels,
249
+ hidden_channels,
250
+ kernel_size,
251
+ dilation_rate,
252
+ n_layers,
253
+ gin_channels=0):
254
+ super().__init__()
255
+ self.in_channels = in_channels
256
+ self.out_channels = out_channels
257
+ self.hidden_channels = hidden_channels
258
+ self.kernel_size = kernel_size
259
+ self.dilation_rate = dilation_rate
260
+ self.n_layers = n_layers
261
+ self.gin_channels = gin_channels
262
+
263
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
264
+ self.enc = modules.WN(hidden_channels,
265
+ kernel_size,
266
+ dilation_rate,
267
+ n_layers,
268
+ gin_channels=gin_channels)
269
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
270
+
271
+ def forward(self, x, x_lengths, g=None):
272
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
273
+ 1).to(x.dtype)
274
+ x = self.pre(x) * x_mask
275
+ x = self.enc(x, x_mask, g=g)
276
+ stats = self.proj(x) * x_mask
277
+ m, logs = torch.split(stats, self.out_channels, dim=1)
278
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
279
+ return z, m, logs, x_mask
280
+
281
+
282
+ class Generator(torch.nn.Module):
283
+ def __init__(self,
284
+ initial_channel,
285
+ resblock,
286
+ resblock_kernel_sizes,
287
+ resblock_dilation_sizes,
288
+ upsample_rates,
289
+ upsample_initial_channel,
290
+ upsample_kernel_sizes,
291
+ gin_channels=0):
292
+ super(Generator, self).__init__()
293
+ self.num_kernels = len(resblock_kernel_sizes)
294
+ self.num_upsamples = len(upsample_rates)
295
+ self.conv_pre = Conv1d(initial_channel,
296
+ upsample_initial_channel,
297
+ 7,
298
+ 1,
299
+ padding=3)
300
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
301
+
302
+ self.ups = nn.ModuleList()
303
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
304
+ self.ups.append(
305
+ weight_norm(
306
+ ConvTranspose1d(upsample_initial_channel // (2**i),
307
+ upsample_initial_channel // (2**(i + 1)),
308
+ k,
309
+ u,
310
+ padding=(k - u) // 2)))
311
+
312
+ self.resblocks = nn.ModuleList()
313
+ for i in range(len(self.ups)):
314
+ ch = upsample_initial_channel // (2**(i + 1))
315
+ for j, (k, d) in enumerate(
316
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)):
317
+ self.resblocks.append(resblock(ch, k, d))
318
+
319
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
320
+ self.ups.apply(init_weights)
321
+
322
+ if gin_channels != 0:
323
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
324
+
325
+ def forward(self, x, g=None):
326
+ x = self.conv_pre(x)
327
+ if g is not None:
328
+ x = x + self.cond(g)
329
+
330
+ for i in range(self.num_upsamples):
331
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
332
+ x = self.ups[i](x)
333
+ xs = None
334
+ for j in range(self.num_kernels):
335
+ if xs is None:
336
+ xs = self.resblocks[i * self.num_kernels + j](x)
337
+ else:
338
+ xs += self.resblocks[i * self.num_kernels + j](x)
339
+ x = xs / self.num_kernels
340
+ x = F.leaky_relu(x)
341
+ x = self.conv_post(x)
342
+ x = torch.tanh(x)
343
+
344
+ return x
345
+
346
+ def remove_weight_norm(self):
347
+ print('Removing weight norm...')
348
+ for l in self.ups:
349
+ remove_weight_norm(l)
350
+ for l in self.resblocks:
351
+ l.remove_weight_norm()
352
+
353
+
354
+ class DiscriminatorP(torch.nn.Module):
355
+ def __init__(self,
356
+ period,
357
+ kernel_size=5,
358
+ stride=3,
359
+ use_spectral_norm=False):
360
+ super(DiscriminatorP, self).__init__()
361
+ self.period = period
362
+ self.use_spectral_norm = use_spectral_norm
363
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
364
+ self.convs = nn.ModuleList([
365
+ norm_f(
366
+ Conv2d(1,
367
+ 32, (kernel_size, 1), (stride, 1),
368
+ padding=(get_padding(kernel_size, 1), 0))),
369
+ norm_f(
370
+ Conv2d(32,
371
+ 128, (kernel_size, 1), (stride, 1),
372
+ padding=(get_padding(kernel_size, 1), 0))),
373
+ norm_f(
374
+ Conv2d(128,
375
+ 512, (kernel_size, 1), (stride, 1),
376
+ padding=(get_padding(kernel_size, 1), 0))),
377
+ norm_f(
378
+ Conv2d(512,
379
+ 1024, (kernel_size, 1), (stride, 1),
380
+ padding=(get_padding(kernel_size, 1), 0))),
381
+ norm_f(
382
+ Conv2d(1024,
383
+ 1024, (kernel_size, 1),
384
+ 1,
385
+ padding=(get_padding(kernel_size, 1), 0))),
386
+ ])
387
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
388
+
389
+ def forward(self, x):
390
+ fmap = []
391
+
392
+ # 1d to 2d
393
+ b, c, t = x.shape
394
+ if t % self.period != 0: # pad first
395
+ n_pad = self.period - (t % self.period)
396
+ x = F.pad(x, (0, n_pad), "reflect")
397
+ t = t + n_pad
398
+ x = x.view(b, c, t // self.period, self.period)
399
+
400
+ for l in self.convs:
401
+ x = l(x)
402
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
403
+ fmap.append(x)
404
+ x = self.conv_post(x)
405
+ fmap.append(x)
406
+ x = torch.flatten(x, 1, -1)
407
+
408
+ return x, fmap
409
+
410
+
411
+ class DiscriminatorS(torch.nn.Module):
412
+ def __init__(self, use_spectral_norm=False):
413
+ super(DiscriminatorS, self).__init__()
414
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
415
+ self.convs = nn.ModuleList([
416
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
417
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
418
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
419
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
420
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
421
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
422
+ ])
423
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
424
+
425
+ def forward(self, x):
426
+ fmap = []
427
+
428
+ for l in self.convs:
429
+ x = l(x)
430
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
431
+ fmap.append(x)
432
+ x = self.conv_post(x)
433
+ fmap.append(x)
434
+ x = torch.flatten(x, 1, -1)
435
+
436
+ return x, fmap
437
+
438
+
439
+ class MultiPeriodDiscriminator(torch.nn.Module):
440
+ def __init__(self, use_spectral_norm=False):
441
+ super(MultiPeriodDiscriminator, self).__init__()
442
+ periods = [2, 3, 5, 7, 11]
443
+
444
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
445
+ discs = discs + [
446
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm)
447
+ for i in periods
448
+ ]
449
+ self.discriminators = nn.ModuleList(discs)
450
+
451
+ def forward(self, y, y_hat):
452
+ y_d_rs = []
453
+ y_d_gs = []
454
+ fmap_rs = []
455
+ fmap_gs = []
456
+ for i, d in enumerate(self.discriminators):
457
+ y_d_r, fmap_r = d(y)
458
+ y_d_g, fmap_g = d(y_hat)
459
+ y_d_rs.append(y_d_r)
460
+ y_d_gs.append(y_d_g)
461
+ fmap_rs.append(fmap_r)
462
+ fmap_gs.append(fmap_g)
463
+
464
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
465
+
466
+
467
+ class SynthesizerTrn(nn.Module):
468
+ """
469
+ Synthesizer for Training
470
+ """
471
+ def __init__(self,
472
+ n_vocab,
473
+ spec_channels,
474
+ segment_size,
475
+ inter_channels,
476
+ hidden_channels,
477
+ filter_channels,
478
+ n_heads,
479
+ n_layers,
480
+ kernel_size,
481
+ p_dropout,
482
+ resblock,
483
+ resblock_kernel_sizes,
484
+ resblock_dilation_sizes,
485
+ upsample_rates,
486
+ upsample_initial_channel,
487
+ upsample_kernel_sizes,
488
+ n_speakers=0,
489
+ gin_channels=0,
490
+ use_sdp=True,
491
+ **kwargs):
492
+
493
+ super().__init__()
494
+ self.n_vocab = n_vocab
495
+ self.spec_channels = spec_channels
496
+ self.inter_channels = inter_channels
497
+ self.hidden_channels = hidden_channels
498
+ self.filter_channels = filter_channels
499
+ self.n_heads = n_heads
500
+ self.n_layers = n_layers
501
+ self.kernel_size = kernel_size
502
+ self.p_dropout = p_dropout
503
+ self.resblock = resblock
504
+ self.resblock_kernel_sizes = resblock_kernel_sizes
505
+ self.resblock_dilation_sizes = resblock_dilation_sizes
506
+ self.upsample_rates = upsample_rates
507
+ self.upsample_initial_channel = upsample_initial_channel
508
+ self.upsample_kernel_sizes = upsample_kernel_sizes
509
+ self.segment_size = segment_size
510
+ self.n_speakers = n_speakers
511
+ self.gin_channels = gin_channels
512
+ if self.n_speakers != 0:
513
+ message = "gin_channels must be none zero for multiple speakers"
514
+ assert gin_channels != 0, message
515
+
516
+ self.use_sdp = use_sdp
517
+
518
+ self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels,
519
+ filter_channels, n_heads, n_layers,
520
+ kernel_size, p_dropout)
521
+ self.dec = Generator(inter_channels,
522
+ resblock,
523
+ resblock_kernel_sizes,
524
+ resblock_dilation_sizes,
525
+ upsample_rates,
526
+ upsample_initial_channel,
527
+ upsample_kernel_sizes,
528
+ gin_channels=gin_channels)
529
+ self.enc_q = PosteriorEncoder(spec_channels,
530
+ inter_channels,
531
+ hidden_channels,
532
+ 5,
533
+ 1,
534
+ 16,
535
+ gin_channels=gin_channels)
536
+ self.flow = ResidualCouplingBlock(inter_channels,
537
+ hidden_channels,
538
+ 5,
539
+ 1,
540
+ 4,
541
+ gin_channels=gin_channels)
542
+
543
+ if use_sdp:
544
+ self.dp = StochasticDurationPredictor(hidden_channels,
545
+ 192,
546
+ 3,
547
+ 0.5,
548
+ 4,
549
+ gin_channels=gin_channels)
550
+ else:
551
+ self.dp = DurationPredictor(hidden_channels,
552
+ 256,
553
+ 3,
554
+ 0.5,
555
+ gin_channels=gin_channels)
556
+
557
+ if n_speakers > 1:
558
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
559
+
560
+ def forward(self, x, x_lengths, y, y_lengths, sid=None):
561
+
562
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
563
+ if self.n_speakers > 0:
564
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
565
+ else:
566
+ g = None
567
+
568
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
569
+ z_p = self.flow(z, y_mask, g=g)
570
+
571
+ with torch.no_grad():
572
+ # negative cross-entropy
573
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
574
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1],
575
+ keepdim=True) # [b, 1, t_s]
576
+ neg_cent2 = torch.matmul(
577
+ -0.5 * (z_p**2).transpose(1, 2),
578
+ s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
579
+ neg_cent3 = torch.matmul(
580
+ z_p.transpose(1, 2),
581
+ (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
582
+ neg_cent4 = torch.sum(-0.5 * (m_p**2) * s_p_sq_r, [1],
583
+ keepdim=True) # [b, 1, t_s]
584
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
585
+
586
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
587
+ y_mask, -1)
588
+ attn = monotonic_align.maximum_path(
589
+ neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
590
+
591
+ w = attn.sum(2)
592
+ if self.use_sdp:
593
+ l_length = self.dp(x, x_mask, w, g=g)
594
+ l_length = l_length / torch.sum(x_mask)
595
+ else:
596
+ logw_ = torch.log(w + 1e-6) * x_mask
597
+ logw = self.dp(x, x_mask, g=g)
598
+ l_length = torch.sum(
599
+ (logw - logw_)**2, [1, 2]) / torch.sum(x_mask) # for averaging
600
+
601
+ # expand prior
602
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1,
603
+ 2)).transpose(1, 2)
604
+ logs_p = torch.matmul(attn.squeeze(1),
605
+ logs_p.transpose(1, 2)).transpose(1, 2)
606
+
607
+ z_slice, ids_slice = commons.rand_slice_segments(
608
+ z, y_lengths, self.segment_size)
609
+ o = self.dec(z_slice, g=g)
610
+ return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p,
611
+ logs_p, m_q,
612
+ logs_q)
613
+
614
+ def infer(self,
615
+ x,
616
+ x_lengths,
617
+ sid=None,
618
+ noise_scale=1,
619
+ length_scale=1,
620
+ noise_scale_w=1.,
621
+ max_len=None):
622
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
623
+ if self.n_speakers > 0:
624
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
625
+ else:
626
+ g = None
627
+
628
+ if self.use_sdp:
629
+ logw = self.dp(x,
630
+ x_mask,
631
+ g=g,
632
+ reverse=True,
633
+ noise_scale=noise_scale_w)
634
+ else:
635
+ logw = self.dp(x, x_mask, g=g)
636
+ w = torch.exp(logw) * x_mask * length_scale
637
+ w_ceil = torch.ceil(w)
638
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
639
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None),
640
+ 1).to(x_mask.dtype)
641
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
642
+ attn = commons.generate_path(w_ceil, attn_mask)
643
+
644
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
645
+ 1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
646
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(
647
+ 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
648
+
649
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
650
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
651
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
652
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
653
+
654
+ def export_forward(self, x, x_lengths, scales, sid):
655
+ # shape of scales: Bx3, make triton happy
656
+ audio, *_ = self.infer(x,
657
+ x_lengths,
658
+ sid,
659
+ noise_scale=scales[0][0],
660
+ length_scale=scales[0][1],
661
+ noise_scale_w=scales[0][2])
662
+ return audio
663
+
664
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
665
+ assert self.n_speakers > 0, "n_speakers has to be larger than 0."
666
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
667
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
668
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
669
+ z_p = self.flow(z, y_mask, g=g_src)
670
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
671
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
672
+ return o_hat, y_mask, (z, z_p, z_hat)
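As a quick orientation (not part of this commit), the following is a minimal sketch of how the SynthesizerTrn defined above could be driven for inference. The hyperparameter values mirror a typical VITS base configuration, and the `models` import path and phone ids are assumptions for illustration rather than anything this diff defines.

# Illustrative only -- hyperparameters and phone ids below are assumed, not taken from this commit.
import torch
from models import SynthesizerTrn

net_g = SynthesizerTrn(
    n_vocab=40,                      # assumed phone-set size
    spec_channels=1024 // 2 + 1,     # filter_length // 2 + 1
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock='1',
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4])
net_g.eval()

phone_ids = torch.LongTensor([[5, 12, 7, 3, 9]])   # assumed phone ids
lengths = torch.LongTensor([phone_ids.size(1)])
with torch.no_grad():
    audio, *_ = net_g.infer(phone_ids, lengths,
                            noise_scale=0.667, length_scale=1.0,
                            noise_scale_w=0.8)
print(audio.shape)  # [1, 1, T] waveform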
export/vits/modules.py ADDED
@@ -0,0 +1,469 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ import commons
10
+ from commons import init_weights, get_padding
11
+ from transforms import piecewise_rational_quadratic_transform
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, channels, eps=1e-5):
18
+ super().__init__()
19
+ self.channels = channels
20
+ self.eps = eps
21
+
22
+ self.gamma = nn.Parameter(torch.ones(channels))
23
+ self.beta = nn.Parameter(torch.zeros(channels))
24
+
25
+ def forward(self, x):
26
+ x = x.transpose(1, -1)
27
+ x = F.layer_norm(x, (self.channels, ), self.gamma, self.beta, self.eps)
28
+ return x.transpose(1, -1)
29
+
30
+
31
+ class ConvReluNorm(nn.Module):
32
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
33
+ n_layers, p_dropout):
34
+ super().__init__()
35
+ self.in_channels = in_channels
36
+ self.hidden_channels = hidden_channels
37
+ self.out_channels = out_channels
38
+ self.kernel_size = kernel_size
39
+ self.n_layers = n_layers
40
+ self.p_dropout = p_dropout
41
+ assert n_layers > 1, "Number of layers should be larger than 1."
42
+
43
+ self.conv_layers = nn.ModuleList()
44
+ self.norm_layers = nn.ModuleList()
45
+ self.conv_layers.append(
46
+ nn.Conv1d(in_channels,
47
+ hidden_channels,
48
+ kernel_size,
49
+ padding=kernel_size // 2))
50
+ self.norm_layers.append(LayerNorm(hidden_channels))
51
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
52
+ for _ in range(n_layers - 1):
53
+ self.conv_layers.append(
54
+ nn.Conv1d(hidden_channels,
55
+ hidden_channels,
56
+ kernel_size,
57
+ padding=kernel_size // 2))
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
60
+ self.proj.weight.data.zero_()
61
+ self.proj.bias.data.zero_()
62
+
63
+ def forward(self, x, x_mask):
64
+ x_org = x
65
+ for i in range(self.n_layers):
66
+ x = self.conv_layers[i](x * x_mask)
67
+ x = self.norm_layers[i](x)
68
+ x = self.relu_drop(x)
69
+ x = x_org + self.proj(x)
70
+ return x * x_mask
71
+
72
+
73
+ class DDSConv(nn.Module):
74
+ """
75
+ Dilated and Depth-Separable Convolution
76
+ """
77
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
78
+ super().__init__()
79
+ self.channels = channels
80
+ self.kernel_size = kernel_size
81
+ self.n_layers = n_layers
82
+ self.p_dropout = p_dropout
83
+
84
+ self.drop = nn.Dropout(p_dropout)
85
+ self.convs_sep = nn.ModuleList()
86
+ self.convs_1x1 = nn.ModuleList()
87
+ self.norms_1 = nn.ModuleList()
88
+ self.norms_2 = nn.ModuleList()
89
+ for i in range(n_layers):
90
+ dilation = kernel_size**i
91
+ padding = (kernel_size * dilation - dilation) // 2
92
+ self.convs_sep.append(
93
+ nn.Conv1d(channels,
94
+ channels,
95
+ kernel_size,
96
+ groups=channels,
97
+ dilation=dilation,
98
+ padding=padding))
99
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
100
+ self.norms_1.append(LayerNorm(channels))
101
+ self.norms_2.append(LayerNorm(channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ if g is not None:
105
+ x = x + g
106
+ for i in range(self.n_layers):
107
+ y = self.convs_sep[i](x * x_mask)
108
+ y = self.norms_1[i](y)
109
+ y = F.gelu(y)
110
+ y = self.convs_1x1[i](y)
111
+ y = self.norms_2[i](y)
112
+ y = F.gelu(y)
113
+ y = self.drop(y)
114
+ x = x + y
115
+ return x * x_mask
116
+
117
+
118
+ class WN(torch.nn.Module):
119
+ def __init__(self,
120
+ hidden_channels,
121
+ kernel_size,
122
+ dilation_rate,
123
+ n_layers,
124
+ gin_channels=0,
125
+ p_dropout=0):
126
+ super(WN, self).__init__()
127
+ assert (kernel_size % 2 == 1)
128
+ self.hidden_channels = hidden_channels
129
+ self.kernel_size = kernel_size
130
+ self.dilation_rate = dilation_rate
131
+ self.n_layers = n_layers
132
+ self.gin_channels = gin_channels
133
+ self.p_dropout = p_dropout
134
+
135
+ self.in_layers = torch.nn.ModuleList()
136
+ self.res_skip_layers = torch.nn.ModuleList()
137
+ self.drop = nn.Dropout(p_dropout)
138
+
139
+ if gin_channels != 0:
140
+ cond_layer = torch.nn.Conv1d(gin_channels,
141
+ 2 * hidden_channels * n_layers, 1)
142
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
143
+ name='weight')
144
+
145
+ for i in range(n_layers):
146
+ dilation = dilation_rate**i
147
+ padding = int((kernel_size * dilation - dilation) / 2)
148
+ in_layer = torch.nn.Conv1d(hidden_channels,
149
+ 2 * hidden_channels,
150
+ kernel_size,
151
+ dilation=dilation,
152
+ padding=padding)
153
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
154
+ self.in_layers.append(in_layer)
155
+
156
+ # last one is not necessary
157
+ if i < n_layers - 1:
158
+ res_skip_channels = 2 * hidden_channels
159
+ else:
160
+ res_skip_channels = hidden_channels
161
+
162
+ res_skip_layer = torch.nn.Conv1d(hidden_channels,
163
+ res_skip_channels, 1)
164
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
165
+ name='weight')
166
+ self.res_skip_layers.append(res_skip_layer)
167
+
168
+ def forward(self, x, x_mask, g=None, **kwargs):
169
+ output = torch.zeros_like(x)
170
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
171
+
172
+ if g is not None:
173
+ g = self.cond_layer(g)
174
+
175
+ for i in range(self.n_layers):
176
+ x_in = self.in_layers[i](x)
177
+ if g is not None:
178
+ cond_offset = i * 2 * self.hidden_channels
179
+ g_l = g[:,
180
+ cond_offset:cond_offset + 2 * self.hidden_channels, :]
181
+ else:
182
+ g_l = torch.zeros_like(x_in)
183
+
184
+ acts = commons.fused_add_tanh_sigmoid_multiply(
185
+ x_in, g_l, n_channels_tensor)
186
+ acts = self.drop(acts)
187
+
188
+ res_skip_acts = self.res_skip_layers[i](acts)
189
+ if i < self.n_layers - 1:
190
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
191
+ x = (x + res_acts) * x_mask
192
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
193
+ else:
194
+ output = output + res_skip_acts
195
+ return output * x_mask
196
+
197
+ def remove_weight_norm(self):
198
+ if self.gin_channels != 0:
199
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
200
+ for l in self.in_layers:
201
+ torch.nn.utils.remove_weight_norm(l)
202
+ for l in self.res_skip_layers:
203
+ torch.nn.utils.remove_weight_norm(l)
204
+
205
+
206
+ class ResBlock1(torch.nn.Module):
207
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
208
+ super(ResBlock1, self).__init__()
209
+ self.convs1 = nn.ModuleList([
210
+ weight_norm(
211
+ Conv1d(channels,
212
+ channels,
213
+ kernel_size,
214
+ 1,
215
+ dilation=dilation[0],
216
+ padding=get_padding(kernel_size, dilation[0]))),
217
+ weight_norm(
218
+ Conv1d(channels,
219
+ channels,
220
+ kernel_size,
221
+ 1,
222
+ dilation=dilation[1],
223
+ padding=get_padding(kernel_size, dilation[1]))),
224
+ weight_norm(
225
+ Conv1d(channels,
226
+ channels,
227
+ kernel_size,
228
+ 1,
229
+ dilation=dilation[2],
230
+ padding=get_padding(kernel_size, dilation[2])))
231
+ ])
232
+ self.convs1.apply(init_weights)
233
+
234
+ self.convs2 = nn.ModuleList([
235
+ weight_norm(
236
+ Conv1d(channels,
237
+ channels,
238
+ kernel_size,
239
+ 1,
240
+ dilation=1,
241
+ padding=get_padding(kernel_size, 1))),
242
+ weight_norm(
243
+ Conv1d(channels,
244
+ channels,
245
+ kernel_size,
246
+ 1,
247
+ dilation=1,
248
+ padding=get_padding(kernel_size, 1))),
249
+ weight_norm(
250
+ Conv1d(channels,
251
+ channels,
252
+ kernel_size,
253
+ 1,
254
+ dilation=1,
255
+ padding=get_padding(kernel_size, 1)))
256
+ ])
257
+ self.convs2.apply(init_weights)
258
+
259
+ def forward(self, x, x_mask=None):
260
+ for c1, c2 in zip(self.convs1, self.convs2):
261
+ xt = F.leaky_relu(x, LRELU_SLOPE)
262
+ if x_mask is not None:
263
+ xt = xt * x_mask
264
+ xt = c1(xt)
265
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
266
+ if x_mask is not None:
267
+ xt = xt * x_mask
268
+ xt = c2(xt)
269
+ x = xt + x
270
+ if x_mask is not None:
271
+ x = x * x_mask
272
+ return x
273
+
274
+ def remove_weight_norm(self):
275
+ for l in self.convs1:
276
+ remove_weight_norm(l)
277
+ for l in self.convs2:
278
+ remove_weight_norm(l)
279
+
280
+
281
+ class ResBlock2(torch.nn.Module):
282
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
283
+ super(ResBlock2, self).__init__()
284
+ self.convs = nn.ModuleList([
285
+ weight_norm(
286
+ Conv1d(channels,
287
+ channels,
288
+ kernel_size,
289
+ 1,
290
+ dilation=dilation[0],
291
+ padding=get_padding(kernel_size, dilation[0]))),
292
+ weight_norm(
293
+ Conv1d(channels,
294
+ channels,
295
+ kernel_size,
296
+ 1,
297
+ dilation=dilation[1],
298
+ padding=get_padding(kernel_size, dilation[1])))
299
+ ])
300
+ self.convs.apply(init_weights)
301
+
302
+ def forward(self, x, x_mask=None):
303
+ for c in self.convs:
304
+ xt = F.leaky_relu(x, LRELU_SLOPE)
305
+ if x_mask is not None:
306
+ xt = xt * x_mask
307
+ xt = c(xt)
308
+ x = xt + x
309
+ if x_mask is not None:
310
+ x = x * x_mask
311
+ return x
312
+
313
+ def remove_weight_norm(self):
314
+ for l in self.convs:
315
+ remove_weight_norm(l)
316
+
317
+
318
+ class Log(nn.Module):
319
+ def forward(self, x, x_mask, reverse=False, **kwargs):
320
+ if not reverse:
321
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
322
+ logdet = torch.sum(-y, [1, 2])
323
+ return y, logdet
324
+ else:
325
+ x = torch.exp(x) * x_mask
326
+ return x
327
+
328
+
329
+ class Flip(nn.Module):
330
+ def forward(self, x, *args, reverse=False, **kwargs):
331
+ x = torch.flip(x, [1])
332
+ if not reverse:
333
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
334
+ return x, logdet
335
+ else:
336
+ return x
337
+
338
+
339
+ class ElementwiseAffine(nn.Module):
340
+ def __init__(self, channels):
341
+ super().__init__()
342
+ self.channels = channels
343
+ self.m = nn.Parameter(torch.zeros(channels, 1))
344
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
345
+
346
+ def forward(self, x, x_mask, reverse=False, **kwargs):
347
+ if not reverse:
348
+ y = self.m + torch.exp(self.logs) * x
349
+ y = y * x_mask
350
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
351
+ return y, logdet
352
+ else:
353
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
354
+ return x
355
+
356
+
357
+ class ResidualCouplingLayer(nn.Module):
358
+ def __init__(self,
359
+ channels,
360
+ hidden_channels,
361
+ kernel_size,
362
+ dilation_rate,
363
+ n_layers,
364
+ p_dropout=0,
365
+ gin_channels=0,
366
+ mean_only=False):
367
+ assert channels % 2 == 0, "channels should be divisible by 2"
368
+ super().__init__()
369
+ self.channels = channels
370
+ self.hidden_channels = hidden_channels
371
+ self.kernel_size = kernel_size
372
+ self.dilation_rate = dilation_rate
373
+ self.n_layers = n_layers
374
+ self.half_channels = channels // 2
375
+ self.mean_only = mean_only
376
+
377
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
378
+ self.enc = WN(hidden_channels,
379
+ kernel_size,
380
+ dilation_rate,
381
+ n_layers,
382
+ p_dropout=p_dropout,
383
+ gin_channels=gin_channels)
384
+ self.post = nn.Conv1d(hidden_channels,
385
+ self.half_channels * (2 - mean_only), 1)
386
+ self.post.weight.data.zero_()
387
+ self.post.bias.data.zero_()
388
+
389
+ def forward(self, x, x_mask, g=None, reverse=False):
390
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
391
+ h = self.pre(x0) * x_mask
392
+ h = self.enc(h, x_mask, g=g)
393
+ stats = self.post(h) * x_mask
394
+ if not self.mean_only:
395
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
396
+ else:
397
+ m = stats
398
+ logs = torch.zeros_like(m)
399
+
400
+ if not reverse:
401
+ x1 = m + x1 * torch.exp(logs) * x_mask
402
+ x = torch.cat([x0, x1], 1)
403
+ logdet = torch.sum(logs, [1, 2])
404
+ return x, logdet
405
+ else:
406
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
407
+ x = torch.cat([x0, x1], 1)
408
+ return x
409
+
410
+
411
+ class ConvFlow(nn.Module):
412
+ def __init__(self,
413
+ in_channels,
414
+ filter_channels,
415
+ kernel_size,
416
+ n_layers,
417
+ num_bins=10,
418
+ tail_bound=5.0):
419
+ super().__init__()
420
+ self.in_channels = in_channels
421
+ self.filter_channels = filter_channels
422
+ self.kernel_size = kernel_size
423
+ self.n_layers = n_layers
424
+ self.num_bins = num_bins
425
+ self.tail_bound = tail_bound
426
+ self.half_channels = in_channels // 2
427
+
428
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
429
+ self.convs = DDSConv(filter_channels,
430
+ kernel_size,
431
+ n_layers,
432
+ p_dropout=0.)
433
+ self.proj = nn.Conv1d(filter_channels,
434
+ self.half_channels * (num_bins * 3 - 1), 1)
435
+ self.proj.weight.data.zero_()
436
+ self.proj.bias.data.zero_()
437
+
438
+ def forward(self, x, x_mask, g=None, reverse=False):
439
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
440
+ h = self.pre(x0)
441
+ h = self.convs(h, x_mask, g=g)
442
+ h = self.proj(h) * x_mask
443
+
444
+ b, c, t = x0.shape
445
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3,
446
+ 2) # [b, cx?, t] -> [b, c, t, ?]
447
+
448
+ unnormalized_widths = h[..., :self.num_bins] / math.sqrt(
449
+ self.filter_channels)
450
+ unnormalized_heights = h[...,
451
+ self.num_bins:2 * self.num_bins] / math.sqrt(
452
+ self.filter_channels)
453
+ unnormalized_derivatives = h[..., 2 * self.num_bins:]
454
+
455
+ x1, logabsdet = piecewise_rational_quadratic_transform(
456
+ x1,
457
+ unnormalized_widths,
458
+ unnormalized_heights,
459
+ unnormalized_derivatives,
460
+ inverse=reverse,
461
+ tails='linear',
462
+ tail_bound=self.tail_bound)
463
+
464
+ x = torch.cat([x0, x1], 1) * x_mask
465
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
466
+ if not reverse:
467
+ return x, logdet
468
+ else:
469
+ return x
export/vits/text/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2017 Keith Ito
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
export/vits/text/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ from text import cleaners
3
+
4
+
5
+ def text_to_sequence(text, symbols, cleaner_names):
6
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
7
+ Args:
8
+ text: string to convert to a sequence
9
+ cleaner_names: names of the cleaner functions to run the text through
10
+ Returns:
11
+ List of integers corresponding to the symbols in the text
12
+ '''
13
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
14
+
15
+ sequence = []
16
+
17
+ clean_text = _clean_text(text, cleaner_names)
18
+
19
+ sequence = [
20
+ _symbol_to_id[symbol] for symbol in clean_text if symbol in _symbol_to_id.keys()
21
+ ]
22
+
23
+ # for symbol in clean_text:
24
+ # if symbol not in _symbol_to_id.keys():
25
+ # continue
26
+ # symbol_id = _symbol_to_id[symbol]
27
+ # sequence += [symbol_id]
28
+ return sequence
29
+
30
+
31
+ def _clean_text(text, cleaner_names):
32
+ for name in cleaner_names:
33
+ cleaner = getattr(cleaners, name)
34
+ if not cleaner:
35
+ raise Exception('Unknown cleaner: %s' % name)
36
+ text = cleaner(text)
37
+ return text
export/vits/text/cleaners.py ADDED
@@ -0,0 +1,58 @@
1
+ import re
2
+ from unidecode import unidecode
3
+ import pyopenjtalk
4
+
5
+ pyopenjtalk._lazy_init()
6
+
7
+ # Regular expression matching Japanese without punctuation marks:
8
+ _japanese_characters = re.compile(
9
+ r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
10
+
11
+ # Regular expression matching non-Japanese characters or punctuation marks:
12
+ _japanese_marks = re.compile(
13
+ r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
14
+
15
+
16
+ def japanese_cleaners(text):
17
+ '''Pipeline for notating accent in Japanese text.
18
+ Reference: https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
19
+ sentences = re.split(_japanese_marks, text)
20
+ marks = re.findall(_japanese_marks, text)
21
+ text = ''
22
+ for i, sentence in enumerate(sentences):
23
+ if re.match(_japanese_characters, sentence):
24
+ if text != '':
25
+ text += ' '
26
+ labels = pyopenjtalk.extract_fullcontext(sentence)
27
+ for n, label in enumerate(labels):
28
+ phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
29
+ if phoneme not in ['sil', 'pau']:
30
+ text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
31
+ else:
32
+ continue
33
+ n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
34
+ a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
35
+ a2 = int(re.search(r"\+(\d+)\+", label).group(1))
36
+ a3 = int(re.search(r"\+(\d+)/", label).group(1))
37
+ if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
38
+ a2_next = -1
39
+ else:
40
+ a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
41
+ # Accent phrase boundary
42
+ if a3 == 1 and a2_next == 1:
43
+ text += ' '
44
+ # Falling
45
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
46
+ text += '↓'
47
+ # Rising
48
+ elif a2 == 1 and a2_next == 2:
49
+ text += '↑'
50
+ if i < len(marks):
51
+ text += unidecode(marks[i]).replace(' ', '')
52
+ if re.match('[A-Za-z]', text[-1]):
53
+ text += '.'
54
+ return text
55
+
56
+
57
+ def japanese_cleaners2(text):
58
+ return japanese_cleaners(text).replace('ts','ʦ').replace('...','…')
export/vits/text/symbols.py ADDED
@@ -0,0 +1,39 @@
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+
5
+ '''# japanese_cleaners
6
+ _pad = '_'
7
+ _punctuation = ',.!?-'
8
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9
+ '''
10
+ # jp_cleaners
11
+ _pad = '_'
12
+ _punctuation = ',.!?-'
13
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
14
+
15
+
16
+
17
+ # japanese_cleaners2
18
+ # _pad = '_'
19
+ # _punctuation = ',.!?-~…'
20
+ # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
21
+
22
+
23
+ '''# korean_cleaners
24
+ _pad = '_'
25
+ _punctuation = ',.!?…~'
26
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
27
+ '''
28
+
29
+ '''# chinese_cleaners
30
+ _pad = '_'
31
+ _punctuation = ',。!?—…'
32
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
33
+ '''
34
+
35
+ # Export all symbols:
36
+ symbols = [_pad] + list(_punctuation) + list(_letters)
37
+
38
+ # Special symbol ids
39
+ SPACE_ID = symbols.index(" ")
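Taken together, text/cleaners.py, text/symbols.py and text/__init__.py turn raw Japanese text into a list of symbol ids. A minimal usage sketch follows (illustrative only; it assumes pyopenjtalk and unidecode are installed, that the package is importable as `text`, and the sample sentence is arbitrary).

# Illustrative only -- the exact ids depend on the pyopenjtalk output.
from text import text_to_sequence
from text.symbols import symbols

ids = text_to_sequence('こんにちは。', symbols, ['japanese_cleaners'])
print(ids)  # a list of integers, one per retained phoneme / accent mark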
export/vits/train.py ADDED
@@ -0,0 +1,328 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import torch.multiprocessing as mp
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.cuda.amp import autocast, GradScaler
11
+
12
+ import commons
13
+ import utils
14
+ from data_utils import (TextAudioSpeakerLoader, TextAudioSpeakerCollate,
15
+ DistributedBucketSampler)
16
+ from models import (
17
+ SynthesizerTrn,
18
+ MultiPeriodDiscriminator,
19
+ )
20
+ from losses import (generator_loss, discriminator_loss, feature_loss, kl_loss)
21
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
22
+
23
+ torch.backends.cudnn.benchmark = True
24
+ global_step = 0
25
+
26
+
27
+ def main():
28
+ """Assume Single Node Multi GPUs Training Only"""
29
+ assert torch.cuda.is_available(), "CPU training is not allowed."
30
+
31
+ n_gpus = torch.cuda.device_count()
32
+ hps = utils.get_hparams()
33
+ mp.spawn(run, nprocs=n_gpus, args=(
34
+ n_gpus,
35
+ hps,
36
+ ))
37
+
38
+
39
+ def run(rank, n_gpus, hps):
40
+ global global_step
41
+ if rank == 0:
42
+ logger = utils.get_logger(hps.model_dir)
43
+ logger.info(hps)
44
+ utils.check_git_hash(hps.model_dir)
45
+ writer = SummaryWriter(log_dir=hps.model_dir)
46
+ writer_eval = SummaryWriter(
47
+ log_dir=os.path.join(hps.model_dir, "eval"))
48
+
49
+ dist.init_process_group(backend='nccl',
50
+ init_method='env://',
51
+ world_size=n_gpus,
52
+ rank=rank)
53
+ torch.manual_seed(hps.train.seed)
54
+ torch.cuda.set_device(rank)
55
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
56
+ train_sampler = DistributedBucketSampler(
57
+ train_dataset,
58
+ hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000],
59
+ num_replicas=n_gpus,
60
+ rank=rank,
61
+ shuffle=True)
62
+ collate_fn = TextAudioSpeakerCollate()
63
+ train_loader = DataLoader(train_dataset,
64
+ num_workers=8,
65
+ shuffle=False,
66
+ pin_memory=True,
67
+ collate_fn=collate_fn,
68
+ batch_sampler=train_sampler)
69
+ if rank == 0:
70
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files,
71
+ hps.data)
72
+ eval_loader = DataLoader(eval_dataset,
73
+ num_workers=8,
74
+ shuffle=False,
75
+ batch_size=hps.train.batch_size,
76
+ pin_memory=True,
77
+ drop_last=False,
78
+ collate_fn=collate_fn)
79
+
80
+ net_g = SynthesizerTrn(hps.data.num_phones,
81
+ hps.data.filter_length // 2 + 1,
82
+ hps.train.segment_size // hps.data.hop_length,
83
+ n_speakers=hps.data.n_speakers,
84
+ **hps.model).cuda(rank)
85
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
86
+ optim_g = torch.optim.AdamW(net_g.parameters(),
87
+ hps.train.learning_rate,
88
+ betas=hps.train.betas,
89
+ eps=hps.train.eps)
90
+ optim_d = torch.optim.AdamW(net_d.parameters(),
91
+ hps.train.learning_rate,
92
+ betas=hps.train.betas,
93
+ eps=hps.train.eps)
94
+ net_g = DDP(net_g, device_ids=[rank])
95
+ net_d = DDP(net_d, device_ids=[rank])
96
+
97
+ try:
98
+ _, _, _, epoch_str = utils.load_checkpoint(
99
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
100
+ optim_g)
101
+ _, _, _, epoch_str = utils.load_checkpoint(
102
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
103
+ optim_d)
104
+ global_step = (epoch_str - 1) * len(train_loader)
105
+ except Exception as e:
106
+ epoch_str = 1
107
+ global_step = 0
108
+
109
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
110
+ optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
111
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
112
+ optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
113
+
114
+ scaler = GradScaler(enabled=hps.train.fp16_run)
115
+
116
+ for epoch in range(epoch_str, hps.train.epochs + 1):
117
+ if rank == 0:
118
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d],
119
+ [optim_g, optim_d], [scheduler_g, scheduler_d],
120
+ scaler, [train_loader, eval_loader], logger,
121
+ [writer, writer_eval])
122
+ else:
123
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d],
124
+ [optim_g, optim_d], [scheduler_g, scheduler_d],
125
+ scaler, [train_loader, None], None, None)
126
+ scheduler_g.step()
127
+ scheduler_d.step()
128
+
129
+
130
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
131
+ loaders, logger, writers):
132
+ net_g, net_d = nets
133
+ optim_g, optim_d = optims
134
+ scheduler_g, scheduler_d = schedulers
135
+ train_loader, eval_loader = loaders
136
+ if writers is not None:
137
+ writer, writer_eval = writers
138
+
139
+ train_loader.batch_sampler.set_epoch(epoch)
140
+ global global_step
141
+
142
+ net_g.train()
143
+ net_d.train()
144
+ for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths,
145
+ speakers) in enumerate(train_loader):
146
+ x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
147
+ rank, non_blocking=True)
148
+ spec, spec_lengths = spec.cuda(
149
+ rank, non_blocking=True), spec_lengths.cuda(rank,
150
+ non_blocking=True)
151
+ y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
152
+ rank, non_blocking=True)
153
+ speakers = speakers.cuda(rank, non_blocking=True)
154
+
155
+ with autocast(enabled=hps.train.fp16_run):
156
+ y_hat, l_length, attn, ids_slice, x_mask, z_mask, (
157
+ z, z_p, m_p, logs_p, m_q,
158
+ logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)
159
+
160
+ mel = spec_to_mel_torch(spec, hps.data.filter_length,
161
+ hps.data.n_mel_channels,
162
+ hps.data.sampling_rate, hps.data.mel_fmin,
163
+ hps.data.mel_fmax)
164
+ y_mel = commons.slice_segments(
165
+ mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
166
+ y_hat_mel = mel_spectrogram_torch(
167
+ y_hat.squeeze(1), hps.data.filter_length,
168
+ hps.data.n_mel_channels, hps.data.sampling_rate,
169
+ hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin,
170
+ hps.data.mel_fmax)
171
+
172
+ y = commons.slice_segments(y, ids_slice * hps.data.hop_length,
173
+ hps.train.segment_size) # slice
174
+
175
+ # Discriminator
176
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
177
+ with autocast(enabled=False):
178
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
179
+ y_d_hat_r, y_d_hat_g)
180
+ loss_disc_all = loss_disc
181
+ optim_d.zero_grad()
182
+ scaler.scale(loss_disc_all).backward()
183
+ scaler.unscale_(optim_d)
184
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
185
+ scaler.step(optim_d)
186
+
187
+ with autocast(enabled=hps.train.fp16_run):
188
+ # Generator
189
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
190
+ with autocast(enabled=False):
191
+ loss_dur = torch.sum(l_length.float())
192
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
193
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p,
194
+ z_mask) * hps.train.c_kl
195
+
196
+ loss_fm = feature_loss(fmap_r, fmap_g)
197
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
198
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
199
+ optim_g.zero_grad()
200
+ scaler.scale(loss_gen_all).backward()
201
+ scaler.unscale_(optim_g)
202
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
203
+ scaler.step(optim_g)
204
+ scaler.update()
205
+
206
+ if rank == 0:
207
+ if global_step % hps.train.log_interval == 0:
208
+ lr = optim_g.param_groups[0]['lr']
209
+ losses = [
210
+ loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl
211
+ ]
212
+ logger.info('Train Epoch: {} [{:.0f}%]'.format(
213
+ epoch, 100. * batch_idx / len(train_loader)))
214
+ logger.info([x.item() for x in losses] + [global_step, lr])
215
+
216
+ scalar_dict = {
217
+ "loss/g/total": loss_gen_all,
218
+ "loss/d/total": loss_disc_all,
219
+ "learning_rate": lr,
220
+ "grad_norm_d": grad_norm_d,
221
+ "grad_norm_g": grad_norm_g
222
+ }
223
+ scalar_dict.update({
224
+ "loss/g/fm": loss_fm,
225
+ "loss/g/mel": loss_mel,
226
+ "loss/g/dur": loss_dur,
227
+ "loss/g/kl": loss_kl
228
+ })
229
+
230
+ scalar_dict.update({
231
+ "loss/g/{}".format(i): v
232
+ for i, v in enumerate(losses_gen)
233
+ })
234
+ scalar_dict.update({
235
+ "loss/d_r/{}".format(i): v
236
+ for i, v in enumerate(losses_disc_r)
237
+ })
238
+ scalar_dict.update({
239
+ "loss/d_g/{}".format(i): v
240
+ for i, v in enumerate(losses_disc_g)
241
+ })
242
+ image_dict = {
243
+ "slice/mel_org":
244
+ utils.plot_spectrogram_to_numpy(
245
+ y_mel[0].data.cpu().numpy()),
246
+ "slice/mel_gen":
247
+ utils.plot_spectrogram_to_numpy(
248
+ y_hat_mel[0].data.cpu().numpy()),
249
+ "all/mel":
250
+ utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
251
+ "all/attn":
252
+ utils.plot_alignment_to_numpy(attn[0,
253
+ 0].data.cpu().numpy())
254
+ }
255
+ utils.summarize(writer=writer,
256
+ global_step=global_step,
257
+ images=image_dict,
258
+ scalars=scalar_dict)
259
+
260
+ if global_step % hps.train.eval_interval == 0:
261
+ evaluate(hps, net_g, eval_loader, writer_eval)
262
+ utils.save_checkpoint(
263
+ net_g, optim_g, hps.train.learning_rate, epoch,
264
+ os.path.join(hps.model_dir,
265
+ "G_{}.pth".format(global_step)))
266
+ utils.save_checkpoint(
267
+ net_d, optim_d, hps.train.learning_rate, epoch,
268
+ os.path.join(hps.model_dir,
269
+ "D_{}.pth".format(global_step)))
270
+ global_step += 1
271
+
272
+ if rank == 0:
273
+ logger.info('====> Epoch: {}'.format(epoch))
274
+
275
+
276
+ def evaluate(hps, generator, eval_loader, writer_eval):
277
+ generator.eval()
278
+ with torch.no_grad():
279
+ for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths,
280
+ speakers) in enumerate(eval_loader):
281
+ x, x_lengths = x.cuda(0), x_lengths.cuda(0)
282
+ spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
283
+ y, y_lengths = y.cuda(0), y_lengths.cuda(0)
284
+ speakers = speakers.cuda(0)
285
+
286
+ # only evaluate the first utterance of the first batch
287
+ x = x[:1]
288
+ x_lengths = x_lengths[:1]
289
+ spec = spec[:1]
290
+ spec_lengths = spec_lengths[:1]
291
+ y = y[:1]
292
+ y_lengths = y_lengths[:1]
293
+ speakers = speakers[:1]
294
+ break
295
+ y_hat, attn, mask, *_ = generator.module.infer(x,
296
+ x_lengths,
297
+ speakers,
298
+ max_len=1000)
299
+ y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
300
+
301
+ mel = spec_to_mel_torch(spec, hps.data.filter_length,
302
+ hps.data.n_mel_channels,
303
+ hps.data.sampling_rate, hps.data.mel_fmin,
304
+ hps.data.mel_fmax)
305
+ y_hat_mel = mel_spectrogram_torch(
306
+ y_hat.squeeze(1).float(), hps.data.filter_length,
307
+ hps.data.n_mel_channels, hps.data.sampling_rate,
308
+ hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin,
309
+ hps.data.mel_fmax)
310
+ image_dict = {
311
+ "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
312
+ }
313
+ audio_dict = {"gen/audio": y_hat[0, :, :y_hat_lengths[0]]}
314
+ if global_step == 0:
315
+ image_dict.update(
316
+ {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
317
+ audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]})
318
+
319
+ utils.summarize(writer=writer_eval,
320
+ global_step=global_step,
321
+ images=image_dict,
322
+ audios=audio_dict,
323
+ audio_sampling_rate=hps.data.sampling_rate)
324
+ generator.train()
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
export/vits/transforms.py ADDED
@@ -0,0 +1,200 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
6
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
7
+ DEFAULT_MIN_DERIVATIVE = 1e-3
8
+
9
+
10
+ def piecewise_rational_quadratic_transform(
11
+ inputs,
12
+ unnormalized_widths,
13
+ unnormalized_heights,
14
+ unnormalized_derivatives,
15
+ inverse=False,
16
+ tails=None,
17
+ tail_bound=1.,
18
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
19
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
20
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
21
+
22
+ if tails is None:
23
+ spline_fn = rational_quadratic_spline
24
+ spline_kwargs = {}
25
+ else:
26
+ spline_fn = unconstrained_rational_quadratic_spline
27
+ spline_kwargs = {'tails': tails, 'tail_bound': tail_bound}
28
+
29
+ outputs, logabsdet = spline_fn(
30
+ inputs=inputs,
31
+ unnormalized_widths=unnormalized_widths,
32
+ unnormalized_heights=unnormalized_heights,
33
+ unnormalized_derivatives=unnormalized_derivatives,
34
+ inverse=inverse,
35
+ min_bin_width=min_bin_width,
36
+ min_bin_height=min_bin_height,
37
+ min_derivative=min_derivative,
38
+ **spline_kwargs)
39
+ return outputs, logabsdet
40
+
41
+
42
+ def searchsorted(bin_locations, inputs, eps=1e-6):
43
+ bin_locations[..., bin_locations.size(-1) - 1] += eps
44
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
45
+
46
+
47
+ def unconstrained_rational_quadratic_spline(
48
+ inputs,
49
+ unnormalized_widths,
50
+ unnormalized_heights,
51
+ unnormalized_derivatives,
52
+ inverse=False,
53
+ tails='linear',
54
+ tail_bound=1.,
55
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
56
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
57
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
58
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
59
+ outside_interval_mask = ~inside_interval_mask
60
+
61
+ outputs = torch.zeros_like(inputs)
62
+ logabsdet = torch.zeros_like(inputs)
63
+
64
+ if tails == 'linear':
65
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
66
+ constant = np.log(np.exp(1 - min_derivative) - 1)
67
+ unnormalized_derivatives[..., 0] = constant
68
+ unnormalized_derivatives[..., unnormalized_derivatives.size(-1) - 1] = constant
69
+
70
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
71
+ logabsdet[outside_interval_mask] = 0
72
+ else:
73
+ raise RuntimeError('{} tails are not implemented.'.format(tails))
74
+
75
+ outputs[inside_interval_mask], logabsdet[
76
+ inside_interval_mask] = rational_quadratic_spline(
77
+ inputs=inputs[inside_interval_mask],
78
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
79
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
80
+ unnormalized_derivatives=unnormalized_derivatives[
81
+ inside_interval_mask, :],
82
+ inverse=inverse,
83
+ left=-tail_bound,
84
+ right=tail_bound,
85
+ bottom=-tail_bound,
86
+ top=tail_bound,
87
+ min_bin_width=min_bin_width,
88
+ min_bin_height=min_bin_height,
89
+ min_derivative=min_derivative)
90
+
91
+ return outputs, logabsdet
92
+
93
+
94
+ def rational_quadratic_spline(inputs,
95
+ unnormalized_widths,
96
+ unnormalized_heights,
97
+ unnormalized_derivatives,
98
+ inverse=False,
99
+ left=0.,
100
+ right=1.,
101
+ bottom=0.,
102
+ top=1.,
103
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
104
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
105
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
106
+ if torch.min(inputs) < left or torch.max(inputs) > right:
107
+ raise ValueError('Input to a transform is not within its domain')
108
+
109
+ num_bins = unnormalized_widths.shape[-1]
110
+
111
+ if min_bin_width * num_bins > 1.0:
112
+ raise ValueError('Minimal bin width too large for the number of bins')
113
+ if min_bin_height * num_bins > 1.0:
114
+ raise ValueError('Minimal bin height too large for the number of bins')
115
+
116
+ widths = F.softmax(unnormalized_widths, dim=-1)
117
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
118
+ cumwidths = torch.cumsum(widths, dim=-1)
119
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
120
+ cumwidths = (right - left) * cumwidths + left
121
+ cumwidths[..., 0] = left
122
+ cumwidths[..., cumwidths.size(-1) - 1] = right
123
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
124
+
125
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
126
+
127
+ heights = F.softmax(unnormalized_heights, dim=-1)
128
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
129
+ cumheights = torch.cumsum(heights, dim=-1)
130
+ cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
131
+ cumheights = (top - bottom) * cumheights + bottom
132
+ cumheights[..., 0] = bottom
133
+ cumheights[..., cumheights.size(-1) - 1] = top
134
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
135
+
136
+ if inverse:
137
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
138
+ else:
139
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
140
+
141
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
142
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
143
+
144
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
145
+ delta = heights / widths
146
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
147
+
148
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
149
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[...,
150
+ 0]
151
+
152
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
153
+
154
+ if inverse:
155
+ a = (
156
+ ((inputs - input_cumheights) *
157
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
158
+ + input_heights * (input_delta - input_derivatives)))
159
+ b = (
160
+ input_heights * input_derivatives - (inputs - input_cumheights) *
161
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta))
162
+ c = -input_delta * (inputs - input_cumheights)
163
+
164
+ discriminant = b.pow(2) - 4 * a * c
165
+ assert (discriminant >= 0).all()
166
+
167
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
168
+ outputs = root * input_bin_widths + input_cumwidths
169
+
170
+ theta_one_minus_theta = root * (1 - root)
171
+ denominator = input_delta + (
172
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
173
+ * theta_one_minus_theta)
174
+ derivative_numerator = input_delta.pow(2) * (
175
+ input_derivatives_plus_one * root.pow(2) +
176
+ 2 * input_delta * theta_one_minus_theta + input_derivatives *
177
+ (1 - root).pow(2))
178
+ logabsdet = torch.log(
179
+ derivative_numerator) - 2 * torch.log(denominator)
180
+
181
+ return outputs, -logabsdet
182
+ else:
183
+ theta = (inputs - input_cumwidths) / input_bin_widths
184
+ theta_one_minus_theta = theta * (1 - theta)
185
+
186
+ numerator = input_heights * (input_delta * theta.pow(2) +
187
+ input_derivatives * theta_one_minus_theta)
188
+ denominator = input_delta + (
189
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
190
+ * theta_one_minus_theta)
191
+ outputs = input_cumheights + numerator / denominator
192
+
193
+ derivative_numerator = input_delta.pow(2) * (
194
+ input_derivatives_plus_one * theta.pow(2) +
195
+ 2 * input_delta * theta_one_minus_theta + input_derivatives *
196
+ (1 - theta).pow(2))
197
+ logabsdet = torch.log(
198
+ derivative_numerator) - 2 * torch.log(denominator)
199
+
200
+ return outputs, logabsdet
export/vits/utils.py ADDED
@@ -0,0 +1,307 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import subprocess
7
+ import sys
8
+
9
+ import numpy as np
10
+ from scipy.io.wavfile import read
11
+ import torch
12
+
13
+ MATPLOTLIB_FLAG = False
14
+
15
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16
+ logger = logging
17
+
18
+
19
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
20
+ assert os.path.isfile(checkpoint_path)
21
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
22
+ iteration = checkpoint_dict['iteration']
23
+ learning_rate = checkpoint_dict['learning_rate']
24
+ if optimizer is not None:
25
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
26
+ saved_state_dict = checkpoint_dict['model']
27
+ if hasattr(model, 'module'):
28
+ state_dict = model.module.state_dict()
29
+ else:
30
+ state_dict = model.state_dict()
31
+ new_state_dict = {}
32
+ for k, v in state_dict.items():
33
+ try:
34
+ new_state_dict[k] = saved_state_dict[k]
35
+ except Exception as e:
36
+ logger.info("%s is not in the checkpoint" % k)
37
+ new_state_dict[k] = v
38
+ if hasattr(model, 'module'):
39
+ model.module.load_state_dict(new_state_dict)
40
+ else:
41
+ model.load_state_dict(new_state_dict)
42
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
43
+ checkpoint_path, iteration))
44
+ return model, optimizer, learning_rate, iteration
45
+
46
+
47
+ def save_checkpoint(model, optimizer, learning_rate, iteration,
48
+ checkpoint_path):
49
+ logger.info(
50
+ "Saving model and optimizer state at iteration {} to {}".format(
51
+ iteration, checkpoint_path))
52
+ if hasattr(model, 'module'):
53
+ state_dict = model.module.state_dict()
54
+ else:
55
+ state_dict = model.state_dict()
56
+ torch.save(
57
+ {
58
+ 'model': state_dict,
59
+ 'iteration': iteration,
60
+ 'optimizer': optimizer.state_dict(),
61
+ 'learning_rate': learning_rate
62
+ }, checkpoint_path)
63
+
64
+
+ def summarize(
+         writer,
+         global_step,
+         scalars={},  # noqa
+         histograms={},  # noqa
+         images={},  # noqa
+         audios={},  # noqa
+         audio_sampling_rate=22050):
+     for k, v in scalars.items():
+         writer.add_scalar(k, v, global_step)
+     for k, v in histograms.items():
+         writer.add_histogram(k, v, global_step)
+     for k, v in images.items():
+         writer.add_image(k, v, global_step, dataformats='HWC')
+     for k, v in audios.items():
+         writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
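+ # Return the newest checkpoint in dir_path, sorted by the digits embedded in
+ # the filename (e.g. G_100000.pth).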
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+     f_list = glob.glob(os.path.join(dir_path, regex))
+     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+     x = f_list[-1]
+     print(x)
+     return x
+
+
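+ # Render a spectrogram / alignment matrix to an HWC uint8 numpy image for
+ # TensorBoard; matplotlib is imported lazily with the Agg backend.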
+ def plot_spectrogram_to_numpy(spectrogram):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger('matplotlib')
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram,
+                    aspect="auto",
+                    origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+     plt.close()
+     return data
+
+
+ def plot_alignment_to_numpy(alignment, info=None):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger('matplotlib')
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(6, 4))
+     im = ax.imshow(alignment.transpose(),
+                    aspect='auto',
+                    origin='lower',
+                    interpolation='none')
+     fig.colorbar(im, ax=ax)
+     xlabel = 'Decoder timestep'
+     if info is not None:
+         xlabel += '\n\n' + info
+     plt.xlabel(xlabel)
+     plt.ylabel('Encoder timestep')
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+     plt.close()
+     return data
+
+
+ def load_wav_to_torch(full_path):
+     sampling_rate, data = read(full_path)
+     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+ def load_filepaths_and_text(filename, split="|"):
+     with open(filename, encoding='utf-8') as f:
+         filepaths_and_text = [line.strip().split(split) for line in f]
+     return filepaths_and_text
+
+
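+ # Parse command-line options, copy the JSON config into the model directory
+ # and inject data paths plus phone/speaker vocabulary sizes.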
+ def get_hparams(init=True):
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-c',
+                         '--config',
+                         type=str,
+                         default="./configs/base.json",
+                         help='JSON file for configuration')
+     parser.add_argument('-m',
+                         '--model',
+                         type=str,
+                         required=True,
+                         help='Model name')
+     parser.add_argument('--train_data',
+                         type=str,
+                         required=True,
+                         help='train data')
+     parser.add_argument('--val_data', type=str, required=True, help='val data')
+     parser.add_argument('--phone_table',
+                         type=str,
+                         required=True,
+                         help='phone table')
+     parser.add_argument('--speaker_table',
+                         type=str,
+                         default=None,
+                         help='speaker table, required for multiple speakers')
+
+     args = parser.parse_args()
+     model_dir = args.model
+
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     config_path = args.config
+     config_save_path = os.path.join(model_dir, "config.json")
+     if init:
+         with open(config_path, "r", encoding='utf8') as f:
+             data = f.read()
+         with open(config_save_path, "w", encoding='utf8') as f:
+             f.write(data)
+     else:
+         with open(config_save_path, "r", encoding='utf8') as f:
+             data = f.read()
+     config = json.loads(data)
+     config['data']['training_files'] = args.train_data
+     config['data']['validation_files'] = args.val_data
+     config['data']['phone_table'] = args.phone_table
+     # 0 is kept for blank
+     config['data']['num_phones'] = len(open(args.phone_table).readlines()) + 1
+     if args.speaker_table is not None:
+         config['data']['speaker_table'] = args.speaker_table
+         # 0 is kept for unknown speaker
+         config['data']['n_speakers'] = len(
+             open(args.speaker_table).readlines()) + 1
+     else:
+         config['data']['n_speakers'] = 0
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_dir(model_dir):
+     config_save_path = os.path.join(model_dir, "config.json")
+     with open(config_save_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_file(config_path):
+     with open(config_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     return hparams
+
+
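+ # Record the current git hash in the model directory and warn when it no
+ # longer matches the hash saved with an earlier run.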
+ def check_git_hash(model_dir):
+     source_dir = os.path.dirname(os.path.realpath(__file__))
+     if not os.path.exists(os.path.join(source_dir, ".git")):
+         logger.warn('''{} is not a git repository, therefore hash value
+             comparison will be ignored.'''.format(source_dir))
+         return
+
+     cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+     path = os.path.join(model_dir, "githash")
+     if os.path.exists(path):
+         saved_hash = open(path).read()
+         if saved_hash != cur_hash:
+             logger.warn(
+                 "git hash values are different. {}(saved) != {}(current)".
+                 format(saved_hash[:8], cur_hash[:8]))
+     else:
+         open(path, "w").write(cur_hash)
+
+
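+ # Create a named logger for the model directory and attach a file handler
+ # writing to <model_dir>/<filename>.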
+ def get_logger(model_dir, filename="train.log"):
+     global logger
+     logger = logging.getLogger(os.path.basename(model_dir))
+     logger.setLevel(logging.INFO)
+
+     formatter = logging.Formatter(
+         "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     h = logging.FileHandler(os.path.join(model_dir, filename))
+     h.setLevel(logging.INFO)
+     h.setFormatter(formatter)
+     logger.addHandler(h)
+     return logger
+
+
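+ # Lightweight attribute container: nested dicts from the JSON config become
+ # nested HParams, accessible both as attributes and as dict-style items.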
+ class HParams():
+     def __init__(self, **kwargs):
+         for k, v in kwargs.items():
+             if type(v) == dict:
+                 v = HParams(**v)
+             self[k] = v
+
+     def keys(self):
+         return self.__dict__.keys()
+
+     def items(self):
+         return self.__dict__.items()
+
+     def values(self):
+         return self.__dict__.values()
+
+     def __len__(self):
+         return len(self.__dict__)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def __setitem__(self, key, value):
+         return setattr(self, key, value)
+
+     def __contains__(self, key):
+         return key in self.__dict__
+
+     def __repr__(self):
+         return self.__dict__.__repr__()
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,20 @@
+ [tool.poetry]
+ name = "vits-onnx"
+ version = "0.1.0"
+ description = ""
+ authors = ["ccds <ccdesue@163.com>"]
+ readme = "README.md"
+ packages = [{include = "vits_onnx"}]
+
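+ # Runtime dependencies of the Gradio + onnxruntime inference demo.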
+ [tool.poetry.dependencies]
+ python = "^3.9"
+ gradio = "^3.16.1"
+ loguru = "^0.6.0"
+ onnxruntime = "^1.13.1"
+ pyopenjtalk = "^0.3.0"
+ unidecode = "^1.3.6"
+
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,77 @@
+ aiohttp==3.8.3 ; python_version >= "3.9" and python_version < "4.0"
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
+ altair==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
+ anyio==3.6.2 ; python_version >= "3.9" and python_version < "4.0"
+ async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0"
+ attrs==22.2.0 ; python_version >= "3.9" and python_version < "4.0"
+ certifi==2022.12.7 ; python_version >= "3.9" and python_version < "4.0"
+ charset-normalizer==2.1.1 ; python_version >= "3.9" and python_version < "4.0"
+ click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
+ coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "4.0"
+ contourpy==1.0.6 ; python_version >= "3.9" and python_version < "4.0"
+ cycler==0.11.0 ; python_version >= "3.9" and python_version < "4.0"
+ cython==0.29.33 ; python_version >= "3.9" and python_version < "4.0"
+ entrypoints==0.4 ; python_version >= "3.9" and python_version < "4.0"
+ fastapi==0.89.0 ; python_version >= "3.9" and python_version < "4.0"
+ ffmpy==0.3.0 ; python_version >= "3.9" and python_version < "4.0"
+ flatbuffers==23.1.4 ; python_version >= "3.9" and python_version < "4.0"
+ fonttools==4.38.0 ; python_version >= "3.9" and python_version < "4.0"
+ frozenlist==1.3.3 ; python_version >= "3.9" and python_version < "4.0"
+ fsspec==2022.11.0 ; python_version >= "3.9" and python_version < "4.0"
+ gradio==3.16.1 ; python_version >= "3.9" and python_version < "4.0"
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
+ httpcore==0.16.3 ; python_version >= "3.9" and python_version < "4.0"
+ httpx==0.23.3 ; python_version >= "3.9" and python_version < "4.0"
+ humanfriendly==10.0 ; python_version >= "3.9" and python_version < "4.0"
+ idna==3.4 ; python_version >= "3.9" and python_version < "4"
+ jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
+ jsonschema==4.17.3 ; python_version >= "3.9" and python_version < "4.0"
+ kiwisolver==1.4.4 ; python_version >= "3.9" and python_version < "4.0"
+ linkify-it-py==1.0.3 ; python_version >= "3.9" and python_version < "4.0"
+ loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
+ markdown-it-py==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
+ markdown-it-py[linkify,plugins]==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
+ markupsafe==2.1.1 ; python_version >= "3.9" and python_version < "4.0"
+ matplotlib==3.6.2 ; python_version >= "3.9" and python_version < "4.0"
+ mdit-py-plugins==0.3.3 ; python_version >= "3.9" and python_version < "4.0"
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
+ mpmath==1.2.1 ; python_version >= "3.9" and python_version < "4.0"
+ multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
+ numpy==1.24.1 ; python_version < "4.0" and python_version >= "3.9"
+ onnxruntime==1.13.1 ; python_version >= "3.9" and python_version < "4.0"
+ orjson==3.8.4 ; python_version >= "3.9" and python_version < "4.0"
+ packaging==23.0 ; python_version >= "3.9" and python_version < "4.0"
+ pandas==1.5.2 ; python_version >= "3.9" and python_version < "4.0"
+ pillow==9.4.0 ; python_version >= "3.9" and python_version < "4.0"
+ protobuf==4.21.12 ; python_version >= "3.9" and python_version < "4.0"
+ pycryptodome==3.16.0 ; python_version >= "3.9" and python_version < "4.0"
+ pydantic==1.10.4 ; python_version >= "3.9" and python_version < "4.0"
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
+ pyopenjtalk==0.3.0 ; python_version >= "3.9" and python_version < "4.0"
+ pyparsing==3.0.9 ; python_version >= "3.9" and python_version < "4.0"
+ pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "4.0"
+ pyrsistent==0.19.3 ; python_version >= "3.9" and python_version < "4.0"
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
+ python-multipart==0.0.5 ; python_version >= "3.9" and python_version < "4.0"
+ pytz==2022.7 ; python_version >= "3.9" and python_version < "4.0"
+ pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
+ requests==2.28.1 ; python_version >= "3.9" and python_version < "4"
+ rfc3986[idna2008]==1.5.0 ; python_version >= "3.9" and python_version < "4.0"
+ setuptools-scm==7.1.0 ; python_version >= "3.9" and python_version < "4.0"
+ setuptools==65.6.3 ; python_version >= "3.9" and python_version < "4.0"
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
+ starlette==0.22.0 ; python_version >= "3.9" and python_version < "4.0"
+ sympy==1.11.1 ; python_version >= "3.9" and python_version < "4.0"
+ tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
+ toolz==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
+ tqdm==4.64.1 ; python_version >= "3.9" and python_version < "4.0"
+ typing-extensions==4.4.0 ; python_version >= "3.9" and python_version < "4.0"
+ uc-micro-py==1.0.1 ; python_version >= "3.9" and python_version < "4.0"
+ unidecode==1.3.6 ; python_version >= "3.9" and python_version < "4.0"
+ urllib3==1.26.13 ; python_version >= "3.9" and python_version < "4"
+ uvicorn==0.20.0 ; python_version >= "3.9" and python_version < "4.0"
+ websockets==10.4 ; python_version >= "3.9" and python_version < "4.0"
+ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
+ yarl==1.8.2 ; python_version >= "3.9" and python_version < "4.0"
setup.sh ADDED
@@ -0,0 +1,22 @@
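+ # Manual dev-machine setup: create a conda env, install poetry, export the
+ # locked requirements and pip-install them.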
+ # sudo su
+ conda create -n dl python=3.9 -y
+ conda init bash
+ bash
+
+ conda activate dl
+ export POETRY_VERSION=1.3.1
+ export DEBIAN_FRONTEND=noninteractive && \
+     sudo apt-get update && \
+     sudo apt-get install cmake build-essential -y --no-install-recommends && \
+     pip install poetry==$POETRY_VERSION
+
+
+ poetry export -f requirements.txt -o requirements.txt --without dev --without test --without-hashes && \
+     pip install --upgrade pip && \
+     pip install -r requirements.txt
+
+
+
+ function run {
+     python -m app.main
+ }
util/build_docker.sh ADDED
@@ -0,0 +1,2 @@
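+ # Build the demo image with Docker BuildKit enabled.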
+ export DOCKER_BUILDKIT=1
+ docker build -f Dockerfile -t ccdesue/vits_demo .
util/extract_w.py ADDED
@@ -0,0 +1,20 @@
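+ # Shrink a full training checkpoint: keep only the model weights and the
+ # iteration count (optimizer state dropped) and save the result as a new .pth.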
+ # from tabnanny import check
+ import torch
+ import pathlib
+
+ path = r"/workspaces/vits_web_demo/export/model/D_second.pth"
+ model_path = pathlib.Path(path)
+
+ assert model_path.exists(), "model path does not exist"
+
+ checkpoint = torch.load(str(model_path), map_location='cpu')
+
+ state_file = checkpoint['model']
+ iteration = checkpoint['iteration']
+
+ out_path = model_path.parent / pathlib.Path("19_"+str(iteration)+'_demo'+'.pth')
+ out_path = str(out_path)
+ torch.save({'model': state_file,
+             'iteration': iteration,
+             'optimizer': None,
+             'learning_rate': None}, out_path)