zhzluke96 commited on
Commit
bed01bd
·
1 Parent(s): 37195a7
Files changed (42) hide show
  1. CHANGELOG.md +150 -7
  2. launch.py +49 -11
  3. modules/ChatTTS/ChatTTS/core.py +124 -93
  4. modules/ChatTTS/ChatTTS/infer/api.py +4 -0
  5. modules/ChatTTS/ChatTTS/model/gpt.py +123 -74
  6. modules/ChatTTS/ChatTTS/utils/infer_utils.py +31 -6
  7. modules/Enhancer/ResembleEnhance.py +1 -1
  8. modules/SentenceSplitter.py +32 -1
  9. modules/SynthesizeSegments.py +30 -17
  10. modules/api/api_setup.py +5 -98
  11. modules/api/impl/handler/AudioHandler.py +19 -1
  12. modules/api/impl/handler/SSMLHandler.py +5 -0
  13. modules/api/impl/handler/TTSHandler.py +60 -1
  14. modules/api/impl/model/audio_model.py +4 -0
  15. modules/api/impl/tts_api.py +18 -4
  16. modules/api/impl/xtts_v2_api.py +97 -37
  17. modules/api/worker.py +3 -7
  18. modules/devices/devices.py +7 -1
  19. modules/finetune/train_speaker.py +2 -2
  20. modules/generate_audio.py +84 -8
  21. modules/models.py +11 -2
  22. modules/models_setup.py +74 -0
  23. modules/normalization.py +35 -24
  24. modules/refiner.py +8 -0
  25. modules/repos_static/resemble_enhance/inference.py +5 -5
  26. modules/speaker.py +6 -0
  27. modules/synthesize_audio.py +0 -2
  28. modules/synthesize_stream.py +42 -0
  29. modules/utils/HomophonesReplacer.py +39 -0
  30. modules/utils/audio.py +64 -58
  31. modules/utils/detect_lang.py +27 -0
  32. modules/utils/html.py +26 -0
  33. modules/utils/ignore_warn.py +9 -0
  34. modules/utils/markdown.py +1 -0
  35. modules/webui/localization_runtime.py +22 -0
  36. modules/webui/speaker/speaker_creator.py +6 -6
  37. modules/webui/ssml/podcast_tab.py +1 -1
  38. modules/webui/ssml/ssml_tab.py +61 -14
  39. modules/webui/tts_tab.py +49 -2
  40. modules/webui/webui_utils.py +83 -31
  41. requirements.txt +11 -9
  42. webui.py +5 -6
CHANGELOG.md CHANGED
@@ -1,22 +1,150 @@
1
  # Changelog
2
 
3
- <a name="0.5.6-rc"></a>
4
- ## 0.5.6-rc (2024-06-09)
5
 
6
  ### Added
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  - ✨ add localization [[c05035d](https://github.com/lenML/ChatTTS-Forge/commit/c05035d5cdcc5aa7efd995fe42f6a2541abe718b)]
9
  - ✨ SSML 支持 enhancer [[5c2788e](https://github.com/lenML/ChatTTS-Forge/commit/5c2788e04f3debfa8bafd8a2e2371dde30f38d4d)]
10
  - ✨ webui 增加 podcast 工具 tab [[b0b169d](https://github.com/lenML/ChatTTS-Forge/commit/b0b169d8b49c8e013209e59d1f8b637382d8b997)]
11
- - ✨ 完善 enhancer [[205ebeb](https://github.com/lenML/ChatTTS-Forge/commit/205ebebeb7530c81fde7ea96c7e4c6a888a29835)]
12
 
13
  ### Changed
14
 
 
 
 
 
 
15
  - 🍱 update banner [[dbc293e](https://github.com/lenML/ChatTTS-Forge/commit/dbc293e1a7dec35f60020dcaf783ba3b7c734bfa)]
16
  - ⚡ 增强 TN [[092c1b9](https://github.com/lenML/ChatTTS-Forge/commit/092c1b94147249880198fe2ad3dfe3b209099e19)]
17
  - ⚡ enhancer 支持 off_tqdm [[94d34d6](https://github.com/lenML/ChatTTS-Forge/commit/94d34d657fa3433dae9ff61775e0c364a6f77aff)]
18
  - ⚡ 增加 git env [[43d9c65](https://github.com/lenML/ChatTTS-Forge/commit/43d9c65877ff68ad94716bc2e505ccc7ae8869a8)]
19
- - ⚡ 修改webui保存文件格式 [[2da41c9](https://github.com/lenML/ChatTTS-Forge/commit/2da41c90aa81bf87403598aefaea3e0ae2e83d79)]
 
 
 
 
20
 
21
  ### Removed
22
 
@@ -24,6 +152,17 @@
24
 
25
  ### Fixed
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  - 🐛 fix hparams config [#22](https://github.com/lenML/ChatTTS-Forge/issues/22) [[61d9809](https://github.com/lenML/ChatTTS-Forge/commit/61d9809804ad8c141d36afde51a608734a105662)]
28
  - 🐛 fix enhance 下载脚本 [[d2e14b0](https://github.com/lenML/ChatTTS-Forge/commit/d2e14b0a4905724a55b03493fa4b94b5c4383c95)]
29
  - 🐛 fix &#x27;trange&#x27; referenced [[d1a8dae](https://github.com/lenML/ChatTTS-Forge/commit/d1a8daee61e62d14cf5fd7a17fab4424e24b1c41)]
@@ -33,10 +172,14 @@
33
 
34
  ### Miscellaneous
35
 
 
 
 
 
 
 
 
36
  - 🌐 更新翻译文案 [[f56caa7](https://github.com/lenML/ChatTTS-Forge/commit/f56caa71e9186680b93c487d9645186ae18c1dc6)]
37
- - 📝 update [[7cacf91](https://github.com/lenML/ChatTTS-Forge/commit/7cacf913541ee5f86eaa80d8b193b94b3db2b67c)]
38
- - 📝 update webui document [[7f2bb22](https://github.com/lenML/ChatTTS-Forge/commit/7f2bb227027cc0eff312c37758a20916c1ebade6)]
39
-
40
 
41
  <a name="0.5.5"></a>
42
 
 
1
  # Changelog
2
 
3
+ <a name="0.6.2-rc"></a>
4
+ ## 0.6.2-rc (2024-06-23)
5
 
6
  ### Added
7
 
8
+ - ✨ add adjuster to webui [[01f09b4](https://github.com/lenML/ChatTTS-Forge/commit/01f09b4fad2eb8b24a16b7768403de4975d51774)]
9
+ - ✨ stream mode support adjuster [[585d2dd](https://github.com/lenML/ChatTTS-Forge/commit/585d2dd488d8f8387e0d9435fb399f090a41b9cc)]
10
+ - ✨ improve xtts_v2 api [[fec66c7](https://github.com/lenML/ChatTTS-Forge/commit/fec66c7c00939a3c7c15e007536e037ac01153fa)]
11
+ - ✨ improve normalize [[d0da37e](https://github.com/lenML/ChatTTS-Forge/commit/d0da37e43f1de4088ef638edd90723f93894b1d2)]
12
+ - ✨ improve normalize/spliter [[163b649](https://github.com/lenML/ChatTTS-Forge/commit/163b6490e4d453c37cc259ce27208f55d10a9084)]
13
+ - ✨ add loudness equalization [[bc8bda7](https://github.com/lenML/ChatTTS-Forge/commit/bc8bda74825c31985d3cc1a44366ad92af1b623a)]
14
+ - ✨ support &#x60;--use_cpu&#x3D;chattts,enhancer,trainer,all&#x60; [[23023bc](https://github.com/lenML/ChatTTS-Forge/commit/23023bc610f6f74a157faa8a6c6aacf64d91d870)]
15
+ - ✨ improve normalizetion.py [[1a7c0ed](https://github.com/lenML/ChatTTS-Forge/commit/1a7c0ed3923234ceadb79f397fa7577f9e682f2d)]
16
+ - ✨ ignore_useless_warnings [[4b9a32e](https://github.com/lenML/ChatTTS-Forge/commit/4b9a32ef821d85ceaf3d62af8f871aeb5088e084)]
17
+ - ✨ enhance logger, info &#x3D;&gt; debug [[73bc8e7](https://github.com/lenML/ChatTTS-Forge/commit/73bc8e72b40146debd0a59100b1cca4cc42f5029)]
18
+ - ✨ add playground.stream page [[31377b0](https://github.com/lenML/ChatTTS-Forge/commit/31377b060c182519d74a12d81e66c8e73686bcd8)]
19
+ - ✨ tts api support stream [#5](https://github.com/lenML/ChatTTS-Forge/issues/5) [[15e0b2c](https://github.com/lenML/ChatTTS-Forge/commit/15e0b2cb051ba39dcf99f60f1faa11941f6dc656)]
20
+
21
+ ### Changed
22
+
23
+ - 🍱 add _p_en [[56f1fbf](https://github.com/lenML/ChatTTS-Forge/commit/56f1fbf1f3fff6f76ca8c29aa12a6ddef665cf9f)]
24
+ - 🍱 update prompt [[4f95b31](https://github.com/lenML/ChatTTS-Forge/commit/4f95b31679225e1ee144a411a9cfa9b30c598450)]
25
+ - ⚡ Reduce popping sounds [[2d0fd68](https://github.com/lenML/ChatTTS-Forge/commit/2d0fd688ad1a5cff1e6aafc0502aee26de3f1d75)]
26
+ - ⚡ improve &#x60;apply_character_map&#x60; [[ea7399f](https://github.com/lenML/ChatTTS-Forge/commit/ea7399facc5c29327a7870bd66ad6222f5731ce3)]
27
+
28
+ ### Fixed
29
+
30
+ - 🐛 fix &#x60;apply_normalize&#x60; missing &#x60;sr&#x60; [[2db6d65](https://github.com/lenML/ChatTTS-Forge/commit/2db6d65ef8fbf8a3a213cbdc3d4b1143396cc165)]
31
+ - 🐛 fix sentence spliter [[5d8937c](https://github.com/lenML/ChatTTS-Forge/commit/5d8937c169d5f7784920a93834df0480dd3a67b3)]
32
+ - 🐛 fix playground url_join [[53e7cbc](https://github.com/lenML/ChatTTS-Forge/commit/53e7cbc6103bc0e3bb83767a9233c45285b77e75)]
33
+ - 🐛 fix generate_audio args [[a7a698c](https://github.com/lenML/ChatTTS-Forge/commit/a7a698c760b5bc97c90a144a4a7afb5e17414995)]
34
+ - 🐛 fix infer func [[b0de527](https://github.com/lenML/ChatTTS-Forge/commit/b0de5275342c02d332a50d0ab5ac171a7007b300)]
35
+ - 🐛 fix webui logging format [[4adc29e](https://github.com/lenML/ChatTTS-Forge/commit/4adc29e6c06fa806a8178f445399bbac8ed57911)]
36
+ - 🐛 fix webui speaker_tab missing progress [[fafe242](https://github.com/lenML/ChatTTS-Forge/commit/fafe242e69ea8019729a62e52f6c0b3c0d6a63ad)]
37
+
38
+ ### Miscellaneous
39
+
40
+ - 📝 添加整合包地址 [[26122d4](https://github.com/lenML/ChatTTS-Forge/commit/26122d4cfd975206211fc37491348cf40aa39561)]
41
+ - 📝 details &#x60;.env&#x60; file and cli usage docs [[ec3d36f](https://github.com/lenML/ChatTTS-Forge/commit/ec3d36f8a67215e243e6b8225aa9144ac888313a)]
42
+ - 📝 update changelog [[22996e9](https://github.com/lenML/ChatTTS-Forge/commit/22996e9f0c42d9cad59950aecfe6b16413f2ab40)]
43
+ - Windows not yet supported for torch.compile fix [[74ac27d](https://github.com/lenML/ChatTTS-Forge/commit/74ac27d56a370f87560329043c42be27022ca0f5)]
44
+ - fix: replace mispronounced words in TTS [[de66e6b](https://github.com/lenML/ChatTTS-Forge/commit/de66e6b8f7f8b5c10e7ac54f7b2488c798e5ef81)]
45
+ - feat: support stream mode [[3da0f0c](https://github.com/lenML/ChatTTS-Forge/commit/3da0f0cb7f213dee40d00a89093166ad9e1d17a0)]
46
+ - optimize: mps audio quality by contiguous scores [[1e4d79f](https://github.com/lenML/ChatTTS-Forge/commit/1e4d79f1a81a3ac8697afff0e44f0cfd2608599a)]
47
+ - 📝 update changelog [[ab55c14](https://github.com/lenML/ChatTTS-Forge/commit/ab55c149d48edc52f1de9c6d4fe6e6ed78b3b134)]
48
+
49
+
50
+ <a name="0.6.1"></a>
51
+
52
+ ## 0.6.1 (2024-06-18)
53
+
54
+ ### Added
55
+
56
+ - ✨ add &#x60;--preload_models&#x60; [[73a41e0](https://github.com/lenML/ChatTTS-Forge/commit/73a41e009cd4426dfe4b0a35325da68189966390)]
57
+ - ✨ add webui progress [[778802d](https://github.com/lenML/ChatTTS-Forge/commit/778802ded12de340520f41a3e1bdb852f00bd637)]
58
+ - ✨ add merger error [[51060bc](https://github.com/lenML/ChatTTS-Forge/commit/51060bc343a6308493b7d582e21dca62eacaa7cb)]
59
+ - ✨ tts prompt &#x3D;&gt; experimental [[d3e6315](https://github.com/lenML/ChatTTS-Forge/commit/d3e6315a3cb8b1fa254cefb2efe2bae7c74a50f8)]
60
+ - ✨ add 基本的 speaker finetune ui [[5f68f19](https://github.com/lenML/ChatTTS-Forge/commit/5f68f193e78f470bd2c3ca4b9fa1008cf809e753)]
61
+ - ✨ add speaker finetune [[5ce27ed](https://github.com/lenML/ChatTTS-Forge/commit/5ce27ed7e4da6c96bb3fd016b8b491768faf319d)]
62
+ - ✨ add &#x60;--ino_half&#x60; remove &#x60;--half&#x60; [[5820e57](https://github.com/lenML/ChatTTS-Forge/commit/5820e576b288df50b929fbdfd9d0d6b6f548b54e)]
63
+ - ✨ add webui podcast 默认值 [[dd786a8](https://github.com/lenML/ChatTTS-Forge/commit/dd786a83733a71d005ff7efe6312e35d652b2525)]
64
+ - ✨ add webui 分割器配置 [[589327b](https://github.com/lenML/ChatTTS-Forge/commit/589327b729188d1385838816b9807e894eb128b0)]
65
+ - ✨ add &#x60;eos&#x60; params to all api [[79c994f](https://github.com/lenML/ChatTTS-Forge/commit/79c994fadf7d60ea432b62f4000b62b67efe7259)]
66
+
67
+ ### Changed
68
+
69
+ - ⬆️ Bump urllib3 from 2.2.1 to 2.2.2 [[097c15b](https://github.com/lenML/ChatTTS-Forge/commit/097c15ba56f8197a4f26adcfb77336a70e5ed806)]
70
+ - 🎨 run formatter [[8c267e1](https://github.com/lenML/ChatTTS-Forge/commit/8c267e151152fe2090528104627ec031453d4ed5)]
71
+ - ⚡ Optimize &#x60;audio_data_to_segment&#x60; [#57](https://github.com/lenML/ChatTTS-Forge/issues/57) [[d33809c](https://github.com/lenML/ChatTTS-Forge/commit/d33809c60a3ac76a01f71de4fd26b315d066c8d3)]
72
+ - ⚡ map_location&#x3D;&quot;cpu&quot; [[0f58c10](https://github.com/lenML/ChatTTS-Forge/commit/0f58c10a445efaa9829f862acb4fb94bc07f07bf)]
73
+ - ⚡ colab use default GPU [[c7938ad](https://github.com/lenML/ChatTTS-Forge/commit/c7938adb6d3615f37210b1f3cbe4671f93d58285)]
74
+ - ⚡ improve hf calling [[2dde612](https://github.com/lenML/ChatTTS-Forge/commit/2dde6127906ce6e77a970b4cd96e68f7a5417c6a)]
75
+ - 🍱 add &#x60;bob_ft10.pt&#x60; [[9eee965](https://github.com/lenML/ChatTTS-Forge/commit/9eee965425a7d6640eba22d843db4975dd3e355a)]
76
+ - ⚡ enhance SynthesizeSegments [[0bb4dd7](https://github.com/lenML/ChatTTS-Forge/commit/0bb4dd7676c38249f10bf0326174ff8b74b2abae)]
77
+ - 🍱 add &#x60;bob_ft10.pt&#x60; [[bef1b02](https://github.com/lenML/ChatTTS-Forge/commit/bef1b02435c39830612b18738bb31ac48e340fc6)]
78
+ - ♻️ refactor api [[671fcc3](https://github.com/lenML/ChatTTS-Forge/commit/671fcc38a570d0cb7de0a214d318281084c9608c)]
79
+ - ⚡ improve xtts_v2 api [[206fabc](https://github.com/lenML/ChatTTS-Forge/commit/206fabc76f1dbad261c857cb02f8c99c21e99eef)]
80
+ - ⚡ train text &#x3D;&gt; just text [[e2037e0](https://github.com/lenML/ChatTTS-Forge/commit/e2037e0f97f15ff560fce14bbdc3926e3261bff9)]
81
+ - ⚡ improve TN [[a0069ed](https://github.com/lenML/ChatTTS-Forge/commit/a0069ed2d0c3122444e873fb13b9922f9ab88a79)]
82
+
83
+ ### Fixed
84
+
85
+ - 🐛 fix webui speaker_editor missing &#x60;describe&#x60; [[2a2a36d](https://github.com/lenML/ChatTTS-Forge/commit/2a2a36d62d8f253fc2e17ccc558038dbcc99d1ee)]
86
+ - 💚 Dependabot alerts [[f501860](https://github.com/lenML/ChatTTS-Forge/commit/f5018607f602769d4dda7aa00573b9a06e659d91)]
87
+ - 🐛 fix &#x60;numpy&lt;2&#x60; [#50](https://github.com/lenML/ChatTTS-Forge/issues/50) [[e4fea4f](https://github.com/lenML/ChatTTS-Forge/commit/e4fea4f80b31d962f02cd1146ce8c73bf75b6a39)]
88
+ - 🐛 fix Box() index [#49](https://github.com/lenML/ChatTTS-Forge/issues/49) add testcase [[d982e33](https://github.com/lenML/ChatTTS-Forge/commit/d982e33ed30749d7ae6570ade5ec7b560a3d1f06)]
89
+ - 🐛 fix Box() index [#49](https://github.com/lenML/ChatTTS-Forge/issues/49) [[1788318](https://github.com/lenML/ChatTTS-Forge/commit/1788318a96c014a53ee41c4db7d60fdd4b15cfca)]
90
+ - 🐛 fix &#x60;--use_cpu&#x60; [#47](https://github.com/lenML/ChatTTS-Forge/issues/47) update conftest [[4095b08](https://github.com/lenML/ChatTTS-Forge/commit/4095b085c4c6523f2579e00edfb1569d65608ca2)]
91
+ - 🐛 fix &#x60;--use_cpu&#x60; [#47](https://github.com/lenML/ChatTTS-Forge/issues/47) [[221962f](https://github.com/lenML/ChatTTS-Forge/commit/221962fd0f61d3f269918b26a814cbcd5aabd1f0)]
92
+ - 🐛 fix webui speaker args [[3b3c331](https://github.com/lenML/ChatTTS-Forge/commit/3b3c3311dd0add0e567179fc38223a3cc5e56f6e)]
93
+ - 🐛 fix speaker trainer [[52d473f](https://github.com/lenML/ChatTTS-Forge/commit/52d473f37f6a3950d4c8738c294f048f11198776)]
94
+ - 🐛 兼容 win32 [[7ffa37f](https://github.com/lenML/ChatTTS-Forge/commit/7ffa37f3d36fb9ba53ab051b2fce6229920b1208)]
95
+ - 🐛 fix google api ssml synthesize [#43](https://github.com/lenML/ChatTTS-Forge/issues/43) [[1566f88](https://github.com/lenML/ChatTTS-Forge/commit/1566f8891c22d63681d756deba70374e2b75d078)]
96
+
97
+ ### Miscellaneous
98
+
99
+ - Merge pull request [#58](https://github.com/lenML/ChatTTS-Forge/issues/58) from lenML/dependabot/pip/urllib3-2.2.2 [[f259f18](https://github.com/lenML/ChatTTS-Forge/commit/f259f180af57f9a6938b14bf263d0387b6900e57)]
100
+ - 📝 update changelog [[b9da7ec](https://github.com/lenML/ChatTTS-Forge/commit/b9da7ec1afed416a825e9e4a507b8263f69bf47e)]
101
+ - 📝 update [[8439437](https://github.com/lenML/ChatTTS-Forge/commit/84394373de66b81a9f7f70ef8484254190e292ab)]
102
+ - 📝 update [[ef97206](https://github.com/lenML/ChatTTS-Forge/commit/ef972066558d0b229d6d0b3d83bb4f8e8517558f)]
103
+ - 📝 improve readme.md [[7bf3de2](https://github.com/lenML/ChatTTS-Forge/commit/7bf3de2afb41b9a29071bec18ee6306ce8e70183)]
104
+ - 📝 add bug report forms [[091cf09](https://github.com/lenML/ChatTTS-Forge/commit/091cf0958a719236c77107acf4cfb8c0ba090946)]
105
+ - 📝 update changelog [[3d519ec](https://github.com/lenML/ChatTTS-Forge/commit/3d519ec8a20098c2de62631ae586f39053dd89a5)]
106
+ - 📝 update [[66963f8](https://github.com/lenML/ChatTTS-Forge/commit/66963f8ff8f29c298de64cd4a54913b1d3e29a6a)]
107
+ - 📝 update [[b7a63b5](https://github.com/lenML/ChatTTS-Forge/commit/b7a63b59132d2c8dbb4ad2e15bd23713f00f0084)]
108
+
109
+ <a name="0.6.0"></a>
110
+
111
+ ## 0.6.0 (2024-06-12)
112
+
113
+ ### Added
114
+
115
+ - ✨ add XTTSv2 api [#42](https://github.com/lenML/ChatTTS-Forge/issues/42) [[d1fc63c](https://github.com/lenML/ChatTTS-Forge/commit/d1fc63cd1e847d622135c96371bbfe2868a80c19)]
116
+ - ✨ google api 支持 enhancer [[14fecdb](https://github.com/lenML/ChatTTS-Forge/commit/14fecdb8ea0f9a5d872a4c7ca862e901990076c0)]
117
+ - ✨ 修改 podcast 脚本默认 style [[98186c2](https://github.com/lenML/ChatTTS-Forge/commit/98186c25743cbfa24ca7d41336d4ec84aa34aacf)]
118
+ - ✨ playground google api [[4109adb](https://github.com/lenML/ChatTTS-Forge/commit/4109adb317be215970d756b4ba7064c9dc4d6fdc)]
119
+ - ✨ 添加 unload api [[ed9d61a](https://github.com/lenML/ChatTTS-Forge/commit/ed9d61a2fe4ba1d902d91517148f8f7dea47b51b)]
120
+ - ✨ support api workers [[babdada](https://github.com/lenML/ChatTTS-Forge/commit/babdada50e79e425bac4d3074f8e42dfb4c4c33a)]
121
+ - ✨ add ffmpeg version to webui footer [[e9241a1](https://github.com/lenML/ChatTTS-Forge/commit/e9241a1a8d1f5840ae6259e46020684ba70a0efb)]
122
+ - ✨ support use internal ffmpeg [[0e02ab0](https://github.com/lenML/ChatTTS-Forge/commit/0e02ab0f5d81fbfb6166793cb4f6d58c5f17f34c)]
123
+ - ✨ 增加参数 debug_generate [[94e876a](https://github.com/lenML/ChatTTS-Forge/commit/94e876ae3819c3efbde4a239085f91342874bd5a)]
124
+ - ✨ 支持 api 服务与 webui 并存 [[4901491](https://github.com/lenML/ChatTTS-Forge/commit/4901491eced3955c51030388d1dcebf049cd790e)]
125
+ - ✨ refiner api support normalize [[ef665da](https://github.com/lenML/ChatTTS-Forge/commit/ef665dad5a5517c610f0b430bc52a5b0ba3c2d96)]
126
+ - ✨ add webui 音色编辑器 [[fb4c7b3](https://github.com/lenML/ChatTTS-Forge/commit/fb4c7b3b0949ac669da0d069c739934f116b83e2)]
127
  - ✨ add localization [[c05035d](https://github.com/lenML/ChatTTS-Forge/commit/c05035d5cdcc5aa7efd995fe42f6a2541abe718b)]
128
  - ✨ SSML 支持 enhancer [[5c2788e](https://github.com/lenML/ChatTTS-Forge/commit/5c2788e04f3debfa8bafd8a2e2371dde30f38d4d)]
129
  - ✨ webui 增加 podcast 工具 tab [[b0b169d](https://github.com/lenML/ChatTTS-Forge/commit/b0b169d8b49c8e013209e59d1f8b637382d8b997)]
130
+ - ✨ 完善 enhancer [[205ebeb](https://github.com/lenML/ChatTTS-Forge/commit/205ebebeb7530c81fde7ea96c7e4c6a888a29835)]
131
 
132
  ### Changed
133
 
134
+ - ⚡ improve synthesize_audio [[759adc2](https://github.com/lenML/ChatTTS-Forge/commit/759adc2ead1da8395df62ea1724456dad6894eb1)]
135
+ - ⚡ reduce enhancer chunk vram usage [[3464b42](https://github.com/lenML/ChatTTS-Forge/commit/3464b427b14878ee11e03ebdfb91efee1550de59)]
136
+ - ⚡ 增加默认说话人 [[d702ad5](https://github.com/lenML/ChatTTS-Forge/commit/d702ad5ad585978f8650284ab99238571dbd163b)]
137
+ - 🍱 add &#x60;podcast&#x60; &#x60;podcast_p&#x60; style [[2b9e5bf](https://github.com/lenML/ChatTTS-Forge/commit/2b9e5bfd8fe4700f802097b995f5b68bf1097087)]
138
+ - 🎨 improve code [[317951e](https://github.com/lenML/ChatTTS-Forge/commit/317951e431b16c735df31187b1af7230a1608c41)]
139
  - 🍱 update banner [[dbc293e](https://github.com/lenML/ChatTTS-Forge/commit/dbc293e1a7dec35f60020dcaf783ba3b7c734bfa)]
140
  - ⚡ 增强 TN [[092c1b9](https://github.com/lenML/ChatTTS-Forge/commit/092c1b94147249880198fe2ad3dfe3b209099e19)]
141
  - ⚡ enhancer 支持 off_tqdm [[94d34d6](https://github.com/lenML/ChatTTS-Forge/commit/94d34d657fa3433dae9ff61775e0c364a6f77aff)]
142
  - ⚡ 增加 git env [[43d9c65](https://github.com/lenML/ChatTTS-Forge/commit/43d9c65877ff68ad94716bc2e505ccc7ae8869a8)]
143
+ - ⚡ 修改 webui 保存文件格式 [[2da41c9](https://github.com/lenML/ChatTTS-Forge/commit/2da41c90aa81bf87403598aefaea3e0ae2e83d79)]
144
+
145
+ ### Breaking changes
146
+
147
+ - 💥 enhancer support --half [[fef2ed6](https://github.com/lenML/ChatTTS-Forge/commit/fef2ed659fd7fe5a14807d286c209904875ce594)]
148
 
149
  ### Removed
150
 
 
152
 
153
  ### Fixed
154
 
155
+ - 🐛 fix worker env loader [[5b0bf4e](https://github.com/lenML/ChatTTS-Forge/commit/5b0bf4e93738bcd115f006376691c4eaa89b66de)]
156
+ - 🐛 fix colab default lang missing [[d4e5190](https://github.com/lenML/ChatTTS-Forge/commit/d4e51901856305fc039d886a92e38eea2a2cd24d)]
157
+ - 🐛 fix &quot;reflection_pad1d&quot; not implemented for &#x27;Half&#x27; [[536c19b](https://github.com/lenML/ChatTTS-Forge/commit/536c19b7f6dc3f1702fcc2a90daa3277040e70f0)]
158
+ - 🐛 fix [#33](https://github.com/lenML/ChatTTS-Forge/issues/33) [[76e0b58](https://github.com/lenML/ChatTTS-Forge/commit/76e0b5808ede71ebb28edbf0ce0af7d9da9bcb27)]
159
+ - 🐛 fix localization error [[507dbe7](https://github.com/lenML/ChatTTS-Forge/commit/507dbe7a3b92d1419164d24f7804295f6686b439)]
160
+ - 🐛 block main thread [#30](https://github.com/lenML/ChatTTS-Forge/issues/30) [[3a7cbde](https://github.com/lenML/ChatTTS-Forge/commit/3a7cbde6ccdfd20a6c53d7625d4e652007367fbf)]
161
+ - 🐛 fix webui skip no-translate [[a8d595e](https://github.com/lenML/ChatTTS-Forge/commit/a8d595eb490f23c943d6efc35b65b33266c033b7)]
162
+ - 🐛 fix hf.space force abort [[f564536](https://github.com/lenML/ChatTTS-Forge/commit/f5645360dd1f45a7bf112f01c85fb862ee57df3c)]
163
+ - 🐛 fix missing device [#25](https://github.com/lenML/ChatTTS-Forge/issues/25) [[07cf6c1](https://github.com/lenML/ChatTTS-Forge/commit/07cf6c1386900999b6c9436debbfcbe59f6b692a)]
164
+ - 🐛 fix Chat.refiner_prompt() [[0839863](https://github.com/lenML/ChatTTS-Forge/commit/083986369d0e67fcb4bd71930ad3d2bc3fc038fb)]
165
+ - 🐛 fix --language type check [[50d354c](https://github.com/lenML/ChatTTS-Forge/commit/50d354c91c659d9ae16c8eaa0218d9e08275fbb2)]
166
  - 🐛 fix hparams config [#22](https://github.com/lenML/ChatTTS-Forge/issues/22) [[61d9809](https://github.com/lenML/ChatTTS-Forge/commit/61d9809804ad8c141d36afde51a608734a105662)]
167
  - 🐛 fix enhance 下载脚本 [[d2e14b0](https://github.com/lenML/ChatTTS-Forge/commit/d2e14b0a4905724a55b03493fa4b94b5c4383c95)]
168
  - 🐛 fix &#x27;trange&#x27; referenced [[d1a8dae](https://github.com/lenML/ChatTTS-Forge/commit/d1a8daee61e62d14cf5fd7a17fab4424e24b1c41)]
 
172
 
173
  ### Miscellaneous
174
 
175
+ - 🐳 fix docker / 兼容 py 3.9 [[ebb096f](https://github.com/lenML/ChatTTS-Forge/commit/ebb096f9b1b843b65d150fb34da7d3b5acb13011)]
176
+ - 🐳 add .dockerignore [[57262b8](https://github.com/lenML/ChatTTS-Forge/commit/57262b81a8df3ed26ca5da5e264d5dca7b022471)]
177
+ - 🧪 add tests [[a807640](https://github.com/lenML/ChatTTS-Forge/commit/a80764030b790baee45a10cbe2d4edd7f183ef3c)]
178
+ - 🌐 fix [[b34a0f8](https://github.com/lenML/ChatTTS-Forge/commit/b34a0f8654467f3068e43056708742ab69e3665b)]
179
+ - 🌐 remove chat limit desc [[3f81eca](https://github.com/lenML/ChatTTS-Forge/commit/3f81ecae6e4521eeb4e867534defc36be741e1e2)]
180
+ - 🧪 add tests [[7a54225](https://github.com/lenML/ChatTTS-Forge/commit/7a542256a157a281a15312bbf987bc9fb16876ee)]
181
+ - 🔨 improve model downloader [[79a0c59](https://github.com/lenML/ChatTTS-Forge/commit/79a0c599f03b4e47346315a03f1df3d92578fe5d)]
182
  - 🌐 更新翻译文案 [[f56caa7](https://github.com/lenML/ChatTTS-Forge/commit/f56caa71e9186680b93c487d9645186ae18c1dc6)]
 
 
 
183
 
184
  <a name="0.5.5"></a>
185
 
launch.py CHANGED
@@ -5,6 +5,7 @@ from modules.ffmpeg_env import setup_ffmpeg_path
5
 
6
  try:
7
  setup_ffmpeg_path()
 
8
  logging.basicConfig(
9
  level=os.getenv("LOG_LEVEL", "INFO"),
10
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -16,26 +17,44 @@ import argparse
16
 
17
  import uvicorn
18
 
19
- from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
 
20
  from modules.utils import env
 
 
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
- if __name__ == "__main__":
25
- import dotenv
26
 
27
- dotenv.load_dotenv(
28
- dotenv_path=os.getenv("ENV_FILE", ".env.api"),
 
 
 
29
  )
30
- parser = argparse.ArgumentParser(
31
- description="Start the FastAPI server with command line arguments"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
- setup_api_args(parser)
34
- setup_model_args(parser)
35
- setup_uvicon_args(parser=parser)
36
 
37
- args = parser.parse_args()
38
 
 
39
  host = env.get_and_update_env(args, "host", "0.0.0.0", str)
40
  port = env.get_and_update_env(args, "port", 7870, int)
41
  reload = env.get_and_update_env(args, "reload", False, bool)
@@ -68,3 +87,22 @@ if __name__ == "__main__":
68
  ssl_certfile=ssl_certfile,
69
  ssl_keyfile_password=ssl_keyfile_password,
70
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  try:
7
  setup_ffmpeg_path()
8
+ # NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
9
  logging.basicConfig(
10
  level=os.getenv("LOG_LEVEL", "INFO"),
11
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 
17
 
18
  import uvicorn
19
 
20
+ from modules.api.api_setup import setup_api_args
21
+ from modules.models_setup import setup_model_args
22
  from modules.utils import env
23
+ from modules.utils.ignore_warn import ignore_useless_warnings
24
+
25
+ ignore_useless_warnings()
26
 
27
  logger = logging.getLogger(__name__)
28
 
 
 
29
 
30
+ def setup_uvicon_args(parser: argparse.ArgumentParser):
31
+ parser.add_argument("--host", type=str, help="Host to run the server on")
32
+ parser.add_argument("--port", type=int, help="Port to run the server on")
33
+ parser.add_argument(
34
+ "--reload", action="store_true", help="Enable auto-reload for development"
35
  )
36
+ parser.add_argument("--workers", type=int, help="Number of worker processes")
37
+ parser.add_argument("--log_level", type=str, help="Log level")
38
+ parser.add_argument("--access_log", action="store_true", help="Enable access log")
39
+ parser.add_argument(
40
+ "--proxy_headers", action="store_true", help="Enable proxy headers"
41
+ )
42
+ parser.add_argument(
43
+ "--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
44
+ )
45
+ parser.add_argument(
46
+ "--timeout_graceful_shutdown",
47
+ type=int,
48
+ help="Graceful shutdown timeout duration",
49
+ )
50
+ parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
51
+ parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
52
+ parser.add_argument(
53
+ "--ssl_keyfile_password", type=str, help="SSL key file password"
54
  )
 
 
 
55
 
 
56
 
57
+ def process_uvicon_args(args):
58
  host = env.get_and_update_env(args, "host", "0.0.0.0", str)
59
  port = env.get_and_update_env(args, "port", 7870, int)
60
  reload = env.get_and_update_env(args, "reload", False, bool)
 
87
  ssl_certfile=ssl_certfile,
88
  ssl_keyfile_password=ssl_keyfile_password,
89
  )
90
+
91
+
92
+ if __name__ == "__main__":
93
+ import dotenv
94
+
95
+ dotenv.load_dotenv(
96
+ dotenv_path=os.getenv("ENV_FILE", ".env.api"),
97
+ )
98
+ parser = argparse.ArgumentParser(
99
+ description="Start the FastAPI server with command line arguments"
100
+ )
101
+ # NOTE: 主进程中不需要处理 model args / api args,但是要接收这些参数, 具体处理在 worker.py 中
102
+ setup_api_args(parser=parser)
103
+ setup_model_args(parser=parser)
104
+ setup_uvicon_args(parser=parser)
105
+
106
+ args = parser.parse_args()
107
+
108
+ process_uvicon_args(args)
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
 
 
4
  import torch
5
  from huggingface_hub import snapshot_download
6
  from omegaconf import OmegaConf
@@ -142,9 +143,12 @@ class Chat:
142
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=map_location))
143
  if compile and "cuda" in str(device):
144
  self.logger.info("compile gpt model")
145
- gpt.gpt.forward = torch.compile(
146
- gpt.gpt.forward, backend="inductor", dynamic=True
147
- )
 
 
 
148
  self.pretrain_models["gpt"] = gpt
149
  spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
150
  assert os.path.exists(
@@ -173,7 +177,7 @@ class Chat:
173
 
174
  self.check_model()
175
 
176
- def infer(
177
  self,
178
  text,
179
  skip_refine_text=False,
@@ -181,9 +185,11 @@ class Chat:
181
  params_refine_text={},
182
  params_infer_code={"prompt": "[speed_5]"},
183
  use_decoder=True,
 
 
184
  ):
185
 
186
- assert self.check_model(use_decoder=use_decoder)
187
 
188
  if not isinstance(text, list):
189
  text = [text]
@@ -192,122 +198,147 @@ class Chat:
192
  reserved_tokens = self.pretrain_models[
193
  "tokenizer"
194
  ].additional_special_tokens
195
- invalid_characters = count_invalid_characters(t, reserved_tokens)
 
 
196
  if len(invalid_characters):
197
  self.logger.log(
198
  logging.WARNING, f"Invalid characters found! : {invalid_characters}"
199
  )
200
- text[i] = apply_character_map(t)
201
 
202
  if not skip_refine_text:
203
- text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)[
204
- "ids"
205
- ]
206
-
207
- text_tokens = [
208
- i[
209
- i
210
- < self.pretrain_models["tokenizer"].convert_tokens_to_ids(
211
- "[break_0]"
212
- )
 
 
 
213
  ]
214
- for i in text_tokens
215
- ]
216
- text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
217
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  if refine_text_only:
219
- return text
220
 
221
  text = [params_infer_code.get("prompt", "") + i for i in text]
222
  params_infer_code.pop("prompt", "")
223
- result = infer_code(
224
- self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder
 
 
 
 
225
  )
226
-
227
  if use_decoder:
228
- mel_spec = [
229
- self.pretrain_models["decoder"](i[None].permute(0, 2, 1))
230
- for i in result["hiddens"]
231
- ]
232
  else:
233
- mel_spec = [
234
- self.pretrain_models["dvae"](i[None].permute(0, 2, 1))
235
- for i in result["ids"]
236
- ]
237
-
238
- wav = [self.pretrain_models["vocos"].decode(i).cpu().numpy() for i in mel_spec]
239
-
240
- return wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- def refiner_prompt(
243
  self,
244
  text,
 
 
245
  params_refine_text={},
246
- ) -> str:
247
-
248
- # assert self.check_model(use_decoder=False)
249
-
250
- if not isinstance(text, list):
251
- text = [text]
252
-
253
- for i, t in enumerate(text):
254
- reserved_tokens = self.pretrain_models[
255
- "tokenizer"
256
- ].additional_special_tokens
257
- invalid_characters = count_invalid_characters(t, reserved_tokens)
258
- if len(invalid_characters):
259
- self.logger.log(
260
- logging.WARNING, f"Invalid characters found! : {invalid_characters}"
261
- )
262
- text[i] = apply_character_map(t)
263
-
264
- text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)[
265
- "ids"
266
- ]
267
- text_tokens = [
268
- i[i < self.pretrain_models["tokenizer"].convert_tokens_to_ids("[break_0]")]
269
- for i in text_tokens
270
- ]
271
- text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
272
-
273
- return text[0]
274
 
275
  def generate_audio(
276
  self,
277
  prompt,
278
  params_infer_code={"prompt": "[speed_5]"},
279
  use_decoder=True,
280
- ) -> list:
281
-
282
- # assert self.check_model(use_decoder=use_decoder)
283
-
284
- if not isinstance(prompt, list):
285
- prompt = [prompt]
286
-
287
- prompt = [params_infer_code.get("prompt", "") + i for i in prompt]
288
- params_infer_code.pop("prompt", "")
289
- result = infer_code(
290
- self.pretrain_models,
291
  prompt,
292
- return_hidden=use_decoder,
293
- **params_infer_code,
 
 
294
  )
295
 
296
- if use_decoder:
297
- mel_spec = [
298
- self.pretrain_models["decoder"](i[None].permute(0, 2, 1))
299
- for i in result["hiddens"]
300
- ]
301
- else:
302
- mel_spec = [
303
- self.pretrain_models["dvae"](i[None].permute(0, 2, 1))
304
- for i in result["ids"]
305
- ]
306
-
307
- wav = [self.pretrain_models["vocos"].decode(i).cpu().numpy() for i in mel_spec]
308
-
309
- return wav
310
-
311
  def sample_random_speaker(
312
  self,
313
  ) -> torch.Tensor:
 
1
  import logging
2
  import os
3
 
4
+ import numpy as np
5
  import torch
6
  from huggingface_hub import snapshot_download
7
  from omegaconf import OmegaConf
 
143
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=map_location))
144
  if compile and "cuda" in str(device):
145
  self.logger.info("compile gpt model")
146
+ try:
147
+ gpt.gpt.forward = torch.compile(
148
+ gpt.gpt.forward, backend="inductor", dynamic=True
149
+ )
150
+ except RuntimeError as e:
151
+ logging.warning(f"Compile failed,{e}. fallback to normal mode.")
152
  self.pretrain_models["gpt"] = gpt
153
  spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
154
  assert os.path.exists(
 
177
 
178
  self.check_model()
179
 
180
+ def _infer(
181
  self,
182
  text,
183
  skip_refine_text=False,
 
185
  params_refine_text={},
186
  params_infer_code={"prompt": "[speed_5]"},
187
  use_decoder=True,
188
+ stream=False,
189
+ stream_text=False,
190
  ):
191
 
192
+ # assert self.check_model(use_decoder=use_decoder)
193
 
194
  if not isinstance(text, list):
195
  text = [text]
 
198
  reserved_tokens = self.pretrain_models[
199
  "tokenizer"
200
  ].additional_special_tokens
201
+ invalid_characters = count_invalid_characters(
202
+ t, reserved_tokens=reserved_tokens
203
+ )
204
  if len(invalid_characters):
205
  self.logger.log(
206
  logging.WARNING, f"Invalid characters found! : {invalid_characters}"
207
  )
208
+ text[i] = apply_character_map(t, reserved_tokens=reserved_tokens)
209
 
210
  if not skip_refine_text:
211
+ text_tokens_gen = refine_text(
212
+ self.pretrain_models, text, stream=stream, **params_refine_text
213
+ )
214
+
215
+ def decode_text(text_tokens):
216
+ text_tokens = [
217
+ i[
218
+ i
219
+ < self.pretrain_models["tokenizer"].convert_tokens_to_ids(
220
+ "[break_0]"
221
+ )
222
+ ]
223
+ for i in text_tokens
224
  ]
225
+ text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
226
+ return text
 
227
 
228
+ if stream_text:
229
+ for result in text_tokens_gen:
230
+ text_incomplete = decode_text(result["ids"])
231
+ if refine_text_only and stream:
232
+ yield text_incomplete
233
+ if refine_text_only:
234
+ return
235
+ else:
236
+ result = next(text_tokens_gen)
237
+ text = decode_text(result["ids"])
238
+ if refine_text_only:
239
+ yield text
240
  if refine_text_only:
241
+ return
242
 
243
  text = [params_infer_code.get("prompt", "") + i for i in text]
244
  params_infer_code.pop("prompt", "")
245
+ result_gen = infer_code(
246
+ self.pretrain_models,
247
+ text,
248
+ **params_infer_code,
249
+ return_hidden=use_decoder,
250
+ stream=stream,
251
  )
 
252
  if use_decoder:
253
+ field = "hiddens"
254
+ docoder_name = "decoder"
 
 
255
  else:
256
+ field = "ids"
257
+ docoder_name = "dvae"
258
+ vocos_decode = lambda spec: [
259
+ self.pretrain_models["vocos"]
260
+ .decode(i.cpu() if torch.backends.mps.is_available() else i)
261
+ .cpu()
262
+ .numpy()
263
+ for i in spec
264
+ ]
265
+ if stream:
266
+
267
+ length = 0
268
+ for result in result_gen:
269
+ chunk_data = result[field][0]
270
+ assert len(result[field]) == 1
271
+ start_seek = length
272
+ length = len(chunk_data)
273
+ self.logger.debug(
274
+ f"{start_seek=} total len: {length}, new len: {length - start_seek = }"
275
+ )
276
+ chunk_data = chunk_data[start_seek:]
277
+ if not len(chunk_data):
278
+ continue
279
+ self.logger.debug(f"new hidden {len(chunk_data)=}")
280
+ mel_spec = [
281
+ self.pretrain_models[docoder_name](i[None].permute(0, 2, 1))
282
+ for i in [chunk_data]
283
+ ]
284
+ wav = vocos_decode(mel_spec)
285
+ self.logger.debug(f"yield wav chunk {len(wav[0])=} {len(wav[0][0])=}")
286
+ yield wav
287
+ return
288
+ mel_spec = [
289
+ self.pretrain_models[docoder_name](i[None].permute(0, 2, 1))
290
+ for i in next(result_gen)[field]
291
+ ]
292
+ yield vocos_decode(mel_spec)
293
 
294
+ def infer(
295
  self,
296
  text,
297
+ skip_refine_text=False,
298
+ refine_text_only=False,
299
  params_refine_text={},
300
+ params_infer_code={"prompt": "[speed_5]"},
301
+ use_decoder=True,
302
+ stream=False,
303
+ ):
304
+ res_gen = self._infer(
305
+ text=text,
306
+ skip_refine_text=skip_refine_text,
307
+ refine_text_only=refine_text_only,
308
+ params_refine_text=params_refine_text,
309
+ params_infer_code=params_infer_code,
310
+ use_decoder=use_decoder,
311
+ stream=stream,
312
+ )
313
+ if stream:
314
+ return res_gen
315
+ else:
316
+ return next(res_gen)
317
+
318
+ def refiner_prompt(self, text, params_refine_text={}, stream=False):
319
+ return self.infer(
320
+ text=text,
321
+ skip_refine_text=False,
322
+ refine_text_only=True,
323
+ params_refine_text=params_refine_text,
324
+ stream=stream,
325
+ )
 
 
326
 
327
  def generate_audio(
328
  self,
329
  prompt,
330
  params_infer_code={"prompt": "[speed_5]"},
331
  use_decoder=True,
332
+ stream=False,
333
+ ):
334
+ return self.infer(
 
 
 
 
 
 
 
 
335
  prompt,
336
+ skip_refine_text=True,
337
+ params_infer_code=params_infer_code,
338
+ use_decoder=use_decoder,
339
+ stream=stream,
340
  )
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  def sample_random_speaker(
343
  self,
344
  ) -> torch.Tensor:
modules/ChatTTS/ChatTTS/infer/api.py CHANGED
@@ -17,6 +17,7 @@ def infer_code(
17
  prompt1="",
18
  prompt2="",
19
  prefix="",
 
20
  **kwargs,
21
  ):
22
 
@@ -83,6 +84,7 @@ def infer_code(
83
  eos_token=num_code,
84
  max_new_token=max_new_token,
85
  infer_text=False,
 
86
  **kwargs,
87
  )
88
 
@@ -98,6 +100,7 @@ def refine_text(
98
  repetition_penalty=1.0,
99
  max_new_token=384,
100
  prompt="",
 
101
  **kwargs,
102
  ):
103
  device = next(models["gpt"].parameters()).device
@@ -152,6 +155,7 @@ def refine_text(
152
  )[None],
153
  max_new_token=max_new_token,
154
  infer_text=True,
 
155
  **kwargs,
156
  )
157
  return result
 
17
  prompt1="",
18
  prompt2="",
19
  prefix="",
20
+ stream=False,
21
  **kwargs,
22
  ):
23
 
 
84
  eos_token=num_code,
85
  max_new_token=max_new_token,
86
  infer_text=False,
87
+ stream=stream,
88
  **kwargs,
89
  )
90
 
 
100
  repetition_penalty=1.0,
101
  max_new_token=384,
102
  prompt="",
103
+ stream=False,
104
  **kwargs,
105
  ):
106
  device = next(models["gpt"].parameters()).device
 
155
  )[None],
156
  max_new_token=max_new_token,
157
  infer_text=True,
158
+ stream=stream,
159
  **kwargs,
160
  )
161
  return result
modules/ChatTTS/ChatTTS/model/gpt.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
  import logging
 
6
 
7
  import torch
8
  import torch.nn as nn
@@ -37,7 +38,6 @@ class GPT_warpper(nn.Module):
37
  num_audio_tokens,
38
  num_text_tokens,
39
  num_vq=4,
40
- **kwargs,
41
  ):
42
  super().__init__()
43
 
@@ -211,12 +211,13 @@ class GPT_warpper(nn.Module):
211
  infer_text=False,
212
  return_attn=False,
213
  return_hidden=False,
 
214
  disable_tqdm=False,
215
  ):
 
 
216
  if disable_tqdm:
217
- tqdm = lambda x: x
218
- else:
219
- from tqdm import tqdm
220
 
221
  with torch.no_grad():
222
 
@@ -242,90 +243,136 @@ class GPT_warpper(nn.Module):
242
  if attention_mask is not None:
243
  attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
244
 
245
- for i in tqdm(range(max_new_token)):
246
- if finish.all():
247
- continue
248
 
249
- model_input = self.prepare_inputs_for_generation(
250
- inputs_ids,
251
- outputs.past_key_values if i != 0 else None,
252
- attention_mask_cache[:, : inputs_ids.shape[1]],
253
- use_cache=True,
254
- )
255
 
256
- if i == 0:
257
- model_input["inputs_embeds"] = emb
258
- else:
259
- if infer_text:
260
- model_input["inputs_embeds"] = self.emb_text(
261
- model_input["input_ids"][:, :, 0]
262
- )
263
- else:
264
- code_emb = [
265
- self.emb_code[i](model_input["input_ids"][:, :, i])
266
- for i in range(self.num_vq)
267
- ]
268
- model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(3)
269
-
270
- model_input["input_ids"] = None
271
- outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
272
- attentions.append(outputs.attentions)
273
- hidden_states = outputs[0] # 🐻
274
- if return_hidden:
275
- hiddens.append(hidden_states[:, -1])
276
-
277
- with P.cached():
278
- if infer_text:
279
- logits = self.head_text(hidden_states)
280
  else:
281
- logits = torch.stack(
282
- [
283
- self.head_code[i](hidden_states)
 
 
 
 
284
  for i in range(self.num_vq)
285
- ],
286
- 3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  )
 
 
288
 
289
- logits = logits[:, -1].float()
290
 
291
- if not infer_text:
292
- logits = rearrange(logits, "b c n -> (b n) c")
293
- logits_token = rearrange(
294
- inputs_ids[:, start_idx:], "b c n -> (b n) c"
295
- )
296
- else:
297
- logits_token = inputs_ids[:, start_idx:, 0]
298
 
299
- logits = logits / temperature
 
300
 
301
- for logitsProcessors in LogitsProcessors:
302
- logits = logitsProcessors(logits_token, logits)
303
 
304
- for logitsWarpers in LogitsWarpers:
305
- logits = logitsWarpers(logits_token, logits)
306
 
307
- if i < min_new_token:
308
- logits[:, eos_token] = -torch.inf
309
 
310
- scores = F.softmax(logits, dim=-1)
311
 
312
- idx_next = torch.multinomial(scores, num_samples=1)
313
 
314
- if not infer_text:
315
- idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
316
- finish = finish | (idx_next == eos_token).any(1)
317
- inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
318
- else:
319
- finish = finish | (idx_next == eos_token).any(1)
320
- inputs_ids = torch.cat(
321
- [
322
- inputs_ids,
323
- idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
324
- ],
325
- 1,
326
- )
 
 
 
 
327
 
328
- end_idx = end_idx + (~finish).int()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  inputs_ids = [
331
  inputs_ids[idx, start_idx : start_idx + i]
@@ -342,7 +389,9 @@ class GPT_warpper(nn.Module):
342
  f"Incomplete result. hit max_new_token: {max_new_token}"
343
  )
344
 
345
- return {
 
 
346
  "ids": inputs_ids,
347
  "attentions": attentions,
348
  "hiddens": hiddens,
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
  import logging
6
+ from functools import partial
7
 
8
  import torch
9
  import torch.nn as nn
 
38
  num_audio_tokens,
39
  num_text_tokens,
40
  num_vq=4,
 
41
  ):
42
  super().__init__()
43
 
 
211
  infer_text=False,
212
  return_attn=False,
213
  return_hidden=False,
214
+ stream=False,
215
  disable_tqdm=False,
216
  ):
217
+ from tqdm import tqdm
218
+
219
  if disable_tqdm:
220
+ tqdm = partial(tqdm, disable=True)
 
 
221
 
222
  with torch.no_grad():
223
 
 
243
  if attention_mask is not None:
244
  attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
245
 
246
+ with tqdm(total=max_new_token) as pbar:
 
 
247
 
248
+ past_key_values = None
 
 
 
 
 
249
 
250
+ for i in range(max_new_token):
251
+ pbar.update(1)
252
+ model_input = self.prepare_inputs_for_generation(
253
+ inputs_ids,
254
+ past_key_values,
255
+ attention_mask_cache[:, : inputs_ids.shape[1]],
256
+ use_cache=True,
257
+ )
258
+
259
+ if i == 0:
260
+ model_input["inputs_embeds"] = emb
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  else:
262
+ if infer_text:
263
+ model_input["inputs_embeds"] = self.emb_text(
264
+ model_input["input_ids"][:, :, 0]
265
+ )
266
+ else:
267
+ code_emb = [
268
+ self.emb_code[i](model_input["input_ids"][:, :, i])
269
  for i in range(self.num_vq)
270
+ ]
271
+ model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(
272
+ 3
273
+ )
274
+
275
+ model_input["input_ids"] = None
276
+ outputs = self.gpt.forward(
277
+ **model_input, output_attentions=return_attn
278
+ )
279
+ del model_input
280
+ attentions.append(outputs.attentions)
281
+ hidden_states = outputs[0] # 🐻
282
+ past_key_values = outputs.past_key_values
283
+ del outputs
284
+ if return_hidden:
285
+ hiddens.append(hidden_states[:, -1])
286
+
287
+ with P.cached():
288
+ if infer_text:
289
+ logits = self.head_text(hidden_states)
290
+ else:
291
+ logits = torch.stack(
292
+ [
293
+ self.head_code[i](hidden_states)
294
+ for i in range(self.num_vq)
295
+ ],
296
+ 3,
297
+ )
298
+
299
+ logits = logits[:, -1].float()
300
+
301
+ if not infer_text:
302
+ logits = rearrange(logits, "b c n -> (b n) c")
303
+ logits_token = rearrange(
304
+ inputs_ids[:, start_idx:], "b c n -> (b n) c"
305
  )
306
+ else:
307
+ logits_token = inputs_ids[:, start_idx:, 0]
308
 
309
+ logits = logits / temperature
310
 
311
+ for logitsProcessors in LogitsProcessors:
312
+ logits = logitsProcessors(logits_token, logits)
 
 
 
 
 
313
 
314
+ for logitsWarpers in LogitsWarpers:
315
+ logits = logitsWarpers(logits_token, logits)
316
 
317
+ del logits_token
 
318
 
319
+ if i < min_new_token:
320
+ logits[:, eos_token] = -torch.inf
321
 
322
+ scores = F.softmax(logits, dim=-1)
 
323
 
324
+ del logits
325
 
326
+ idx_next = torch.multinomial(scores, num_samples=1)
327
 
328
+ if not infer_text:
329
+ idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
330
+ finish_or = (idx_next == eos_token).any(1)
331
+ finish |= finish_or
332
+ del finish_or
333
+ inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
334
+ else:
335
+ finish_or = (idx_next == eos_token).any(1)
336
+ finish |= finish_or
337
+ del finish_or
338
+ inputs_ids = torch.cat(
339
+ [
340
+ inputs_ids,
341
+ idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
342
+ ],
343
+ 1,
344
+ )
345
 
346
+ del idx_next
347
+
348
+ end_idx += (~finish).int().to(end_idx.device)
349
+ if stream:
350
+ if end_idx % 24 and not finish.all():
351
+ continue
352
+ y_inputs_ids = [
353
+ inputs_ids[idx, start_idx : start_idx + i]
354
+ for idx, i in enumerate(end_idx.int())
355
+ ]
356
+ y_inputs_ids = (
357
+ [i[:, 0] for i in y_inputs_ids]
358
+ if infer_text
359
+ else y_inputs_ids
360
+ )
361
+ y_hiddens = [[]]
362
+ if return_hidden:
363
+ y_hiddens = torch.stack(hiddens, 1)
364
+ y_hiddens = [
365
+ y_hiddens[idx, :i]
366
+ for idx, i in enumerate(end_idx.int())
367
+ ]
368
+ yield {
369
+ "ids": y_inputs_ids,
370
+ "attentions": attentions,
371
+ "hiddens": y_hiddens,
372
+ }
373
+ if finish.all():
374
+ pbar.update(max_new_token - i - 1)
375
+ break
376
 
377
  inputs_ids = [
378
  inputs_ids[idx, start_idx : start_idx + i]
 
389
  f"Incomplete result. hit max_new_token: {max_new_token}"
390
  )
391
 
392
+ del finish
393
+
394
+ yield {
395
  "ids": inputs_ids,
396
  "attentions": attentions,
397
  "hiddens": hiddens,
modules/ChatTTS/ChatTTS/utils/infer_utils.py CHANGED
@@ -24,6 +24,7 @@ class CustomRepetitionPenaltyLogitsProcessorRepeat:
24
  freq = F.one_hot(input_ids, scores.size(1)).sum(1)
25
  freq[self.max_input_ids :] = 0
26
  alpha = self.penalty**freq
 
27
  scores = torch.where(scores < 0, scores * alpha, scores / alpha)
28
 
29
  return scores
@@ -145,11 +146,35 @@ halfwidth_2_fullwidth_map = {
145
  }
146
 
147
 
148
- def apply_half2full_map(text):
149
- translation_table = str.maketrans(halfwidth_2_fullwidth_map)
150
- return text.translate(translation_table)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
- def apply_character_map(text):
154
- translation_table = str.maketrans(character_map)
155
- return text.translate(translation_table)
 
 
24
  freq = F.one_hot(input_ids, scores.size(1)).sum(1)
25
  freq[self.max_input_ids :] = 0
26
  alpha = self.penalty**freq
27
+ scores = scores.contiguous()
28
  scores = torch.where(scores < 0, scores * alpha, scores / alpha)
29
 
30
  return scores
 
146
  }
147
 
148
 
149
+ def replace_unsupported_chars(text, replace_dict, reserved_tokens: list = []):
150
+ escaped_tokens = [re.escape(token) for token in reserved_tokens]
151
+ special_tokens_pattern = "|".join(escaped_tokens)
152
+ tokens = re.split(f"({special_tokens_pattern})", text)
153
+
154
+ def replace_chars(segment):
155
+ for old_char, new_char in replace_dict.items():
156
+ segment = segment.replace(old_char, new_char)
157
+ return segment
158
+
159
+ result = "".join(
160
+ (
161
+ replace_chars(segment)
162
+ if not re.match(special_tokens_pattern, segment)
163
+ else segment
164
+ )
165
+ for segment in tokens
166
+ )
167
+
168
+ return result
169
+
170
+
171
+ def apply_half2full_map(text, reserved_tokens: list = []):
172
+ return replace_unsupported_chars(
173
+ text, halfwidth_2_fullwidth_map, reserved_tokens=reserved_tokens
174
+ )
175
 
176
 
177
+ def apply_character_map(text, reserved_tokens: list = []):
178
+ return replace_unsupported_chars(
179
+ text, character_map, reserved_tokens=reserved_tokens
180
+ )
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -85,7 +85,7 @@ def load_enhancer() -> ResembleEnhance:
85
  if resemble_enhance is None:
86
  logger.info("Loading ResembleEnhance model")
87
  resemble_enhance = ResembleEnhance(
88
- device=devices.device, dtype=devices.dtype
89
  )
90
  resemble_enhance.load_model()
91
  logger.info("ResembleEnhance model loaded")
 
85
  if resemble_enhance is None:
86
  logger.info("Loading ResembleEnhance model")
87
  resemble_enhance = ResembleEnhance(
88
+ device=devices.get_device_for("enhancer"), dtype=devices.dtype
89
  )
90
  resemble_enhance.load_model()
91
  logger.info("ResembleEnhance model loaded")
modules/SentenceSplitter.py CHANGED
@@ -2,6 +2,8 @@ import re
2
 
3
  import zhon
4
 
 
 
5
 
6
  def split_zhon_sentence(text):
7
  result = []
@@ -21,6 +23,35 @@ def split_zhon_sentence(text):
21
  return result
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # 解析文本 并根据停止符号分割成句子
25
  # 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
26
  class SentenceSplitter:
@@ -28,7 +59,7 @@ class SentenceSplitter:
28
  self.sentence_threshold = threshold
29
 
30
  def parse(self, text):
31
- sentences = split_zhon_sentence(text)
32
 
33
  # 合并小于最大阈值的片段
34
  merged_sentences = []
 
2
 
3
  import zhon
4
 
5
+ from modules.utils.detect_lang import guess_lang
6
+
7
 
8
  def split_zhon_sentence(text):
9
  result = []
 
23
  return result
24
 
25
 
26
+ def split_en_sentence(text):
27
+ """
28
+ Split English text into sentences.
29
+ """
30
+ # Define a regex pattern for English sentence splitting
31
+ pattern = re.compile(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s")
32
+ result = pattern.split(text)
33
+
34
+ # Filter out any empty strings or strings that are just whitespace
35
+ result = [sentence.strip() for sentence in result if sentence.strip()]
36
+
37
+ return result
38
+
39
+
40
+ def is_eng_sentence(text):
41
+ return guess_lang(text) == "en"
42
+
43
+
44
+ def split_zhon_paragraph(text):
45
+ lines = text.split("\n")
46
+ result = []
47
+ for line in lines:
48
+ if is_eng_sentence(line):
49
+ result.extend(split_en_sentence(line))
50
+ else:
51
+ result.extend(split_zhon_sentence(line))
52
+ return result
53
+
54
+
55
  # 解析文本 并根据停止符号分割成句子
56
  # 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
57
  class SentenceSplitter:
 
59
  self.sentence_threshold = threshold
60
 
61
  def parse(self, text):
62
+ sentences = split_zhon_paragraph(text)
63
 
64
  # 合并小于最大阈值的片段
65
  merged_sentences = []
modules/SynthesizeSegments.py CHANGED
@@ -1,4 +1,5 @@
1
  import copy
 
2
  import json
3
  import logging
4
  import re
@@ -7,6 +8,7 @@ from typing import List, Union
7
  import numpy as np
8
  from box import Box
9
  from pydub import AudioSegment
 
10
 
11
  from modules import generate_audio
12
  from modules.api.utils import calc_spk_style
@@ -15,15 +17,39 @@ from modules.SentenceSplitter import SentenceSplitter
15
  from modules.speaker import Speaker
16
  from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
17
  from modules.utils import rng
18
- from modules.utils.audio import pitch_shift, time_stretch
19
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def audio_data_to_segment(audio_data: np.ndarray, sr: int):
24
  """
25
  optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
26
  """
 
 
 
 
27
  audio_data = (audio_data * 32767).astype(np.int16)
28
  audio_segment = AudioSegment(
29
  audio_data.tobytes(),
@@ -41,21 +67,6 @@ def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
41
  return combined_audio
42
 
43
 
44
- def apply_prosody(
45
- audio_segment: AudioSegment, rate: float, volume: float, pitch: float
46
- ) -> AudioSegment:
47
- if rate != 1:
48
- audio_segment = time_stretch(audio_segment, rate)
49
-
50
- if volume != 0:
51
- audio_segment += volume
52
-
53
- if pitch != 0:
54
- audio_segment = pitch_shift(audio_segment, pitch)
55
-
56
- return audio_segment
57
-
58
-
59
  def to_number(value, t, default=0):
60
  try:
61
  number = t(value)
@@ -202,7 +213,9 @@ class SynthesizeSegments:
202
  pitch = float(segment.get("pitch", "0"))
203
 
204
  audio_segment = audio_data_to_segment(audio_data, sr)
205
- audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
 
 
206
  # compare by Box object
207
  original_index = src_segments.index(segment)
208
  audio_segments[original_index] = audio_segment
 
1
  import copy
2
+ import io
3
  import json
4
  import logging
5
  import re
 
8
  import numpy as np
9
  from box import Box
10
  from pydub import AudioSegment
11
+ from scipy.io import wavfile
12
 
13
  from modules import generate_audio
14
  from modules.api.utils import calc_spk_style
 
17
  from modules.speaker import Speaker
18
  from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
19
  from modules.utils import rng
20
+ from modules.utils.audio import apply_prosody_to_audio_segment
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ def audio_data_to_segment_slow(audio_data, sr):
26
+ byte_io = io.BytesIO()
27
+ wavfile.write(byte_io, rate=sr, data=audio_data)
28
+ byte_io.seek(0)
29
+
30
+ return AudioSegment.from_file(byte_io, format="wav")
31
+
32
+
33
+ def clip_audio(audio_data: np.ndarray, threshold: float = 0.99):
34
+ audio_data = np.clip(audio_data, -threshold, threshold)
35
+ return audio_data
36
+
37
+
38
+ def normalize_audio(audio_data: np.ndarray, norm_factor: float = 0.8):
39
+ max_amplitude = np.max(np.abs(audio_data))
40
+ if max_amplitude > 0:
41
+ audio_data = audio_data / max_amplitude * norm_factor
42
+ return audio_data
43
+
44
+
45
  def audio_data_to_segment(audio_data: np.ndarray, sr: int):
46
  """
47
  optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
48
  """
49
+
50
+ audio_data = normalize_audio(audio_data)
51
+ audio_data = clip_audio(audio_data)
52
+
53
  audio_data = (audio_data * 32767).astype(np.int16)
54
  audio_segment = AudioSegment(
55
  audio_data.tobytes(),
 
67
  return combined_audio
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def to_number(value, t, default=0):
71
  try:
72
  number = t(value)
 
213
  pitch = float(segment.get("pitch", "0"))
214
 
215
  audio_segment = audio_data_to_segment(audio_data, sr)
216
+ audio_segment = apply_prosody_to_audio_segment(
217
+ audio_segment, rate=rate, volume=volume, pitch=pitch
218
+ )
219
  # compare by Box object
220
  original_index = src_segments.index(segment)
221
  audio_segments[original_index] = audio_segment
modules/api/api_setup.py CHANGED
@@ -1,7 +1,9 @@
1
  import argparse
2
  import logging
3
 
4
- from modules import config, generate_audio
 
 
5
  from modules.api.Api import APIManager
6
  from modules.api.impl import (
7
  google_api,
@@ -15,15 +17,12 @@ from modules.api.impl import (
15
  tts_api,
16
  xtts_v2_api,
17
  )
18
- from modules.devices import devices
19
- from modules.Enhancer.ResembleEnhance import load_enhancer
20
- from modules.models import load_chat_tts
21
  from modules.utils import env
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
- def create_api(app, exclude=[]):
27
  app_mgr = APIManager(app=app, exclude_patterns=exclude)
28
 
29
  ping_api.setup(app_mgr)
@@ -40,98 +39,6 @@ def create_api(app, exclude=[]):
40
  return app_mgr
41
 
42
 
43
- def setup_model_args(parser: argparse.ArgumentParser):
44
- parser.add_argument("--compile", action="store_true", help="Enable model compile")
45
- parser.add_argument(
46
- "--no_half",
47
- action="store_true",
48
- help="Disalbe half precision for model inference",
49
- )
50
- parser.add_argument(
51
- "--off_tqdm",
52
- action="store_true",
53
- help="Disable tqdm progress bar",
54
- )
55
- parser.add_argument(
56
- "--device_id",
57
- type=str,
58
- help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
59
- default=None,
60
- )
61
- parser.add_argument(
62
- "--use_cpu",
63
- nargs="+",
64
- help="use CPU as torch device for specified modules",
65
- default=[],
66
- type=str.lower,
67
- )
68
- parser.add_argument(
69
- "--lru_size",
70
- type=int,
71
- default=64,
72
- help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
73
- )
74
- parser.add_argument(
75
- "--debug_generate",
76
- action="store_true",
77
- help="Enable debug mode for audio generation",
78
- )
79
- parser.add_argument(
80
- "--preload_models",
81
- action="store_true",
82
- help="Preload all models at startup",
83
- )
84
-
85
-
86
- def process_model_args(args):
87
- lru_size = env.get_and_update_env(args, "lru_size", 64, int)
88
- compile = env.get_and_update_env(args, "compile", False, bool)
89
- device_id = env.get_and_update_env(args, "device_id", None, str)
90
- use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
91
- no_half = env.get_and_update_env(args, "no_half", False, bool)
92
- off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
93
- debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
94
- preload_models = env.get_and_update_env(args, "preload_models", False, bool)
95
-
96
- generate_audio.setup_lru_cache()
97
- devices.reset_device()
98
- devices.first_time_calculation()
99
-
100
- if debug_generate:
101
- generate_audio.logger.setLevel(logging.DEBUG)
102
-
103
- if preload_models:
104
- load_chat_tts()
105
- load_enhancer()
106
-
107
-
108
- def setup_uvicon_args(parser: argparse.ArgumentParser):
109
- parser.add_argument("--host", type=str, help="Host to run the server on")
110
- parser.add_argument("--port", type=int, help="Port to run the server on")
111
- parser.add_argument(
112
- "--reload", action="store_true", help="Enable auto-reload for development"
113
- )
114
- parser.add_argument("--workers", type=int, help="Number of worker processes")
115
- parser.add_argument("--log_level", type=str, help="Log level")
116
- parser.add_argument("--access_log", action="store_true", help="Enable access log")
117
- parser.add_argument(
118
- "--proxy_headers", action="store_true", help="Enable proxy headers"
119
- )
120
- parser.add_argument(
121
- "--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
122
- )
123
- parser.add_argument(
124
- "--timeout_graceful_shutdown",
125
- type=int,
126
- help="Graceful shutdown timeout duration",
127
- )
128
- parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
129
- parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
130
- parser.add_argument(
131
- "--ssl_keyfile_password", type=str, help="SSL key file password"
132
- )
133
-
134
-
135
  def setup_api_args(parser: argparse.ArgumentParser):
136
  parser.add_argument(
137
  "--cors_origin",
@@ -156,7 +63,7 @@ def setup_api_args(parser: argparse.ArgumentParser):
156
  )
157
 
158
 
159
- def process_api_args(args, app):
160
  cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
161
  no_playground = env.get_and_update_env(args, "no_playground", False, bool)
162
  no_docs = env.get_and_update_env(args, "no_docs", False, bool)
 
1
  import argparse
2
  import logging
3
 
4
+ from fastapi import FastAPI
5
+
6
+ from modules import config
7
  from modules.api.Api import APIManager
8
  from modules.api.impl import (
9
  google_api,
 
17
  tts_api,
18
  xtts_v2_api,
19
  )
 
 
 
20
  from modules.utils import env
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ def create_api(app: FastAPI, exclude=[]):
26
  app_mgr = APIManager(app=app, exclude_patterns=exclude)
27
 
28
  ping_api.setup(app_mgr)
 
39
  return app_mgr
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def setup_api_args(parser: argparse.ArgumentParser):
43
  parser.add_argument(
44
  "--cors_origin",
 
63
  )
64
 
65
 
66
+ def process_api_args(args: argparse.Namespace, app: FastAPI):
67
  cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
68
  no_playground = env.get_and_update_env(args, "no_playground", False, bool)
69
  no_docs = env.get_and_update_env(args, "no_docs", False, bool)
modules/api/impl/handler/AudioHandler.py CHANGED
@@ -1,5 +1,6 @@
1
  import base64
2
  import io
 
3
 
4
  import numpy as np
5
  import soundfile as sf
@@ -10,7 +11,24 @@ from modules.api.impl.model.audio_model import AudioFormat
10
 
11
  class AudioHandler:
12
  def enqueue(self) -> tuple[np.ndarray, int]:
13
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
16
  audio_data, sample_rate = self.enqueue()
 
1
  import base64
2
  import io
3
+ from typing import Generator
4
 
5
  import numpy as np
6
  import soundfile as sf
 
11
 
12
  class AudioHandler:
13
  def enqueue(self) -> tuple[np.ndarray, int]:
14
+ raise NotImplementedError("Method 'enqueue' must be implemented by subclass")
15
+
16
+ def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
17
+ raise NotImplementedError(
18
+ "Method 'enqueue_stream' must be implemented by subclass"
19
+ )
20
+
21
+ def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]:
22
+ for audio_data, sample_rate in self.enqueue_stream():
23
+ buffer = io.BytesIO()
24
+ sf.write(buffer, audio_data, sample_rate, format="wav")
25
+ buffer.seek(0)
26
+
27
+ if format == AudioFormat.mp3:
28
+ buffer = api_utils.wav_to_mp3(buffer)
29
+
30
+ binary = buffer.read()
31
+ yield binary
32
 
33
  def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
34
  audio_data, sample_rate = self.enqueue()
modules/api/impl/handler/SSMLHandler.py CHANGED
@@ -91,4 +91,9 @@ class SSMLHandler(AudioHandler):
91
  sr=sample_rate,
92
  )
93
 
 
 
 
 
 
94
  return audio_data, sample_rate
 
91
  sr=sample_rate,
92
  )
93
 
94
+ if adjust_config.normalize:
95
+ sample_rate, audio_data = audio.apply_normalize(
96
+ audio_data=audio_data, headroom=adjust_config.headroom, sr=sample_rate
97
+ )
98
+
99
  return audio_data, sample_rate
modules/api/impl/handler/TTSHandler.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import numpy as np
2
 
3
  from modules.api.impl.handler.AudioHandler import AudioHandler
@@ -8,7 +11,10 @@ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
8
  from modules.normalization import text_normalize
9
  from modules.speaker import Speaker
10
  from modules.synthesize_audio import synthesize_audio
11
- from modules.utils.audio import apply_prosody_to_audio_data
 
 
 
12
 
13
 
14
  class TTSHandler(AudioHandler):
@@ -94,4 +100,57 @@ class TTSHandler(AudioHandler):
94
  sr=sample_rate,
95
  )
96
 
 
 
 
 
 
 
 
97
  return audio_data, sample_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Generator
3
+
4
  import numpy as np
5
 
6
  from modules.api.impl.handler.AudioHandler import AudioHandler
 
11
  from modules.normalization import text_normalize
12
  from modules.speaker import Speaker
13
  from modules.synthesize_audio import synthesize_audio
14
+ from modules.synthesize_stream import synthesize_stream
15
+ from modules.utils.audio import apply_normalize, apply_prosody_to_audio_data
16
+
17
+ logger = logging.getLogger(__name__)
18
 
19
 
20
  class TTSHandler(AudioHandler):
 
100
  sr=sample_rate,
101
  )
102
 
103
+ if adjust_config.normalize:
104
+ sample_rate, audio_data = apply_normalize(
105
+ audio_data=audio_data,
106
+ headroom=adjust_config.headroom,
107
+ sr=sample_rate,
108
+ )
109
+
110
  return audio_data, sample_rate
111
+
112
+ def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
113
+ text = text_normalize(self.text_content)
114
+ tts_config = self.tts_config
115
+ infer_config = self.infer_config
116
+ adjust_config = self.adjest_config
117
+ enhancer_config = self.enhancer_config
118
+
119
+ if enhancer_config.enabled:
120
+ logger.warning(
121
+ "enhancer_config is enabled, but it is not supported in stream mode"
122
+ )
123
+
124
+ gen = synthesize_stream(
125
+ text,
126
+ spk=self.spk,
127
+ temperature=tts_config.temperature,
128
+ top_P=tts_config.top_p,
129
+ top_K=tts_config.top_k,
130
+ prompt1=tts_config.prompt1,
131
+ prompt2=tts_config.prompt2,
132
+ prefix=tts_config.prefix,
133
+ infer_seed=infer_config.seed,
134
+ spliter_threshold=infer_config.spliter_threshold,
135
+ end_of_sentence=infer_config.eos,
136
+ )
137
+
138
+ # FIXME: 很奇怪,合并出来的音频每个 chunk 之前会有一段异常,暂时没有查出来是哪里的问题,可能是解码时候切割漏了?或者多了?
139
+ for sr, wav in gen:
140
+
141
+ wav = apply_prosody_to_audio_data(
142
+ audio_data=wav,
143
+ rate=adjust_config.speed_rate,
144
+ pitch=adjust_config.pitch,
145
+ volume=adjust_config.volume_gain_db,
146
+ sr=sr,
147
+ )
148
+
149
+ if adjust_config.normalize:
150
+ sr, wav = apply_normalize(
151
+ audio_data=wav,
152
+ headroom=adjust_config.headroom,
153
+ sr=sr,
154
+ )
155
+
156
+ yield wav, sr
modules/api/impl/model/audio_model.py CHANGED
@@ -12,3 +12,7 @@ class AdjustConfig(BaseModel):
12
  pitch: float = 0
13
  speed_rate: float = 1
14
  volume_gain_db: float = 0
 
 
 
 
 
12
  pitch: float = 0
13
  speed_rate: float = 1
14
  volume_gain_db: float = 0
15
+
16
+ # 响度均衡
17
+ normalize: bool = True
18
+ headroom: float = 1
modules/api/impl/tts_api.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from fastapi import Depends, HTTPException, Query
2
  from fastapi.responses import FileResponse, StreamingResponse
3
  from pydantic import BaseModel
@@ -10,6 +12,8 @@ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
10
  from modules.api.impl.model.enhancer_model import EnhancerConfig
11
  from modules.speaker import Speaker
12
 
 
 
13
 
14
  class TTSParams(BaseModel):
15
  text: str = Query(..., description="Text to synthesize")
@@ -44,6 +48,8 @@ class TTSParams(BaseModel):
44
  pitch: float = Query(0, description="Pitch of the audio")
45
  volume_gain: float = Query(0, description="Volume gain of the audio")
46
 
 
 
47
 
48
  async def synthesize_tts(params: TTSParams = Depends()):
49
  try:
@@ -132,14 +138,22 @@ async def synthesize_tts(params: TTSParams = Depends()):
132
  adjust_config=adjust_config,
133
  enhancer_config=enhancer_config,
134
  )
135
-
136
- buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
137
-
138
  media_type = f"audio/{params.format}"
139
  if params.format == "mp3":
140
  media_type = "audio/mpeg"
141
- return StreamingResponse(buffer, media_type=media_type)
142
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  except Exception as e:
144
  import logging
145
 
 
1
+ import logging
2
+
3
  from fastapi import Depends, HTTPException, Query
4
  from fastapi.responses import FileResponse, StreamingResponse
5
  from pydantic import BaseModel
 
12
  from modules.api.impl.model.enhancer_model import EnhancerConfig
13
  from modules.speaker import Speaker
14
 
15
+ logger = logging.getLogger(__name__)
16
+
17
 
18
  class TTSParams(BaseModel):
19
  text: str = Query(..., description="Text to synthesize")
 
48
  pitch: float = Query(0, description="Pitch of the audio")
49
  volume_gain: float = Query(0, description="Volume gain of the audio")
50
 
51
+ stream: bool = Query(False, description="Stream the audio")
52
+
53
 
54
  async def synthesize_tts(params: TTSParams = Depends()):
55
  try:
 
138
  adjust_config=adjust_config,
139
  enhancer_config=enhancer_config,
140
  )
 
 
 
141
  media_type = f"audio/{params.format}"
142
  if params.format == "mp3":
143
  media_type = "audio/mpeg"
 
144
 
145
+ if params.stream:
146
+ if infer_config.batch_size != 1:
147
+ # 流式生成下仅支持 batch size 为 1,当前请求参数将被忽略
148
+ logger.warning(
149
+ f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
150
+ )
151
+
152
+ buffer_gen = handler.enqueue_to_stream(format=AudioFormat(params.format))
153
+ return StreamingResponse(buffer_gen, media_type=media_type)
154
+ else:
155
+ buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
156
+ return StreamingResponse(buffer, media_type=media_type)
157
  except Exception as e:
158
  import logging
159
 
modules/api/impl/xtts_v2_api.py CHANGED
@@ -1,18 +1,15 @@
1
- import io
2
  import logging
3
 
4
- import soundfile as sf
5
- from fastapi import HTTPException
6
  from fastapi.responses import StreamingResponse
7
  from pydantic import BaseModel
8
 
9
- from modules import config
10
- from modules.api import utils as api_utils
11
  from modules.api.Api import APIManager
12
- from modules.normalization import text_normalize
 
 
 
13
  from modules.speaker import speaker_mgr
14
- from modules.synthesize_audio import synthesize_audio
15
- from modules.utils.audio import apply_prosody_to_audio_data
16
 
17
  logger = logging.getLogger(__name__)
18
 
@@ -22,8 +19,11 @@ class XTTS_V2_Settings:
22
  self.stream_chunk_size = 100
23
  self.temperature = 0.3
24
  self.speed = 1
 
 
25
  self.length_penalty = 0.5
26
  self.repetition_penalty = 1.0
 
27
  self.top_p = 0.7
28
  self.top_k = 20
29
  self.enable_text_splitting = True
@@ -37,6 +37,7 @@ class XTTS_V2_Settings:
37
  self.prompt2 = ""
38
  self.prefix = ""
39
  self.spliter_threshold = 100
 
40
 
41
 
42
  class TTSSettingsRequest(BaseModel):
@@ -58,6 +59,7 @@ class TTSSettingsRequest(BaseModel):
58
  prompt2: str = None
59
  prefix: str = None
60
  spliter_threshold: int = None
 
61
 
62
 
63
  class SynthesisRequest(BaseModel):
@@ -95,45 +97,101 @@ def setup(app: APIManager):
95
  if spk is None:
96
  raise HTTPException(status_code=400, detail="Invalid speaker id")
97
 
98
- text = text_normalize(text, is_end=True)
99
- sample_rate, audio_data = synthesize_audio(
100
- # TODO: 这两个参数现在用不着...但是其实gpt是可以用的
101
- # length_penalty=XTTSV2.length_penalty,
102
- # repetition_penalty=XTTSV2.repetition_penalty,
103
- text=text,
104
  temperature=XTTSV2.temperature,
105
- top_P=XTTSV2.top_p,
106
- top_K=XTTSV2.top_k,
107
- spk=spk,
108
- spliter_threshold=XTTSV2.spliter_threshold,
109
- batch_size=XTTSV2.batch_size,
110
- end_of_sentence=XTTSV2.eos,
111
- infer_seed=XTTSV2.infer_seed,
112
- use_decoder=XTTSV2.use_decoder,
113
  prompt1=XTTSV2.prompt1,
114
  prompt2=XTTSV2.prompt2,
115
- prefix=XTTSV2.prefix,
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  )
117
 
118
- if XTTSV2.speed:
119
- audio_data = apply_prosody_to_audio_data(
120
- audio_data,
121
- rate=XTTSV2.speed,
122
- sr=sample_rate,
123
- )
124
-
125
- # to mp3
126
- buffer = io.BytesIO()
127
- sf.write(buffer, audio_data, sample_rate, format="wav")
128
- buffer.seek(0)
129
 
130
- buffer = api_utils.wav_to_mp3(buffer)
131
 
132
  return StreamingResponse(buffer, media_type="audio/mpeg")
133
 
134
  @app.get("/v1/xtts_v2/tts_stream")
135
- async def tts_stream():
136
- raise HTTPException(status_code=501, detail="Not implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  @app.post("/v1/xtts_v2/set_tts_settings")
139
  async def set_tts_settings(request: TTSSettingsRequest):
@@ -195,6 +253,8 @@ def setup(app: APIManager):
195
  XTTSV2.prefix = request.prefix
196
  if request.spliter_threshold:
197
  XTTSV2.spliter_threshold = request.spliter_threshold
 
 
198
 
199
  return {"message": "Settings successfully applied"}
200
  except Exception as e:
 
 
1
  import logging
2
 
3
+ from fastapi import HTTPException, Query, Request
 
4
  from fastapi.responses import StreamingResponse
5
  from pydantic import BaseModel
6
 
 
 
7
  from modules.api.Api import APIManager
8
+ from modules.api.impl.handler.TTSHandler import TTSHandler
9
+ from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
10
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
11
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
12
  from modules.speaker import speaker_mgr
 
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
19
  self.stream_chunk_size = 100
20
  self.temperature = 0.3
21
  self.speed = 1
22
+
23
+ # TODO: 这两个参数现在用不着...但是其实gpt是可以用的可以考虑增加
24
  self.length_penalty = 0.5
25
  self.repetition_penalty = 1.0
26
+
27
  self.top_p = 0.7
28
  self.top_k = 20
29
  self.enable_text_splitting = True
 
37
  self.prompt2 = ""
38
  self.prefix = ""
39
  self.spliter_threshold = 100
40
+ self.style = ""
41
 
42
 
43
  class TTSSettingsRequest(BaseModel):
 
59
  prompt2: str = None
60
  prefix: str = None
61
  spliter_threshold: int = None
62
+ style: str = None
63
 
64
 
65
  class SynthesisRequest(BaseModel):
 
97
  if spk is None:
98
  raise HTTPException(status_code=400, detail="Invalid speaker id")
99
 
100
+ tts_config = ChatTTSConfig(
101
+ style=XTTSV2.style,
 
 
 
 
102
  temperature=XTTSV2.temperature,
103
+ top_k=XTTSV2.top_k,
104
+ top_p=XTTSV2.top_p,
105
+ prefix=XTTSV2.prefix,
 
 
 
 
 
106
  prompt1=XTTSV2.prompt1,
107
  prompt2=XTTSV2.prompt2,
108
+ )
109
+ infer_config = InferConfig(
110
+ batch_size=XTTSV2.batch_size,
111
+ spliter_threshold=XTTSV2.spliter_threshold,
112
+ eos=XTTSV2.eos,
113
+ seed=XTTSV2.infer_seed,
114
+ )
115
+ adjust_config = AdjustConfig(
116
+ speed_rate=XTTSV2.speed,
117
+ )
118
+ # TODO: support enhancer
119
+ enhancer_config = EnhancerConfig(
120
+ # enabled=params.enhance or params.denoise or False,
121
+ # lambd=0.9 if params.denoise else 0.1,
122
  )
123
 
124
+ handler = TTSHandler(
125
+ text_content=text,
126
+ spk=spk,
127
+ tts_config=tts_config,
128
+ infer_config=infer_config,
129
+ adjust_config=adjust_config,
130
+ enhancer_config=enhancer_config,
131
+ )
 
 
 
132
 
133
+ buffer = handler.enqueue_to_buffer(AudioFormat.mp3)
134
 
135
  return StreamingResponse(buffer, media_type="audio/mpeg")
136
 
137
  @app.get("/v1/xtts_v2/tts_stream")
138
+ async def tts_stream(
139
+ request: Request,
140
+ text: str = Query(),
141
+ speaker_wav: str = Query(),
142
+ language: str = Query(),
143
+ ):
144
+ # speaker_wav 就是 speaker id 。。。
145
+ voice_id = speaker_wav
146
+
147
+ spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
148
+ voice_id
149
+ )
150
+ if spk is None:
151
+ raise HTTPException(status_code=400, detail="Invalid speaker id")
152
+
153
+ tts_config = ChatTTSConfig(
154
+ style=XTTSV2.style,
155
+ temperature=XTTSV2.temperature,
156
+ top_k=XTTSV2.top_k,
157
+ top_p=XTTSV2.top_p,
158
+ prefix=XTTSV2.prefix,
159
+ prompt1=XTTSV2.prompt1,
160
+ prompt2=XTTSV2.prompt2,
161
+ )
162
+ infer_config = InferConfig(
163
+ batch_size=XTTSV2.batch_size,
164
+ spliter_threshold=XTTSV2.spliter_threshold,
165
+ eos=XTTSV2.eos,
166
+ seed=XTTSV2.infer_seed,
167
+ )
168
+ adjust_config = AdjustConfig(
169
+ speed_rate=XTTSV2.speed,
170
+ )
171
+ # TODO: support enhancer
172
+ enhancer_config = EnhancerConfig(
173
+ # enabled=params.enhance or params.denoise or False,
174
+ # lambd=0.9 if params.denoise else 0.1,
175
+ )
176
+
177
+ handler = TTSHandler(
178
+ text_content=text,
179
+ spk=spk,
180
+ tts_config=tts_config,
181
+ infer_config=infer_config,
182
+ adjust_config=adjust_config,
183
+ enhancer_config=enhancer_config,
184
+ )
185
+
186
+ async def generator():
187
+ for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
188
+ disconnected = await request.is_disconnected()
189
+ if disconnected:
190
+ break
191
+
192
+ yield chunk
193
+
194
+ return StreamingResponse(generator(), media_type="audio/mpeg")
195
 
196
  @app.post("/v1/xtts_v2/set_tts_settings")
197
  async def set_tts_settings(request: TTSSettingsRequest):
 
253
  XTTSV2.prefix = request.prefix
254
  if request.spliter_threshold:
255
  XTTSV2.spliter_threshold = request.spliter_threshold
256
+ if request.style:
257
+ XTTSV2.style = request.style
258
 
259
  return {"message": "Settings successfully applied"}
260
  except Exception as e:
modules/api/worker.py CHANGED
@@ -5,7 +5,9 @@ import os
5
  import dotenv
6
  from fastapi import FastAPI
7
 
 
8
  from modules.ffmpeg_env import setup_ffmpeg_path
 
9
 
10
  setup_ffmpeg_path()
11
  logging.basicConfig(
@@ -14,13 +16,7 @@ logging.basicConfig(
14
  )
15
 
16
  from modules import config
17
- from modules.api.api_setup import (
18
- process_api_args,
19
- process_model_args,
20
- setup_api_args,
21
- setup_model_args,
22
- setup_uvicon_args,
23
- )
24
  from modules.api.app_config import app_description, app_title, app_version
25
  from modules.utils.torch_opt import configure_torch_optimizations
26
 
 
5
  import dotenv
6
  from fastapi import FastAPI
7
 
8
+ from launch import setup_uvicon_args
9
  from modules.ffmpeg_env import setup_ffmpeg_path
10
+ from modules.models_setup import process_model_args, setup_model_args
11
 
12
  setup_ffmpeg_path()
13
  logging.basicConfig(
 
16
  )
17
 
18
  from modules import config
19
+ from modules.api.api_setup import process_api_args, setup_api_args
 
 
 
 
 
 
20
  from modules.api.app_config import app_description, app_title, app_version
21
  from modules.utils.torch_opt import configure_torch_optimizations
22
 
modules/devices/devices.py CHANGED
@@ -92,7 +92,10 @@ def get_optimal_device():
92
 
93
 
94
  def get_device_for(task):
95
- if task in config.cmd_opts.use_cpu or "all" in config.cmd_opts.use_cpu:
 
 
 
96
  return cpu
97
 
98
  return get_optimal_device()
@@ -128,6 +131,9 @@ def reset_device():
128
  global dtype_gpt
129
  global dtype_decoder
130
 
 
 
 
131
  if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
132
  logger.warning(
133
  "Cannot use half precision with CPU, using full precision instead"
 
92
 
93
 
94
  def get_device_for(task):
95
+ if (
96
+ task in config.runtime_env_vars.use_cpu
97
+ or "all" in config.runtime_env_vars.use_cpu
98
+ ):
99
  return cpu
100
 
101
  return get_optimal_device()
 
131
  global dtype_gpt
132
  global dtype_decoder
133
 
134
+ if config.runtime_env_vars.use_cpu is None:
135
+ config.runtime_env_vars.use_cpu = []
136
+
137
  if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
138
  logger.warning(
139
  "Cannot use half precision with CPU, using full precision instead"
modules/finetune/train_speaker.py CHANGED
@@ -255,7 +255,7 @@ if __name__ == "__main__":
255
  vocos_model=chat.pretrain_models["vocos"],
256
  tar_path=tar_path,
257
  tar_in_memory=tar_in_memory,
258
- device=devices.device,
259
  # speakers=None, # set(['speaker_A', 'speaker_B'])
260
  )
261
 
@@ -267,7 +267,7 @@ if __name__ == "__main__":
267
  speaker_embeds = {
268
  speaker: torch.tensor(
269
  spk.emb,
270
- device=devices.device,
271
  requires_grad=True,
272
  )
273
  for speaker in dataset.speakers
 
255
  vocos_model=chat.pretrain_models["vocos"],
256
  tar_path=tar_path,
257
  tar_in_memory=tar_in_memory,
258
+ device=devices.get_device_for("trainer"),
259
  # speakers=None, # set(['speaker_A', 'speaker_B'])
260
  )
261
 
 
267
  speaker_embeds = {
268
  speaker: torch.tensor(
269
  spk.emb,
270
+ device=devices.get_device_for("trainer"),
271
  requires_grad=True,
272
  )
273
  for speaker in dataset.speakers
modules/generate_audio.py CHANGED
@@ -1,11 +1,12 @@
1
  import gc
2
  import logging
3
- from typing import Union
4
 
5
  import numpy as np
6
  import torch
7
 
8
  from modules import config, models
 
9
  from modules.devices import devices
10
  from modules.speaker import Speaker
11
  from modules.utils.cache import conditional_cache
@@ -13,6 +14,8 @@ from modules.utils.SeedContext import SeedContext
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
 
16
 
17
  def generate_audio(
18
  text: str,
@@ -42,20 +45,18 @@ def generate_audio(
42
  return (sample_rate, wav)
43
 
44
 
45
- @torch.inference_mode()
46
- def generate_audio_batch(
47
  texts: list[str],
 
48
  temperature: float = 0.3,
49
  top_P: float = 0.7,
50
  top_K: float = 20,
51
  spk: Union[int, Speaker] = -1,
52
  infer_seed: int = -1,
53
- use_decoder: bool = True,
54
  prompt1: str = "",
55
  prompt2: str = "",
56
  prefix: str = "",
57
  ):
58
- chat_tts = models.load_chat_tts()
59
  params_infer_code = {
60
  "spk_emb": None,
61
  "temperature": temperature,
@@ -97,18 +98,93 @@ def generate_audio_batch(
97
  }
98
  )
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  with SeedContext(infer_seed, True):
101
  wavs = chat_tts.generate_audio(
102
- texts, params_infer_code, use_decoder=use_decoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
 
105
- sample_rate = 24000
 
106
 
107
  if config.auto_gc:
108
  devices.torch_gc()
109
  gc.collect()
110
 
111
- return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
112
 
113
 
114
  lru_cache_enabled = False
 
1
  import gc
2
  import logging
3
+ from typing import Generator, Union
4
 
5
  import numpy as np
6
  import torch
7
 
8
  from modules import config, models
9
+ from modules.ChatTTS import ChatTTS
10
  from modules.devices import devices
11
  from modules.speaker import Speaker
12
  from modules.utils.cache import conditional_cache
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
+ SAMPLE_RATE = 24000
18
+
19
 
20
  def generate_audio(
21
  text: str,
 
45
  return (sample_rate, wav)
46
 
47
 
48
+ def parse_infer_params(
 
49
  texts: list[str],
50
+ chat_tts: ChatTTS.Chat,
51
  temperature: float = 0.3,
52
  top_P: float = 0.7,
53
  top_K: float = 20,
54
  spk: Union[int, Speaker] = -1,
55
  infer_seed: int = -1,
 
56
  prompt1: str = "",
57
  prompt2: str = "",
58
  prefix: str = "",
59
  ):
 
60
  params_infer_code = {
61
  "spk_emb": None,
62
  "temperature": temperature,
 
98
  }
99
  )
100
 
101
+ return params_infer_code
102
+
103
+
104
+ @torch.inference_mode()
105
+ def generate_audio_batch(
106
+ texts: list[str],
107
+ temperature: float = 0.3,
108
+ top_P: float = 0.7,
109
+ top_K: float = 20,
110
+ spk: Union[int, Speaker] = -1,
111
+ infer_seed: int = -1,
112
+ use_decoder: bool = True,
113
+ prompt1: str = "",
114
+ prompt2: str = "",
115
+ prefix: str = "",
116
+ ):
117
+ chat_tts = models.load_chat_tts()
118
+ params_infer_code = parse_infer_params(
119
+ texts=texts,
120
+ chat_tts=chat_tts,
121
+ temperature=temperature,
122
+ top_P=top_P,
123
+ top_K=top_K,
124
+ spk=spk,
125
+ infer_seed=infer_seed,
126
+ prompt1=prompt1,
127
+ prompt2=prompt2,
128
+ prefix=prefix,
129
+ )
130
+
131
  with SeedContext(infer_seed, True):
132
  wavs = chat_tts.generate_audio(
133
+ prompt=texts, params_infer_code=params_infer_code, use_decoder=use_decoder
134
+ )
135
+
136
+ if config.auto_gc:
137
+ devices.torch_gc()
138
+ gc.collect()
139
+
140
+ return [(SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
141
+
142
+
143
+ # TODO: generate_audio_stream 也应该支持 lru cache
144
+ @torch.inference_mode()
145
+ def generate_audio_stream(
146
+ text: str,
147
+ temperature: float = 0.3,
148
+ top_P: float = 0.7,
149
+ top_K: float = 20,
150
+ spk: Union[int, Speaker] = -1,
151
+ infer_seed: int = -1,
152
+ use_decoder: bool = True,
153
+ prompt1: str = "",
154
+ prompt2: str = "",
155
+ prefix: str = "",
156
+ ) -> Generator[tuple[int, np.ndarray], None, None]:
157
+ chat_tts = models.load_chat_tts()
158
+ texts = [text]
159
+ params_infer_code = parse_infer_params(
160
+ texts=texts,
161
+ chat_tts=chat_tts,
162
+ temperature=temperature,
163
+ top_P=top_P,
164
+ top_K=top_K,
165
+ spk=spk,
166
+ infer_seed=infer_seed,
167
+ prompt1=prompt1,
168
+ prompt2=prompt2,
169
+ prefix=prefix,
170
+ )
171
+
172
+ with SeedContext(infer_seed, True):
173
+ wavs_gen = chat_tts.generate_audio(
174
+ prompt=texts,
175
+ params_infer_code=params_infer_code,
176
+ use_decoder=use_decoder,
177
+ stream=True,
178
  )
179
 
180
+ for wav in wavs_gen:
181
+ yield [SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)]
182
 
183
  if config.auto_gc:
184
  devices.torch_gc()
185
  gc.collect()
186
 
187
+ return
188
 
189
 
190
  lru_cache_enabled = False
modules/models.py CHANGED
@@ -21,18 +21,27 @@ def load_chat_tts_in_thread():
21
 
22
  logger.info("Loading ChatTTS models")
23
  chat_tts = ChatTTS.Chat()
 
 
24
  chat_tts.load_models(
25
  compile=config.runtime_env_vars.compile,
26
  source="local",
27
  local_path="./models/ChatTTS",
28
- device=devices.device,
29
- dtype=devices.dtype,
30
  dtype_vocos=devices.dtype_vocos,
31
  dtype_dvae=devices.dtype_dvae,
32
  dtype_gpt=devices.dtype_gpt,
33
  dtype_decoder=devices.dtype_decoder,
34
  )
35
 
 
 
 
 
 
 
 
36
  devices.torch_gc()
37
  logger.info("ChatTTS models loaded")
38
 
 
21
 
22
  logger.info("Loading ChatTTS models")
23
  chat_tts = ChatTTS.Chat()
24
+ device = devices.get_device_for("chattts")
25
+ dtype = devices.dtype
26
  chat_tts.load_models(
27
  compile=config.runtime_env_vars.compile,
28
  source="local",
29
  local_path="./models/ChatTTS",
30
+ device=device,
31
+ dtype=dtype,
32
  dtype_vocos=devices.dtype_vocos,
33
  dtype_dvae=devices.dtype_dvae,
34
  dtype_gpt=devices.dtype_gpt,
35
  dtype_decoder=devices.dtype_decoder,
36
  )
37
 
38
+ # 如果 device 为 cpu 同时,又是 dtype == float16 那么报 warn
39
+ # 提示可能无法正常运行,建议使用 float32 即开启 `--no_half` 参数
40
+ if device == devices.cpu and dtype == torch.float16:
41
+ logger.warning(
42
+ "The device is CPU and dtype is float16, which may not work properly. It is recommended to use float32 by enabling the `--no_half` parameter."
43
+ )
44
+
45
  devices.torch_gc()
46
  logger.info("ChatTTS models loaded")
47
 
modules/models_setup.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+
4
+ from modules import generate_audio
5
+ from modules.devices import devices
6
+ from modules.Enhancer.ResembleEnhance import load_enhancer
7
+ from modules.models import load_chat_tts
8
+ from modules.utils import env
9
+
10
+
11
+ def setup_model_args(parser: argparse.ArgumentParser):
12
+ parser.add_argument("--compile", action="store_true", help="Enable model compile")
13
+ parser.add_argument(
14
+ "--no_half",
15
+ action="store_true",
16
+ help="Disalbe half precision for model inference",
17
+ )
18
+ parser.add_argument(
19
+ "--off_tqdm",
20
+ action="store_true",
21
+ help="Disable tqdm progress bar",
22
+ )
23
+ parser.add_argument(
24
+ "--device_id",
25
+ type=str,
26
+ help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
27
+ default=None,
28
+ )
29
+ parser.add_argument(
30
+ "--use_cpu",
31
+ nargs="+",
32
+ help="use CPU as torch device for specified modules",
33
+ default=[],
34
+ type=str.lower,
35
+ choices=["all", "chattts", "enhancer", "trainer"],
36
+ )
37
+ parser.add_argument(
38
+ "--lru_size",
39
+ type=int,
40
+ default=64,
41
+ help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
42
+ )
43
+ parser.add_argument(
44
+ "--debug_generate",
45
+ action="store_true",
46
+ help="Enable debug mode for audio generation",
47
+ )
48
+ parser.add_argument(
49
+ "--preload_models",
50
+ action="store_true",
51
+ help="Preload all models at startup",
52
+ )
53
+
54
+
55
+ def process_model_args(args: argparse.Namespace):
56
+ lru_size = env.get_and_update_env(args, "lru_size", 64, int)
57
+ compile = env.get_and_update_env(args, "compile", False, bool)
58
+ device_id = env.get_and_update_env(args, "device_id", None, str)
59
+ use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
60
+ no_half = env.get_and_update_env(args, "no_half", False, bool)
61
+ off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
62
+ debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
63
+ preload_models = env.get_and_update_env(args, "preload_models", False, bool)
64
+
65
+ generate_audio.setup_lru_cache()
66
+ devices.reset_device()
67
+ devices.first_time_calculation()
68
+
69
+ if debug_generate:
70
+ generate_audio.logger.setLevel(logging.DEBUG)
71
+
72
+ if preload_models:
73
+ load_chat_tts()
74
+ load_enhancer()
modules/normalization.py CHANGED
@@ -1,39 +1,21 @@
 
1
  import re
2
- from functools import lru_cache
3
 
4
  import emojiswitch
 
5
 
6
  from modules import models
 
 
 
7
  from modules.utils.markdown import markdown_to_text
8
- from modules.utils.zh_normalization.text_normlization import *
9
 
10
  # 是否关闭 unk token 检查
11
  # NOTE: 单测的时候用于跳过模型加载
12
  DISABLE_UNK_TOKEN_CHECK = False
13
 
14
 
15
- @lru_cache(maxsize=64)
16
- def is_chinese(text):
17
- # 中文字符的 Unicode 范围是 \u4e00-\u9fff
18
- chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
19
- return bool(chinese_pattern.search(text))
20
-
21
-
22
- @lru_cache(maxsize=64)
23
- def is_eng(text):
24
- eng_pattern = re.compile(r"[a-zA-Z]")
25
- return bool(eng_pattern.search(text))
26
-
27
-
28
- @lru_cache(maxsize=64)
29
- def guess_lang(text):
30
- if is_chinese(text):
31
- return "zh"
32
- if is_eng(text):
33
- return "en"
34
- return "zh"
35
-
36
-
37
  post_normalize_pipeline = []
38
  pre_normalize_pipeline = []
39
 
@@ -184,9 +166,32 @@ def replace_unk_tokens(text):
184
  return output_text
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
187
  ## ---------- pre normalize ----------
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  @pre_normalize()
191
  def apply_markdown_to_text(text):
192
  if is_markdown(text):
@@ -194,6 +199,11 @@ def apply_markdown_to_text(text):
194
  return text
195
 
196
 
 
 
 
 
 
197
  # 将 "xxx" => \nxxx\n
198
  # 将 'xxx' => \nxxx\n
199
  @pre_normalize()
@@ -293,6 +303,7 @@ if __name__ == "__main__":
293
  " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
294
  " 明天有62%的概率降雨",
295
  "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
 
296
  """
297
  # 你好,世界
298
  ```js
 
1
+ import html
2
  import re
 
3
 
4
  import emojiswitch
5
+ import ftfy
6
 
7
  from modules import models
8
+ from modules.utils.detect_lang import guess_lang
9
+ from modules.utils.HomophonesReplacer import HomophonesReplacer
10
+ from modules.utils.html import remove_html_tags as _remove_html_tags
11
  from modules.utils.markdown import markdown_to_text
12
+ from modules.utils.zh_normalization.text_normlization import TextNormalizer
13
 
14
  # 是否关闭 unk token 检查
15
  # NOTE: 单测的时候用于跳过模型加载
16
  DISABLE_UNK_TOKEN_CHECK = False
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  post_normalize_pipeline = []
20
  pre_normalize_pipeline = []
21
 
 
166
  return output_text
167
 
168
 
169
+ homo_replacer = HomophonesReplacer(map_file_path="./data/homophones_map.json")
170
+
171
+
172
+ @post_normalize()
173
+ def replace_homophones(text):
174
+ lang = guess_lang(text)
175
+ if lang == "zh":
176
+ text = homo_replacer.replace(text)
177
+ return text
178
+
179
+
180
  ## ---------- pre normalize ----------
181
 
182
 
183
+ @pre_normalize()
184
+ def html_unescape(text):
185
+ text = html.unescape(text)
186
+ text = html.unescape(text)
187
+ return text
188
+
189
+
190
+ @pre_normalize()
191
+ def fix_text(text):
192
+ return ftfy.fix_text(text=text)
193
+
194
+
195
  @pre_normalize()
196
  def apply_markdown_to_text(text):
197
  if is_markdown(text):
 
199
  return text
200
 
201
 
202
+ @pre_normalize()
203
+ def remove_html_tags(text):
204
+ return _remove_html_tags(text)
205
+
206
+
207
  # 将 "xxx" => \nxxx\n
208
  # 将 'xxx' => \nxxx\n
209
  @pre_normalize()
 
303
  " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
304
  " 明天有62%的概率降雨",
305
  "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
306
+ "I like eating 🍏",
307
  """
308
  # 你好,世界
309
  ```js
modules/refiner.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  import torch
3
 
@@ -31,4 +33,10 @@ def refine_text(
31
  "disable_tqdm": config.runtime_env_vars.off_tqdm,
32
  },
33
  )
 
 
 
 
 
 
34
  return refined_text
 
1
+ from typing import Generator
2
+
3
  import numpy as np
4
  import torch
5
 
 
33
  "disable_tqdm": config.runtime_env_vars.off_tqdm,
34
  },
35
  )
36
+ if isinstance(refined_text, Generator):
37
+ raise NotImplementedError(
38
+ "Refiner is not yet implemented for generator output"
39
+ )
40
+ if isinstance(refined_text, list):
41
+ refined_text = "\n".join(refined_text)
42
  return refined_text
modules/repos_static/resemble_enhance/inference.py CHANGED
@@ -1,12 +1,12 @@
1
  import logging
2
  import time
 
3
 
4
  import torch
5
  import torch.nn.functional as F
6
  from torch.nn.utils.parametrize import remove_parametrizations
7
  from torchaudio.functional import resample
8
  from torchaudio.transforms import MelSpectrogram
9
- from tqdm import trange
10
 
11
  from modules import config
12
  from modules.devices import devices
@@ -142,10 +142,10 @@ def inference(
142
  chunk_seconds: float = 30.0,
143
  overlap_seconds: float = 1.0,
144
  ):
 
 
145
  if config.runtime_env_vars.off_tqdm:
146
- trange = range
147
- else:
148
- from tqdm import trange
149
 
150
  remove_weight_norm_recursively(model)
151
 
@@ -188,7 +188,7 @@ def inference(
188
  torch.cuda.synchronize()
189
 
190
  elapsed_time = time.perf_counter() - start_time
191
- logger.info(
192
  f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
193
  )
194
  devices.torch_gc()
 
1
  import logging
2
  import time
3
+ from functools import partial
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  from torch.nn.utils.parametrize import remove_parametrizations
8
  from torchaudio.functional import resample
9
  from torchaudio.transforms import MelSpectrogram
 
10
 
11
  from modules import config
12
  from modules.devices import devices
 
142
  chunk_seconds: float = 30.0,
143
  overlap_seconds: float = 1.0,
144
  ):
145
+ from tqdm import trange
146
+
147
  if config.runtime_env_vars.off_tqdm:
148
+ trange = partial(trange, disable=True)
 
 
149
 
150
  remove_weight_norm_recursively(model)
151
 
 
188
  torch.cuda.synchronize()
189
 
190
  elapsed_time = time.perf_counter() - start_time
191
+ logger.debug(
192
  f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
193
  )
194
  devices.torch_gc()
modules/speaker.py CHANGED
@@ -29,6 +29,12 @@ class Speaker:
29
  speaker.emb = tensor
30
  return speaker
31
 
 
 
 
 
 
 
32
  def __init__(
33
  self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
34
  ):
 
29
  speaker.emb = tensor
30
  return speaker
31
 
32
+ @staticmethod
33
+ def from_seed(seed: int):
34
+ speaker = Speaker(seed_or_tensor=seed)
35
+ speaker.emb = create_speaker_from_seed(seed)
36
+ return speaker
37
+
38
  def __init__(
39
  self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
40
  ):
modules/synthesize_audio.py CHANGED
@@ -1,7 +1,5 @@
1
- import io
2
  from typing import Union
3
 
4
- from modules import generate_audio as generate
5
  from modules.SentenceSplitter import SentenceSplitter
6
  from modules.speaker import Speaker
7
  from modules.ssml_parser.SSMLParser import SSMLSegment
 
 
1
  from typing import Union
2
 
 
3
  from modules.SentenceSplitter import SentenceSplitter
4
  from modules.speaker import Speaker
5
  from modules.ssml_parser.SSMLParser import SSMLSegment
modules/synthesize_stream.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Generator, Union
3
+
4
+ import numpy as np
5
+
6
+ from modules import generate_audio as generate
7
+ from modules.SentenceSplitter import SentenceSplitter
8
+ from modules.speaker import Speaker
9
+
10
+
11
+ def synthesize_stream(
12
+ text: str,
13
+ temperature: float = 0.3,
14
+ top_P: float = 0.7,
15
+ top_K: float = 20,
16
+ spk: Union[int, Speaker] = -1,
17
+ infer_seed: int = -1,
18
+ use_decoder: bool = True,
19
+ prompt1: str = "",
20
+ prompt2: str = "",
21
+ prefix: str = "",
22
+ spliter_threshold: int = 100,
23
+ end_of_sentence="",
24
+ ) -> Generator[tuple[int, np.ndarray], None, None]:
25
+ spliter = SentenceSplitter(spliter_threshold)
26
+ sentences = spliter.parse(text)
27
+
28
+ for sentence in sentences:
29
+ wav_gen = generate.generate_audio_stream(
30
+ text=sentence + end_of_sentence,
31
+ temperature=temperature,
32
+ top_P=top_P,
33
+ top_K=top_K,
34
+ spk=spk,
35
+ infer_seed=infer_seed,
36
+ use_decoder=use_decoder,
37
+ prompt1=prompt1,
38
+ prompt2=prompt2,
39
+ prefix=prefix,
40
+ )
41
+ for sr, wav in wav_gen:
42
+ yield sr, wav
modules/utils/HomophonesReplacer.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+
4
+ # ref: https://github.com/2noise/ChatTTS/commit/ce1c962b6235bd7d0c637fbdcda5e2dccdbac80d
5
+ class HomophonesReplacer:
6
+ """
7
+ Homophones Replacer
8
+
9
+ Replace the mispronounced characters with correctly pronounced ones.
10
+
11
+ Creation process of homophones_map.json:
12
+
13
+ 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
14
+ 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
15
+ 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS.
16
+ 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
17
+
18
+ Thanks to:
19
+ [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
20
+ [python-pinyin](https://github.com/mozillazg/python-pinyin)
21
+
22
+ """
23
+
24
+ def __init__(self, map_file_path):
25
+ self.homophones_map = self.load_homophones_map(map_file_path)
26
+
27
+ def load_homophones_map(self, map_file_path):
28
+ with open(map_file_path, "r", encoding="utf-8") as f:
29
+ homophones_map = json.load(f)
30
+ return homophones_map
31
+
32
+ def replace(self, text):
33
+ result = []
34
+ for char in text:
35
+ if char in self.homophones_map:
36
+ result.append(self.homophones_map[char])
37
+ else:
38
+ result.append(char)
39
+ return "".join(result)
modules/utils/audio.py CHANGED
@@ -2,14 +2,14 @@ import sys
2
  from io import BytesIO
3
 
4
  import numpy as np
5
- import pyrubberband as pyrb
6
  import soundfile as sf
7
- from pydub import AudioSegment
 
8
 
9
  INT16_MAX = np.iinfo(np.int16).max
10
 
11
 
12
- def audio_to_int16(audio_data):
13
  if (
14
  audio_data.dtype == np.float32
15
  or audio_data.dtype == np.float64
@@ -20,6 +20,23 @@ def audio_to_int16(audio_data):
20
  return audio_data
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
24
  """
25
  Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
@@ -35,64 +52,42 @@ def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
35
  return fp_arr
36
 
37
 
38
- def pydub_to_np(audio: AudioSegment) -> tuple[int, np.ndarray]:
39
- """
40
- Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
41
- where each value is in range [-1.0, 1.0].
42
- Returns tuple (audio_np_array, sample_rate).
43
- """
44
- return (
45
- audio.frame_rate,
46
- np.array(audio.get_array_of_samples(), dtype=np.float32).reshape(
47
- (-1, audio.channels)
48
- )
49
- / (1 << (8 * audio.sample_width - 1)),
50
- )
51
-
52
-
53
- def ndarray_to_segment(ndarray, frame_rate):
54
  buffer = BytesIO()
55
- sf.write(buffer, ndarray, frame_rate, format="wav")
56
  buffer.seek(0)
57
- sound = AudioSegment.from_wav(
58
- buffer,
59
- )
60
- return sound
61
 
 
 
 
 
62
 
63
- def time_stretch(input_segment: AudioSegment, time_factor: float) -> AudioSegment:
64
- """
65
- factor range -> [0.2,10]
66
- """
67
- time_factor = np.clip(time_factor, 0.2, 10)
68
- sr = input_segment.frame_rate
69
- y = audiosegment_to_librosawav(input_segment)
70
- y_stretch = pyrb.time_stretch(y, sr, time_factor)
71
-
72
- sound = ndarray_to_segment(
73
- y_stretch,
74
- frame_rate=sr,
75
  )
76
- return sound
77
 
78
 
79
- def pitch_shift(
80
- input_segment: AudioSegment,
81
- pitch_shift_factor: float,
 
 
 
82
  ) -> AudioSegment:
83
- """
84
- factor range -> [-12,12]
85
- """
86
- pitch_shift_factor = np.clip(pitch_shift_factor, -12, 12)
87
- sr = input_segment.frame_rate
88
- y = audiosegment_to_librosawav(input_segment)
89
- y_shift = pyrb.pitch_shift(y, sr, pitch_shift_factor)
90
-
91
- sound = ndarray_to_segment(
92
- y_shift,
93
- frame_rate=sr,
94
  )
95
- return sound
 
96
 
97
 
98
  def apply_prosody_to_audio_data(
@@ -114,6 +109,17 @@ def apply_prosody_to_audio_data(
114
  return audio_data
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  if __name__ == "__main__":
118
  input_file = sys.argv[1]
119
 
@@ -123,11 +129,11 @@ if __name__ == "__main__":
123
  input_sound = AudioSegment.from_mp3(input_file)
124
 
125
  for time_factor in time_stretch_factors:
126
- output_wav = f"time_stretched_{int(time_factor * 100)}.wav"
127
- sound = time_stretch(input_sound, time_factor)
128
- sound.export(output_wav, format="wav")
129
 
130
  for pitch_factor in pitch_shift_factors:
131
- output_wav = f"pitch_shifted_{int(pitch_factor * 100)}.wav"
132
- sound = pitch_shift(input_sound, pitch_factor)
133
- sound.export(output_wav, format="wav")
 
2
  from io import BytesIO
3
 
4
  import numpy as np
 
5
  import soundfile as sf
6
+ from pydub import AudioSegment, effects
7
+ import pyrubberband as pyrb
8
 
9
  INT16_MAX = np.iinfo(np.int16).max
10
 
11
 
12
+ def audio_to_int16(audio_data: np.ndarray) -> np.ndarray:
13
  if (
14
  audio_data.dtype == np.float32
15
  or audio_data.dtype == np.float64
 
20
  return audio_data
21
 
22
 
23
+ def pydub_to_np(audio: AudioSegment) -> tuple[int, np.ndarray]:
24
+ """
25
+ Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
26
+ where each value is in range [-1.0, 1.0].
27
+ Returns tuple (audio_np_array, sample_rate).
28
+ """
29
+ nd_array = np.array(audio.get_array_of_samples(), dtype=np.float32)
30
+ if audio.channels != 1:
31
+ nd_array = nd_array.reshape((-1, audio.channels))
32
+ nd_array = nd_array / (1 << (8 * audio.sample_width - 1))
33
+
34
+ return (
35
+ audio.frame_rate,
36
+ nd_array,
37
+ )
38
+
39
+
40
  def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
41
  """
42
  Converts pydub audio segment into np.float32 of shape [duration_in_seconds*sample_rate, channels],
 
52
  return fp_arr
53
 
54
 
55
+ def ndarray_to_segment(
56
+ ndarray: np.ndarray, frame_rate: int, sample_width: int = None, channels: int = None
57
+ ) -> AudioSegment:
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  buffer = BytesIO()
59
+ sf.write(buffer, ndarray, frame_rate, format="wav", subtype="PCM_16")
60
  buffer.seek(0)
61
+ sound: AudioSegment = AudioSegment.from_wav(buffer)
 
 
 
62
 
63
+ if sample_width is None:
64
+ sample_width = sound.sample_width
65
+ if channels is None:
66
+ channels = sound.channels
67
 
68
+ return (
69
+ sound.set_frame_rate(frame_rate)
70
+ .set_sample_width(sample_width)
71
+ .set_channels(channels)
 
 
 
 
 
 
 
 
72
  )
 
73
 
74
 
75
+ def apply_prosody_to_audio_segment(
76
+ audio_segment: AudioSegment,
77
+ rate: float = 1,
78
+ volume: float = 0,
79
+ pitch: int = 0,
80
+ sr: int = 24000,
81
  ) -> AudioSegment:
82
+ audio_data = audiosegment_to_librosawav(audio_segment)
83
+
84
+ audio_data = apply_prosody_to_audio_data(audio_data, rate, volume, pitch, sr)
85
+
86
+ audio_segment = ndarray_to_segment(
87
+ audio_data, sr, audio_segment.sample_width, audio_segment.channels
 
 
 
 
 
88
  )
89
+
90
+ return audio_segment
91
 
92
 
93
  def apply_prosody_to_audio_data(
 
109
  return audio_data
110
 
111
 
112
+ def apply_normalize(
113
+ audio_data: np.ndarray,
114
+ headroom: float = 1,
115
+ sr: int = 24000,
116
+ ):
117
+ segment = ndarray_to_segment(audio_data, sr)
118
+ segment = effects.normalize(seg=segment, headroom=headroom)
119
+
120
+ return pydub_to_np(segment)
121
+
122
+
123
  if __name__ == "__main__":
124
  input_file = sys.argv[1]
125
 
 
129
  input_sound = AudioSegment.from_mp3(input_file)
130
 
131
  for time_factor in time_stretch_factors:
132
+ output_wav = f"{input_file}_time_{time_factor}.wav"
133
+ output_sound = apply_prosody_to_audio_segment(input_sound, rate=time_factor)
134
+ output_sound.export(output_wav, format="wav")
135
 
136
  for pitch_factor in pitch_shift_factors:
137
+ output_wav = f"{input_file}_pitch_{pitch_factor}.wav"
138
+ output_sound = apply_prosody_to_audio_segment(input_sound, pitch=pitch_factor)
139
+ output_sound.export(output_wav, format="wav")
modules/utils/detect_lang.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Literal
3
+
4
+
5
+ @lru_cache(maxsize=64)
6
+ def is_chinese(text):
7
+ for char in text:
8
+ if "\u4e00" <= char <= "\u9fff":
9
+ return True
10
+ return False
11
+
12
+
13
+ @lru_cache(maxsize=64)
14
+ def is_eng(text):
15
+ for char in text:
16
+ if "a" <= char.lower() <= "z":
17
+ return True
18
+ return False
19
+
20
+
21
+ @lru_cache(maxsize=64)
22
+ def guess_lang(text) -> Literal["zh", "en"]:
23
+ if is_chinese(text):
24
+ return "zh"
25
+ if is_eng(text):
26
+ return "en"
27
+ return "zh"
modules/utils/html.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html.parser import HTMLParser
2
+
3
+
4
+ class HTMLTagRemover(HTMLParser):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.reset()
8
+ self.fed = []
9
+
10
+ def handle_data(self, data):
11
+ self.fed.append(data)
12
+
13
+ def get_data(self):
14
+ return "\n".join(self.fed)
15
+
16
+
17
+ def remove_html_tags(text):
18
+ parser = HTMLTagRemover()
19
+ parser.feed(text)
20
+ return parser.get_data()
21
+
22
+
23
+ if __name__ == "__main__":
24
+ input_text = "<h1>一个标题</h1> 这是一段包含<code>标签</code>的文本。"
25
+ output_text = remove_html_tags(input_text)
26
+ print(output_text) # 输出: 一个标题 这是一段包含标签的文本。
modules/utils/ignore_warn.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+
4
+ def ignore_useless_warnings():
5
+
6
+ # NOTE: 因为触发位置在 `vocos/heads.py:60` 改不动...所以忽略
7
+ warnings.filterwarnings(
8
+ "ignore", category=UserWarning, message="ComplexHalf support is experimental"
9
+ )
modules/utils/markdown.py CHANGED
@@ -36,6 +36,7 @@ class PlainTextRenderer(mistune.HTMLRenderer):
36
  return html + "\n" + text + "\n"
37
  return "\n" + text + "\n"
38
 
 
39
  def list_item(self, text):
40
  return "" + text + "\n"
41
 
 
36
  return html + "\n" + text + "\n"
37
  return "\n" + text + "\n"
38
 
39
+ # FIXME: 现在的 list 转换没法保留序号
40
  def list_item(self, text):
41
  return "" + text + "\n"
42
 
modules/webui/localization_runtime.py CHANGED
@@ -90,6 +90,28 @@ class ZHLocalizationVars(LocalizationVars):
90
  ]
91
 
92
  self.tts_examples = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  {
94
  "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
95
  },
 
90
  ]
91
 
92
  self.tts_examples = [
93
+ {
94
+ "text": """
95
+ Fear is the path to the dark side. Fear leads to anger. Anger leads to hate. Hate leads to suffering.
96
+ 恐惧是通向黑暗之路。恐惧导致愤怒。愤怒引发仇恨。仇恨造成痛苦。 [lbreak]
97
+ Do. Or do not. There is no try.
98
+ 要么做,要么不做,没有试试看。[lbreak]
99
+ Peace is a lie, there is only passion.
100
+ 安宁即是谎言,激情方为王道。[lbreak]
101
+ Through passion, I gain strength.
102
+ 我以激情换取力量。[lbreak]
103
+ Through strength, I gain power.
104
+ 以力量赚取权力。[lbreak]
105
+ Through power, I gain victory.
106
+ 以权力赢取胜利。[lbreak]
107
+ Through victory, my chains are broken.
108
+ 于胜利中超越自我。[lbreak]
109
+ The Force shall free me.
110
+ 原力任我逍遥。[lbreak]
111
+ May the force be with you!
112
+ 愿原力与你同在![lbreak]
113
+ """.strip()
114
+ },
115
  {
116
  "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
117
  },
modules/webui/speaker/speaker_creator.py CHANGED
@@ -62,11 +62,10 @@ def create_spk_from_seed(
62
  gender: str,
63
  desc: str,
64
  ):
65
- chat_tts = load_chat_tts()
66
- with SeedContext(seed, True):
67
- emb = chat_tts.sample_random_speaker()
68
- spk = Speaker(seed_or_tensor=-2, name=name, gender=gender, describe=desc)
69
- spk.emb = emb
70
 
71
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
72
  torch.save(spk, tmp_file)
@@ -82,7 +81,8 @@ def test_spk_voice(
82
  text: str,
83
  progress=gr.Progress(track_tqdm=True),
84
  ):
85
- return tts_generate(spk=seed, text=text, progress=progress)
 
86
 
87
 
88
  def random_speaker():
 
62
  gender: str,
63
  desc: str,
64
  ):
65
+ spk = Speaker.from_seed(seed)
66
+ spk.name = name
67
+ spk.gender = gender
68
+ spk.describe = desc
 
69
 
70
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
71
  torch.save(spk, tmp_file)
 
81
  text: str,
82
  progress=gr.Progress(track_tqdm=True),
83
  ):
84
+ spk = Speaker.from_seed(seed)
85
+ return tts_generate(spk=spk, text=text, progress=progress)
86
 
87
 
88
  def random_speaker():
modules/webui/ssml/podcast_tab.py CHANGED
@@ -124,7 +124,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
124
 
125
  def send_to_ssml(msg, spk, style, sheet: pd.DataFrame):
126
  if sheet.empty:
127
- return gr.Error("Please add some text to the script table.")
128
  msg, spk, style, ssml = merge_dataframe_to_ssml(msg, spk, style, sheet)
129
  return [
130
  msg,
 
124
 
125
  def send_to_ssml(msg, spk, style, sheet: pd.DataFrame):
126
  if sheet.empty:
127
+ raise gr.Error("Please add some text to the script table.")
128
  msg, spk, style, ssml = merge_dataframe_to_ssml(msg, spk, style, sheet)
129
  return [
130
  msg,
modules/webui/ssml/ssml_tab.py CHANGED
@@ -6,19 +6,6 @@ from modules.webui.webui_utils import synthesize_ssml
6
 
7
  def create_ssml_interface():
8
  with gr.Row():
9
- with gr.Column(scale=3):
10
- with gr.Group():
11
- gr.Markdown("📝SSML Input")
12
- gr.Markdown("SSML_TEXT_GUIDE")
13
- ssml_input = gr.Textbox(
14
- label="SSML Input",
15
- lines=10,
16
- value=webui_config.localization.DEFAULT_SSML_TEXT,
17
- placeholder="输入 SSML 或选择示例",
18
- elem_id="ssml_input",
19
- show_label=False,
20
- )
21
- ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
22
  with gr.Column(scale=1):
23
  with gr.Group():
24
  gr.Markdown("🎛️Parameters")
@@ -44,11 +31,64 @@ def create_ssml_interface():
44
  step=1,
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with gr.Group():
48
  gr.Markdown("💪🏼Enhance")
49
  enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
50
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  with gr.Group():
53
  gr.Markdown("🎄Examples")
54
  gr.Examples(
@@ -56,7 +96,9 @@ def create_ssml_interface():
56
  inputs=[ssml_input],
57
  )
58
 
59
- ssml_output = gr.Audio(label="Generated Audio", format="mp3")
 
 
60
 
61
  ssml_button.click(
62
  synthesize_ssml,
@@ -67,6 +109,11 @@ def create_ssml_interface():
67
  enable_de_noise,
68
  eos_input,
69
  spliter_thr_input,
 
 
 
 
 
70
  ],
71
  outputs=ssml_output,
72
  )
 
6
 
7
  def create_ssml_interface():
8
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  with gr.Column(scale=1):
10
  with gr.Group():
11
  gr.Markdown("🎛️Parameters")
 
31
  step=1,
32
  )
33
 
34
+ with gr.Group():
35
+ gr.Markdown("🎛️Adjuster")
36
+ # 调节 speed pitch volume
37
+ # 可以选择开启 响度均衡
38
+
39
+ speed_input = gr.Slider(
40
+ label="Speed",
41
+ value=1.0,
42
+ minimum=0.5,
43
+ maximum=2.0,
44
+ step=0.1,
45
+ )
46
+ pitch_input = gr.Slider(
47
+ label="Pitch",
48
+ value=0,
49
+ minimum=-12,
50
+ maximum=12,
51
+ step=0.1,
52
+ )
53
+ volume_up_input = gr.Slider(
54
+ label="Volume Gain",
55
+ value=0,
56
+ minimum=-12,
57
+ maximum=12,
58
+ step=0.1,
59
+ )
60
+
61
+ enable_loudness_normalization = gr.Checkbox(
62
+ value=True,
63
+ label="Enable Loudness EQ",
64
+ )
65
+ headroom_input = gr.Slider(
66
+ label="Headroom",
67
+ value=1,
68
+ minimum=0,
69
+ maximum=12,
70
+ step=0.1,
71
+ )
72
+
73
  with gr.Group():
74
  gr.Markdown("💪🏼Enhance")
75
  enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
76
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
77
 
78
+ with gr.Column(scale=3):
79
+ with gr.Group():
80
+ gr.Markdown("📝SSML Input")
81
+ gr.Markdown("SSML_TEXT_GUIDE")
82
+ ssml_input = gr.Textbox(
83
+ label="SSML Input",
84
+ lines=10,
85
+ value=webui_config.localization.DEFAULT_SSML_TEXT,
86
+ placeholder="输入 SSML 或选择示例",
87
+ elem_id="ssml_input",
88
+ show_label=False,
89
+ )
90
+ ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
91
+
92
  with gr.Group():
93
  gr.Markdown("🎄Examples")
94
  gr.Examples(
 
96
  inputs=[ssml_input],
97
  )
98
 
99
+ with gr.Group():
100
+ gr.Markdown("🎨Output")
101
+ ssml_output = gr.Audio(label="Generated Audio", format="mp3")
102
 
103
  ssml_button.click(
104
  synthesize_ssml,
 
109
  enable_de_noise,
110
  eos_input,
111
  spliter_thr_input,
112
+ pitch_input,
113
+ speed_input,
114
+ volume_up_input,
115
+ enable_loudness_normalization,
116
+ headroom_input,
117
  ],
118
  outputs=ssml_output,
119
  )
modules/webui/tts_tab.py CHANGED
@@ -228,14 +228,56 @@ def create_tts_interface():
228
  label="prompt_audio", visible=webui_config.experimental
229
  )
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  with gr.Group():
232
  gr.Markdown("🔊Generate")
233
  disable_normalize_input = gr.Checkbox(
234
- value=False, label="Disable Normalize"
 
 
 
235
  )
236
 
237
  with gr.Group():
238
- gr.Markdown("💪🏼Enhance")
239
  enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
240
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
241
  tts_button = gr.Button(
@@ -271,6 +313,11 @@ def create_tts_interface():
271
  spk_file_upload,
272
  spliter_thr_input,
273
  eos_input,
 
 
 
 
 
274
  ],
275
  outputs=tts_output,
276
  )
 
228
  label="prompt_audio", visible=webui_config.experimental
229
  )
230
 
231
+ with gr.Group():
232
+ gr.Markdown("🎛️Adjuster")
233
+ # 调节 speed pitch volume
234
+ # 可以选择开启 响度均衡
235
+
236
+ speed_input = gr.Slider(
237
+ label="Speed",
238
+ value=1.0,
239
+ minimum=0.5,
240
+ maximum=2.0,
241
+ step=0.1,
242
+ )
243
+ pitch_input = gr.Slider(
244
+ label="Pitch",
245
+ value=0,
246
+ minimum=-12,
247
+ maximum=12,
248
+ step=0.1,
249
+ )
250
+ volume_up_input = gr.Slider(
251
+ label="Volume Gain",
252
+ value=0,
253
+ minimum=-12,
254
+ maximum=12,
255
+ step=0.1,
256
+ )
257
+
258
+ enable_loudness_normalization = gr.Checkbox(
259
+ value=True,
260
+ label="Enable Loudness EQ",
261
+ )
262
+ headroom_input = gr.Slider(
263
+ label="Headroom",
264
+ value=1,
265
+ minimum=0,
266
+ maximum=12,
267
+ step=0.1,
268
+ )
269
+
270
  with gr.Group():
271
  gr.Markdown("🔊Generate")
272
  disable_normalize_input = gr.Checkbox(
273
+ value=False,
274
+ label="Disable Normalize",
275
+ # 不需要了
276
+ visible=False,
277
  )
278
 
279
  with gr.Group():
280
+ # gr.Markdown("💪🏼Enhance")
281
  enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
282
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
283
  tts_button = gr.Button(
 
313
  spk_file_upload,
314
  spliter_thr_input,
315
  eos_input,
316
+ pitch_input,
317
+ speed_input,
318
+ volume_up_input,
319
+ enable_loudness_normalization,
320
+ headroom_input,
321
  ],
322
  outputs=tts_output,
323
  )
modules/webui/webui_utils.py CHANGED
@@ -6,6 +6,11 @@ import torch
6
  import torch.profiler
7
 
8
  from modules import refiner
 
 
 
 
 
9
  from modules.api.utils import calc_spk_style
10
  from modules.data import styles_mgr
11
  from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
@@ -13,8 +18,6 @@ from modules.normalization import text_normalize
13
  from modules.SentenceSplitter import SentenceSplitter
14
  from modules.speaker import Speaker, speaker_mgr
15
  from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser
16
- from modules.synthesize_audio import synthesize_audio
17
- from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
18
  from modules.utils import audio
19
  from modules.utils.hf import spaces
20
  from modules.webui import webui_config
@@ -89,6 +92,11 @@ def synthesize_ssml(
89
  enable_denoise=False,
90
  eos: str = "[uv_break]",
91
  spliter_thr: int = 100,
 
 
 
 
 
92
  progress=gr.Progress(track_tqdm=True),
93
  ):
94
  try:
@@ -99,7 +107,7 @@ def synthesize_ssml(
99
  ssml = ssml.strip()
100
 
101
  if ssml == "":
102
- return None
103
 
104
  parser = create_ssml_parser()
105
  segments = parser.parse(ssml)
@@ -107,22 +115,36 @@ def synthesize_ssml(
107
  segments = segments_length_limit(segments, max_len)
108
 
109
  if len(segments) == 0:
110
- return None
111
 
112
- synthesize = SynthesizeSegments(
113
- batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  )
115
- audio_segments = synthesize.synthesize_segments(segments)
116
- combined_audio = combine_audio_segments(audio_segments)
117
-
118
- sr = combined_audio.frame_rate
119
- audio_data, sr = apply_audio_enhance(
120
- audio.audiosegment_to_librosawav(combined_audio),
121
- sr,
122
- enable_denoise,
123
- enable_enhance,
124
  )
125
 
 
 
126
  # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
127
  audio_data = audio.audio_to_int16(audio_data)
128
 
@@ -150,6 +172,11 @@ def tts_generate(
150
  spk_file=None,
151
  spliter_thr: int = 100,
152
  eos: str = "[uv_break]",
 
 
 
 
 
153
  progress=gr.Progress(track_tqdm=True),
154
  ):
155
  try:
@@ -161,10 +188,10 @@ def tts_generate(
161
  text = text.strip()[0:max_len]
162
 
163
  if text == "":
164
- return None
165
 
166
  if style == "*auto":
167
- style = None
168
 
169
  if isinstance(top_k, float):
170
  top_k = int(top_k)
@@ -181,31 +208,56 @@ def tts_generate(
181
  infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
182
  infer_seed = int(infer_seed)
183
 
184
- if not disable_normalize:
185
- text = text_normalize(text)
186
 
187
  if spk_file:
188
- spk = Speaker.from_file(spk_file)
 
 
 
189
 
190
- sample_rate, audio_data = synthesize_audio(
191
- text=text,
 
 
 
192
  temperature=temperature,
193
- top_P=top_p,
194
- top_K=top_k,
195
- spk=spk,
196
- infer_seed=infer_seed,
197
- use_decoder=use_decoder,
198
  prompt1=prompt1,
199
  prompt2=prompt2,
200
- prefix=prefix,
 
201
  batch_size=batch_size,
202
- end_of_sentence=eos,
203
  spliter_threshold=spliter_thr,
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  )
205
 
206
- audio_data, sample_rate = apply_audio_enhance(
207
- audio_data, sample_rate, enable_denoise, enable_enhance
 
 
 
 
 
208
  )
 
 
 
209
  # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
210
  audio_data = audio.audio_to_int16(audio_data)
211
  return sample_rate, audio_data
 
6
  import torch.profiler
7
 
8
  from modules import refiner
9
+ from modules.api.impl.handler.SSMLHandler import SSMLHandler
10
+ from modules.api.impl.handler.TTSHandler import TTSHandler
11
+ from modules.api.impl.model.audio_model import AdjustConfig
12
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
13
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
14
  from modules.api.utils import calc_spk_style
15
  from modules.data import styles_mgr
16
  from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
 
18
  from modules.SentenceSplitter import SentenceSplitter
19
  from modules.speaker import Speaker, speaker_mgr
20
  from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser
 
 
21
  from modules.utils import audio
22
  from modules.utils.hf import spaces
23
  from modules.webui import webui_config
 
92
  enable_denoise=False,
93
  eos: str = "[uv_break]",
94
  spliter_thr: int = 100,
95
+ pitch: float = 0,
96
+ speed_rate: float = 1,
97
+ volume_gain_db: float = 0,
98
+ normalize: bool = True,
99
+ headroom: float = 1,
100
  progress=gr.Progress(track_tqdm=True),
101
  ):
102
  try:
 
107
  ssml = ssml.strip()
108
 
109
  if ssml == "":
110
+ raise gr.Error("SSML is empty, please input some SSML")
111
 
112
  parser = create_ssml_parser()
113
  segments = parser.parse(ssml)
 
115
  segments = segments_length_limit(segments, max_len)
116
 
117
  if len(segments) == 0:
118
+ raise gr.Error("No valid segments in SSML")
119
 
120
+ infer_config = InferConfig(
121
+ batch_size=batch_size,
122
+ spliter_threshold=spliter_thr,
123
+ eos=eos,
124
+ # NOTE: SSML not support `infer_seed` contorl
125
+ # seed=42,
126
+ )
127
+ adjust_config = AdjustConfig(
128
+ pitch=pitch,
129
+ speed_rate=speed_rate,
130
+ volume_gain_db=volume_gain_db,
131
+ normalize=normalize,
132
+ headroom=headroom,
133
+ )
134
+ enhancer_config = EnhancerConfig(
135
+ enabled=enable_denoise or enable_enhance or False,
136
+ lambd=0.9 if enable_denoise else 0.1,
137
  )
138
+
139
+ handler = SSMLHandler(
140
+ ssml_content=ssml,
141
+ infer_config=infer_config,
142
+ adjust_config=adjust_config,
143
+ enhancer_config=enhancer_config,
 
 
 
144
  )
145
 
146
+ audio_data, sr = handler.enqueue()
147
+
148
  # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
149
  audio_data = audio.audio_to_int16(audio_data)
150
 
 
172
  spk_file=None,
173
  spliter_thr: int = 100,
174
  eos: str = "[uv_break]",
175
+ pitch: float = 0,
176
+ speed_rate: float = 1,
177
+ volume_gain_db: float = 0,
178
+ normalize: bool = True,
179
+ headroom: float = 1,
180
  progress=gr.Progress(track_tqdm=True),
181
  ):
182
  try:
 
188
  text = text.strip()[0:max_len]
189
 
190
  if text == "":
191
+ raise gr.Error("Text is empty, please input some text")
192
 
193
  if style == "*auto":
194
+ style = ""
195
 
196
  if isinstance(top_k, float):
197
  top_k = int(top_k)
 
208
  infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
209
  infer_seed = int(infer_seed)
210
 
211
+ if isinstance(spk, int):
212
+ spk = Speaker.from_seed(spk)
213
 
214
  if spk_file:
215
+ try:
216
+ spk: Speaker = Speaker.from_file(spk_file)
217
+ except Exception:
218
+ raise gr.Error("Failed to load speaker file")
219
 
220
+ if not isinstance(spk.emb, torch.Tensor):
221
+ raise gr.Error("Speaker file is not supported")
222
+
223
+ tts_config = ChatTTSConfig(
224
+ style=style,
225
  temperature=temperature,
226
+ top_k=top_k,
227
+ top_p=top_p,
228
+ prefix=prefix,
 
 
229
  prompt1=prompt1,
230
  prompt2=prompt2,
231
+ )
232
+ infer_config = InferConfig(
233
  batch_size=batch_size,
 
234
  spliter_threshold=spliter_thr,
235
+ eos=eos,
236
+ seed=infer_seed,
237
+ )
238
+ adjust_config = AdjustConfig(
239
+ pitch=pitch,
240
+ speed_rate=speed_rate,
241
+ volume_gain_db=volume_gain_db,
242
+ normalize=normalize,
243
+ headroom=headroom,
244
+ )
245
+ enhancer_config = EnhancerConfig(
246
+ enabled=enable_denoise or enable_enhance or False,
247
+ lambd=0.9 if enable_denoise else 0.1,
248
  )
249
 
250
+ handler = TTSHandler(
251
+ text_content=text,
252
+ spk=spk,
253
+ tts_config=tts_config,
254
+ infer_config=infer_config,
255
+ adjust_config=adjust_config,
256
+ enhancer_config=enhancer_config,
257
  )
258
+
259
+ audio_data, sample_rate = handler.enqueue()
260
+
261
  # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式
262
  audio_data = audio.audio_to_int16(audio_data)
263
  return sample_rate, audio_data
requirements.txt CHANGED
@@ -1,27 +1,29 @@
1
- numpy
2
  scipy
3
  lxml
4
  pydub
5
  fastapi
6
  soundfile
7
- pyrubberband
8
  omegaconf
9
  pypinyin
 
10
  pandas
11
  vector_quantize_pytorch
12
  einops
 
13
  omegaconf~=2.3.0
14
  tqdm
15
- huggingface_hub>=0.22.2,<1.0
16
- vocos==0.0.1
17
- transformers==4.41.2
18
- torch
19
- torchvision
20
- torchaudio
21
  gradio
22
  emojiswitch
23
  python-dotenv
24
  zhon
25
  mistune==3.0.2
26
  cn2an
27
- python-box
 
 
 
 
 
1
+ numpy==1.26.4
2
  scipy
3
  lxml
4
  pydub
5
  fastapi
6
  soundfile
 
7
  omegaconf
8
  pypinyin
9
+ vocos
10
  pandas
11
  vector_quantize_pytorch
12
  einops
13
+ transformers~=4.41.1
14
  omegaconf~=2.3.0
15
  tqdm
16
+ # torch
17
+ # torchvision
18
+ # torchaudio
 
 
 
19
  gradio
20
  emojiswitch
21
  python-dotenv
22
  zhon
23
  mistune==3.0.2
24
  cn2an
25
+ # audio_denoiser
26
+ python-box
27
+ ftfy
28
+ librosa
29
+ pyrubberband
webui.py CHANGED
@@ -6,6 +6,7 @@ from modules.ffmpeg_env import setup_ffmpeg_path
6
 
7
  try:
8
  setup_ffmpeg_path()
 
9
  logging.basicConfig(
10
  level=os.getenv("LOG_LEVEL", "INFO"),
11
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -16,20 +17,18 @@ except BaseException:
16
  import argparse
17
 
18
  from modules import config
19
- from modules.api.api_setup import (
20
- process_api_args,
21
- process_model_args,
22
- setup_api_args,
23
- setup_model_args,
24
- )
25
  from modules.api.app_config import app_description, app_title, app_version
26
  from modules.gradio_dcls_fix import dcls_patch
 
27
  from modules.utils.env import get_and_update_env
 
28
  from modules.utils.torch_opt import configure_torch_optimizations
29
  from modules.webui import webui_config
30
  from modules.webui.app import create_interface, webui_init
31
 
32
  dcls_patch()
 
33
 
34
 
35
  def setup_webui_args(parser: argparse.ArgumentParser):
 
6
 
7
  try:
8
  setup_ffmpeg_path()
9
+ # NOTE: 因为 logger 都是在模块中初始化,所以这个 config 必须在最前面
10
  logging.basicConfig(
11
  level=os.getenv("LOG_LEVEL", "INFO"),
12
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 
17
  import argparse
18
 
19
  from modules import config
20
+ from modules.api.api_setup import process_api_args, setup_api_args
 
 
 
 
 
21
  from modules.api.app_config import app_description, app_title, app_version
22
  from modules.gradio_dcls_fix import dcls_patch
23
+ from modules.models_setup import process_model_args, setup_model_args
24
  from modules.utils.env import get_and_update_env
25
+ from modules.utils.ignore_warn import ignore_useless_warnings
26
  from modules.utils.torch_opt import configure_torch_optimizations
27
  from modules.webui import webui_config
28
  from modules.webui.app import create_interface, webui_init
29
 
30
  dcls_patch()
31
+ ignore_useless_warnings()
32
 
33
 
34
  def setup_webui_args(parser: argparse.ArgumentParser):