zhzluke96
commited on
Commit
·
bed01bd
1
Parent(s):
37195a7
update
Browse files- CHANGELOG.md +150 -7
- launch.py +49 -11
- modules/ChatTTS/ChatTTS/core.py +124 -93
- modules/ChatTTS/ChatTTS/infer/api.py +4 -0
- modules/ChatTTS/ChatTTS/model/gpt.py +123 -74
- modules/ChatTTS/ChatTTS/utils/infer_utils.py +31 -6
- modules/Enhancer/ResembleEnhance.py +1 -1
- modules/SentenceSplitter.py +32 -1
- modules/SynthesizeSegments.py +30 -17
- modules/api/api_setup.py +5 -98
- modules/api/impl/handler/AudioHandler.py +19 -1
- modules/api/impl/handler/SSMLHandler.py +5 -0
- modules/api/impl/handler/TTSHandler.py +60 -1
- modules/api/impl/model/audio_model.py +4 -0
- modules/api/impl/tts_api.py +18 -4
- modules/api/impl/xtts_v2_api.py +97 -37
- modules/api/worker.py +3 -7
- modules/devices/devices.py +7 -1
- modules/finetune/train_speaker.py +2 -2
- modules/generate_audio.py +84 -8
- modules/models.py +11 -2
- modules/models_setup.py +74 -0
- modules/normalization.py +35 -24
- modules/refiner.py +8 -0
- modules/repos_static/resemble_enhance/inference.py +5 -5
- modules/speaker.py +6 -0
- modules/synthesize_audio.py +0 -2
- modules/synthesize_stream.py +42 -0
- modules/utils/HomophonesReplacer.py +39 -0
- modules/utils/audio.py +64 -58
- modules/utils/detect_lang.py +27 -0
- modules/utils/html.py +26 -0
- modules/utils/ignore_warn.py +9 -0
- modules/utils/markdown.py +1 -0
- modules/webui/localization_runtime.py +22 -0
- modules/webui/speaker/speaker_creator.py +6 -6
- modules/webui/ssml/podcast_tab.py +1 -1
- modules/webui/ssml/ssml_tab.py +61 -14
- modules/webui/tts_tab.py +49 -2
- modules/webui/webui_utils.py +83 -31
- requirements.txt +11 -9
- webui.py +5 -6
CHANGELOG.md
CHANGED
@@ -1,22 +1,150 @@
|
|
1 |
# Changelog
|
2 |
|
3 |
-
<a name="0.
|
4 |
-
## 0.
|
5 |
|
6 |
### Added
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
- ✨ add localization [[c05035d](https://github.com/lenML/ChatTTS-Forge/commit/c05035d5cdcc5aa7efd995fe42f6a2541abe718b)]
|
9 |
- ✨ SSML 支持 enhancer [[5c2788e](https://github.com/lenML/ChatTTS-Forge/commit/5c2788e04f3debfa8bafd8a2e2371dde30f38d4d)]
|
10 |
- ✨ webui 增加 podcast 工具 tab [[b0b169d](https://github.com/lenML/ChatTTS-Forge/commit/b0b169d8b49c8e013209e59d1f8b637382d8b997)]
|
11 |
-
- ✨ 完善
|
12 |
|
13 |
### Changed
|
14 |
|
|
|
|
|
|
|
|
|
|
|
15 |
- 🍱 update banner [[dbc293e](https://github.com/lenML/ChatTTS-Forge/commit/dbc293e1a7dec35f60020dcaf783ba3b7c734bfa)]
|
16 |
- ⚡ 增强 TN [[092c1b9](https://github.com/lenML/ChatTTS-Forge/commit/092c1b94147249880198fe2ad3dfe3b209099e19)]
|
17 |
- ⚡ enhancer 支持 off_tqdm [[94d34d6](https://github.com/lenML/ChatTTS-Forge/commit/94d34d657fa3433dae9ff61775e0c364a6f77aff)]
|
18 |
- ⚡ 增加 git env [[43d9c65](https://github.com/lenML/ChatTTS-Forge/commit/43d9c65877ff68ad94716bc2e505ccc7ae8869a8)]
|
19 |
-
- ⚡ 修改webui保存文件格式 [[2da41c9](https://github.com/lenML/ChatTTS-Forge/commit/2da41c90aa81bf87403598aefaea3e0ae2e83d79)]
|
|
|
|
|
|
|
|
|
20 |
|
21 |
### Removed
|
22 |
|
@@ -24,6 +152,17 @@
|
|
24 |
|
25 |
### Fixed
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
- 🐛 fix hparams config [#22](https://github.com/lenML/ChatTTS-Forge/issues/22) [[61d9809](https://github.com/lenML/ChatTTS-Forge/commit/61d9809804ad8c141d36afde51a608734a105662)]
|
28 |
- 🐛 fix enhance 下载脚本 [[d2e14b0](https://github.com/lenML/ChatTTS-Forge/commit/d2e14b0a4905724a55b03493fa4b94b5c4383c95)]
|
29 |
- 🐛 fix 'trange' referenced [[d1a8dae](https://github.com/lenML/ChatTTS-Forge/commit/d1a8daee61e62d14cf5fd7a17fab4424e24b1c41)]
|
@@ -33,10 +172,14 @@
|
|
33 |
|
34 |
### Miscellaneous
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
- 🌐 更新翻译文案 [[f56caa7](https://github.com/lenML/ChatTTS-Forge/commit/f56caa71e9186680b93c487d9645186ae18c1dc6)]
|
37 |
-
- 📝 update [[7cacf91](https://github.com/lenML/ChatTTS-Forge/commit/7cacf913541ee5f86eaa80d8b193b94b3db2b67c)]
|
38 |
-
- 📝 update webui document [[7f2bb22](https://github.com/lenML/ChatTTS-Forge/commit/7f2bb227027cc0eff312c37758a20916c1ebade6)]
|
39 |
-
|
40 |
|
41 |
<a name="0.5.5"></a>
|
42 |
|
|
|
1 |
# Changelog
|
2 |
|
3 |
+
<a name="0.6.2-rc"></a>
|
4 |
+
## 0.6.2-rc (2024-06-23)
|
5 |
|
6 |
### Added
|
7 |
|
8 |
+
- ✨ add adjuster to webui [[01f09b4](https://github.com/lenML/ChatTTS-Forge/commit/01f09b4fad2eb8b24a16b7768403de4975d51774)]
|
9 |
+
- ✨ stream mode support adjuster [[585d2dd](https://github.com/lenML/ChatTTS-Forge/commit/585d2dd488d8f8387e0d9435fb399f090a41b9cc)]
|
10 |
+
- ✨ improve xtts_v2 api [[fec66c7](https://github.com/lenML/ChatTTS-Forge/commit/fec66c7c00939a3c7c15e007536e037ac01153fa)]
|
11 |
+
- ✨ improve normalize [[d0da37e](https://github.com/lenML/ChatTTS-Forge/commit/d0da37e43f1de4088ef638edd90723f93894b1d2)]
|
12 |
+
- ✨ improve normalize/spliter [[163b649](https://github.com/lenML/ChatTTS-Forge/commit/163b6490e4d453c37cc259ce27208f55d10a9084)]
|
13 |
+
- ✨ add loudness equalization [[bc8bda7](https://github.com/lenML/ChatTTS-Forge/commit/bc8bda74825c31985d3cc1a44366ad92af1b623a)]
|
14 |
+
- ✨ support `--use_cpu=chattts,enhancer,trainer,all` [[23023bc](https://github.com/lenML/ChatTTS-Forge/commit/23023bc610f6f74a157faa8a6c6aacf64d91d870)]
|
15 |
+
- ✨ improve normalizetion.py [[1a7c0ed](https://github.com/lenML/ChatTTS-Forge/commit/1a7c0ed3923234ceadb79f397fa7577f9e682f2d)]
|
16 |
+
- ✨ ignore_useless_warnings [[4b9a32e](https://github.com/lenML/ChatTTS-Forge/commit/4b9a32ef821d85ceaf3d62af8f871aeb5088e084)]
|
17 |
+
- ✨ enhance logger, info => debug [[73bc8e7](https://github.com/lenML/ChatTTS-Forge/commit/73bc8e72b40146debd0a59100b1cca4cc42f5029)]
|
18 |
+
- ✨ add playground.stream page [[31377b0](https://github.com/lenML/ChatTTS-Forge/commit/31377b060c182519d74a12d81e66c8e73686bcd8)]
|
19 |
+
- ✨ tts api support stream [#5](https://github.com/lenML/ChatTTS-Forge/issues/5) [[15e0b2c](https://github.com/lenML/ChatTTS-Forge/commit/15e0b2cb051ba39dcf99f60f1faa11941f6dc656)]
|
20 |
+
|
21 |
+
### Changed
|
22 |
+
|
23 |
+
- 🍱 add _p_en [[56f1fbf](https://github.com/lenML/ChatTTS-Forge/commit/56f1fbf1f3fff6f76ca8c29aa12a6ddef665cf9f)]
|
24 |
+
- 🍱 update prompt [[4f95b31](https://github.com/lenML/ChatTTS-Forge/commit/4f95b31679225e1ee144a411a9cfa9b30c598450)]
|
25 |
+
- ⚡ Reduce popping sounds [[2d0fd68](https://github.com/lenML/ChatTTS-Forge/commit/2d0fd688ad1a5cff1e6aafc0502aee26de3f1d75)]
|
26 |
+
- ⚡ improve `apply_character_map` [[ea7399f](https://github.com/lenML/ChatTTS-Forge/commit/ea7399facc5c29327a7870bd66ad6222f5731ce3)]
|
27 |
+
|
28 |
+
### Fixed
|
29 |
+
|
30 |
+
- 🐛 fix `apply_normalize` missing `sr` [[2db6d65](https://github.com/lenML/ChatTTS-Forge/commit/2db6d65ef8fbf8a3a213cbdc3d4b1143396cc165)]
|
31 |
+
- 🐛 fix sentence spliter [[5d8937c](https://github.com/lenML/ChatTTS-Forge/commit/5d8937c169d5f7784920a93834df0480dd3a67b3)]
|
32 |
+
- 🐛 fix playground url_join [[53e7cbc](https://github.com/lenML/ChatTTS-Forge/commit/53e7cbc6103bc0e3bb83767a9233c45285b77e75)]
|
33 |
+
- 🐛 fix generate_audio args [[a7a698c](https://github.com/lenML/ChatTTS-Forge/commit/a7a698c760b5bc97c90a144a4a7afb5e17414995)]
|
34 |
+
- 🐛 fix infer func [[b0de527](https://github.com/lenML/ChatTTS-Forge/commit/b0de5275342c02d332a50d0ab5ac171a7007b300)]
|
35 |
+
- 🐛 fix webui logging format [[4adc29e](https://github.com/lenML/ChatTTS-Forge/commit/4adc29e6c06fa806a8178f445399bbac8ed57911)]
|
36 |
+
- 🐛 fix webui speaker_tab missing progress [[fafe242](https://github.com/lenML/ChatTTS-Forge/commit/fafe242e69ea8019729a62e52f6c0b3c0d6a63ad)]
|
37 |
+
|
38 |
+
### Miscellaneous
|
39 |
+
|
40 |
+
- 📝 添加整合包地址 [[26122d4](https://github.com/lenML/ChatTTS-Forge/commit/26122d4cfd975206211fc37491348cf40aa39561)]
|
41 |
+
- 📝 details `.env` file and cli usage docs [[ec3d36f](https://github.com/lenML/ChatTTS-Forge/commit/ec3d36f8a67215e243e6b8225aa9144ac888313a)]
|
42 |
+
- 📝 update changelog [[22996e9](https://github.com/lenML/ChatTTS-Forge/commit/22996e9f0c42d9cad59950aecfe6b16413f2ab40)]
|
43 |
+
- Windows not yet supported for torch.compile fix [[74ac27d](https://github.com/lenML/ChatTTS-Forge/commit/74ac27d56a370f87560329043c42be27022ca0f5)]
|
44 |
+
- fix: replace mispronounced words in TTS [[de66e6b](https://github.com/lenML/ChatTTS-Forge/commit/de66e6b8f7f8b5c10e7ac54f7b2488c798e5ef81)]
|
45 |
+
- feat: support stream mode [[3da0f0c](https://github.com/lenML/ChatTTS-Forge/commit/3da0f0cb7f213dee40d00a89093166ad9e1d17a0)]
|
46 |
+
- optimize: mps audio quality by contiguous scores [[1e4d79f](https://github.com/lenML/ChatTTS-Forge/commit/1e4d79f1a81a3ac8697afff0e44f0cfd2608599a)]
|
47 |
+
- 📝 update changelog [[ab55c14](https://github.com/lenML/ChatTTS-Forge/commit/ab55c149d48edc52f1de9c6d4fe6e6ed78b3b134)]
|
48 |
+
|
49 |
+
|
50 |
+
<a name="0.6.1"></a>
|
51 |
+
|
52 |
+
## 0.6.1 (2024-06-18)
|
53 |
+
|
54 |
+
### Added
|
55 |
+
|
56 |
+
- ✨ add `--preload_models` [[73a41e0](https://github.com/lenML/ChatTTS-Forge/commit/73a41e009cd4426dfe4b0a35325da68189966390)]
|
57 |
+
- ✨ add webui progress [[778802d](https://github.com/lenML/ChatTTS-Forge/commit/778802ded12de340520f41a3e1bdb852f00bd637)]
|
58 |
+
- ✨ add merger error [[51060bc](https://github.com/lenML/ChatTTS-Forge/commit/51060bc343a6308493b7d582e21dca62eacaa7cb)]
|
59 |
+
- ✨ tts prompt => experimental [[d3e6315](https://github.com/lenML/ChatTTS-Forge/commit/d3e6315a3cb8b1fa254cefb2efe2bae7c74a50f8)]
|
60 |
+
- ✨ add 基本的 speaker finetune ui [[5f68f19](https://github.com/lenML/ChatTTS-Forge/commit/5f68f193e78f470bd2c3ca4b9fa1008cf809e753)]
|
61 |
+
- ✨ add speaker finetune [[5ce27ed](https://github.com/lenML/ChatTTS-Forge/commit/5ce27ed7e4da6c96bb3fd016b8b491768faf319d)]
|
62 |
+
- ✨ add `--ino_half` remove `--half` [[5820e57](https://github.com/lenML/ChatTTS-Forge/commit/5820e576b288df50b929fbdfd9d0d6b6f548b54e)]
|
63 |
+
- ✨ add webui podcast 默认值 [[dd786a8](https://github.com/lenML/ChatTTS-Forge/commit/dd786a83733a71d005ff7efe6312e35d652b2525)]
|
64 |
+
- ✨ add webui 分割器配置 [[589327b](https://github.com/lenML/ChatTTS-Forge/commit/589327b729188d1385838816b9807e894eb128b0)]
|
65 |
+
- ✨ add `eos` params to all api [[79c994f](https://github.com/lenML/ChatTTS-Forge/commit/79c994fadf7d60ea432b62f4000b62b67efe7259)]
|
66 |
+
|
67 |
+
### Changed
|
68 |
+
|
69 |
+
- ⬆️ Bump urllib3 from 2.2.1 to 2.2.2 [[097c15b](https://github.com/lenML/ChatTTS-Forge/commit/097c15ba56f8197a4f26adcfb77336a70e5ed806)]
|
70 |
+
- 🎨 run formatter [[8c267e1](https://github.com/lenML/ChatTTS-Forge/commit/8c267e151152fe2090528104627ec031453d4ed5)]
|
71 |
+
- ⚡ Optimize `audio_data_to_segment` [#57](https://github.com/lenML/ChatTTS-Forge/issues/57) [[d33809c](https://github.com/lenML/ChatTTS-Forge/commit/d33809c60a3ac76a01f71de4fd26b315d066c8d3)]
|
72 |
+
- ⚡ map_location="cpu" [[0f58c10](https://github.com/lenML/ChatTTS-Forge/commit/0f58c10a445efaa9829f862acb4fb94bc07f07bf)]
|
73 |
+
- ⚡ colab use default GPU [[c7938ad](https://github.com/lenML/ChatTTS-Forge/commit/c7938adb6d3615f37210b1f3cbe4671f93d58285)]
|
74 |
+
- ⚡ improve hf calling [[2dde612](https://github.com/lenML/ChatTTS-Forge/commit/2dde6127906ce6e77a970b4cd96e68f7a5417c6a)]
|
75 |
+
- 🍱 add `bob_ft10.pt` [[9eee965](https://github.com/lenML/ChatTTS-Forge/commit/9eee965425a7d6640eba22d843db4975dd3e355a)]
|
76 |
+
- ⚡ enhance SynthesizeSegments [[0bb4dd7](https://github.com/lenML/ChatTTS-Forge/commit/0bb4dd7676c38249f10bf0326174ff8b74b2abae)]
|
77 |
+
- 🍱 add `bob_ft10.pt` [[bef1b02](https://github.com/lenML/ChatTTS-Forge/commit/bef1b02435c39830612b18738bb31ac48e340fc6)]
|
78 |
+
- ♻️ refactor api [[671fcc3](https://github.com/lenML/ChatTTS-Forge/commit/671fcc38a570d0cb7de0a214d318281084c9608c)]
|
79 |
+
- ⚡ improve xtts_v2 api [[206fabc](https://github.com/lenML/ChatTTS-Forge/commit/206fabc76f1dbad261c857cb02f8c99c21e99eef)]
|
80 |
+
- ⚡ train text => just text [[e2037e0](https://github.com/lenML/ChatTTS-Forge/commit/e2037e0f97f15ff560fce14bbdc3926e3261bff9)]
|
81 |
+
- ⚡ improve TN [[a0069ed](https://github.com/lenML/ChatTTS-Forge/commit/a0069ed2d0c3122444e873fb13b9922f9ab88a79)]
|
82 |
+
|
83 |
+
### Fixed
|
84 |
+
|
85 |
+
- 🐛 fix webui speaker_editor missing `describe` [[2a2a36d](https://github.com/lenML/ChatTTS-Forge/commit/2a2a36d62d8f253fc2e17ccc558038dbcc99d1ee)]
|
86 |
+
- 💚 Dependabot alerts [[f501860](https://github.com/lenML/ChatTTS-Forge/commit/f5018607f602769d4dda7aa00573b9a06e659d91)]
|
87 |
+
- 🐛 fix `numpy<2` [#50](https://github.com/lenML/ChatTTS-Forge/issues/50) [[e4fea4f](https://github.com/lenML/ChatTTS-Forge/commit/e4fea4f80b31d962f02cd1146ce8c73bf75b6a39)]
|
88 |
+
- 🐛 fix Box() index [#49](https://github.com/lenML/ChatTTS-Forge/issues/49) add testcase [[d982e33](https://github.com/lenML/ChatTTS-Forge/commit/d982e33ed30749d7ae6570ade5ec7b560a3d1f06)]
|
89 |
+
- 🐛 fix Box() index [#49](https://github.com/lenML/ChatTTS-Forge/issues/49) [[1788318](https://github.com/lenML/ChatTTS-Forge/commit/1788318a96c014a53ee41c4db7d60fdd4b15cfca)]
|
90 |
+
- 🐛 fix `--use_cpu` [#47](https://github.com/lenML/ChatTTS-Forge/issues/47) update conftest [[4095b08](https://github.com/lenML/ChatTTS-Forge/commit/4095b085c4c6523f2579e00edfb1569d65608ca2)]
|
91 |
+
- 🐛 fix `--use_cpu` [#47](https://github.com/lenML/ChatTTS-Forge/issues/47) [[221962f](https://github.com/lenML/ChatTTS-Forge/commit/221962fd0f61d3f269918b26a814cbcd5aabd1f0)]
|
92 |
+
- 🐛 fix webui speaker args [[3b3c331](https://github.com/lenML/ChatTTS-Forge/commit/3b3c3311dd0add0e567179fc38223a3cc5e56f6e)]
|
93 |
+
- 🐛 fix speaker trainer [[52d473f](https://github.com/lenML/ChatTTS-Forge/commit/52d473f37f6a3950d4c8738c294f048f11198776)]
|
94 |
+
- 🐛 兼容 win32 [[7ffa37f](https://github.com/lenML/ChatTTS-Forge/commit/7ffa37f3d36fb9ba53ab051b2fce6229920b1208)]
|
95 |
+
- 🐛 fix google api ssml synthesize [#43](https://github.com/lenML/ChatTTS-Forge/issues/43) [[1566f88](https://github.com/lenML/ChatTTS-Forge/commit/1566f8891c22d63681d756deba70374e2b75d078)]
|
96 |
+
|
97 |
+
### Miscellaneous
|
98 |
+
|
99 |
+
- Merge pull request [#58](https://github.com/lenML/ChatTTS-Forge/issues/58) from lenML/dependabot/pip/urllib3-2.2.2 [[f259f18](https://github.com/lenML/ChatTTS-Forge/commit/f259f180af57f9a6938b14bf263d0387b6900e57)]
|
100 |
+
- 📝 update changelog [[b9da7ec](https://github.com/lenML/ChatTTS-Forge/commit/b9da7ec1afed416a825e9e4a507b8263f69bf47e)]
|
101 |
+
- 📝 update [[8439437](https://github.com/lenML/ChatTTS-Forge/commit/84394373de66b81a9f7f70ef8484254190e292ab)]
|
102 |
+
- 📝 update [[ef97206](https://github.com/lenML/ChatTTS-Forge/commit/ef972066558d0b229d6d0b3d83bb4f8e8517558f)]
|
103 |
+
- 📝 improve readme.md [[7bf3de2](https://github.com/lenML/ChatTTS-Forge/commit/7bf3de2afb41b9a29071bec18ee6306ce8e70183)]
|
104 |
+
- 📝 add bug report forms [[091cf09](https://github.com/lenML/ChatTTS-Forge/commit/091cf0958a719236c77107acf4cfb8c0ba090946)]
|
105 |
+
- 📝 update changelog [[3d519ec](https://github.com/lenML/ChatTTS-Forge/commit/3d519ec8a20098c2de62631ae586f39053dd89a5)]
|
106 |
+
- 📝 update [[66963f8](https://github.com/lenML/ChatTTS-Forge/commit/66963f8ff8f29c298de64cd4a54913b1d3e29a6a)]
|
107 |
+
- 📝 update [[b7a63b5](https://github.com/lenML/ChatTTS-Forge/commit/b7a63b59132d2c8dbb4ad2e15bd23713f00f0084)]
|
108 |
+
|
109 |
+
<a name="0.6.0"></a>
|
110 |
+
|
111 |
+
## 0.6.0 (2024-06-12)
|
112 |
+
|
113 |
+
### Added
|
114 |
+
|
115 |
+
- ✨ add XTTSv2 api [#42](https://github.com/lenML/ChatTTS-Forge/issues/42) [[d1fc63c](https://github.com/lenML/ChatTTS-Forge/commit/d1fc63cd1e847d622135c96371bbfe2868a80c19)]
|
116 |
+
- ✨ google api 支持 enhancer [[14fecdb](https://github.com/lenML/ChatTTS-Forge/commit/14fecdb8ea0f9a5d872a4c7ca862e901990076c0)]
|
117 |
+
- ✨ 修改 podcast 脚本默认 style [[98186c2](https://github.com/lenML/ChatTTS-Forge/commit/98186c25743cbfa24ca7d41336d4ec84aa34aacf)]
|
118 |
+
- ✨ playground google api [[4109adb](https://github.com/lenML/ChatTTS-Forge/commit/4109adb317be215970d756b4ba7064c9dc4d6fdc)]
|
119 |
+
- ✨ 添加 unload api [[ed9d61a](https://github.com/lenML/ChatTTS-Forge/commit/ed9d61a2fe4ba1d902d91517148f8f7dea47b51b)]
|
120 |
+
- ✨ support api workers [[babdada](https://github.com/lenML/ChatTTS-Forge/commit/babdada50e79e425bac4d3074f8e42dfb4c4c33a)]
|
121 |
+
- ✨ add ffmpeg version to webui footer [[e9241a1](https://github.com/lenML/ChatTTS-Forge/commit/e9241a1a8d1f5840ae6259e46020684ba70a0efb)]
|
122 |
+
- ✨ support use internal ffmpeg [[0e02ab0](https://github.com/lenML/ChatTTS-Forge/commit/0e02ab0f5d81fbfb6166793cb4f6d58c5f17f34c)]
|
123 |
+
- ✨ 增加参数 debug_generate [[94e876a](https://github.com/lenML/ChatTTS-Forge/commit/94e876ae3819c3efbde4a239085f91342874bd5a)]
|
124 |
+
- ✨ 支持 api 服务与 webui 并存 [[4901491](https://github.com/lenML/ChatTTS-Forge/commit/4901491eced3955c51030388d1dcebf049cd790e)]
|
125 |
+
- ✨ refiner api support normalize [[ef665da](https://github.com/lenML/ChatTTS-Forge/commit/ef665dad5a5517c610f0b430bc52a5b0ba3c2d96)]
|
126 |
+
- ✨ add webui 音色编辑器 [[fb4c7b3](https://github.com/lenML/ChatTTS-Forge/commit/fb4c7b3b0949ac669da0d069c739934f116b83e2)]
|
127 |
- ✨ add localization [[c05035d](https://github.com/lenML/ChatTTS-Forge/commit/c05035d5cdcc5aa7efd995fe42f6a2541abe718b)]
|
128 |
- ✨ SSML 支持 enhancer [[5c2788e](https://github.com/lenML/ChatTTS-Forge/commit/5c2788e04f3debfa8bafd8a2e2371dde30f38d4d)]
|
129 |
- ✨ webui 增加 podcast 工具 tab [[b0b169d](https://github.com/lenML/ChatTTS-Forge/commit/b0b169d8b49c8e013209e59d1f8b637382d8b997)]
|
130 |
+
- ✨ 完善 enhancer [[205ebeb](https://github.com/lenML/ChatTTS-Forge/commit/205ebebeb7530c81fde7ea96c7e4c6a888a29835)]
|
131 |
|
132 |
### Changed
|
133 |
|
134 |
+
- ⚡ improve synthesize_audio [[759adc2](https://github.com/lenML/ChatTTS-Forge/commit/759adc2ead1da8395df62ea1724456dad6894eb1)]
|
135 |
+
- ⚡ reduce enhancer chunk vram usage [[3464b42](https://github.com/lenML/ChatTTS-Forge/commit/3464b427b14878ee11e03ebdfb91efee1550de59)]
|
136 |
+
- ⚡ 增加默认说话人 [[d702ad5](https://github.com/lenML/ChatTTS-Forge/commit/d702ad5ad585978f8650284ab99238571dbd163b)]
|
137 |
+
- 🍱 add `podcast` `podcast_p` style [[2b9e5bf](https://github.com/lenML/ChatTTS-Forge/commit/2b9e5bfd8fe4700f802097b995f5b68bf1097087)]
|
138 |
+
- 🎨 improve code [[317951e](https://github.com/lenML/ChatTTS-Forge/commit/317951e431b16c735df31187b1af7230a1608c41)]
|
139 |
- 🍱 update banner [[dbc293e](https://github.com/lenML/ChatTTS-Forge/commit/dbc293e1a7dec35f60020dcaf783ba3b7c734bfa)]
|
140 |
- ⚡ 增强 TN [[092c1b9](https://github.com/lenML/ChatTTS-Forge/commit/092c1b94147249880198fe2ad3dfe3b209099e19)]
|
141 |
- ⚡ enhancer 支持 off_tqdm [[94d34d6](https://github.com/lenML/ChatTTS-Forge/commit/94d34d657fa3433dae9ff61775e0c364a6f77aff)]
|
142 |
- ⚡ 增加 git env [[43d9c65](https://github.com/lenML/ChatTTS-Forge/commit/43d9c65877ff68ad94716bc2e505ccc7ae8869a8)]
|
143 |
+
- ⚡ 修改 webui 保存文件格式 [[2da41c9](https://github.com/lenML/ChatTTS-Forge/commit/2da41c90aa81bf87403598aefaea3e0ae2e83d79)]
|
144 |
+
|
145 |
+
### Breaking changes
|
146 |
+
|
147 |
+
- 💥 enhancer support --half [[fef2ed6](https://github.com/lenML/ChatTTS-Forge/commit/fef2ed659fd7fe5a14807d286c209904875ce594)]
|
148 |
|
149 |
### Removed
|
150 |
|
|
|
152 |
|
153 |
### Fixed
|
154 |
|
155 |
+
- 🐛 fix worker env loader [[5b0bf4e](https://github.com/lenML/ChatTTS-Forge/commit/5b0bf4e93738bcd115f006376691c4eaa89b66de)]
|
156 |
+
- 🐛 fix colab default lang missing [[d4e5190](https://github.com/lenML/ChatTTS-Forge/commit/d4e51901856305fc039d886a92e38eea2a2cd24d)]
|
157 |
+
- 🐛 fix "reflection_pad1d" not implemented for 'Half' [[536c19b](https://github.com/lenML/ChatTTS-Forge/commit/536c19b7f6dc3f1702fcc2a90daa3277040e70f0)]
|
158 |
+
- 🐛 fix [#33](https://github.com/lenML/ChatTTS-Forge/issues/33) [[76e0b58](https://github.com/lenML/ChatTTS-Forge/commit/76e0b5808ede71ebb28edbf0ce0af7d9da9bcb27)]
|
159 |
+
- 🐛 fix localization error [[507dbe7](https://github.com/lenML/ChatTTS-Forge/commit/507dbe7a3b92d1419164d24f7804295f6686b439)]
|
160 |
+
- 🐛 block main thread [#30](https://github.com/lenML/ChatTTS-Forge/issues/30) [[3a7cbde](https://github.com/lenML/ChatTTS-Forge/commit/3a7cbde6ccdfd20a6c53d7625d4e652007367fbf)]
|
161 |
+
- 🐛 fix webui skip no-translate [[a8d595e](https://github.com/lenML/ChatTTS-Forge/commit/a8d595eb490f23c943d6efc35b65b33266c033b7)]
|
162 |
+
- 🐛 fix hf.space force abort [[f564536](https://github.com/lenML/ChatTTS-Forge/commit/f5645360dd1f45a7bf112f01c85fb862ee57df3c)]
|
163 |
+
- 🐛 fix missing device [#25](https://github.com/lenML/ChatTTS-Forge/issues/25) [[07cf6c1](https://github.com/lenML/ChatTTS-Forge/commit/07cf6c1386900999b6c9436debbfcbe59f6b692a)]
|
164 |
+
- 🐛 fix Chat.refiner_prompt() [[0839863](https://github.com/lenML/ChatTTS-Forge/commit/083986369d0e67fcb4bd71930ad3d2bc3fc038fb)]
|
165 |
+
- 🐛 fix --language type check [[50d354c](https://github.com/lenML/ChatTTS-Forge/commit/50d354c91c659d9ae16c8eaa0218d9e08275fbb2)]
|
166 |
- 🐛 fix hparams config [#22](https://github.com/lenML/ChatTTS-Forge/issues/22) [[61d9809](https://github.com/lenML/ChatTTS-Forge/commit/61d9809804ad8c141d36afde51a608734a105662)]
|
167 |
- 🐛 fix enhance 下载脚本 [[d2e14b0](https://github.com/lenML/ChatTTS-Forge/commit/d2e14b0a4905724a55b03493fa4b94b5c4383c95)]
|
168 |
- 🐛 fix 'trange' referenced [[d1a8dae](https://github.com/lenML/ChatTTS-Forge/commit/d1a8daee61e62d14cf5fd7a17fab4424e24b1c41)]
|
|
|
172 |
|
173 |
### Miscellaneous
|
174 |
|
175 |
+
- 🐳 fix docker / 兼容 py 3.9 [[ebb096f](https://github.com/lenML/ChatTTS-Forge/commit/ebb096f9b1b843b65d150fb34da7d3b5acb13011)]
|
176 |
+
- 🐳 add .dockerignore [[57262b8](https://github.com/lenML/ChatTTS-Forge/commit/57262b81a8df3ed26ca5da5e264d5dca7b022471)]
|
177 |
+
- 🧪 add tests [[a807640](https://github.com/lenML/ChatTTS-Forge/commit/a80764030b790baee45a10cbe2d4edd7f183ef3c)]
|
178 |
+
- 🌐 fix [[b34a0f8](https://github.com/lenML/ChatTTS-Forge/commit/b34a0f8654467f3068e43056708742ab69e3665b)]
|
179 |
+
- 🌐 remove chat limit desc [[3f81eca](https://github.com/lenML/ChatTTS-Forge/commit/3f81ecae6e4521eeb4e867534defc36be741e1e2)]
|
180 |
+
- 🧪 add tests [[7a54225](https://github.com/lenML/ChatTTS-Forge/commit/7a542256a157a281a15312bbf987bc9fb16876ee)]
|
181 |
+
- 🔨 improve model downloader [[79a0c59](https://github.com/lenML/ChatTTS-Forge/commit/79a0c599f03b4e47346315a03f1df3d92578fe5d)]
|
182 |
- 🌐 更新翻译文案 [[f56caa7](https://github.com/lenML/ChatTTS-Forge/commit/f56caa71e9186680b93c487d9645186ae18c1dc6)]
|
|
|
|
|
|
|
183 |
|
184 |
<a name="0.5.5"></a>
|
185 |
|
launch.py
CHANGED
@@ -5,6 +5,7 @@ from modules.ffmpeg_env import setup_ffmpeg_path
|
|
5 |
|
6 |
try:
|
7 |
setup_ffmpeg_path()
|
|
|
8 |
logging.basicConfig(
|
9 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
10 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
@@ -16,26 +17,44 @@ import argparse
|
|
16 |
|
17 |
import uvicorn
|
18 |
|
19 |
-
from modules.api.api_setup import setup_api_args
|
|
|
20 |
from modules.utils import env
|
|
|
|
|
|
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
-
if __name__ == "__main__":
|
25 |
-
import dotenv
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
)
|
30 |
-
parser =
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
)
|
33 |
-
setup_api_args(parser)
|
34 |
-
setup_model_args(parser)
|
35 |
-
setup_uvicon_args(parser=parser)
|
36 |
|
37 |
-
args = parser.parse_args()
|
38 |
|
|
|
39 |
host = env.get_and_update_env(args, "host", "0.0.0.0", str)
|
40 |
port = env.get_and_update_env(args, "port", 7870, int)
|
41 |
reload = env.get_and_update_env(args, "reload", False, bool)
|
@@ -68,3 +87,22 @@ if __name__ == "__main__":
|
|
68 |
ssl_certfile=ssl_certfile,
|
69 |
ssl_keyfile_password=ssl_keyfile_password,
|
70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
try:
|
7 |
setup_ffmpeg_path()
|
8 |
+
# NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
|
9 |
logging.basicConfig(
|
10 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
11 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
|
17 |
|
18 |
import uvicorn
|
19 |
|
20 |
+
from modules.api.api_setup import setup_api_args
|
21 |
+
from modules.models_setup import setup_model_args
|
22 |
from modules.utils import env
|
23 |
+
from modules.utils.ignore_warn import ignore_useless_warnings
|
24 |
+
|
25 |
+
ignore_useless_warnings()
|
26 |
|
27 |
logger = logging.getLogger(__name__)
|
28 |
|
|
|
|
|
29 |
|
30 |
+
def setup_uvicon_args(parser: argparse.ArgumentParser):
|
31 |
+
parser.add_argument("--host", type=str, help="Host to run the server on")
|
32 |
+
parser.add_argument("--port", type=int, help="Port to run the server on")
|
33 |
+
parser.add_argument(
|
34 |
+
"--reload", action="store_true", help="Enable auto-reload for development"
|
35 |
)
|
36 |
+
parser.add_argument("--workers", type=int, help="Number of worker processes")
|
37 |
+
parser.add_argument("--log_level", type=str, help="Log level")
|
38 |
+
parser.add_argument("--access_log", action="store_true", help="Enable access log")
|
39 |
+
parser.add_argument(
|
40 |
+
"--proxy_headers", action="store_true", help="Enable proxy headers"
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--timeout_graceful_shutdown",
|
47 |
+
type=int,
|
48 |
+
help="Graceful shutdown timeout duration",
|
49 |
+
)
|
50 |
+
parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
|
51 |
+
parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
|
52 |
+
parser.add_argument(
|
53 |
+
"--ssl_keyfile_password", type=str, help="SSL key file password"
|
54 |
)
|
|
|
|
|
|
|
55 |
|
|
|
56 |
|
57 |
+
def process_uvicon_args(args):
|
58 |
host = env.get_and_update_env(args, "host", "0.0.0.0", str)
|
59 |
port = env.get_and_update_env(args, "port", 7870, int)
|
60 |
reload = env.get_and_update_env(args, "reload", False, bool)
|
|
|
87 |
ssl_certfile=ssl_certfile,
|
88 |
ssl_keyfile_password=ssl_keyfile_password,
|
89 |
)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
import dotenv
|
94 |
+
|
95 |
+
dotenv.load_dotenv(
|
96 |
+
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
|
97 |
+
)
|
98 |
+
parser = argparse.ArgumentParser(
|
99 |
+
description="Start the FastAPI server with command line arguments"
|
100 |
+
)
|
101 |
+
# NOTE: 主进程中不需要处理 model args / api args,但是要接收这些参数, 具体处理在 worker.py 中
|
102 |
+
setup_api_args(parser=parser)
|
103 |
+
setup_model_args(parser=parser)
|
104 |
+
setup_uvicon_args(parser=parser)
|
105 |
+
|
106 |
+
args = parser.parse_args()
|
107 |
+
|
108 |
+
process_uvicon_args(args)
|
modules/ChatTTS/ChatTTS/core.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
|
|
|
4 |
import torch
|
5 |
from huggingface_hub import snapshot_download
|
6 |
from omegaconf import OmegaConf
|
@@ -142,9 +143,12 @@ class Chat:
|
|
142 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=map_location))
|
143 |
if compile and "cuda" in str(device):
|
144 |
self.logger.info("compile gpt model")
|
145 |
-
|
146 |
-
gpt.gpt.forward
|
147 |
-
|
|
|
|
|
|
|
148 |
self.pretrain_models["gpt"] = gpt
|
149 |
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
|
150 |
assert os.path.exists(
|
@@ -173,7 +177,7 @@ class Chat:
|
|
173 |
|
174 |
self.check_model()
|
175 |
|
176 |
-
def
|
177 |
self,
|
178 |
text,
|
179 |
skip_refine_text=False,
|
@@ -181,9 +185,11 @@ class Chat:
|
|
181 |
params_refine_text={},
|
182 |
params_infer_code={"prompt": "[speed_5]"},
|
183 |
use_decoder=True,
|
|
|
|
|
184 |
):
|
185 |
|
186 |
-
assert self.check_model(use_decoder=use_decoder)
|
187 |
|
188 |
if not isinstance(text, list):
|
189 |
text = [text]
|
@@ -192,122 +198,147 @@ class Chat:
|
|
192 |
reserved_tokens = self.pretrain_models[
|
193 |
"tokenizer"
|
194 |
].additional_special_tokens
|
195 |
-
invalid_characters = count_invalid_characters(
|
|
|
|
|
196 |
if len(invalid_characters):
|
197 |
self.logger.log(
|
198 |
logging.WARNING, f"Invalid characters found! : {invalid_characters}"
|
199 |
)
|
200 |
-
text[i] = apply_character_map(t)
|
201 |
|
202 |
if not skip_refine_text:
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
text_tokens
|
208 |
-
|
209 |
-
i
|
210 |
-
|
211 |
-
"
|
212 |
-
|
|
|
|
|
|
|
213 |
]
|
214 |
-
|
215 |
-
|
216 |
-
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
if refine_text_only:
|
219 |
-
return
|
220 |
|
221 |
text = [params_infer_code.get("prompt", "") + i for i in text]
|
222 |
params_infer_code.pop("prompt", "")
|
223 |
-
|
224 |
-
self.pretrain_models,
|
|
|
|
|
|
|
|
|
225 |
)
|
226 |
-
|
227 |
if use_decoder:
|
228 |
-
|
229 |
-
|
230 |
-
for i in result["hiddens"]
|
231 |
-
]
|
232 |
else:
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
]
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
-
def
|
243 |
self,
|
244 |
text,
|
|
|
|
|
245 |
params_refine_text={},
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
text
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
return text[0]
|
274 |
|
275 |
def generate_audio(
|
276 |
self,
|
277 |
prompt,
|
278 |
params_infer_code={"prompt": "[speed_5]"},
|
279 |
use_decoder=True,
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
if not isinstance(prompt, list):
|
285 |
-
prompt = [prompt]
|
286 |
-
|
287 |
-
prompt = [params_infer_code.get("prompt", "") + i for i in prompt]
|
288 |
-
params_infer_code.pop("prompt", "")
|
289 |
-
result = infer_code(
|
290 |
-
self.pretrain_models,
|
291 |
prompt,
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
)
|
295 |
|
296 |
-
if use_decoder:
|
297 |
-
mel_spec = [
|
298 |
-
self.pretrain_models["decoder"](i[None].permute(0, 2, 1))
|
299 |
-
for i in result["hiddens"]
|
300 |
-
]
|
301 |
-
else:
|
302 |
-
mel_spec = [
|
303 |
-
self.pretrain_models["dvae"](i[None].permute(0, 2, 1))
|
304 |
-
for i in result["ids"]
|
305 |
-
]
|
306 |
-
|
307 |
-
wav = [self.pretrain_models["vocos"].decode(i).cpu().numpy() for i in mel_spec]
|
308 |
-
|
309 |
-
return wav
|
310 |
-
|
311 |
def sample_random_speaker(
|
312 |
self,
|
313 |
) -> torch.Tensor:
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
|
4 |
+
import numpy as np
|
5 |
import torch
|
6 |
from huggingface_hub import snapshot_download
|
7 |
from omegaconf import OmegaConf
|
|
|
143 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=map_location))
|
144 |
if compile and "cuda" in str(device):
|
145 |
self.logger.info("compile gpt model")
|
146 |
+
try:
|
147 |
+
gpt.gpt.forward = torch.compile(
|
148 |
+
gpt.gpt.forward, backend="inductor", dynamic=True
|
149 |
+
)
|
150 |
+
except RuntimeError as e:
|
151 |
+
logging.warning(f"Compile failed,{e}. fallback to normal mode.")
|
152 |
self.pretrain_models["gpt"] = gpt
|
153 |
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
|
154 |
assert os.path.exists(
|
|
|
177 |
|
178 |
self.check_model()
|
179 |
|
180 |
+
def _infer(
|
181 |
self,
|
182 |
text,
|
183 |
skip_refine_text=False,
|
|
|
185 |
params_refine_text={},
|
186 |
params_infer_code={"prompt": "[speed_5]"},
|
187 |
use_decoder=True,
|
188 |
+
stream=False,
|
189 |
+
stream_text=False,
|
190 |
):
|
191 |
|
192 |
+
# assert self.check_model(use_decoder=use_decoder)
|
193 |
|
194 |
if not isinstance(text, list):
|
195 |
text = [text]
|
|
|
198 |
reserved_tokens = self.pretrain_models[
|
199 |
"tokenizer"
|
200 |
].additional_special_tokens
|
201 |
+
invalid_characters = count_invalid_characters(
|
202 |
+
t, reserved_tokens=reserved_tokens
|
203 |
+
)
|
204 |
if len(invalid_characters):
|
205 |
self.logger.log(
|
206 |
logging.WARNING, f"Invalid characters found! : {invalid_characters}"
|
207 |
)
|
208 |
+
text[i] = apply_character_map(t, reserved_tokens=reserved_tokens)
|
209 |
|
210 |
if not skip_refine_text:
|
211 |
+
text_tokens_gen = refine_text(
|
212 |
+
self.pretrain_models, text, stream=stream, **params_refine_text
|
213 |
+
)
|
214 |
+
|
215 |
+
def decode_text(text_tokens):
|
216 |
+
text_tokens = [
|
217 |
+
i[
|
218 |
+
i
|
219 |
+
< self.pretrain_models["tokenizer"].convert_tokens_to_ids(
|
220 |
+
"[break_0]"
|
221 |
+
)
|
222 |
+
]
|
223 |
+
for i in text_tokens
|
224 |
]
|
225 |
+
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
|
226 |
+
return text
|
|
|
227 |
|
228 |
+
if stream_text:
|
229 |
+
for result in text_tokens_gen:
|
230 |
+
text_incomplete = decode_text(result["ids"])
|
231 |
+
if refine_text_only and stream:
|
232 |
+
yield text_incomplete
|
233 |
+
if refine_text_only:
|
234 |
+
return
|
235 |
+
else:
|
236 |
+
result = next(text_tokens_gen)
|
237 |
+
text = decode_text(result["ids"])
|
238 |
+
if refine_text_only:
|
239 |
+
yield text
|
240 |
if refine_text_only:
|
241 |
+
return
|
242 |
|
243 |
text = [params_infer_code.get("prompt", "") + i for i in text]
|
244 |
params_infer_code.pop("prompt", "")
|
245 |
+
result_gen = infer_code(
|
246 |
+
self.pretrain_models,
|
247 |
+
text,
|
248 |
+
**params_infer_code,
|
249 |
+
return_hidden=use_decoder,
|
250 |
+
stream=stream,
|
251 |
)
|
|
|
252 |
if use_decoder:
|
253 |
+
field = "hiddens"
|
254 |
+
docoder_name = "decoder"
|
|
|
|
|
255 |
else:
|
256 |
+
field = "ids"
|
257 |
+
docoder_name = "dvae"
|
258 |
+
vocos_decode = lambda spec: [
|
259 |
+
self.pretrain_models["vocos"]
|
260 |
+
.decode(i.cpu() if torch.backends.mps.is_available() else i)
|
261 |
+
.cpu()
|
262 |
+
.numpy()
|
263 |
+
for i in spec
|
264 |
+
]
|
265 |
+
if stream:
|
266 |
+
|
267 |
+
length = 0
|
268 |
+
for result in result_gen:
|
269 |
+
chunk_data = result[field][0]
|
270 |
+
assert len(result[field]) == 1
|
271 |
+
start_seek = length
|
272 |
+
length = len(chunk_data)
|
273 |
+
self.logger.debug(
|
274 |
+
f"{start_seek=} total len: {length}, new len: {length - start_seek = }"
|
275 |
+
)
|
276 |
+
chunk_data = chunk_data[start_seek:]
|
277 |
+
if not len(chunk_data):
|
278 |
+
continue
|
279 |
+
self.logger.debug(f"new hidden {len(chunk_data)=}")
|
280 |
+
mel_spec = [
|
281 |
+
self.pretrain_models[docoder_name](i[None].permute(0, 2, 1))
|
282 |
+
for i in [chunk_data]
|
283 |
+
]
|
284 |
+
wav = vocos_decode(mel_spec)
|
285 |
+
self.logger.debug(f"yield wav chunk {len(wav[0])=} {len(wav[0][0])=}")
|
286 |
+
yield wav
|
287 |
+
return
|
288 |
+
mel_spec = [
|
289 |
+
self.pretrain_models[docoder_name](i[None].permute(0, 2, 1))
|
290 |
+
for i in next(result_gen)[field]
|
291 |
+
]
|
292 |
+
yield vocos_decode(mel_spec)
|
293 |
|
294 |
+
def infer(
|
295 |
self,
|
296 |
text,
|
297 |
+
skip_refine_text=False,
|
298 |
+
refine_text_only=False,
|
299 |
params_refine_text={},
|
300 |
+
params_infer_code={"prompt": "[speed_5]"},
|
301 |
+
use_decoder=True,
|
302 |
+
stream=False,
|
303 |
+
):
|
304 |
+
res_gen = self._infer(
|
305 |
+
text=text,
|
306 |
+
skip_refine_text=skip_refine_text,
|
307 |
+
refine_text_only=refine_text_only,
|
308 |
+
params_refine_text=params_refine_text,
|
309 |
+
params_infer_code=params_infer_code,
|
310 |
+
use_decoder=use_decoder,
|
311 |
+
stream=stream,
|
312 |
+
)
|
313 |
+
if stream:
|
314 |
+
return res_gen
|
315 |
+
else:
|
316 |
+
return next(res_gen)
|
317 |
+
|
318 |
+
def refiner_prompt(self, text, params_refine_text={}, stream=False):
|
319 |
+
return self.infer(
|
320 |
+
text=text,
|
321 |
+
skip_refine_text=False,
|
322 |
+
refine_text_only=True,
|
323 |
+
params_refine_text=params_refine_text,
|
324 |
+
stream=stream,
|
325 |
+
)
|
|
|
|
|
326 |
|
327 |
def generate_audio(
|
328 |
self,
|
329 |
prompt,
|
330 |
params_infer_code={"prompt": "[speed_5]"},
|
331 |
use_decoder=True,
|
332 |
+
stream=False,
|
333 |
+
):
|
334 |
+
return self.infer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
prompt,
|
336 |
+
skip_refine_text=True,
|
337 |
+
params_infer_code=params_infer_code,
|
338 |
+
use_decoder=use_decoder,
|
339 |
+
stream=stream,
|
340 |
)
|
341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
def sample_random_speaker(
|
343 |
self,
|
344 |
) -> torch.Tensor:
|
modules/ChatTTS/ChatTTS/infer/api.py
CHANGED
@@ -17,6 +17,7 @@ def infer_code(
|
|
17 |
prompt1="",
|
18 |
prompt2="",
|
19 |
prefix="",
|
|
|
20 |
**kwargs,
|
21 |
):
|
22 |
|
@@ -83,6 +84,7 @@ def infer_code(
|
|
83 |
eos_token=num_code,
|
84 |
max_new_token=max_new_token,
|
85 |
infer_text=False,
|
|
|
86 |
**kwargs,
|
87 |
)
|
88 |
|
@@ -98,6 +100,7 @@ def refine_text(
|
|
98 |
repetition_penalty=1.0,
|
99 |
max_new_token=384,
|
100 |
prompt="",
|
|
|
101 |
**kwargs,
|
102 |
):
|
103 |
device = next(models["gpt"].parameters()).device
|
@@ -152,6 +155,7 @@ def refine_text(
|
|
152 |
)[None],
|
153 |
max_new_token=max_new_token,
|
154 |
infer_text=True,
|
|
|
155 |
**kwargs,
|
156 |
)
|
157 |
return result
|
|
|
17 |
prompt1="",
|
18 |
prompt2="",
|
19 |
prefix="",
|
20 |
+
stream=False,
|
21 |
**kwargs,
|
22 |
):
|
23 |
|
|
|
84 |
eos_token=num_code,
|
85 |
max_new_token=max_new_token,
|
86 |
infer_text=False,
|
87 |
+
stream=stream,
|
88 |
**kwargs,
|
89 |
)
|
90 |
|
|
|
100 |
repetition_penalty=1.0,
|
101 |
max_new_token=384,
|
102 |
prompt="",
|
103 |
+
stream=False,
|
104 |
**kwargs,
|
105 |
):
|
106 |
device = next(models["gpt"].parameters()).device
|
|
|
155 |
)[None],
|
156 |
max_new_token=max_new_token,
|
157 |
infer_text=True,
|
158 |
+
stream=stream,
|
159 |
**kwargs,
|
160 |
)
|
161 |
return result
|
modules/ChatTTS/ChatTTS/model/gpt.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
4 |
|
5 |
import logging
|
|
|
6 |
|
7 |
import torch
|
8 |
import torch.nn as nn
|
@@ -37,7 +38,6 @@ class GPT_warpper(nn.Module):
|
|
37 |
num_audio_tokens,
|
38 |
num_text_tokens,
|
39 |
num_vq=4,
|
40 |
-
**kwargs,
|
41 |
):
|
42 |
super().__init__()
|
43 |
|
@@ -211,12 +211,13 @@ class GPT_warpper(nn.Module):
|
|
211 |
infer_text=False,
|
212 |
return_attn=False,
|
213 |
return_hidden=False,
|
|
|
214 |
disable_tqdm=False,
|
215 |
):
|
|
|
|
|
216 |
if disable_tqdm:
|
217 |
-
tqdm =
|
218 |
-
else:
|
219 |
-
from tqdm import tqdm
|
220 |
|
221 |
with torch.no_grad():
|
222 |
|
@@ -242,90 +243,136 @@ class GPT_warpper(nn.Module):
|
|
242 |
if attention_mask is not None:
|
243 |
attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
|
244 |
|
245 |
-
|
246 |
-
if finish.all():
|
247 |
-
continue
|
248 |
|
249 |
-
|
250 |
-
inputs_ids,
|
251 |
-
outputs.past_key_values if i != 0 else None,
|
252 |
-
attention_mask_cache[:, : inputs_ids.shape[1]],
|
253 |
-
use_cache=True,
|
254 |
-
)
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
]
|
268 |
-
model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(3)
|
269 |
-
|
270 |
-
model_input["input_ids"] = None
|
271 |
-
outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
|
272 |
-
attentions.append(outputs.attentions)
|
273 |
-
hidden_states = outputs[0] # 🐻
|
274 |
-
if return_hidden:
|
275 |
-
hiddens.append(hidden_states[:, -1])
|
276 |
-
|
277 |
-
with P.cached():
|
278 |
-
if infer_text:
|
279 |
-
logits = self.head_text(hidden_states)
|
280 |
else:
|
281 |
-
|
282 |
-
[
|
283 |
-
|
|
|
|
|
|
|
|
|
284 |
for i in range(self.num_vq)
|
285 |
-
]
|
286 |
-
3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
)
|
|
|
|
|
288 |
|
289 |
-
|
290 |
|
291 |
-
|
292 |
-
|
293 |
-
logits_token = rearrange(
|
294 |
-
inputs_ids[:, start_idx:], "b c n -> (b n) c"
|
295 |
-
)
|
296 |
-
else:
|
297 |
-
logits_token = inputs_ids[:, start_idx:, 0]
|
298 |
|
299 |
-
|
|
|
300 |
|
301 |
-
|
302 |
-
logits = logitsProcessors(logits_token, logits)
|
303 |
|
304 |
-
|
305 |
-
|
306 |
|
307 |
-
|
308 |
-
logits[:, eos_token] = -torch.inf
|
309 |
|
310 |
-
|
311 |
|
312 |
-
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
|
|
|
|
|
|
|
|
327 |
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
inputs_ids = [
|
331 |
inputs_ids[idx, start_idx : start_idx + i]
|
@@ -342,7 +389,9 @@ class GPT_warpper(nn.Module):
|
|
342 |
f"Incomplete result. hit max_new_token: {max_new_token}"
|
343 |
)
|
344 |
|
345 |
-
|
|
|
|
|
346 |
"ids": inputs_ids,
|
347 |
"attentions": attentions,
|
348 |
"hiddens": hiddens,
|
|
|
3 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
4 |
|
5 |
import logging
|
6 |
+
from functools import partial
|
7 |
|
8 |
import torch
|
9 |
import torch.nn as nn
|
|
|
38 |
num_audio_tokens,
|
39 |
num_text_tokens,
|
40 |
num_vq=4,
|
|
|
41 |
):
|
42 |
super().__init__()
|
43 |
|
|
|
211 |
infer_text=False,
|
212 |
return_attn=False,
|
213 |
return_hidden=False,
|
214 |
+
stream=False,
|
215 |
disable_tqdm=False,
|
216 |
):
|
217 |
+
from tqdm import tqdm
|
218 |
+
|
219 |
if disable_tqdm:
|
220 |
+
tqdm = partial(tqdm, disable=True)
|
|
|
|
|
221 |
|
222 |
with torch.no_grad():
|
223 |
|
|
|
243 |
if attention_mask is not None:
|
244 |
attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
|
245 |
|
246 |
+
with tqdm(total=max_new_token) as pbar:
|
|
|
|
|
247 |
|
248 |
+
past_key_values = None
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
+
for i in range(max_new_token):
|
251 |
+
pbar.update(1)
|
252 |
+
model_input = self.prepare_inputs_for_generation(
|
253 |
+
inputs_ids,
|
254 |
+
past_key_values,
|
255 |
+
attention_mask_cache[:, : inputs_ids.shape[1]],
|
256 |
+
use_cache=True,
|
257 |
+
)
|
258 |
+
|
259 |
+
if i == 0:
|
260 |
+
model_input["inputs_embeds"] = emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
else:
|
262 |
+
if infer_text:
|
263 |
+
model_input["inputs_embeds"] = self.emb_text(
|
264 |
+
model_input["input_ids"][:, :, 0]
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
code_emb = [
|
268 |
+
self.emb_code[i](model_input["input_ids"][:, :, i])
|
269 |
for i in range(self.num_vq)
|
270 |
+
]
|
271 |
+
model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(
|
272 |
+
3
|
273 |
+
)
|
274 |
+
|
275 |
+
model_input["input_ids"] = None
|
276 |
+
outputs = self.gpt.forward(
|
277 |
+
**model_input, output_attentions=return_attn
|
278 |
+
)
|
279 |
+
del model_input
|
280 |
+
attentions.append(outputs.attentions)
|
281 |
+
hidden_states = outputs[0] # 🐻
|
282 |
+
past_key_values = outputs.past_key_values
|
283 |
+
del outputs
|
284 |
+
if return_hidden:
|
285 |
+
hiddens.append(hidden_states[:, -1])
|
286 |
+
|
287 |
+
with P.cached():
|
288 |
+
if infer_text:
|
289 |
+
logits = self.head_text(hidden_states)
|
290 |
+
else:
|
291 |
+
logits = torch.stack(
|
292 |
+
[
|
293 |
+
self.head_code[i](hidden_states)
|
294 |
+
for i in range(self.num_vq)
|
295 |
+
],
|
296 |
+
3,
|
297 |
+
)
|
298 |
+
|
299 |
+
logits = logits[:, -1].float()
|
300 |
+
|
301 |
+
if not infer_text:
|
302 |
+
logits = rearrange(logits, "b c n -> (b n) c")
|
303 |
+
logits_token = rearrange(
|
304 |
+
inputs_ids[:, start_idx:], "b c n -> (b n) c"
|
305 |
)
|
306 |
+
else:
|
307 |
+
logits_token = inputs_ids[:, start_idx:, 0]
|
308 |
|
309 |
+
logits = logits / temperature
|
310 |
|
311 |
+
for logitsProcessors in LogitsProcessors:
|
312 |
+
logits = logitsProcessors(logits_token, logits)
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
+
for logitsWarpers in LogitsWarpers:
|
315 |
+
logits = logitsWarpers(logits_token, logits)
|
316 |
|
317 |
+
del logits_token
|
|
|
318 |
|
319 |
+
if i < min_new_token:
|
320 |
+
logits[:, eos_token] = -torch.inf
|
321 |
|
322 |
+
scores = F.softmax(logits, dim=-1)
|
|
|
323 |
|
324 |
+
del logits
|
325 |
|
326 |
+
idx_next = torch.multinomial(scores, num_samples=1)
|
327 |
|
328 |
+
if not infer_text:
|
329 |
+
idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
330 |
+
finish_or = (idx_next == eos_token).any(1)
|
331 |
+
finish |= finish_or
|
332 |
+
del finish_or
|
333 |
+
inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
|
334 |
+
else:
|
335 |
+
finish_or = (idx_next == eos_token).any(1)
|
336 |
+
finish |= finish_or
|
337 |
+
del finish_or
|
338 |
+
inputs_ids = torch.cat(
|
339 |
+
[
|
340 |
+
inputs_ids,
|
341 |
+
idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
|
342 |
+
],
|
343 |
+
1,
|
344 |
+
)
|
345 |
|
346 |
+
del idx_next
|
347 |
+
|
348 |
+
end_idx += (~finish).int().to(end_idx.device)
|
349 |
+
if stream:
|
350 |
+
if end_idx % 24 and not finish.all():
|
351 |
+
continue
|
352 |
+
y_inputs_ids = [
|
353 |
+
inputs_ids[idx, start_idx : start_idx + i]
|
354 |
+
for idx, i in enumerate(end_idx.int())
|
355 |
+
]
|
356 |
+
y_inputs_ids = (
|
357 |
+
[i[:, 0] for i in y_inputs_ids]
|
358 |
+
if infer_text
|
359 |
+
else y_inputs_ids
|
360 |
+
)
|
361 |
+
y_hiddens = [[]]
|
362 |
+
if return_hidden:
|
363 |
+
y_hiddens = torch.stack(hiddens, 1)
|
364 |
+
y_hiddens = [
|
365 |
+
y_hiddens[idx, :i]
|
366 |
+
for idx, i in enumerate(end_idx.int())
|
367 |
+
]
|
368 |
+
yield {
|
369 |
+
"ids": y_inputs_ids,
|
370 |
+
"attentions": attentions,
|
371 |
+
"hiddens": y_hiddens,
|
372 |
+
}
|
373 |
+
if finish.all():
|
374 |
+
pbar.update(max_new_token - i - 1)
|
375 |
+
break
|
376 |
|
377 |
inputs_ids = [
|
378 |
inputs_ids[idx, start_idx : start_idx + i]
|
|
|
389 |
f"Incomplete result. hit max_new_token: {max_new_token}"
|
390 |
)
|
391 |
|
392 |
+
del finish
|
393 |
+
|
394 |
+
yield {
|
395 |
"ids": inputs_ids,
|
396 |
"attentions": attentions,
|
397 |
"hiddens": hiddens,
|
modules/ChatTTS/ChatTTS/utils/infer_utils.py
CHANGED
@@ -24,6 +24,7 @@ class CustomRepetitionPenaltyLogitsProcessorRepeat:
|
|
24 |
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
25 |
freq[self.max_input_ids :] = 0
|
26 |
alpha = self.penalty**freq
|
|
|
27 |
scores = torch.where(scores < 0, scores * alpha, scores / alpha)
|
28 |
|
29 |
return scores
|
@@ -145,11 +146,35 @@ halfwidth_2_fullwidth_map = {
|
|
145 |
}
|
146 |
|
147 |
|
148 |
-
def
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
|
153 |
-
def apply_character_map(text):
|
154 |
-
|
155 |
-
|
|
|
|
24 |
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
25 |
freq[self.max_input_ids :] = 0
|
26 |
alpha = self.penalty**freq
|
27 |
+
scores = scores.contiguous()
|
28 |
scores = torch.where(scores < 0, scores * alpha, scores / alpha)
|
29 |
|
30 |
return scores
|
|
|
146 |
}
|
147 |
|
148 |
|
149 |
+
def replace_unsupported_chars(text, replace_dict, reserved_tokens: list = []):
|
150 |
+
escaped_tokens = [re.escape(token) for token in reserved_tokens]
|
151 |
+
special_tokens_pattern = "|".join(escaped_tokens)
|
152 |
+
tokens = re.split(f"({special_tokens_pattern})", text)
|
153 |
+
|
154 |
+
def replace_chars(segment):
|
155 |
+
for old_char, new_char in replace_dict.items():
|
156 |
+
segment = segment.replace(old_char, new_char)
|
157 |
+
return segment
|
158 |
+
|
159 |
+
result = "".join(
|
160 |
+
(
|
161 |
+
replace_chars(segment)
|
162 |
+
if not re.match(special_tokens_pattern, segment)
|
163 |
+
else segment
|
164 |
+
)
|
165 |
+
for segment in tokens
|
166 |
+
)
|
167 |
+
|
168 |
+
return result
|
169 |
+
|
170 |
+
|
171 |
+
def apply_half2full_map(text, reserved_tokens: list = []):
|
172 |
+
return replace_unsupported_chars(
|
173 |
+
text, halfwidth_2_fullwidth_map, reserved_tokens=reserved_tokens
|
174 |
+
)
|
175 |
|
176 |
|
177 |
+
def apply_character_map(text, reserved_tokens: list = []):
|
178 |
+
return replace_unsupported_chars(
|
179 |
+
text, character_map, reserved_tokens=reserved_tokens
|
180 |
+
)
|
modules/Enhancer/ResembleEnhance.py
CHANGED
@@ -85,7 +85,7 @@ def load_enhancer() -> ResembleEnhance:
|
|
85 |
if resemble_enhance is None:
|
86 |
logger.info("Loading ResembleEnhance model")
|
87 |
resemble_enhance = ResembleEnhance(
|
88 |
-
device=devices.
|
89 |
)
|
90 |
resemble_enhance.load_model()
|
91 |
logger.info("ResembleEnhance model loaded")
|
|
|
85 |
if resemble_enhance is None:
|
86 |
logger.info("Loading ResembleEnhance model")
|
87 |
resemble_enhance = ResembleEnhance(
|
88 |
+
device=devices.get_device_for("enhancer"), dtype=devices.dtype
|
89 |
)
|
90 |
resemble_enhance.load_model()
|
91 |
logger.info("ResembleEnhance model loaded")
|
modules/SentenceSplitter.py
CHANGED
@@ -2,6 +2,8 @@ import re
|
|
2 |
|
3 |
import zhon
|
4 |
|
|
|
|
|
5 |
|
6 |
def split_zhon_sentence(text):
|
7 |
result = []
|
@@ -21,6 +23,35 @@ def split_zhon_sentence(text):
|
|
21 |
return result
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# 解析文本 并根据停止符号分割成句子
|
25 |
# 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
|
26 |
class SentenceSplitter:
|
@@ -28,7 +59,7 @@ class SentenceSplitter:
|
|
28 |
self.sentence_threshold = threshold
|
29 |
|
30 |
def parse(self, text):
|
31 |
-
sentences =
|
32 |
|
33 |
# 合并小于最大阈值的片段
|
34 |
merged_sentences = []
|
|
|
2 |
|
3 |
import zhon
|
4 |
|
5 |
+
from modules.utils.detect_lang import guess_lang
|
6 |
+
|
7 |
|
8 |
def split_zhon_sentence(text):
|
9 |
result = []
|
|
|
23 |
return result
|
24 |
|
25 |
|
26 |
+
def split_en_sentence(text):
|
27 |
+
"""
|
28 |
+
Split English text into sentences.
|
29 |
+
"""
|
30 |
+
# Define a regex pattern for English sentence splitting
|
31 |
+
pattern = re.compile(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s")
|
32 |
+
result = pattern.split(text)
|
33 |
+
|
34 |
+
# Filter out any empty strings or strings that are just whitespace
|
35 |
+
result = [sentence.strip() for sentence in result if sentence.strip()]
|
36 |
+
|
37 |
+
return result
|
38 |
+
|
39 |
+
|
40 |
+
def is_eng_sentence(text):
|
41 |
+
return guess_lang(text) == "en"
|
42 |
+
|
43 |
+
|
44 |
+
def split_zhon_paragraph(text):
|
45 |
+
lines = text.split("\n")
|
46 |
+
result = []
|
47 |
+
for line in lines:
|
48 |
+
if is_eng_sentence(line):
|
49 |
+
result.extend(split_en_sentence(line))
|
50 |
+
else:
|
51 |
+
result.extend(split_zhon_sentence(line))
|
52 |
+
return result
|
53 |
+
|
54 |
+
|
55 |
# 解析文本 并根据停止符号分割成句子
|
56 |
# 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
|
57 |
class SentenceSplitter:
|
|
|
59 |
self.sentence_threshold = threshold
|
60 |
|
61 |
def parse(self, text):
|
62 |
+
sentences = split_zhon_paragraph(text)
|
63 |
|
64 |
# 合并小于最大阈值的片段
|
65 |
merged_sentences = []
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import copy
|
|
|
2 |
import json
|
3 |
import logging
|
4 |
import re
|
@@ -7,6 +8,7 @@ from typing import List, Union
|
|
7 |
import numpy as np
|
8 |
from box import Box
|
9 |
from pydub import AudioSegment
|
|
|
10 |
|
11 |
from modules import generate_audio
|
12 |
from modules.api.utils import calc_spk_style
|
@@ -15,15 +17,39 @@ from modules.SentenceSplitter import SentenceSplitter
|
|
15 |
from modules.speaker import Speaker
|
16 |
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
|
17 |
from modules.utils import rng
|
18 |
-
from modules.utils.audio import
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def audio_data_to_segment(audio_data: np.ndarray, sr: int):
|
24 |
"""
|
25 |
optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
|
26 |
"""
|
|
|
|
|
|
|
|
|
27 |
audio_data = (audio_data * 32767).astype(np.int16)
|
28 |
audio_segment = AudioSegment(
|
29 |
audio_data.tobytes(),
|
@@ -41,21 +67,6 @@ def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
|
|
41 |
return combined_audio
|
42 |
|
43 |
|
44 |
-
def apply_prosody(
|
45 |
-
audio_segment: AudioSegment, rate: float, volume: float, pitch: float
|
46 |
-
) -> AudioSegment:
|
47 |
-
if rate != 1:
|
48 |
-
audio_segment = time_stretch(audio_segment, rate)
|
49 |
-
|
50 |
-
if volume != 0:
|
51 |
-
audio_segment += volume
|
52 |
-
|
53 |
-
if pitch != 0:
|
54 |
-
audio_segment = pitch_shift(audio_segment, pitch)
|
55 |
-
|
56 |
-
return audio_segment
|
57 |
-
|
58 |
-
|
59 |
def to_number(value, t, default=0):
|
60 |
try:
|
61 |
number = t(value)
|
@@ -202,7 +213,9 @@ class SynthesizeSegments:
|
|
202 |
pitch = float(segment.get("pitch", "0"))
|
203 |
|
204 |
audio_segment = audio_data_to_segment(audio_data, sr)
|
205 |
-
audio_segment =
|
|
|
|
|
206 |
# compare by Box object
|
207 |
original_index = src_segments.index(segment)
|
208 |
audio_segments[original_index] = audio_segment
|
|
|
1 |
import copy
|
2 |
+
import io
|
3 |
import json
|
4 |
import logging
|
5 |
import re
|
|
|
8 |
import numpy as np
|
9 |
from box import Box
|
10 |
from pydub import AudioSegment
|
11 |
+
from scipy.io import wavfile
|
12 |
|
13 |
from modules import generate_audio
|
14 |
from modules.api.utils import calc_spk_style
|
|
|
17 |
from modules.speaker import Speaker
|
18 |
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
|
19 |
from modules.utils import rng
|
20 |
+
from modules.utils.audio import apply_prosody_to_audio_segment
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
|
25 |
+
def audio_data_to_segment_slow(audio_data, sr):
|
26 |
+
byte_io = io.BytesIO()
|
27 |
+
wavfile.write(byte_io, rate=sr, data=audio_data)
|
28 |
+
byte_io.seek(0)
|
29 |
+
|
30 |
+
return AudioSegment.from_file(byte_io, format="wav")
|
31 |
+
|
32 |
+
|
33 |
+
def clip_audio(audio_data: np.ndarray, threshold: float = 0.99):
|
34 |
+
audio_data = np.clip(audio_data, -threshold, threshold)
|
35 |
+
return audio_data
|
36 |
+
|
37 |
+
|
38 |
+
def normalize_audio(audio_data: np.ndarray, norm_factor: float = 0.8):
|
39 |
+
max_amplitude = np.max(np.abs(audio_data))
|
40 |
+
if max_amplitude > 0:
|
41 |
+
audio_data = audio_data / max_amplitude * norm_factor
|
42 |
+
return audio_data
|
43 |
+
|
44 |
+
|
45 |
def audio_data_to_segment(audio_data: np.ndarray, sr: int):
|
46 |
"""
|
47 |
optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
|
48 |
"""
|
49 |
+
|
50 |
+
audio_data = normalize_audio(audio_data)
|
51 |
+
audio_data = clip_audio(audio_data)
|
52 |
+
|
53 |
audio_data = (audio_data * 32767).astype(np.int16)
|
54 |
audio_segment = AudioSegment(
|
55 |
audio_data.tobytes(),
|
|
|
67 |
return combined_audio
|
68 |
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
def to_number(value, t, default=0):
|
71 |
try:
|
72 |
number = t(value)
|
|
|
213 |
pitch = float(segment.get("pitch", "0"))
|
214 |
|
215 |
audio_segment = audio_data_to_segment(audio_data, sr)
|
216 |
+
audio_segment = apply_prosody_to_audio_segment(
|
217 |
+
audio_segment, rate=rate, volume=volume, pitch=pitch
|
218 |
+
)
|
219 |
# compare by Box object
|
220 |
original_index = src_segments.index(segment)
|
221 |
audio_segments[original_index] = audio_segment
|
modules/api/api_setup.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
|
4 |
-
from
|
|
|
|
|
5 |
from modules.api.Api import APIManager
|
6 |
from modules.api.impl import (
|
7 |
google_api,
|
@@ -15,15 +17,12 @@ from modules.api.impl import (
|
|
15 |
tts_api,
|
16 |
xtts_v2_api,
|
17 |
)
|
18 |
-
from modules.devices import devices
|
19 |
-
from modules.Enhancer.ResembleEnhance import load_enhancer
|
20 |
-
from modules.models import load_chat_tts
|
21 |
from modules.utils import env
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
|
26 |
-
def create_api(app, exclude=[]):
|
27 |
app_mgr = APIManager(app=app, exclude_patterns=exclude)
|
28 |
|
29 |
ping_api.setup(app_mgr)
|
@@ -40,98 +39,6 @@ def create_api(app, exclude=[]):
|
|
40 |
return app_mgr
|
41 |
|
42 |
|
43 |
-
def setup_model_args(parser: argparse.ArgumentParser):
|
44 |
-
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
45 |
-
parser.add_argument(
|
46 |
-
"--no_half",
|
47 |
-
action="store_true",
|
48 |
-
help="Disalbe half precision for model inference",
|
49 |
-
)
|
50 |
-
parser.add_argument(
|
51 |
-
"--off_tqdm",
|
52 |
-
action="store_true",
|
53 |
-
help="Disable tqdm progress bar",
|
54 |
-
)
|
55 |
-
parser.add_argument(
|
56 |
-
"--device_id",
|
57 |
-
type=str,
|
58 |
-
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
59 |
-
default=None,
|
60 |
-
)
|
61 |
-
parser.add_argument(
|
62 |
-
"--use_cpu",
|
63 |
-
nargs="+",
|
64 |
-
help="use CPU as torch device for specified modules",
|
65 |
-
default=[],
|
66 |
-
type=str.lower,
|
67 |
-
)
|
68 |
-
parser.add_argument(
|
69 |
-
"--lru_size",
|
70 |
-
type=int,
|
71 |
-
default=64,
|
72 |
-
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
73 |
-
)
|
74 |
-
parser.add_argument(
|
75 |
-
"--debug_generate",
|
76 |
-
action="store_true",
|
77 |
-
help="Enable debug mode for audio generation",
|
78 |
-
)
|
79 |
-
parser.add_argument(
|
80 |
-
"--preload_models",
|
81 |
-
action="store_true",
|
82 |
-
help="Preload all models at startup",
|
83 |
-
)
|
84 |
-
|
85 |
-
|
86 |
-
def process_model_args(args):
|
87 |
-
lru_size = env.get_and_update_env(args, "lru_size", 64, int)
|
88 |
-
compile = env.get_and_update_env(args, "compile", False, bool)
|
89 |
-
device_id = env.get_and_update_env(args, "device_id", None, str)
|
90 |
-
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
|
91 |
-
no_half = env.get_and_update_env(args, "no_half", False, bool)
|
92 |
-
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
|
93 |
-
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
|
94 |
-
preload_models = env.get_and_update_env(args, "preload_models", False, bool)
|
95 |
-
|
96 |
-
generate_audio.setup_lru_cache()
|
97 |
-
devices.reset_device()
|
98 |
-
devices.first_time_calculation()
|
99 |
-
|
100 |
-
if debug_generate:
|
101 |
-
generate_audio.logger.setLevel(logging.DEBUG)
|
102 |
-
|
103 |
-
if preload_models:
|
104 |
-
load_chat_tts()
|
105 |
-
load_enhancer()
|
106 |
-
|
107 |
-
|
108 |
-
def setup_uvicon_args(parser: argparse.ArgumentParser):
|
109 |
-
parser.add_argument("--host", type=str, help="Host to run the server on")
|
110 |
-
parser.add_argument("--port", type=int, help="Port to run the server on")
|
111 |
-
parser.add_argument(
|
112 |
-
"--reload", action="store_true", help="Enable auto-reload for development"
|
113 |
-
)
|
114 |
-
parser.add_argument("--workers", type=int, help="Number of worker processes")
|
115 |
-
parser.add_argument("--log_level", type=str, help="Log level")
|
116 |
-
parser.add_argument("--access_log", action="store_true", help="Enable access log")
|
117 |
-
parser.add_argument(
|
118 |
-
"--proxy_headers", action="store_true", help="Enable proxy headers"
|
119 |
-
)
|
120 |
-
parser.add_argument(
|
121 |
-
"--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
|
122 |
-
)
|
123 |
-
parser.add_argument(
|
124 |
-
"--timeout_graceful_shutdown",
|
125 |
-
type=int,
|
126 |
-
help="Graceful shutdown timeout duration",
|
127 |
-
)
|
128 |
-
parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
|
129 |
-
parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
|
130 |
-
parser.add_argument(
|
131 |
-
"--ssl_keyfile_password", type=str, help="SSL key file password"
|
132 |
-
)
|
133 |
-
|
134 |
-
|
135 |
def setup_api_args(parser: argparse.ArgumentParser):
|
136 |
parser.add_argument(
|
137 |
"--cors_origin",
|
@@ -156,7 +63,7 @@ def setup_api_args(parser: argparse.ArgumentParser):
|
|
156 |
)
|
157 |
|
158 |
|
159 |
-
def process_api_args(args, app):
|
160 |
cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
|
161 |
no_playground = env.get_and_update_env(args, "no_playground", False, bool)
|
162 |
no_docs = env.get_and_update_env(args, "no_docs", False, bool)
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
|
4 |
+
from fastapi import FastAPI
|
5 |
+
|
6 |
+
from modules import config
|
7 |
from modules.api.Api import APIManager
|
8 |
from modules.api.impl import (
|
9 |
google_api,
|
|
|
17 |
tts_api,
|
18 |
xtts_v2_api,
|
19 |
)
|
|
|
|
|
|
|
20 |
from modules.utils import env
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
|
25 |
+
def create_api(app: FastAPI, exclude=[]):
|
26 |
app_mgr = APIManager(app=app, exclude_patterns=exclude)
|
27 |
|
28 |
ping_api.setup(app_mgr)
|
|
|
39 |
return app_mgr
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def setup_api_args(parser: argparse.ArgumentParser):
|
43 |
parser.add_argument(
|
44 |
"--cors_origin",
|
|
|
63 |
)
|
64 |
|
65 |
|
66 |
+
def process_api_args(args: argparse.Namespace, app: FastAPI):
|
67 |
cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
|
68 |
no_playground = env.get_and_update_env(args, "no_playground", False, bool)
|
69 |
no_docs = env.get_and_update_env(args, "no_docs", False, bool)
|
modules/api/impl/handler/AudioHandler.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
|
4 |
import numpy as np
|
5 |
import soundfile as sf
|
@@ -10,7 +11,24 @@ from modules.api.impl.model.audio_model import AudioFormat
|
|
10 |
|
11 |
class AudioHandler:
|
12 |
def enqueue(self) -> tuple[np.ndarray, int]:
|
13 |
-
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
|
16 |
audio_data, sample_rate = self.enqueue()
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
+
from typing import Generator
|
4 |
|
5 |
import numpy as np
|
6 |
import soundfile as sf
|
|
|
11 |
|
12 |
class AudioHandler:
|
13 |
def enqueue(self) -> tuple[np.ndarray, int]:
|
14 |
+
raise NotImplementedError("Method 'enqueue' must be implemented by subclass")
|
15 |
+
|
16 |
+
def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
|
17 |
+
raise NotImplementedError(
|
18 |
+
"Method 'enqueue_stream' must be implemented by subclass"
|
19 |
+
)
|
20 |
+
|
21 |
+
def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]:
|
22 |
+
for audio_data, sample_rate in self.enqueue_stream():
|
23 |
+
buffer = io.BytesIO()
|
24 |
+
sf.write(buffer, audio_data, sample_rate, format="wav")
|
25 |
+
buffer.seek(0)
|
26 |
+
|
27 |
+
if format == AudioFormat.mp3:
|
28 |
+
buffer = api_utils.wav_to_mp3(buffer)
|
29 |
+
|
30 |
+
binary = buffer.read()
|
31 |
+
yield binary
|
32 |
|
33 |
def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
|
34 |
audio_data, sample_rate = self.enqueue()
|
modules/api/impl/handler/SSMLHandler.py
CHANGED
@@ -91,4 +91,9 @@ class SSMLHandler(AudioHandler):
|
|
91 |
sr=sample_rate,
|
92 |
)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
94 |
return audio_data, sample_rate
|
|
|
91 |
sr=sample_rate,
|
92 |
)
|
93 |
|
94 |
+
if adjust_config.normalize:
|
95 |
+
sample_rate, audio_data = audio.apply_normalize(
|
96 |
+
audio_data=audio_data, headroom=adjust_config.headroom, sr=sample_rate
|
97 |
+
)
|
98 |
+
|
99 |
return audio_data, sample_rate
|
modules/api/impl/handler/TTSHandler.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
|
3 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
@@ -8,7 +11,10 @@ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
|
8 |
from modules.normalization import text_normalize
|
9 |
from modules.speaker import Speaker
|
10 |
from modules.synthesize_audio import synthesize_audio
|
11 |
-
from modules.
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class TTSHandler(AudioHandler):
|
@@ -94,4 +100,57 @@ class TTSHandler(AudioHandler):
|
|
94 |
sr=sample_rate,
|
95 |
)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
return audio_data, sample_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Generator
|
3 |
+
|
4 |
import numpy as np
|
5 |
|
6 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
|
|
11 |
from modules.normalization import text_normalize
|
12 |
from modules.speaker import Speaker
|
13 |
from modules.synthesize_audio import synthesize_audio
|
14 |
+
from modules.synthesize_stream import synthesize_stream
|
15 |
+
from modules.utils.audio import apply_normalize, apply_prosody_to_audio_data
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
|
19 |
|
20 |
class TTSHandler(AudioHandler):
|
|
|
100 |
sr=sample_rate,
|
101 |
)
|
102 |
|
103 |
+
if adjust_config.normalize:
|
104 |
+
sample_rate, audio_data = apply_normalize(
|
105 |
+
audio_data=audio_data,
|
106 |
+
headroom=adjust_config.headroom,
|
107 |
+
sr=sample_rate,
|
108 |
+
)
|
109 |
+
|
110 |
return audio_data, sample_rate
|
111 |
+
|
112 |
+
def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
|
113 |
+
text = text_normalize(self.text_content)
|
114 |
+
tts_config = self.tts_config
|
115 |
+
infer_config = self.infer_config
|
116 |
+
adjust_config = self.adjest_config
|
117 |
+
enhancer_config = self.enhancer_config
|
118 |
+
|
119 |
+
if enhancer_config.enabled:
|
120 |
+
logger.warning(
|
121 |
+
"enhancer_config is enabled, but it is not supported in stream mode"
|
122 |
+
)
|
123 |
+
|
124 |
+
gen = synthesize_stream(
|
125 |
+
text,
|
126 |
+
spk=self.spk,
|
127 |
+
temperature=tts_config.temperature,
|
128 |
+
top_P=tts_config.top_p,
|
129 |
+
top_K=tts_config.top_k,
|
130 |
+
prompt1=tts_config.prompt1,
|
131 |
+
prompt2=tts_config.prompt2,
|
132 |
+
prefix=tts_config.prefix,
|
133 |
+
infer_seed=infer_config.seed,
|
134 |
+
spliter_threshold=infer_config.spliter_threshold,
|
135 |
+
end_of_sentence=infer_config.eos,
|
136 |
+
)
|
137 |
+
|
138 |
+
# FIXME: 很奇怪,合并出来的音频每个 chunk 之前会有一段异常,暂时没有查出来是哪里的问题,可能是解码时候切割漏了?或者多了?
|
139 |
+
for sr, wav in gen:
|
140 |
+
|
141 |
+
wav = apply_prosody_to_audio_data(
|
142 |
+
audio_data=wav,
|
143 |
+
rate=adjust_config.speed_rate,
|
144 |
+
pitch=adjust_config.pitch,
|
145 |
+
volume=adjust_config.volume_gain_db,
|
146 |
+
sr=sr,
|
147 |
+
)
|
148 |
+
|
149 |
+
if adjust_config.normalize:
|
150 |
+
sr, wav = apply_normalize(
|
151 |
+
audio_data=wav,
|
152 |
+
headroom=adjust_config.headroom,
|
153 |
+
sr=sr,
|
154 |
+
)
|
155 |
+
|
156 |
+
yield wav, sr
|
modules/api/impl/model/audio_model.py
CHANGED
@@ -12,3 +12,7 @@ class AdjustConfig(BaseModel):
|
|
12 |
pitch: float = 0
|
13 |
speed_rate: float = 1
|
14 |
volume_gain_db: float = 0
|
|
|
|
|
|
|
|
|
|
12 |
pitch: float = 0
|
13 |
speed_rate: float = 1
|
14 |
volume_gain_db: float = 0
|
15 |
+
|
16 |
+
# 响度均衡
|
17 |
+
normalize: bool = True
|
18 |
+
headroom: float = 1
|
modules/api/impl/tts_api.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from fastapi import Depends, HTTPException, Query
|
2 |
from fastapi.responses import FileResponse, StreamingResponse
|
3 |
from pydantic import BaseModel
|
@@ -10,6 +12,8 @@ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
|
10 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
11 |
from modules.speaker import Speaker
|
12 |
|
|
|
|
|
13 |
|
14 |
class TTSParams(BaseModel):
|
15 |
text: str = Query(..., description="Text to synthesize")
|
@@ -44,6 +48,8 @@ class TTSParams(BaseModel):
|
|
44 |
pitch: float = Query(0, description="Pitch of the audio")
|
45 |
volume_gain: float = Query(0, description="Volume gain of the audio")
|
46 |
|
|
|
|
|
47 |
|
48 |
async def synthesize_tts(params: TTSParams = Depends()):
|
49 |
try:
|
@@ -132,14 +138,22 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
132 |
adjust_config=adjust_config,
|
133 |
enhancer_config=enhancer_config,
|
134 |
)
|
135 |
-
|
136 |
-
buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
|
137 |
-
|
138 |
media_type = f"audio/{params.format}"
|
139 |
if params.format == "mp3":
|
140 |
media_type = "audio/mpeg"
|
141 |
-
return StreamingResponse(buffer, media_type=media_type)
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
except Exception as e:
|
144 |
import logging
|
145 |
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
from fastapi import Depends, HTTPException, Query
|
4 |
from fastapi.responses import FileResponse, StreamingResponse
|
5 |
from pydantic import BaseModel
|
|
|
12 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
13 |
from modules.speaker import Speaker
|
14 |
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
|
18 |
class TTSParams(BaseModel):
|
19 |
text: str = Query(..., description="Text to synthesize")
|
|
|
48 |
pitch: float = Query(0, description="Pitch of the audio")
|
49 |
volume_gain: float = Query(0, description="Volume gain of the audio")
|
50 |
|
51 |
+
stream: bool = Query(False, description="Stream the audio")
|
52 |
+
|
53 |
|
54 |
async def synthesize_tts(params: TTSParams = Depends()):
|
55 |
try:
|
|
|
138 |
adjust_config=adjust_config,
|
139 |
enhancer_config=enhancer_config,
|
140 |
)
|
|
|
|
|
|
|
141 |
media_type = f"audio/{params.format}"
|
142 |
if params.format == "mp3":
|
143 |
media_type = "audio/mpeg"
|
|
|
144 |
|
145 |
+
if params.stream:
|
146 |
+
if infer_config.batch_size != 1:
|
147 |
+
# 流式生成下仅支持 batch size 为 1,当前请求参数将被忽略
|
148 |
+
logger.warning(
|
149 |
+
f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
|
150 |
+
)
|
151 |
+
|
152 |
+
buffer_gen = handler.enqueue_to_stream(format=AudioFormat(params.format))
|
153 |
+
return StreamingResponse(buffer_gen, media_type=media_type)
|
154 |
+
else:
|
155 |
+
buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
|
156 |
+
return StreamingResponse(buffer, media_type=media_type)
|
157 |
except Exception as e:
|
158 |
import logging
|
159 |
|
modules/api/impl/xtts_v2_api.py
CHANGED
@@ -1,18 +1,15 @@
|
|
1 |
-
import io
|
2 |
import logging
|
3 |
|
4 |
-
import
|
5 |
-
from fastapi import HTTPException
|
6 |
from fastapi.responses import StreamingResponse
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
-
from modules import config
|
10 |
-
from modules.api import utils as api_utils
|
11 |
from modules.api.Api import APIManager
|
12 |
-
from modules.
|
|
|
|
|
|
|
13 |
from modules.speaker import speaker_mgr
|
14 |
-
from modules.synthesize_audio import synthesize_audio
|
15 |
-
from modules.utils.audio import apply_prosody_to_audio_data
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
@@ -22,8 +19,11 @@ class XTTS_V2_Settings:
|
|
22 |
self.stream_chunk_size = 100
|
23 |
self.temperature = 0.3
|
24 |
self.speed = 1
|
|
|
|
|
25 |
self.length_penalty = 0.5
|
26 |
self.repetition_penalty = 1.0
|
|
|
27 |
self.top_p = 0.7
|
28 |
self.top_k = 20
|
29 |
self.enable_text_splitting = True
|
@@ -37,6 +37,7 @@ class XTTS_V2_Settings:
|
|
37 |
self.prompt2 = ""
|
38 |
self.prefix = ""
|
39 |
self.spliter_threshold = 100
|
|
|
40 |
|
41 |
|
42 |
class TTSSettingsRequest(BaseModel):
|
@@ -58,6 +59,7 @@ class TTSSettingsRequest(BaseModel):
|
|
58 |
prompt2: str = None
|
59 |
prefix: str = None
|
60 |
spliter_threshold: int = None
|
|
|
61 |
|
62 |
|
63 |
class SynthesisRequest(BaseModel):
|
@@ -95,45 +97,101 @@ def setup(app: APIManager):
|
|
95 |
if spk is None:
|
96 |
raise HTTPException(status_code=400, detail="Invalid speaker id")
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
# TODO: 这两个参数现在用不着...但是其实gpt是可以用的
|
101 |
-
# length_penalty=XTTSV2.length_penalty,
|
102 |
-
# repetition_penalty=XTTSV2.repetition_penalty,
|
103 |
-
text=text,
|
104 |
temperature=XTTSV2.temperature,
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
spliter_threshold=XTTSV2.spliter_threshold,
|
109 |
-
batch_size=XTTSV2.batch_size,
|
110 |
-
end_of_sentence=XTTSV2.eos,
|
111 |
-
infer_seed=XTTSV2.infer_seed,
|
112 |
-
use_decoder=XTTSV2.use_decoder,
|
113 |
prompt1=XTTSV2.prompt1,
|
114 |
prompt2=XTTSV2.prompt2,
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
)
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
buffer = io.BytesIO()
|
127 |
-
sf.write(buffer, audio_data, sample_rate, format="wav")
|
128 |
-
buffer.seek(0)
|
129 |
|
130 |
-
buffer =
|
131 |
|
132 |
return StreamingResponse(buffer, media_type="audio/mpeg")
|
133 |
|
134 |
@app.get("/v1/xtts_v2/tts_stream")
|
135 |
-
async def tts_stream(
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
@app.post("/v1/xtts_v2/set_tts_settings")
|
139 |
async def set_tts_settings(request: TTSSettingsRequest):
|
@@ -195,6 +253,8 @@ def setup(app: APIManager):
|
|
195 |
XTTSV2.prefix = request.prefix
|
196 |
if request.spliter_threshold:
|
197 |
XTTSV2.spliter_threshold = request.spliter_threshold
|
|
|
|
|
198 |
|
199 |
return {"message": "Settings successfully applied"}
|
200 |
except Exception as e:
|
|
|
|
|
1 |
import logging
|
2 |
|
3 |
+
from fastapi import HTTPException, Query, Request
|
|
|
4 |
from fastapi.responses import StreamingResponse
|
5 |
from pydantic import BaseModel
|
6 |
|
|
|
|
|
7 |
from modules.api.Api import APIManager
|
8 |
+
from modules.api.impl.handler.TTSHandler import TTSHandler
|
9 |
+
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
11 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
from modules.speaker import speaker_mgr
|
|
|
|
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
|
|
19 |
self.stream_chunk_size = 100
|
20 |
self.temperature = 0.3
|
21 |
self.speed = 1
|
22 |
+
|
23 |
+
# TODO: 这两个参数现在用不着...但是其实gpt是可以用的可以考虑增加
|
24 |
self.length_penalty = 0.5
|
25 |
self.repetition_penalty = 1.0
|
26 |
+
|
27 |
self.top_p = 0.7
|
28 |
self.top_k = 20
|
29 |
self.enable_text_splitting = True
|
|
|
37 |
self.prompt2 = ""
|
38 |
self.prefix = ""
|
39 |
self.spliter_threshold = 100
|
40 |
+
self.style = ""
|
41 |
|
42 |
|
43 |
class TTSSettingsRequest(BaseModel):
|
|
|
59 |
prompt2: str = None
|
60 |
prefix: str = None
|
61 |
spliter_threshold: int = None
|
62 |
+
style: str = None
|
63 |
|
64 |
|
65 |
class SynthesisRequest(BaseModel):
|
|
|
97 |
if spk is None:
|
98 |
raise HTTPException(status_code=400, detail="Invalid speaker id")
|
99 |
|
100 |
+
tts_config = ChatTTSConfig(
|
101 |
+
style=XTTSV2.style,
|
|
|
|
|
|
|
|
|
102 |
temperature=XTTSV2.temperature,
|
103 |
+
top_k=XTTSV2.top_k,
|
104 |
+
top_p=XTTSV2.top_p,
|
105 |
+
prefix=XTTSV2.prefix,
|
|
|
|
|
|
|
|
|
|
|
106 |
prompt1=XTTSV2.prompt1,
|
107 |
prompt2=XTTSV2.prompt2,
|
108 |
+
)
|
109 |
+
infer_config = InferConfig(
|
110 |
+
batch_size=XTTSV2.batch_size,
|
111 |
+
spliter_threshold=XTTSV2.spliter_threshold,
|
112 |
+
eos=XTTSV2.eos,
|
113 |
+
seed=XTTSV2.infer_seed,
|
114 |
+
)
|
115 |
+
adjust_config = AdjustConfig(
|
116 |
+
speed_rate=XTTSV2.speed,
|
117 |
+
)
|
118 |
+
# TODO: support enhancer
|
119 |
+
enhancer_config = EnhancerConfig(
|
120 |
+
# enabled=params.enhance or params.denoise or False,
|
121 |
+
# lambd=0.9 if params.denoise else 0.1,
|
122 |
)
|
123 |
|
124 |
+
handler = TTSHandler(
|
125 |
+
text_content=text,
|
126 |
+
spk=spk,
|
127 |
+
tts_config=tts_config,
|
128 |
+
infer_config=infer_config,
|
129 |
+
adjust_config=adjust_config,
|
130 |
+
enhancer_config=enhancer_config,
|
131 |
+
)
|
|
|
|
|
|
|
132 |
|
133 |
+
buffer = handler.enqueue_to_buffer(AudioFormat.mp3)
|
134 |
|
135 |
return StreamingResponse(buffer, media_type="audio/mpeg")
|
136 |
|
137 |
@app.get("/v1/xtts_v2/tts_stream")
|
138 |
+
async def tts_stream(
|
139 |
+
request: Request,
|
140 |
+
text: str = Query(),
|
141 |
+
speaker_wav: str = Query(),
|
142 |
+
language: str = Query(),
|
143 |
+
):
|
144 |
+
# speaker_wav 就是 speaker id 。。。
|
145 |
+
voice_id = speaker_wav
|
146 |
+
|
147 |
+
spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
|
148 |
+
voice_id
|
149 |
+
)
|
150 |
+
if spk is None:
|
151 |
+
raise HTTPException(status_code=400, detail="Invalid speaker id")
|
152 |
+
|
153 |
+
tts_config = ChatTTSConfig(
|
154 |
+
style=XTTSV2.style,
|
155 |
+
temperature=XTTSV2.temperature,
|
156 |
+
top_k=XTTSV2.top_k,
|
157 |
+
top_p=XTTSV2.top_p,
|
158 |
+
prefix=XTTSV2.prefix,
|
159 |
+
prompt1=XTTSV2.prompt1,
|
160 |
+
prompt2=XTTSV2.prompt2,
|
161 |
+
)
|
162 |
+
infer_config = InferConfig(
|
163 |
+
batch_size=XTTSV2.batch_size,
|
164 |
+
spliter_threshold=XTTSV2.spliter_threshold,
|
165 |
+
eos=XTTSV2.eos,
|
166 |
+
seed=XTTSV2.infer_seed,
|
167 |
+
)
|
168 |
+
adjust_config = AdjustConfig(
|
169 |
+
speed_rate=XTTSV2.speed,
|
170 |
+
)
|
171 |
+
# TODO: support enhancer
|
172 |
+
enhancer_config = EnhancerConfig(
|
173 |
+
# enabled=params.enhance or params.denoise or False,
|
174 |
+
# lambd=0.9 if params.denoise else 0.1,
|
175 |
+
)
|
176 |
+
|
177 |
+
handler = TTSHandler(
|
178 |
+
text_content=text,
|
179 |
+
spk=spk,
|
180 |
+
tts_config=tts_config,
|
181 |
+
infer_config=infer_config,
|
182 |
+
adjust_config=adjust_config,
|
183 |
+
enhancer_config=enhancer_config,
|
184 |
+
)
|
185 |
+
|
186 |
+
async def generator():
|
187 |
+
for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
|
188 |
+
disconnected = await request.is_disconnected()
|
189 |
+
if disconnected:
|
190 |
+
break
|
191 |
+
|
192 |
+
yield chunk
|
193 |
+
|
194 |
+
return StreamingResponse(generator(), media_type="audio/mpeg")
|
195 |
|
196 |
@app.post("/v1/xtts_v2/set_tts_settings")
|
197 |
async def set_tts_settings(request: TTSSettingsRequest):
|
|
|
253 |
XTTSV2.prefix = request.prefix
|
254 |
if request.spliter_threshold:
|
255 |
XTTSV2.spliter_threshold = request.spliter_threshold
|
256 |
+
if request.style:
|
257 |
+
XTTSV2.style = request.style
|
258 |
|
259 |
return {"message": "Settings successfully applied"}
|
260 |
except Exception as e:
|
modules/api/worker.py
CHANGED
@@ -5,7 +5,9 @@ import os
|
|
5 |
import dotenv
|
6 |
from fastapi import FastAPI
|
7 |
|
|
|
8 |
from modules.ffmpeg_env import setup_ffmpeg_path
|
|
|
9 |
|
10 |
setup_ffmpeg_path()
|
11 |
logging.basicConfig(
|
@@ -14,13 +16,7 @@ logging.basicConfig(
|
|
14 |
)
|
15 |
|
16 |
from modules import config
|
17 |
-
from modules.api.api_setup import
|
18 |
-
process_api_args,
|
19 |
-
process_model_args,
|
20 |
-
setup_api_args,
|
21 |
-
setup_model_args,
|
22 |
-
setup_uvicon_args,
|
23 |
-
)
|
24 |
from modules.api.app_config import app_description, app_title, app_version
|
25 |
from modules.utils.torch_opt import configure_torch_optimizations
|
26 |
|
|
|
5 |
import dotenv
|
6 |
from fastapi import FastAPI
|
7 |
|
8 |
+
from launch import setup_uvicon_args
|
9 |
from modules.ffmpeg_env import setup_ffmpeg_path
|
10 |
+
from modules.models_setup import process_model_args, setup_model_args
|
11 |
|
12 |
setup_ffmpeg_path()
|
13 |
logging.basicConfig(
|
|
|
16 |
)
|
17 |
|
18 |
from modules import config
|
19 |
+
from modules.api.api_setup import process_api_args, setup_api_args
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
from modules.api.app_config import app_description, app_title, app_version
|
21 |
from modules.utils.torch_opt import configure_torch_optimizations
|
22 |
|
modules/devices/devices.py
CHANGED
@@ -92,7 +92,10 @@ def get_optimal_device():
|
|
92 |
|
93 |
|
94 |
def get_device_for(task):
|
95 |
-
if
|
|
|
|
|
|
|
96 |
return cpu
|
97 |
|
98 |
return get_optimal_device()
|
@@ -128,6 +131,9 @@ def reset_device():
|
|
128 |
global dtype_gpt
|
129 |
global dtype_decoder
|
130 |
|
|
|
|
|
|
|
131 |
if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
|
132 |
logger.warning(
|
133 |
"Cannot use half precision with CPU, using full precision instead"
|
|
|
92 |
|
93 |
|
94 |
def get_device_for(task):
|
95 |
+
if (
|
96 |
+
task in config.runtime_env_vars.use_cpu
|
97 |
+
or "all" in config.runtime_env_vars.use_cpu
|
98 |
+
):
|
99 |
return cpu
|
100 |
|
101 |
return get_optimal_device()
|
|
|
131 |
global dtype_gpt
|
132 |
global dtype_decoder
|
133 |
|
134 |
+
if config.runtime_env_vars.use_cpu is None:
|
135 |
+
config.runtime_env_vars.use_cpu = []
|
136 |
+
|
137 |
if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
|
138 |
logger.warning(
|
139 |
"Cannot use half precision with CPU, using full precision instead"
|
modules/finetune/train_speaker.py
CHANGED
@@ -255,7 +255,7 @@ if __name__ == "__main__":
|
|
255 |
vocos_model=chat.pretrain_models["vocos"],
|
256 |
tar_path=tar_path,
|
257 |
tar_in_memory=tar_in_memory,
|
258 |
-
device=devices.
|
259 |
# speakers=None, # set(['speaker_A', 'speaker_B'])
|
260 |
)
|
261 |
|
@@ -267,7 +267,7 @@ if __name__ == "__main__":
|
|
267 |
speaker_embeds = {
|
268 |
speaker: torch.tensor(
|
269 |
spk.emb,
|
270 |
-
device=devices.
|
271 |
requires_grad=True,
|
272 |
)
|
273 |
for speaker in dataset.speakers
|
|
|
255 |
vocos_model=chat.pretrain_models["vocos"],
|
256 |
tar_path=tar_path,
|
257 |
tar_in_memory=tar_in_memory,
|
258 |
+
device=devices.get_device_for("trainer"),
|
259 |
# speakers=None, # set(['speaker_A', 'speaker_B'])
|
260 |
)
|
261 |
|
|
|
267 |
speaker_embeds = {
|
268 |
speaker: torch.tensor(
|
269 |
spk.emb,
|
270 |
+
device=devices.get_device_for("trainer"),
|
271 |
requires_grad=True,
|
272 |
)
|
273 |
for speaker in dataset.speakers
|
modules/generate_audio.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import gc
|
2 |
import logging
|
3 |
-
from typing import Union
|
4 |
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
from modules import config, models
|
|
|
9 |
from modules.devices import devices
|
10 |
from modules.speaker import Speaker
|
11 |
from modules.utils.cache import conditional_cache
|
@@ -13,6 +14,8 @@ from modules.utils.SeedContext import SeedContext
|
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
|
|
|
|
16 |
|
17 |
def generate_audio(
|
18 |
text: str,
|
@@ -42,20 +45,18 @@ def generate_audio(
|
|
42 |
return (sample_rate, wav)
|
43 |
|
44 |
|
45 |
-
|
46 |
-
def generate_audio_batch(
|
47 |
texts: list[str],
|
|
|
48 |
temperature: float = 0.3,
|
49 |
top_P: float = 0.7,
|
50 |
top_K: float = 20,
|
51 |
spk: Union[int, Speaker] = -1,
|
52 |
infer_seed: int = -1,
|
53 |
-
use_decoder: bool = True,
|
54 |
prompt1: str = "",
|
55 |
prompt2: str = "",
|
56 |
prefix: str = "",
|
57 |
):
|
58 |
-
chat_tts = models.load_chat_tts()
|
59 |
params_infer_code = {
|
60 |
"spk_emb": None,
|
61 |
"temperature": temperature,
|
@@ -97,18 +98,93 @@ def generate_audio_batch(
|
|
97 |
}
|
98 |
)
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
with SeedContext(infer_seed, True):
|
101 |
wavs = chat_tts.generate_audio(
|
102 |
-
texts, params_infer_code, use_decoder=use_decoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
)
|
104 |
|
105 |
-
|
|
|
106 |
|
107 |
if config.auto_gc:
|
108 |
devices.torch_gc()
|
109 |
gc.collect()
|
110 |
|
111 |
-
return
|
112 |
|
113 |
|
114 |
lru_cache_enabled = False
|
|
|
1 |
import gc
|
2 |
import logging
|
3 |
+
from typing import Generator, Union
|
4 |
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
from modules import config, models
|
9 |
+
from modules.ChatTTS import ChatTTS
|
10 |
from modules.devices import devices
|
11 |
from modules.speaker import Speaker
|
12 |
from modules.utils.cache import conditional_cache
|
|
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
+
SAMPLE_RATE = 24000
|
18 |
+
|
19 |
|
20 |
def generate_audio(
|
21 |
text: str,
|
|
|
45 |
return (sample_rate, wav)
|
46 |
|
47 |
|
48 |
+
def parse_infer_params(
|
|
|
49 |
texts: list[str],
|
50 |
+
chat_tts: ChatTTS.Chat,
|
51 |
temperature: float = 0.3,
|
52 |
top_P: float = 0.7,
|
53 |
top_K: float = 20,
|
54 |
spk: Union[int, Speaker] = -1,
|
55 |
infer_seed: int = -1,
|
|
|
56 |
prompt1: str = "",
|
57 |
prompt2: str = "",
|
58 |
prefix: str = "",
|
59 |
):
|
|
|
60 |
params_infer_code = {
|
61 |
"spk_emb": None,
|
62 |
"temperature": temperature,
|
|
|
98 |
}
|
99 |
)
|
100 |
|
101 |
+
return params_infer_code
|
102 |
+
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def generate_audio_batch(
|
106 |
+
texts: list[str],
|
107 |
+
temperature: float = 0.3,
|
108 |
+
top_P: float = 0.7,
|
109 |
+
top_K: float = 20,
|
110 |
+
spk: Union[int, Speaker] = -1,
|
111 |
+
infer_seed: int = -1,
|
112 |
+
use_decoder: bool = True,
|
113 |
+
prompt1: str = "",
|
114 |
+
prompt2: str = "",
|
115 |
+
prefix: str = "",
|
116 |
+
):
|
117 |
+
chat_tts = models.load_chat_tts()
|
118 |
+
params_infer_code = parse_infer_params(
|
119 |
+
texts=texts,
|
120 |
+
chat_tts=chat_tts,
|
121 |
+
temperature=temperature,
|
122 |
+
top_P=top_P,
|
123 |
+
top_K=top_K,
|
124 |
+
spk=spk,
|
125 |
+
infer_seed=infer_seed,
|
126 |
+
prompt1=prompt1,
|
127 |
+
prompt2=prompt2,
|
128 |
+
prefix=prefix,
|
129 |
+
)
|
130 |
+
|
131 |
with SeedContext(infer_seed, True):
|
132 |
wavs = chat_tts.generate_audio(
|
133 |
+
prompt=texts, params_infer_code=params_infer_code, use_decoder=use_decoder
|
134 |
+
)
|
135 |
+
|
136 |
+
if config.auto_gc:
|
137 |
+
devices.torch_gc()
|
138 |
+
gc.collect()
|
139 |
+
|
140 |
+
return [(SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
|
141 |
+
|
142 |
+
|
143 |
+
# TODO: generate_audio_stream 也应该支持 lru cache
|
144 |
+
@torch.inference_mode()
|
145 |
+
def generate_audio_stream(
|
146 |
+
text: str,
|
147 |
+
temperature: float = 0.3,
|
148 |
+
top_P: float = 0.7,
|
149 |
+
top_K: float = 20,
|
150 |
+
spk: Union[int, Speaker] = -1,
|
151 |
+
infer_seed: int = -1,
|
152 |
+
use_decoder: bool = True,
|
153 |
+
prompt1: str = "",
|
154 |
+
prompt2: str = "",
|
155 |
+
prefix: str = "",
|
156 |
+
) -> Generator[tuple[int, np.ndarray], None, None]:
|
157 |
+
chat_tts = models.load_chat_tts()
|
158 |
+
texts = [text]
|
159 |
+
params_infer_code = parse_infer_params(
|
160 |
+
texts=texts,
|
161 |
+
chat_tts=chat_tts,
|
162 |
+
temperature=temperature,
|
163 |
+
top_P=top_P,
|
164 |
+
top_K=top_K,
|
165 |
+
spk=spk,
|
166 |
+
infer_seed=infer_seed,
|
167 |
+
prompt1=prompt1,
|
168 |
+
prompt2=prompt2,
|
169 |
+
prefix=prefix,
|
170 |
+
)
|
171 |
+
|
172 |
+
with SeedContext(infer_seed, True):
|
173 |
+
wavs_gen = chat_tts.generate_audio(
|
174 |
+
prompt=texts,
|
175 |
+
params_infer_code=params_infer_code,
|
176 |
+
use_decoder=use_decoder,
|
177 |
+
stream=True,
|
178 |
)
|
179 |
|
180 |
+
for wav in wavs_gen:
|
181 |
+
yield [SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)]
|
182 |
|
183 |
if config.auto_gc:
|
184 |
devices.torch_gc()
|
185 |
gc.collect()
|
186 |
|
187 |
+
return
|
188 |
|
189 |
|
190 |
lru_cache_enabled = False
|
modules/models.py
CHANGED
@@ -21,18 +21,27 @@ def load_chat_tts_in_thread():
|
|
21 |
|
22 |
logger.info("Loading ChatTTS models")
|
23 |
chat_tts = ChatTTS.Chat()
|
|
|
|
|
24 |
chat_tts.load_models(
|
25 |
compile=config.runtime_env_vars.compile,
|
26 |
source="local",
|
27 |
local_path="./models/ChatTTS",
|
28 |
-
device=
|
29 |
-
dtype=
|
30 |
dtype_vocos=devices.dtype_vocos,
|
31 |
dtype_dvae=devices.dtype_dvae,
|
32 |
dtype_gpt=devices.dtype_gpt,
|
33 |
dtype_decoder=devices.dtype_decoder,
|
34 |
)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
devices.torch_gc()
|
37 |
logger.info("ChatTTS models loaded")
|
38 |
|
|
|
21 |
|
22 |
logger.info("Loading ChatTTS models")
|
23 |
chat_tts = ChatTTS.Chat()
|
24 |
+
device = devices.get_device_for("chattts")
|
25 |
+
dtype = devices.dtype
|
26 |
chat_tts.load_models(
|
27 |
compile=config.runtime_env_vars.compile,
|
28 |
source="local",
|
29 |
local_path="./models/ChatTTS",
|
30 |
+
device=device,
|
31 |
+
dtype=dtype,
|
32 |
dtype_vocos=devices.dtype_vocos,
|
33 |
dtype_dvae=devices.dtype_dvae,
|
34 |
dtype_gpt=devices.dtype_gpt,
|
35 |
dtype_decoder=devices.dtype_decoder,
|
36 |
)
|
37 |
|
38 |
+
# 如果 device 为 cpu 同时,又是 dtype == float16 那么报 warn
|
39 |
+
# 提示可能无法正常运行,建议使用 float32 即开启 `--no_half` 参数
|
40 |
+
if device == devices.cpu and dtype == torch.float16:
|
41 |
+
logger.warning(
|
42 |
+
"The device is CPU and dtype is float16, which may not work properly. It is recommended to use float32 by enabling the `--no_half` parameter."
|
43 |
+
)
|
44 |
+
|
45 |
devices.torch_gc()
|
46 |
logger.info("ChatTTS models loaded")
|
47 |
|
modules/models_setup.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from modules import generate_audio
|
5 |
+
from modules.devices import devices
|
6 |
+
from modules.Enhancer.ResembleEnhance import load_enhancer
|
7 |
+
from modules.models import load_chat_tts
|
8 |
+
from modules.utils import env
|
9 |
+
|
10 |
+
|
11 |
+
def setup_model_args(parser: argparse.ArgumentParser):
|
12 |
+
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
13 |
+
parser.add_argument(
|
14 |
+
"--no_half",
|
15 |
+
action="store_true",
|
16 |
+
help="Disalbe half precision for model inference",
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--off_tqdm",
|
20 |
+
action="store_true",
|
21 |
+
help="Disable tqdm progress bar",
|
22 |
+
)
|
23 |
+
parser.add_argument(
|
24 |
+
"--device_id",
|
25 |
+
type=str,
|
26 |
+
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
27 |
+
default=None,
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--use_cpu",
|
31 |
+
nargs="+",
|
32 |
+
help="use CPU as torch device for specified modules",
|
33 |
+
default=[],
|
34 |
+
type=str.lower,
|
35 |
+
choices=["all", "chattts", "enhancer", "trainer"],
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--lru_size",
|
39 |
+
type=int,
|
40 |
+
default=64,
|
41 |
+
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--debug_generate",
|
45 |
+
action="store_true",
|
46 |
+
help="Enable debug mode for audio generation",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--preload_models",
|
50 |
+
action="store_true",
|
51 |
+
help="Preload all models at startup",
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def process_model_args(args: argparse.Namespace):
|
56 |
+
lru_size = env.get_and_update_env(args, "lru_size", 64, int)
|
57 |
+
compile = env.get_and_update_env(args, "compile", False, bool)
|
58 |
+
device_id = env.get_and_update_env(args, "device_id", None, str)
|
59 |
+
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
|
60 |
+
no_half = env.get_and_update_env(args, "no_half", False, bool)
|
61 |
+
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
|
62 |
+
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
|
63 |
+
preload_models = env.get_and_update_env(args, "preload_models", False, bool)
|
64 |
+
|
65 |
+
generate_audio.setup_lru_cache()
|
66 |
+
devices.reset_device()
|
67 |
+
devices.first_time_calculation()
|
68 |
+
|
69 |
+
if debug_generate:
|
70 |
+
generate_audio.logger.setLevel(logging.DEBUG)
|
71 |
+
|
72 |
+
if preload_models:
|
73 |
+
load_chat_tts()
|
74 |
+
load_enhancer()
|
modules/normalization.py
CHANGED
@@ -1,39 +1,21 @@
|
|
|
|
1 |
import re
|
2 |
-
from functools import lru_cache
|
3 |
|
4 |
import emojiswitch
|
|
|
5 |
|
6 |
from modules import models
|
|
|
|
|
|
|
7 |
from modules.utils.markdown import markdown_to_text
|
8 |
-
from modules.utils.zh_normalization.text_normlization import
|
9 |
|
10 |
# 是否关闭 unk token 检查
|
11 |
# NOTE: 单测的时候用于跳过模型加载
|
12 |
DISABLE_UNK_TOKEN_CHECK = False
|
13 |
|
14 |
|
15 |
-
@lru_cache(maxsize=64)
|
16 |
-
def is_chinese(text):
|
17 |
-
# 中文字符的 Unicode 范围是 \u4e00-\u9fff
|
18 |
-
chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
19 |
-
return bool(chinese_pattern.search(text))
|
20 |
-
|
21 |
-
|
22 |
-
@lru_cache(maxsize=64)
|
23 |
-
def is_eng(text):
|
24 |
-
eng_pattern = re.compile(r"[a-zA-Z]")
|
25 |
-
return bool(eng_pattern.search(text))
|
26 |
-
|
27 |
-
|
28 |
-
@lru_cache(maxsize=64)
|
29 |
-
def guess_lang(text):
|
30 |
-
if is_chinese(text):
|
31 |
-
return "zh"
|
32 |
-
if is_eng(text):
|
33 |
-
return "en"
|
34 |
-
return "zh"
|
35 |
-
|
36 |
-
|
37 |
post_normalize_pipeline = []
|
38 |
pre_normalize_pipeline = []
|
39 |
|
@@ -184,9 +166,32 @@ def replace_unk_tokens(text):
|
|
184 |
return output_text
|
185 |
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
## ---------- pre normalize ----------
|
188 |
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
@pre_normalize()
|
191 |
def apply_markdown_to_text(text):
|
192 |
if is_markdown(text):
|
@@ -194,6 +199,11 @@ def apply_markdown_to_text(text):
|
|
194 |
return text
|
195 |
|
196 |
|
|
|
|
|
|
|
|
|
|
|
197 |
# 将 "xxx" => \nxxx\n
|
198 |
# 将 'xxx' => \nxxx\n
|
199 |
@pre_normalize()
|
@@ -293,6 +303,7 @@ if __name__ == "__main__":
|
|
293 |
" [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
|
294 |
" 明天有62%的概率降雨",
|
295 |
"大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
|
|
|
296 |
"""
|
297 |
# 你好,世界
|
298 |
```js
|
|
|
1 |
+
import html
|
2 |
import re
|
|
|
3 |
|
4 |
import emojiswitch
|
5 |
+
import ftfy
|
6 |
|
7 |
from modules import models
|
8 |
+
from modules.utils.detect_lang import guess_lang
|
9 |
+
from modules.utils.HomophonesReplacer import HomophonesReplacer
|
10 |
+
from modules.utils.html import remove_html_tags as _remove_html_tags
|
11 |
from modules.utils.markdown import markdown_to_text
|
12 |
+
from modules.utils.zh_normalization.text_normlization import TextNormalizer
|
13 |
|
14 |
# 是否关闭 unk token 检查
|
15 |
# NOTE: 单测的时候用于跳过模型加载
|
16 |
DISABLE_UNK_TOKEN_CHECK = False
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
post_normalize_pipeline = []
|
20 |
pre_normalize_pipeline = []
|
21 |
|
|
|
166 |
return output_text
|
167 |
|
168 |
|
169 |
+
homo_replacer = HomophonesReplacer(map_file_path="./data/homophones_map.json")
|
170 |
+
|
171 |
+
|
172 |
+
@post_normalize()
|
173 |
+
def replace_homophones(text):
|
174 |
+
lang = guess_lang(text)
|
175 |
+
if lang == "zh":
|
176 |
+
text = homo_replacer.replace(text)
|
177 |
+
return text
|
178 |
+
|
179 |
+
|
180 |
## ---------- pre normalize ----------
|
181 |
|
182 |
|
183 |
+
@pre_normalize()
|
184 |
+
def html_unescape(text):
|
185 |
+
text = html.unescape(text)
|
186 |
+
text = html.unescape(text)
|
187 |
+
return text
|
188 |
+
|
189 |
+
|
190 |
+
@pre_normalize()
|
191 |
+
def fix_text(text):
|
192 |
+
return ftfy.fix_text(text=text)
|
193 |
+
|
194 |
+
|
195 |
@pre_normalize()
|
196 |
def apply_markdown_to_text(text):
|
197 |
if is_markdown(text):
|
|
|
199 |
return text
|
200 |
|
201 |
|
202 |
+
@pre_normalize()
|
203 |
+
def remove_html_tags(text):
|
204 |
+
return _remove_html_tags(text)
|
205 |
+
|
206 |
+
|
207 |
# 将 "xxx" => \nxxx\n
|
208 |
# 将 'xxx' => \nxxx\n
|
209 |
@pre_normalize()
|
|
|
303 |
" [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
|
304 |
" 明天有62%的概率降雨",
|
305 |
"大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
|
306 |
+
"I like eating 🍏",
|
307 |
"""
|
308 |
# 你好,世界
|
309 |
```js
|
modules/refiner.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
|
@@ -31,4 +33,10 @@ def refine_text(
|
|
31 |
"disable_tqdm": config.runtime_env_vars.off_tqdm,
|
32 |
},
|
33 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
return refined_text
|
|
|
1 |
+
from typing import Generator
|
2 |
+
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
|
|
|
33 |
"disable_tqdm": config.runtime_env_vars.off_tqdm,
|
34 |
},
|
35 |
)
|
36 |
+
if isinstance(refined_text, Generator):
|
37 |
+
raise NotImplementedError(
|
38 |
+
"Refiner is not yet implemented for generator output"
|
39 |
+
)
|
40 |
+
if isinstance(refined_text, list):
|
41 |
+
refined_text = "\n".join(refined_text)
|
42 |
return refined_text
|
modules/repos_static/resemble_enhance/inference.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import logging
|
2 |
import time
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from torch.nn.utils.parametrize import remove_parametrizations
|
7 |
from torchaudio.functional import resample
|
8 |
from torchaudio.transforms import MelSpectrogram
|
9 |
-
from tqdm import trange
|
10 |
|
11 |
from modules import config
|
12 |
from modules.devices import devices
|
@@ -142,10 +142,10 @@ def inference(
|
|
142 |
chunk_seconds: float = 30.0,
|
143 |
overlap_seconds: float = 1.0,
|
144 |
):
|
|
|
|
|
145 |
if config.runtime_env_vars.off_tqdm:
|
146 |
-
trange =
|
147 |
-
else:
|
148 |
-
from tqdm import trange
|
149 |
|
150 |
remove_weight_norm_recursively(model)
|
151 |
|
@@ -188,7 +188,7 @@ def inference(
|
|
188 |
torch.cuda.synchronize()
|
189 |
|
190 |
elapsed_time = time.perf_counter() - start_time
|
191 |
-
logger.
|
192 |
f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
|
193 |
)
|
194 |
devices.torch_gc()
|
|
|
1 |
import logging
|
2 |
import time
|
3 |
+
from functools import partial
|
4 |
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from torch.nn.utils.parametrize import remove_parametrizations
|
8 |
from torchaudio.functional import resample
|
9 |
from torchaudio.transforms import MelSpectrogram
|
|
|
10 |
|
11 |
from modules import config
|
12 |
from modules.devices import devices
|
|
|
142 |
chunk_seconds: float = 30.0,
|
143 |
overlap_seconds: float = 1.0,
|
144 |
):
|
145 |
+
from tqdm import trange
|
146 |
+
|
147 |
if config.runtime_env_vars.off_tqdm:
|
148 |
+
trange = partial(trange, disable=True)
|
|
|
|
|
149 |
|
150 |
remove_weight_norm_recursively(model)
|
151 |
|
|
|
188 |
torch.cuda.synchronize()
|
189 |
|
190 |
elapsed_time = time.perf_counter() - start_time
|
191 |
+
logger.debug(
|
192 |
f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
|
193 |
)
|
194 |
devices.torch_gc()
|
modules/speaker.py
CHANGED
@@ -29,6 +29,12 @@ class Speaker:
|
|
29 |
speaker.emb = tensor
|
30 |
return speaker
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def __init__(
|
33 |
self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
|
34 |
):
|
|
|
29 |
speaker.emb = tensor
|
30 |
return speaker
|
31 |
|
32 |
+
@staticmethod
|
33 |
+
def from_seed(seed: int):
|
34 |
+
speaker = Speaker(seed_or_tensor=seed)
|
35 |
+
speaker.emb = create_speaker_from_seed(seed)
|
36 |
+
return speaker
|
37 |
+
|
38 |
def __init__(
|
39 |
self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
|
40 |
):
|
modules/synthesize_audio.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
-
import io
|
2 |
from typing import Union
|
3 |
|
4 |
-
from modules import generate_audio as generate
|
5 |
from modules.SentenceSplitter import SentenceSplitter
|
6 |
from modules.speaker import Speaker
|
7 |
from modules.ssml_parser.SSMLParser import SSMLSegment
|
|
|
|
|
1 |
from typing import Union
|
2 |
|
|
|
3 |
from modules.SentenceSplitter import SentenceSplitter
|
4 |
from modules.speaker import Speaker
|
5 |
from modules.ssml_parser.SSMLParser import SSMLSegment
|
modules/synthesize_stream.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from typing import Generator, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from modules import generate_audio as generate
|
7 |
+
from modules.SentenceSplitter import SentenceSplitter
|
8 |
+
from modules.speaker import Speaker
|
9 |
+
|
10 |
+
|
11 |
+
def synthesize_stream(
|
12 |
+
text: str,
|
13 |
+
temperature: float = 0.3,
|
14 |
+
top_P: float = 0.7,
|
15 |
+
top_K: float = 20,
|
16 |
+
spk: Union[int, Speaker] = -1,
|
17 |
+
infer_seed: int = -1,
|
18 |
+
use_decoder: bool = True,
|
19 |
+
prompt1: str = "",
|
20 |
+
prompt2: str = "",
|
21 |
+
prefix: str = "",
|
22 |
+
spliter_threshold: int = 100,
|
23 |
+
end_of_sentence="",
|
24 |
+
) -> Generator[tuple[int, np.ndarray], None, None]:
|
25 |
+
spliter = SentenceSplitter(spliter_threshold)
|
26 |
+
sentences = spliter.parse(text)
|
27 |
+
|
28 |
+
for sentence in sentences:
|
29 |
+
wav_gen = generate.generate_audio_stream(
|
30 |
+
text=sentence + end_of_sentence,
|
31 |
+
temperature=temperature,
|
32 |
+
top_P=top_P,
|
33 |
+
top_K=top_K,
|
34 |
+
spk=spk,
|
35 |
+
infer_seed=infer_seed,
|
36 |
+
use_decoder=use_decoder,
|
37 |
+
prompt1=prompt1,
|
38 |
+
prompt2=prompt2,
|
39 |
+
prefix=prefix,
|
40 |
+
)
|
41 |
+
for sr, wav in wav_gen:
|
42 |
+
yield sr, wav
|
modules/utils/HomophonesReplacer.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
# ref: https://github.com/2noise/ChatTTS/commit/ce1c962b6235bd7d0c637fbdcda5e2dccdbac80d
|
5 |
+
class HomophonesReplacer:
|
6 |
+
"""
|
7 |
+
Homophones Replacer
|
8 |
+
|
9 |
+
Replace the mispronounced characters with correctly pronounced ones.
|
10 |
+
|
11 |
+
Creation process of homophones_map.json:
|
12 |
+
|
13 |
+
1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
|
14 |
+
2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
|
15 |
+
3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS.
|
16 |
+
4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
|
17 |
+
|
18 |
+
Thanks to:
|
19 |
+
[Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
|
20 |
+
[python-pinyin](https://github.com/mozillazg/python-pinyin)
|
21 |
+
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, map_file_path):
|
25 |
+
self.homophones_map = self.load_homophones_map(map_file_path)
|
26 |
+
|
27 |
+
def load_homophones_map(self, map_file_path):
|
28 |
+
with open(map_file_path, "r", encoding="utf-8") as f:
|
29 |
+
homophones_map = json.load(f)
|
30 |
+
return homophones_map
|
31 |
+
|
32 |
+
def replace(self, text):
|
33 |
+
result = []
|
34 |
+
for char in text:
|
35 |
+
if char in self.homophones_map:
|
36 |
+
result.append(self.homophones_map[char])
|
37 |
+
else:
|
38 |
+
result.append(char)
|
39 |
+
return "".join(result)
|
modules/utils/audio.py
CHANGED
@@ -2,14 +2,14 @@ import sys
|
|
2 |
from io import BytesIO
|
3 |
|
4 |
import numpy as np
|
5 |
-
import pyrubberband as pyrb
|
6 |
import soundfile as sf
|
7 |
-
from pydub import AudioSegment
|
|
|
8 |
|
9 |
INT16_MAX = np.iinfo(np.int16).max
|
10 |
|
11 |
|
12 |
-
def audio_to_int16(audio_data):
|
13 |
if (
|
14 |
audio_data.dtype == np.float32
|
15 |
or audio_data.dtype == np.float64
|
@@ -20,6 +20,23 @@ def audio_to_int16(audio_data):
|
|
20 |
return audio_data
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
|
24 |
"""
|
25 |
Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
|
@@ -35,64 +52,42 @@ def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
|
|
35 |
return fp_arr
|
36 |
|
37 |
|
38 |
-
def
|
39 |
-
|
40 |
-
|
41 |
-
where each value is in range [-1.0, 1.0].
|
42 |
-
Returns tuple (audio_np_array, sample_rate).
|
43 |
-
"""
|
44 |
-
return (
|
45 |
-
audio.frame_rate,
|
46 |
-
np.array(audio.get_array_of_samples(), dtype=np.float32).reshape(
|
47 |
-
(-1, audio.channels)
|
48 |
-
)
|
49 |
-
/ (1 << (8 * audio.sample_width - 1)),
|
50 |
-
)
|
51 |
-
|
52 |
-
|
53 |
-
def ndarray_to_segment(ndarray, frame_rate):
|
54 |
buffer = BytesIO()
|
55 |
-
sf.write(buffer, ndarray, frame_rate, format="wav")
|
56 |
buffer.seek(0)
|
57 |
-
sound = AudioSegment.from_wav(
|
58 |
-
buffer,
|
59 |
-
)
|
60 |
-
return sound
|
61 |
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
time_factor = np.clip(time_factor, 0.2, 10)
|
68 |
-
sr = input_segment.frame_rate
|
69 |
-
y = audiosegment_to_librosawav(input_segment)
|
70 |
-
y_stretch = pyrb.time_stretch(y, sr, time_factor)
|
71 |
-
|
72 |
-
sound = ndarray_to_segment(
|
73 |
-
y_stretch,
|
74 |
-
frame_rate=sr,
|
75 |
)
|
76 |
-
return sound
|
77 |
|
78 |
|
79 |
-
def
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
) -> AudioSegment:
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
y_shift = pyrb.pitch_shift(y, sr, pitch_shift_factor)
|
90 |
-
|
91 |
-
sound = ndarray_to_segment(
|
92 |
-
y_shift,
|
93 |
-
frame_rate=sr,
|
94 |
)
|
95 |
-
|
|
|
96 |
|
97 |
|
98 |
def apply_prosody_to_audio_data(
|
@@ -114,6 +109,17 @@ def apply_prosody_to_audio_data(
|
|
114 |
return audio_data
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
if __name__ == "__main__":
|
118 |
input_file = sys.argv[1]
|
119 |
|
@@ -123,11 +129,11 @@ if __name__ == "__main__":
|
|
123 |
input_sound = AudioSegment.from_mp3(input_file)
|
124 |
|
125 |
for time_factor in time_stretch_factors:
|
126 |
-
output_wav = f"
|
127 |
-
|
128 |
-
|
129 |
|
130 |
for pitch_factor in pitch_shift_factors:
|
131 |
-
output_wav = f"
|
132 |
-
|
133 |
-
|
|
|
2 |
from io import BytesIO
|
3 |
|
4 |
import numpy as np
|
|
|
5 |
import soundfile as sf
|
6 |
+
from pydub import AudioSegment, effects
|
7 |
+
import pyrubberband as pyrb
|
8 |
|
9 |
INT16_MAX = np.iinfo(np.int16).max
|
10 |
|
11 |
|
12 |
+
def audio_to_int16(audio_data: np.ndarray) -> np.ndarray:
|
13 |
if (
|
14 |
audio_data.dtype == np.float32
|
15 |
or audio_data.dtype == np.float64
|
|
|
20 |
return audio_data
|
21 |
|
22 |
|
23 |
+
def pydub_to_np(audio: AudioSegment) -> tuple[int, np.ndarray]:
|
24 |
+
"""
|
25 |
+
Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
|
26 |
+
where each value is in range [-1.0, 1.0].
|
27 |
+
Returns tuple (audio_np_array, sample_rate).
|
28 |
+
"""
|
29 |
+
nd_array = np.array(audio.get_array_of_samples(), dtype=np.float32)
|
30 |
+
if audio.channels != 1:
|
31 |
+
nd_array = nd_array.reshape((-1, audio.channels))
|
32 |
+
nd_array = nd_array / (1 << (8 * audio.sample_width - 1))
|
33 |
+
|
34 |
+
return (
|
35 |
+
audio.frame_rate,
|
36 |
+
nd_array,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
|
41 |
"""
|
42 |
Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
|
|
|
52 |
return fp_arr
|
53 |
|
54 |
|
55 |
+
def ndarray_to_segment(
|
56 |
+
ndarray: np.ndarray, frame_rate: int, sample_width: int = None, channels: int = None
|
57 |
+
) -> AudioSegment:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
buffer = BytesIO()
|
59 |
+
sf.write(buffer, ndarray, frame_rate, format="wav", subtype="PCM_16")
|
60 |
buffer.seek(0)
|
61 |
+
sound: AudioSegment = AudioSegment.from_wav(buffer)
|
|
|
|
|
|
|
62 |
|
63 |
+
if sample_width is None:
|
64 |
+
sample_width = sound.sample_width
|
65 |
+
if channels is None:
|
66 |
+
channels = sound.channels
|
67 |
|
68 |
+
return (
|
69 |
+
sound.set_frame_rate(frame_rate)
|
70 |
+
.set_sample_width(sample_width)
|
71 |
+
.set_channels(channels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
)
|
|
|
73 |
|
74 |
|
75 |
+
def apply_prosody_to_audio_segment(
|
76 |
+
audio_segment: AudioSegment,
|
77 |
+
rate: float = 1,
|
78 |
+
volume: float = 0,
|
79 |
+
pitch: int = 0,
|
80 |
+
sr: int = 24000,
|
81 |
) -> AudioSegment:
|
82 |
+
audio_data = audiosegment_to_librosawav(audio_segment)
|
83 |
+
|
84 |
+
audio_data = apply_prosody_to_audio_data(audio_data, rate, volume, pitch, sr)
|
85 |
+
|
86 |
+
audio_segment = ndarray_to_segment(
|
87 |
+
audio_data, sr, audio_segment.sample_width, audio_segment.channels
|
|
|
|
|
|
|
|
|
|
|
88 |
)
|
89 |
+
|
90 |
+
return audio_segment
|
91 |
|
92 |
|
93 |
def apply_prosody_to_audio_data(
|
|
|
109 |
return audio_data
|
110 |
|
111 |
|
112 |
+
def apply_normalize(
|
113 |
+
audio_data: np.ndarray,
|
114 |
+
headroom: float = 1,
|
115 |
+
sr: int = 24000,
|
116 |
+
):
|
117 |
+
segment = ndarray_to_segment(audio_data, sr)
|
118 |
+
segment = effects.normalize(seg=segment, headroom=headroom)
|
119 |
+
|
120 |
+
return pydub_to_np(segment)
|
121 |
+
|
122 |
+
|
123 |
if __name__ == "__main__":
|
124 |
input_file = sys.argv[1]
|
125 |
|
|
|
129 |
input_sound = AudioSegment.from_mp3(input_file)
|
130 |
|
131 |
for time_factor in time_stretch_factors:
|
132 |
+
output_wav = f"{input_file}_time_{time_factor}.wav"
|
133 |
+
output_sound = apply_prosody_to_audio_segment(input_sound, rate=time_factor)
|
134 |
+
output_sound.export(output_wav, format="wav")
|
135 |
|
136 |
for pitch_factor in pitch_shift_factors:
|
137 |
+
output_wav = f"{input_file}_pitch_{pitch_factor}.wav"
|
138 |
+
output_sound = apply_prosody_to_audio_segment(input_sound, pitch=pitch_factor)
|
139 |
+
output_sound.export(output_wav, format="wav")
|
modules/utils/detect_lang.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
|
5 |
+
@lru_cache(maxsize=64)
|
6 |
+
def is_chinese(text):
|
7 |
+
for char in text:
|
8 |
+
if "\u4e00" <= char <= "\u9fff":
|
9 |
+
return True
|
10 |
+
return False
|
11 |
+
|
12 |
+
|
13 |
+
@lru_cache(maxsize=64)
|
14 |
+
def is_eng(text):
|
15 |
+
for char in text:
|
16 |
+
if "a" <= char.lower() <= "z":
|
17 |
+
return True
|
18 |
+
return False
|
19 |
+
|
20 |
+
|
21 |
+
@lru_cache(maxsize=64)
|
22 |
+
def guess_lang(text) -> Literal["zh", "en"]:
|
23 |
+
if is_chinese(text):
|
24 |
+
return "zh"
|
25 |
+
if is_eng(text):
|
26 |
+
return "en"
|
27 |
+
return "zh"
|
modules/utils/html.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from html.parser import HTMLParser
|
2 |
+
|
3 |
+
|
4 |
+
class HTMLTagRemover(HTMLParser):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.reset()
|
8 |
+
self.fed = []
|
9 |
+
|
10 |
+
def handle_data(self, data):
|
11 |
+
self.fed.append(data)
|
12 |
+
|
13 |
+
def get_data(self):
|
14 |
+
return "\n".join(self.fed)
|
15 |
+
|
16 |
+
|
17 |
+
def remove_html_tags(text):
|
18 |
+
parser = HTMLTagRemover()
|
19 |
+
parser.feed(text)
|
20 |
+
return parser.get_data()
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
input_text = "<h1>一个标题</h1> 这是一段包含<code>标签</code>的文本。"
|
25 |
+
output_text = remove_html_tags(input_text)
|
26 |
+
print(output_text) # 输出: 一个标题 这是一段包含标签的文本。
|
modules/utils/ignore_warn.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
|
4 |
+
def ignore_useless_warnings():
|
5 |
+
|
6 |
+
# NOTE: 因为触发位置在 `vocos/heads.py:60` 改不动...所以忽略
|
7 |
+
warnings.filterwarnings(
|
8 |
+
"ignore", category=UserWarning, message="ComplexHalf support is experimental"
|
9 |
+
)
|
modules/utils/markdown.py
CHANGED
@@ -36,6 +36,7 @@ class PlainTextRenderer(mistune.HTMLRenderer):
|
|
36 |
return html + "\n" + text + "\n"
|
37 |
return "\n" + text + "\n"
|
38 |
|
|
|
39 |
def list_item(self, text):
|
40 |
return "" + text + "\n"
|
41 |
|
|
|
36 |
return html + "\n" + text + "\n"
|
37 |
return "\n" + text + "\n"
|
38 |
|
39 |
+
# FIXME: 现在的 list 转换没法保留序号
|
40 |
def list_item(self, text):
|
41 |
return "" + text + "\n"
|
42 |
|
modules/webui/localization_runtime.py
CHANGED
@@ -90,6 +90,28 @@ class ZHLocalizationVars(LocalizationVars):
|
|
90 |
]
|
91 |
|
92 |
self.tts_examples = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
{
|
94 |
"text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
|
95 |
},
|
|
|
90 |
]
|
91 |
|
92 |
self.tts_examples = [
|
93 |
+
{
|
94 |
+
"text": """
|
95 |
+
Fear is the path to the dark side. Fear leads to anger. Anger leads to hate. Hate leads to suffering.
|
96 |
+
恐惧是通向黑暗之路。恐惧导致愤怒。愤怒引发仇恨。仇恨造成痛苦。 [lbreak]
|
97 |
+
Do. Or do not. There is no try.
|
98 |
+
要么做,要么不做,没有试试看。[lbreak]
|
99 |
+
Peace is a lie, there is only passion.
|
100 |
+
安宁即是谎言,激情方为王道。[lbreak]
|
101 |
+
Through passion, I gain strength.
|
102 |
+
我以激情换取力量。[lbreak]
|
103 |
+
Through strength, I gain power.
|
104 |
+
以力量赚取权力。[lbreak]
|
105 |
+
Through power, I gain victory.
|
106 |
+
以权力赢取胜利。[lbreak]
|
107 |
+
Through victory, my chains are broken.
|
108 |
+
于胜利中超越自我。[lbreak]
|
109 |
+
The Force shall free me.
|
110 |
+
原力任我逍遥。[lbreak]
|
111 |
+
May the force be with you!
|
112 |
+
愿原力与你同在![lbreak]
|
113 |
+
""".strip()
|
114 |
+
},
|
115 |
{
|
116 |
"text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
|
117 |
},
|
modules/webui/speaker/speaker_creator.py
CHANGED
@@ -62,11 +62,10 @@ def create_spk_from_seed(
|
|
62 |
gender: str,
|
63 |
desc: str,
|
64 |
):
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
spk =
|
69 |
-
spk.emb = emb
|
70 |
|
71 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
72 |
torch.save(spk, tmp_file)
|
@@ -82,7 +81,8 @@ def test_spk_voice(
|
|
82 |
text: str,
|
83 |
progress=gr.Progress(track_tqdm=True),
|
84 |
):
|
85 |
-
|
|
|
86 |
|
87 |
|
88 |
def random_speaker():
|
|
|
62 |
gender: str,
|
63 |
desc: str,
|
64 |
):
|
65 |
+
spk = Speaker.from_seed(seed)
|
66 |
+
spk.name = name
|
67 |
+
spk.gender = gender
|
68 |
+
spk.describe = desc
|
|
|
69 |
|
70 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
71 |
torch.save(spk, tmp_file)
|
|
|
81 |
text: str,
|
82 |
progress=gr.Progress(track_tqdm=True),
|
83 |
):
|
84 |
+
spk = Speaker.from_seed(seed)
|
85 |
+
return tts_generate(spk=spk, text=text, progress=progress)
|
86 |
|
87 |
|
88 |
def random_speaker():
|
modules/webui/ssml/podcast_tab.py
CHANGED
@@ -124,7 +124,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
|
|
124 |
|
125 |
def send_to_ssml(msg, spk, style, sheet: pd.DataFrame):
|
126 |
if sheet.empty:
|
127 |
-
|
128 |
msg, spk, style, ssml = merge_dataframe_to_ssml(msg, spk, style, sheet)
|
129 |
return [
|
130 |
msg,
|
|
|
124 |
|
125 |
def send_to_ssml(msg, spk, style, sheet: pd.DataFrame):
|
126 |
if sheet.empty:
|
127 |
+
raise gr.Error("Please add some text to the script table.")
|
128 |
msg, spk, style, ssml = merge_dataframe_to_ssml(msg, spk, style, sheet)
|
129 |
return [
|
130 |
msg,
|
modules/webui/ssml/ssml_tab.py
CHANGED
@@ -6,19 +6,6 @@ from modules.webui.webui_utils import synthesize_ssml
|
|
6 |
|
7 |
def create_ssml_interface():
|
8 |
with gr.Row():
|
9 |
-
with gr.Column(scale=3):
|
10 |
-
with gr.Group():
|
11 |
-
gr.Markdown("📝SSML Input")
|
12 |
-
gr.Markdown("SSML_TEXT_GUIDE")
|
13 |
-
ssml_input = gr.Textbox(
|
14 |
-
label="SSML Input",
|
15 |
-
lines=10,
|
16 |
-
value=webui_config.localization.DEFAULT_SSML_TEXT,
|
17 |
-
placeholder="输入 SSML 或选择示例",
|
18 |
-
elem_id="ssml_input",
|
19 |
-
show_label=False,
|
20 |
-
)
|
21 |
-
ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
|
22 |
with gr.Column(scale=1):
|
23 |
with gr.Group():
|
24 |
gr.Markdown("🎛️Parameters")
|
@@ -44,11 +31,64 @@ def create_ssml_interface():
|
|
44 |
step=1,
|
45 |
)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
with gr.Group():
|
48 |
gr.Markdown("💪🏼Enhance")
|
49 |
enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
|
50 |
enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
with gr.Group():
|
53 |
gr.Markdown("🎄Examples")
|
54 |
gr.Examples(
|
@@ -56,7 +96,9 @@ def create_ssml_interface():
|
|
56 |
inputs=[ssml_input],
|
57 |
)
|
58 |
|
59 |
-
|
|
|
|
|
60 |
|
61 |
ssml_button.click(
|
62 |
synthesize_ssml,
|
@@ -67,6 +109,11 @@ def create_ssml_interface():
|
|
67 |
enable_de_noise,
|
68 |
eos_input,
|
69 |
spliter_thr_input,
|
|
|
|
|
|
|
|
|
|
|
70 |
],
|
71 |
outputs=ssml_output,
|
72 |
)
|
|
|
6 |
|
7 |
def create_ssml_interface():
|
8 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
with gr.Column(scale=1):
|
10 |
with gr.Group():
|
11 |
gr.Markdown("🎛️Parameters")
|
|
|
31 |
step=1,
|
32 |
)
|
33 |
|
34 |
+
with gr.Group():
|
35 |
+
gr.Markdown("🎛️Adjuster")
|
36 |
+
# 调节 speed pitch volume
|
37 |
+
# 可以选择开启 响度均衡
|
38 |
+
|
39 |
+
speed_input = gr.Slider(
|
40 |
+
label="Speed",
|
41 |
+
value=1.0,
|
42 |
+
minimum=0.5,
|
43 |
+
maximum=2.0,
|
44 |
+
step=0.1,
|
45 |
+
)
|
46 |
+
pitch_input = gr.Slider(
|
47 |
+
label="Pitch",
|
48 |
+
value=0,
|
49 |
+
minimum=-12,
|
50 |
+
maximum=12,
|
51 |
+
step=0.1,
|
52 |
+
)
|
53 |
+
volume_up_input = gr.Slider(
|
54 |
+
label="Volume Gain",
|
55 |
+
value=0,
|
56 |
+
minimum=-12,
|
57 |
+
maximum=12,
|
58 |
+
step=0.1,
|
59 |
+
)
|
60 |
+
|
61 |
+
enable_loudness_normalization = gr.Checkbox(
|
62 |
+
value=True,
|
63 |
+
label="Enable Loudness EQ",
|
64 |
+
)
|
65 |
+
headroom_input = gr.Slider(
|
66 |
+
label="Headroom",
|
67 |
+
value=1,
|
68 |
+
minimum=0,
|
69 |
+
maximum=12,
|
70 |
+
step=0.1,
|
71 |
+
)
|
72 |
+
|
73 |
with gr.Group():
|
74 |
gr.Markdown("💪🏼Enhance")
|
75 |
enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
|
76 |
enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
|
77 |
|
78 |
+
with gr.Column(scale=3):
|
79 |
+
with gr.Group():
|
80 |
+
gr.Markdown("📝SSML Input")
|
81 |
+
gr.Markdown("SSML_TEXT_GUIDE")
|
82 |
+
ssml_input = gr.Textbox(
|
83 |
+
label="SSML Input",
|
84 |
+
lines=10,
|
85 |
+
value=webui_config.localization.DEFAULT_SSML_TEXT,
|
86 |
+
placeholder="输入 SSML 或选择示例",
|
87 |
+
elem_id="ssml_input",
|
88 |
+
show_label=False,
|
89 |
+
)
|
90 |
+
ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
|
91 |
+
|
92 |
with gr.Group():
|
93 |
gr.Markdown("🎄Examples")
|
94 |
gr.Examples(
|
|
|
96 |
inputs=[ssml_input],
|
97 |
)
|
98 |
|
99 |
+
with gr.Group():
|
100 |
+
gr.Markdown("🎨Output")
|
101 |
+
ssml_output = gr.Audio(label="Generated Audio", format="mp3")
|
102 |
|
103 |
ssml_button.click(
|
104 |
synthesize_ssml,
|
|
|
109 |
enable_de_noise,
|
110 |
eos_input,
|
111 |
spliter_thr_input,
|
112 |
+
pitch_input,
|
113 |
+
speed_input,
|
114 |
+
volume_up_input,
|
115 |
+
enable_loudness_normalization,
|
116 |
+
headroom_input,
|
117 |
],
|
118 |
outputs=ssml_output,
|
119 |
)
|
modules/webui/tts_tab.py
CHANGED
@@ -228,14 +228,56 @@ def create_tts_interface():
|
|
228 |
label="prompt_audio", visible=webui_config.experimental
|
229 |
)
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
with gr.Group():
|
232 |
gr.Markdown("🔊Generate")
|
233 |
disable_normalize_input = gr.Checkbox(
|
234 |
-
value=False,
|
|
|
|
|
|
|
235 |
)
|
236 |
|
237 |
with gr.Group():
|
238 |
-
gr.Markdown("💪🏼Enhance")
|
239 |
enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
|
240 |
enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
|
241 |
tts_button = gr.Button(
|
@@ -271,6 +313,11 @@ def create_tts_interface():
|
|
271 |
spk_file_upload,
|
272 |
spliter_thr_input,
|
273 |
eos_input,
|
|
|
|
|
|
|
|
|
|
|
274 |
],
|
275 |
outputs=tts_output,
|
276 |
)
|
|
|
228 |
label="prompt_audio", visible=webui_config.experimental
|
229 |
)
|
230 |
|
231 |
+
with gr.Group():
|
232 |
+
gr.Markdown("🎛️Adjuster")
|
233 |
+
# 调节 speed pitch volume
|
234 |
+
# 可以选择开启 响度均衡
|
235 |
+
|
236 |
+
speed_input = gr.Slider(
|
237 |
+
label="Speed",
|
238 |
+
value=1.0,
|
239 |
+
minimum=0.5,
|
240 |
+
maximum=2.0,
|
241 |
+
step=0.1,
|
242 |
+
)
|
243 |
+
pitch_input = gr.Slider(
|
244 |
+
label="Pitch",
|
245 |
+
value=0,
|
246 |
+
minimum=-12,
|
247 |
+
maximum=12,
|
248 |
+
step=0.1,
|
249 |
+
)
|
250 |
+
volume_up_input = gr.Slider(
|
251 |
+
label="Volume Gain",
|
252 |
+
value=0,
|
253 |
+
minimum=-12,
|
254 |
+
maximum=12,
|
255 |
+
step=0.1,
|
256 |
+
)
|
257 |
+
|
258 |
+
enable_loudness_normalization = gr.Checkbox(
|
259 |
+
value=True,
|
260 |
+
label="Enable Loudness EQ",
|
261 |
+
)
|
262 |
+
headroom_input = gr.Slider(
|
263 |
+
label="Headroom",
|
264 |
+
value=1,
|
265 |
+
minimum=0,
|
266 |
+
maximum=12,
|
267 |
+
step=0.1,
|
268 |
+
)
|
269 |
+
|
270 |
with gr.Group():
|
271 |
gr.Markdown("🔊Generate")
|
272 |
disable_normalize_input = gr.Checkbox(
|
273 |
+
value=False,
|
274 |
+
label="Disable Normalize",
|
275 |
+
# 不需要了
|
276 |
+
visible=False,
|
277 |
)
|
278 |
|
279 |
with gr.Group():
|
280 |
+
# gr.Markdown("💪🏼Enhance")
|
281 |
enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
|
282 |
enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
|
283 |
tts_button = gr.Button(
|
|
|
313 |
spk_file_upload,
|
314 |
spliter_thr_input,
|
315 |
eos_input,
|
316 |
+
pitch_input,
|
317 |
+
speed_input,
|
318 |
+
volume_up_input,
|
319 |
+
enable_loudness_normalization,
|
320 |
+
headroom_input,
|
321 |
],
|
322 |
outputs=tts_output,
|
323 |
)
|
modules/webui/webui_utils.py
CHANGED
@@ -6,6 +6,11 @@ import torch
|
|
6 |
import torch.profiler
|
7 |
|
8 |
from modules import refiner
|
|
|
|
|
|
|
|
|
|
|
9 |
from modules.api.utils import calc_spk_style
|
10 |
from modules.data import styles_mgr
|
11 |
from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
|
@@ -13,8 +18,6 @@ from modules.normalization import text_normalize
|
|
13 |
from modules.SentenceSplitter import SentenceSplitter
|
14 |
from modules.speaker import Speaker, speaker_mgr
|
15 |
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser
|
16 |
-
from modules.synthesize_audio import synthesize_audio
|
17 |
-
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
18 |
from modules.utils import audio
|
19 |
from modules.utils.hf import spaces
|
20 |
from modules.webui import webui_config
|
@@ -89,6 +92,11 @@ def synthesize_ssml(
|
|
89 |
enable_denoise=False,
|
90 |
eos: str = "[uv_break]",
|
91 |
spliter_thr: int = 100,
|
|
|
|
|
|
|
|
|
|
|
92 |
progress=gr.Progress(track_tqdm=True),
|
93 |
):
|
94 |
try:
|
@@ -99,7 +107,7 @@ def synthesize_ssml(
|
|
99 |
ssml = ssml.strip()
|
100 |
|
101 |
if ssml == "":
|
102 |
-
|
103 |
|
104 |
parser = create_ssml_parser()
|
105 |
segments = parser.parse(ssml)
|
@@ -107,22 +115,36 @@ def synthesize_ssml(
|
|
107 |
segments = segments_length_limit(segments, max_len)
|
108 |
|
109 |
if len(segments) == 0:
|
110 |
-
|
111 |
|
112 |
-
|
113 |
-
batch_size=batch_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
)
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
sr,
|
122 |
-
enable_denoise,
|
123 |
-
enable_enhance,
|
124 |
)
|
125 |
|
|
|
|
|
126 |
# NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
|
127 |
audio_data = audio.audio_to_int16(audio_data)
|
128 |
|
@@ -150,6 +172,11 @@ def tts_generate(
|
|
150 |
spk_file=None,
|
151 |
spliter_thr: int = 100,
|
152 |
eos: str = "[uv_break]",
|
|
|
|
|
|
|
|
|
|
|
153 |
progress=gr.Progress(track_tqdm=True),
|
154 |
):
|
155 |
try:
|
@@ -161,10 +188,10 @@ def tts_generate(
|
|
161 |
text = text.strip()[0:max_len]
|
162 |
|
163 |
if text == "":
|
164 |
-
|
165 |
|
166 |
if style == "*auto":
|
167 |
-
style =
|
168 |
|
169 |
if isinstance(top_k, float):
|
170 |
top_k = int(top_k)
|
@@ -181,31 +208,56 @@ def tts_generate(
|
|
181 |
infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
|
182 |
infer_seed = int(infer_seed)
|
183 |
|
184 |
-
if
|
185 |
-
|
186 |
|
187 |
if spk_file:
|
188 |
-
|
|
|
|
|
|
|
189 |
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
192 |
temperature=temperature,
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
infer_seed=infer_seed,
|
197 |
-
use_decoder=use_decoder,
|
198 |
prompt1=prompt1,
|
199 |
prompt2=prompt2,
|
200 |
-
|
|
|
201 |
batch_size=batch_size,
|
202 |
-
end_of_sentence=eos,
|
203 |
spliter_threshold=spliter_thr,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
)
|
205 |
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
208 |
)
|
|
|
|
|
|
|
209 |
# NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
|
210 |
audio_data = audio.audio_to_int16(audio_data)
|
211 |
return sample_rate, audio_data
|
|
|
6 |
import torch.profiler
|
7 |
|
8 |
from modules import refiner
|
9 |
+
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
10 |
+
from modules.api.impl.handler.TTSHandler import TTSHandler
|
11 |
+
from modules.api.impl.model.audio_model import AdjustConfig
|
12 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
13 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
14 |
from modules.api.utils import calc_spk_style
|
15 |
from modules.data import styles_mgr
|
16 |
from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
|
|
|
18 |
from modules.SentenceSplitter import SentenceSplitter
|
19 |
from modules.speaker import Speaker, speaker_mgr
|
20 |
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser
|
|
|
|
|
21 |
from modules.utils import audio
|
22 |
from modules.utils.hf import spaces
|
23 |
from modules.webui import webui_config
|
|
|
92 |
enable_denoise=False,
|
93 |
eos: str = "[uv_break]",
|
94 |
spliter_thr: int = 100,
|
95 |
+
pitch: float = 0,
|
96 |
+
speed_rate: float = 1,
|
97 |
+
volume_gain_db: float = 0,
|
98 |
+
normalize: bool = True,
|
99 |
+
headroom: float = 1,
|
100 |
progress=gr.Progress(track_tqdm=True),
|
101 |
):
|
102 |
try:
|
|
|
107 |
ssml = ssml.strip()
|
108 |
|
109 |
if ssml == "":
|
110 |
+
raise gr.Error("SSML is empty, please input some SSML")
|
111 |
|
112 |
parser = create_ssml_parser()
|
113 |
segments = parser.parse(ssml)
|
|
|
115 |
segments = segments_length_limit(segments, max_len)
|
116 |
|
117 |
if len(segments) == 0:
|
118 |
+
raise gr.Error("No valid segments in SSML")
|
119 |
|
120 |
+
infer_config = InferConfig(
|
121 |
+
batch_size=batch_size,
|
122 |
+
spliter_threshold=spliter_thr,
|
123 |
+
eos=eos,
|
124 |
+
# NOTE: SSML not support `infer_seed` contorl
|
125 |
+
# seed=42,
|
126 |
+
)
|
127 |
+
adjust_config = AdjustConfig(
|
128 |
+
pitch=pitch,
|
129 |
+
speed_rate=speed_rate,
|
130 |
+
volume_gain_db=volume_gain_db,
|
131 |
+
normalize=normalize,
|
132 |
+
headroom=headroom,
|
133 |
+
)
|
134 |
+
enhancer_config = EnhancerConfig(
|
135 |
+
enabled=enable_denoise or enable_enhance or False,
|
136 |
+
lambd=0.9 if enable_denoise else 0.1,
|
137 |
)
|
138 |
+
|
139 |
+
handler = SSMLHandler(
|
140 |
+
ssml_content=ssml,
|
141 |
+
infer_config=infer_config,
|
142 |
+
adjust_config=adjust_config,
|
143 |
+
enhancer_config=enhancer_config,
|
|
|
|
|
|
|
144 |
)
|
145 |
|
146 |
+
audio_data, sr = handler.enqueue()
|
147 |
+
|
148 |
# NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
|
149 |
audio_data = audio.audio_to_int16(audio_data)
|
150 |
|
|
|
172 |
spk_file=None,
|
173 |
spliter_thr: int = 100,
|
174 |
eos: str = "[uv_break]",
|
175 |
+
pitch: float = 0,
|
176 |
+
speed_rate: float = 1,
|
177 |
+
volume_gain_db: float = 0,
|
178 |
+
normalize: bool = True,
|
179 |
+
headroom: float = 1,
|
180 |
progress=gr.Progress(track_tqdm=True),
|
181 |
):
|
182 |
try:
|
|
|
188 |
text = text.strip()[0:max_len]
|
189 |
|
190 |
if text == "":
|
191 |
+
raise gr.Error("Text is empty, please input some text")
|
192 |
|
193 |
if style == "*auto":
|
194 |
+
style = ""
|
195 |
|
196 |
if isinstance(top_k, float):
|
197 |
top_k = int(top_k)
|
|
|
208 |
infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
|
209 |
infer_seed = int(infer_seed)
|
210 |
|
211 |
+
if isinstance(spk, int):
|
212 |
+
spk = Speaker.from_seed(spk)
|
213 |
|
214 |
if spk_file:
|
215 |
+
try:
|
216 |
+
spk: Speaker = Speaker.from_file(spk_file)
|
217 |
+
except Exception:
|
218 |
+
raise gr.Error("Failed to load speaker file")
|
219 |
|
220 |
+
if not isinstance(spk.emb, torch.Tensor):
|
221 |
+
raise gr.Error("Speaker file is not supported")
|
222 |
+
|
223 |
+
tts_config = ChatTTSConfig(
|
224 |
+
style=style,
|
225 |
temperature=temperature,
|
226 |
+
top_k=top_k,
|
227 |
+
top_p=top_p,
|
228 |
+
prefix=prefix,
|
|
|
|
|
229 |
prompt1=prompt1,
|
230 |
prompt2=prompt2,
|
231 |
+
)
|
232 |
+
infer_config = InferConfig(
|
233 |
batch_size=batch_size,
|
|
|
234 |
spliter_threshold=spliter_thr,
|
235 |
+
eos=eos,
|
236 |
+
seed=infer_seed,
|
237 |
+
)
|
238 |
+
adjust_config = AdjustConfig(
|
239 |
+
pitch=pitch,
|
240 |
+
speed_rate=speed_rate,
|
241 |
+
volume_gain_db=volume_gain_db,
|
242 |
+
normalize=normalize,
|
243 |
+
headroom=headroom,
|
244 |
+
)
|
245 |
+
enhancer_config = EnhancerConfig(
|
246 |
+
enabled=enable_denoise or enable_enhance or False,
|
247 |
+
lambd=0.9 if enable_denoise else 0.1,
|
248 |
)
|
249 |
|
250 |
+
handler = TTSHandler(
|
251 |
+
text_content=text,
|
252 |
+
spk=spk,
|
253 |
+
tts_config=tts_config,
|
254 |
+
infer_config=infer_config,
|
255 |
+
adjust_config=adjust_config,
|
256 |
+
enhancer_config=enhancer_config,
|
257 |
)
|
258 |
+
|
259 |
+
audio_data, sample_rate = handler.enqueue()
|
260 |
+
|
261 |
# NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
|
262 |
audio_data = audio.audio_to_int16(audio_data)
|
263 |
return sample_rate, audio_data
|
requirements.txt
CHANGED
@@ -1,27 +1,29 @@
|
|
1 |
-
numpy
|
2 |
scipy
|
3 |
lxml
|
4 |
pydub
|
5 |
fastapi
|
6 |
soundfile
|
7 |
-
pyrubberband
|
8 |
omegaconf
|
9 |
pypinyin
|
|
|
10 |
pandas
|
11 |
vector_quantize_pytorch
|
12 |
einops
|
|
|
13 |
omegaconf~=2.3.0
|
14 |
tqdm
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
torch
|
19 |
-
torchvision
|
20 |
-
torchaudio
|
21 |
gradio
|
22 |
emojiswitch
|
23 |
python-dotenv
|
24 |
zhon
|
25 |
mistune==3.0.2
|
26 |
cn2an
|
27 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.26.4
|
2 |
scipy
|
3 |
lxml
|
4 |
pydub
|
5 |
fastapi
|
6 |
soundfile
|
|
|
7 |
omegaconf
|
8 |
pypinyin
|
9 |
+
vocos
|
10 |
pandas
|
11 |
vector_quantize_pytorch
|
12 |
einops
|
13 |
+
transformers~=4.41.1
|
14 |
omegaconf~=2.3.0
|
15 |
tqdm
|
16 |
+
# torch
|
17 |
+
# torchvision
|
18 |
+
# torchaudio
|
|
|
|
|
|
|
19 |
gradio
|
20 |
emojiswitch
|
21 |
python-dotenv
|
22 |
zhon
|
23 |
mistune==3.0.2
|
24 |
cn2an
|
25 |
+
# audio_denoiser
|
26 |
+
python-box
|
27 |
+
ftfy
|
28 |
+
librosa
|
29 |
+
pyrubberband
|
webui.py
CHANGED
@@ -6,6 +6,7 @@ from modules.ffmpeg_env import setup_ffmpeg_path
|
|
6 |
|
7 |
try:
|
8 |
setup_ffmpeg_path()
|
|
|
9 |
logging.basicConfig(
|
10 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
11 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
@@ -16,20 +17,18 @@ except BaseException:
|
|
16 |
import argparse
|
17 |
|
18 |
from modules import config
|
19 |
-
from modules.api.api_setup import
|
20 |
-
process_api_args,
|
21 |
-
process_model_args,
|
22 |
-
setup_api_args,
|
23 |
-
setup_model_args,
|
24 |
-
)
|
25 |
from modules.api.app_config import app_description, app_title, app_version
|
26 |
from modules.gradio_dcls_fix import dcls_patch
|
|
|
27 |
from modules.utils.env import get_and_update_env
|
|
|
28 |
from modules.utils.torch_opt import configure_torch_optimizations
|
29 |
from modules.webui import webui_config
|
30 |
from modules.webui.app import create_interface, webui_init
|
31 |
|
32 |
dcls_patch()
|
|
|
33 |
|
34 |
|
35 |
def setup_webui_args(parser: argparse.ArgumentParser):
|
|
|
6 |
|
7 |
try:
|
8 |
setup_ffmpeg_path()
|
9 |
+
# NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
|
10 |
logging.basicConfig(
|
11 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
12 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
|
17 |
import argparse
|
18 |
|
19 |
from modules import config
|
20 |
+
from modules.api.api_setup import process_api_args, setup_api_args
|
|
|
|
|
|
|
|
|
|
|
21 |
from modules.api.app_config import app_description, app_title, app_version
|
22 |
from modules.gradio_dcls_fix import dcls_patch
|
23 |
+
from modules.models_setup import process_model_args, setup_model_args
|
24 |
from modules.utils.env import get_and_update_env
|
25 |
+
from modules.utils.ignore_warn import ignore_useless_warnings
|
26 |
from modules.utils.torch_opt import configure_torch_optimizations
|
27 |
from modules.webui import webui_config
|
28 |
from modules.webui.app import create_interface, webui_init
|
29 |
|
30 |
dcls_patch()
|
31 |
+
ignore_useless_warnings()
|
32 |
|
33 |
|
34 |
def setup_webui_args(parser: argparse.ArgumentParser):
|