zhzluke96
commited on
Commit
•
627d3d7
1
Parent(s):
72bb5b8
update
Browse files- language/zh-CN.json +2 -2
- launch.py +36 -168
- modules/Enhancer/ResembleEnhance.py +94 -16
- modules/api/Api.py +1 -1
- modules/api/api_setup.py +164 -0
- modules/api/app_config.py +14 -0
- modules/api/impl/google_api.py +47 -17
- modules/api/impl/models_api.py +9 -2
- modules/api/worker.py +49 -0
- modules/config.py +2 -1
- modules/ffmpeg_env.py +16 -0
- modules/generate_audio.py +3 -3
- modules/gradio_dcls_fix.py +1 -0
- modules/models.py +2 -3
- modules/repos_static/resemble_enhance/denoiser/denoiser.py +7 -2
- modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +10 -4
- modules/repos_static/resemble_enhance/inference.py +27 -6
- modules/speaker.py +1 -1
- modules/utils/env.py +9 -0
- modules/utils/ffmpeg.py +20 -0
- modules/utils/git.py +4 -12
- modules/utils/hf.py +17 -0
- modules/utils/torch_opt.py +7 -0
- modules/webui/app.py +4 -7
- modules/webui/gradio_extensions.py +1 -1
- modules/webui/js/index.js +1 -0
- modules/webui/speaker/speaker_creator.py +1 -1
- modules/webui/speaker/speaker_editor.py +1 -1
- modules/webui/speaker/speaker_merger.py +1 -1
- modules/webui/ssml/podcast_tab.py +11 -11
- modules/webui/ssml/spliter_tab.py +1 -1
- modules/webui/webui_utils.py +3 -18
- webui.py +26 -11
language/zh-CN.json
CHANGED
@@ -31,8 +31,8 @@
|
|
31 |
"🔊Generate": "🔊生成",
|
32 |
"Disable Normalize": "禁用文本预处理",
|
33 |
"💪🏼Enhance": "💪🏼增强",
|
34 |
-
"Enable Enhance": "
|
35 |
-
"Enable De-noise": "
|
36 |
"🔊Generate Audio": "🔊生成音频",
|
37 |
"SSML": "SSML",
|
38 |
"Editor": "编辑器",
|
|
|
31 |
"🔊Generate": "🔊生成",
|
32 |
"Disable Normalize": "禁用文本预处理",
|
33 |
"💪🏼Enhance": "💪🏼增强",
|
34 |
+
"Enable Enhance": "启用人声增强",
|
35 |
+
"Enable De-noise": "启用背景降噪",
|
36 |
"🔊Generate Audio": "🔊生成音频",
|
37 |
"SSML": "SSML",
|
38 |
"Editor": "编辑器",
|
launch.py
CHANGED
@@ -1,201 +1,69 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
|
|
|
|
|
|
|
|
4 |
logging.basicConfig(
|
5 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
6 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
7 |
)
|
8 |
|
9 |
-
from modules.devices import devices
|
10 |
import argparse
|
11 |
import uvicorn
|
12 |
|
13 |
-
import torch
|
14 |
from modules import config
|
15 |
from modules.utils import env
|
16 |
-
from modules import generate_audio as generate
|
17 |
-
from modules.api.Api import APIManager
|
18 |
|
19 |
-
from
|
20 |
-
style_api,
|
21 |
-
tts_api,
|
22 |
-
ssml_api,
|
23 |
-
google_api,
|
24 |
-
openai_api,
|
25 |
-
refiner_api,
|
26 |
-
speaker_api,
|
27 |
-
ping_api,
|
28 |
-
models_api,
|
29 |
-
)
|
30 |
|
31 |
logger = logging.getLogger(__name__)
|
32 |
|
33 |
-
torch._dynamo.config.cache_size_limit = 64
|
34 |
-
torch._dynamo.config.suppress_errors = True
|
35 |
-
torch.set_float32_matmul_precision("high")
|
36 |
-
|
37 |
-
|
38 |
-
def create_api(app, no_docs=False, exclude=[]):
|
39 |
-
app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)
|
40 |
-
|
41 |
-
ping_api.setup(app_mgr)
|
42 |
-
models_api.setup(app_mgr)
|
43 |
-
style_api.setup(app_mgr)
|
44 |
-
speaker_api.setup(app_mgr)
|
45 |
-
tts_api.setup(app_mgr)
|
46 |
-
ssml_api.setup(app_mgr)
|
47 |
-
google_api.setup(app_mgr)
|
48 |
-
openai_api.setup(app_mgr)
|
49 |
-
refiner_api.setup(app_mgr)
|
50 |
-
|
51 |
-
return app_mgr
|
52 |
-
|
53 |
-
|
54 |
-
def get_and_update_env(*args):
|
55 |
-
val = env.get_env_or_arg(*args)
|
56 |
-
key = args[1]
|
57 |
-
config.runtime_env_vars[key] = val
|
58 |
-
return val
|
59 |
-
|
60 |
-
|
61 |
-
def setup_model_args(parser: argparse.ArgumentParser):
|
62 |
-
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
63 |
-
parser.add_argument(
|
64 |
-
"--half",
|
65 |
-
action="store_true",
|
66 |
-
help="Enable half precision for model inference",
|
67 |
-
)
|
68 |
-
parser.add_argument(
|
69 |
-
"--off_tqdm",
|
70 |
-
action="store_true",
|
71 |
-
help="Disable tqdm progress bar",
|
72 |
-
)
|
73 |
-
parser.add_argument(
|
74 |
-
"--device_id",
|
75 |
-
type=str,
|
76 |
-
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
77 |
-
default=None,
|
78 |
-
)
|
79 |
-
parser.add_argument(
|
80 |
-
"--use_cpu",
|
81 |
-
nargs="+",
|
82 |
-
help="use CPU as torch device for specified modules",
|
83 |
-
default=[],
|
84 |
-
type=str.lower,
|
85 |
-
)
|
86 |
-
parser.add_argument(
|
87 |
-
"--lru_size",
|
88 |
-
type=int,
|
89 |
-
default=64,
|
90 |
-
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
91 |
-
)
|
92 |
-
|
93 |
-
|
94 |
-
def setup_api_args(parser: argparse.ArgumentParser):
|
95 |
-
parser.add_argument("--api_host", type=str, help="Host to run the server on")
|
96 |
-
parser.add_argument("--api_port", type=int, help="Port to run the server on")
|
97 |
-
parser.add_argument(
|
98 |
-
"--reload", action="store_true", help="Enable auto-reload for development"
|
99 |
-
)
|
100 |
-
parser.add_argument(
|
101 |
-
"--cors_origin",
|
102 |
-
type=str,
|
103 |
-
help="Allowed CORS origins. Use '*' to allow all origins.",
|
104 |
-
)
|
105 |
-
parser.add_argument(
|
106 |
-
"--no_playground",
|
107 |
-
action="store_true",
|
108 |
-
help="Disable the playground entry",
|
109 |
-
)
|
110 |
-
parser.add_argument(
|
111 |
-
"--no_docs",
|
112 |
-
action="store_true",
|
113 |
-
help="Disable the documentation entry",
|
114 |
-
)
|
115 |
-
# 配置哪些api要跳过 比如 exclude="/v1/speakers/*,/v1/tts/*"
|
116 |
-
parser.add_argument(
|
117 |
-
"--exclude",
|
118 |
-
type=str,
|
119 |
-
help="Exclude the specified API from the server",
|
120 |
-
)
|
121 |
-
|
122 |
-
|
123 |
-
def process_model_args(args):
|
124 |
-
lru_size = get_and_update_env(args, "lru_size", 64, int)
|
125 |
-
compile = get_and_update_env(args, "compile", False, bool)
|
126 |
-
device_id = get_and_update_env(args, "device_id", None, str)
|
127 |
-
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
128 |
-
half = get_and_update_env(args, "half", False, bool)
|
129 |
-
off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
|
130 |
-
|
131 |
-
generate.setup_lru_cache()
|
132 |
-
devices.reset_device()
|
133 |
-
devices.first_time_calculation()
|
134 |
-
|
135 |
-
|
136 |
-
def process_api_args(args, app):
|
137 |
-
cors_origin = get_and_update_env(args, "cors_origin", "*", str)
|
138 |
-
no_playground = get_and_update_env(args, "no_playground", False, bool)
|
139 |
-
no_docs = get_and_update_env(args, "no_docs", False, bool)
|
140 |
-
exclude = get_and_update_env(args, "exclude", "", str)
|
141 |
-
|
142 |
-
api = create_api(app=app, no_docs=no_docs, exclude=exclude.split(","))
|
143 |
-
config.api = api
|
144 |
-
|
145 |
-
if cors_origin:
|
146 |
-
api.set_cors(allow_origins=[cors_origin])
|
147 |
-
|
148 |
-
if not no_playground:
|
149 |
-
api.setup_playground()
|
150 |
-
|
151 |
-
if compile:
|
152 |
-
logger.info("Model compile is enabled")
|
153 |
-
|
154 |
-
|
155 |
-
app_description = """
|
156 |
-
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
|
157 |
-
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
|
158 |
-
|
159 |
-
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
160 |
-
|
161 |
-
> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
|
162 |
-
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
|
163 |
-
|
164 |
-
> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
|
165 |
-
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
|
166 |
-
"""
|
167 |
-
app_title = "ChatTTS Forge API"
|
168 |
-
app_version = "0.1.0"
|
169 |
-
|
170 |
if __name__ == "__main__":
|
171 |
import dotenv
|
172 |
-
from fastapi import FastAPI
|
173 |
|
174 |
dotenv.load_dotenv(
|
175 |
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
|
176 |
)
|
177 |
-
|
178 |
parser = argparse.ArgumentParser(
|
179 |
description="Start the FastAPI server with command line arguments"
|
180 |
)
|
181 |
setup_api_args(parser)
|
182 |
setup_model_args(parser)
|
|
|
183 |
|
184 |
args = parser.parse_args()
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
)
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
+
from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
|
5 |
+
from modules.ffmpeg_env import setup_ffmpeg_path
|
6 |
+
|
7 |
+
setup_ffmpeg_path()
|
8 |
logging.basicConfig(
|
9 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
10 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
11 |
)
|
12 |
|
|
|
13 |
import argparse
|
14 |
import uvicorn
|
15 |
|
|
|
16 |
from modules import config
|
17 |
from modules.utils import env
|
|
|
|
|
18 |
|
19 |
+
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
if __name__ == "__main__":
|
24 |
import dotenv
|
|
|
25 |
|
26 |
dotenv.load_dotenv(
|
27 |
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
|
28 |
)
|
|
|
29 |
parser = argparse.ArgumentParser(
|
30 |
description="Start the FastAPI server with command line arguments"
|
31 |
)
|
32 |
setup_api_args(parser)
|
33 |
setup_model_args(parser)
|
34 |
+
setup_uvicon_args(parser=parser)
|
35 |
|
36 |
args = parser.parse_args()
|
37 |
|
38 |
+
host = env.get_and_update_env(args, "host", "0.0.0.0", str)
|
39 |
+
port = env.get_and_update_env(args, "port", 7870, int)
|
40 |
+
reload = env.get_and_update_env(args, "reload", False, bool)
|
41 |
+
workers = env.get_and_update_env(args, "workers", 1, int)
|
42 |
+
log_level = env.get_and_update_env(args, "log_level", "info", str)
|
43 |
+
access_log = env.get_and_update_env(args, "access_log", True, bool)
|
44 |
+
proxy_headers = env.get_and_update_env(args, "proxy_headers", True, bool)
|
45 |
+
timeout_keep_alive = env.get_and_update_env(args, "timeout_keep_alive", 5, int)
|
46 |
+
timeout_graceful_shutdown = env.get_and_update_env(
|
47 |
+
args, "timeout_graceful_shutdown", 0, int
|
48 |
+
)
|
49 |
+
ssl_keyfile = env.get_and_update_env(args, "ssl_keyfile", None, str)
|
50 |
+
ssl_certfile = env.get_and_update_env(args, "ssl_certfile", None, str)
|
51 |
+
ssl_keyfile_password = env.get_and_update_env(
|
52 |
+
args, "ssl_keyfile_password", None, str
|
53 |
)
|
54 |
|
55 |
+
uvicorn.run(
|
56 |
+
"modules.api.worker:app",
|
57 |
+
host=host,
|
58 |
+
port=port,
|
59 |
+
reload=reload,
|
60 |
+
workers=workers,
|
61 |
+
log_level=log_level,
|
62 |
+
access_log=access_log,
|
63 |
+
proxy_headers=proxy_headers,
|
64 |
+
timeout_keep_alive=timeout_keep_alive,
|
65 |
+
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
66 |
+
ssl_keyfile=ssl_keyfile,
|
67 |
+
ssl_certfile=ssl_certfile,
|
68 |
+
ssl_keyfile_password=ssl_keyfile_password,
|
69 |
+
)
|
modules/Enhancer/ResembleEnhance.py
CHANGED
@@ -1,5 +1,8 @@
|
|
|
|
1 |
import os
|
2 |
from typing import List, Literal
|
|
|
|
|
3 |
from modules.devices import devices
|
4 |
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
5 |
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
@@ -11,53 +14,54 @@ from modules.utils.constants import MODELS_DIR
|
|
11 |
from pathlib import Path
|
12 |
|
13 |
from threading import Lock
|
|
|
14 |
|
15 |
-
|
16 |
-
lock = Lock()
|
17 |
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
with lock:
|
22 |
-
if resemble_enhance is None:
|
23 |
-
resemble_enhance = ResembleEnhance(device)
|
24 |
-
resemble_enhance.load_model()
|
25 |
-
return resemble_enhance
|
26 |
|
27 |
|
28 |
class ResembleEnhance:
|
29 |
-
def __init__(self, device: torch.device):
|
30 |
self.device = device
|
|
|
31 |
|
32 |
self.enhancer: HParams = None
|
33 |
self.hparams: Enhancer = None
|
34 |
|
35 |
def load_model(self):
|
36 |
hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
|
37 |
-
enhancer = Enhancer(hparams)
|
38 |
state_dict = torch.load(
|
39 |
Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
|
40 |
map_location=self.device,
|
41 |
)["module"]
|
42 |
enhancer.load_state_dict(state_dict)
|
43 |
-
enhancer.to(self.device).eval()
|
44 |
|
45 |
self.hparams = hparams
|
46 |
self.enhancer = enhancer
|
47 |
|
48 |
@torch.inference_mode()
|
49 |
-
def denoise(self, dwav, sr
|
50 |
assert self.enhancer is not None, "Model not loaded"
|
51 |
assert self.enhancer.denoiser is not None, "Denoiser not loaded"
|
52 |
enhancer = self.enhancer
|
53 |
-
return inference(
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
@torch.inference_mode()
|
56 |
def enhance(
|
57 |
self,
|
58 |
dwav,
|
59 |
sr,
|
60 |
-
device,
|
61 |
nfe=32,
|
62 |
solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
|
63 |
lambd=0.5,
|
@@ -74,7 +78,81 @@ class ResembleEnhance:
|
|
74 |
assert self.enhancer is not None, "Model not loaded"
|
75 |
enhancer = self.enhancer
|
76 |
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
|
77 |
-
return inference(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
|
80 |
if __name__ == "__main__":
|
|
|
1 |
+
import gc
|
2 |
import os
|
3 |
from typing import List, Literal
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
from modules.devices import devices
|
7 |
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
8 |
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
|
|
14 |
from pathlib import Path
|
15 |
|
16 |
from threading import Lock
|
17 |
+
from modules import config
|
18 |
|
19 |
+
import logging
|
|
|
20 |
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
|
23 |
+
resemble_enhance = None
|
24 |
+
lock = Lock()
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
class ResembleEnhance:
|
28 |
+
def __init__(self, device: torch.device, dtype=torch.float32):
|
29 |
self.device = device
|
30 |
+
self.dtype = dtype
|
31 |
|
32 |
self.enhancer: HParams = None
|
33 |
self.hparams: Enhancer = None
|
34 |
|
35 |
def load_model(self):
|
36 |
hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
|
37 |
+
enhancer = Enhancer(hparams).to(device=self.device, dtype=self.dtype).eval()
|
38 |
state_dict = torch.load(
|
39 |
Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
|
40 |
map_location=self.device,
|
41 |
)["module"]
|
42 |
enhancer.load_state_dict(state_dict)
|
|
|
43 |
|
44 |
self.hparams = hparams
|
45 |
self.enhancer = enhancer
|
46 |
|
47 |
@torch.inference_mode()
|
48 |
+
def denoise(self, dwav, sr) -> tuple[torch.Tensor, int]:
|
49 |
assert self.enhancer is not None, "Model not loaded"
|
50 |
assert self.enhancer.denoiser is not None, "Denoiser not loaded"
|
51 |
enhancer = self.enhancer
|
52 |
+
return inference(
|
53 |
+
model=enhancer.denoiser,
|
54 |
+
dwav=dwav,
|
55 |
+
sr=sr,
|
56 |
+
device=self.devicem,
|
57 |
+
dtype=self.dtype,
|
58 |
+
)
|
59 |
|
60 |
@torch.inference_mode()
|
61 |
def enhance(
|
62 |
self,
|
63 |
dwav,
|
64 |
sr,
|
|
|
65 |
nfe=32,
|
66 |
solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
|
67 |
lambd=0.5,
|
|
|
78 |
assert self.enhancer is not None, "Model not loaded"
|
79 |
enhancer = self.enhancer
|
80 |
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
|
81 |
+
return inference(
|
82 |
+
model=enhancer, dwav=dwav, sr=sr, device=self.device, dtype=self.dtype
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def load_enhancer() -> ResembleEnhance:
|
87 |
+
global resemble_enhance
|
88 |
+
with lock:
|
89 |
+
if resemble_enhance is None:
|
90 |
+
logger.info("Loading ResembleEnhance model")
|
91 |
+
resemble_enhance = ResembleEnhance(
|
92 |
+
device=devices.device, dtype=devices.dtype
|
93 |
+
)
|
94 |
+
resemble_enhance.load_model()
|
95 |
+
logger.info("ResembleEnhance model loaded")
|
96 |
+
return resemble_enhance
|
97 |
+
|
98 |
+
|
99 |
+
def unload_enhancer():
|
100 |
+
global resemble_enhance
|
101 |
+
with lock:
|
102 |
+
if resemble_enhance is not None:
|
103 |
+
logger.info("Unloading ResembleEnhance model")
|
104 |
+
del resemble_enhance
|
105 |
+
resemble_enhance = None
|
106 |
+
devices.torch_gc()
|
107 |
+
gc.collect()
|
108 |
+
logger.info("ResembleEnhance model unloaded")
|
109 |
+
|
110 |
+
|
111 |
+
def reload_enhancer():
|
112 |
+
logger.info("Reloading ResembleEnhance model")
|
113 |
+
unload_enhancer()
|
114 |
+
load_enhancer()
|
115 |
+
logger.info("ResembleEnhance model reloaded")
|
116 |
+
|
117 |
+
|
118 |
+
def apply_audio_enhance_full(
|
119 |
+
audio_data: np.ndarray,
|
120 |
+
sr: int,
|
121 |
+
nfe=32,
|
122 |
+
solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
|
123 |
+
lambd=0.5,
|
124 |
+
tau=0.5,
|
125 |
+
):
|
126 |
+
# FIXME: 这里可能改成 to(device) 会优化一点?
|
127 |
+
tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
|
128 |
+
enhancer = load_enhancer()
|
129 |
+
|
130 |
+
tensor, sr = enhancer.enhance(
|
131 |
+
tensor, sr, tau=tau, nfe=nfe, solver=solver, lambd=lambd
|
132 |
+
)
|
133 |
+
|
134 |
+
audio_data = tensor.cpu().numpy()
|
135 |
+
return audio_data, int(sr)
|
136 |
+
|
137 |
+
|
138 |
+
def apply_audio_enhance(
|
139 |
+
audio_data: np.ndarray, sr: int, enable_denoise: bool, enable_enhance: bool
|
140 |
+
):
|
141 |
+
if not enable_denoise and not enable_enhance:
|
142 |
+
return audio_data, sr
|
143 |
+
|
144 |
+
# FIXME: 这里可能改成 to(device) 会优化一点?
|
145 |
+
tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
|
146 |
+
enhancer = load_enhancer()
|
147 |
+
|
148 |
+
if enable_enhance or enable_denoise:
|
149 |
+
lambd = 0.9 if enable_denoise else 0.1
|
150 |
+
tensor, sr = enhancer.enhance(
|
151 |
+
tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd
|
152 |
+
)
|
153 |
+
|
154 |
+
audio_data = tensor.cpu().numpy()
|
155 |
+
return audio_data, int(sr)
|
156 |
|
157 |
|
158 |
if __name__ == "__main__":
|
modules/api/Api.py
CHANGED
@@ -24,7 +24,7 @@ def is_excluded(path, exclude_patterns):
|
|
24 |
|
25 |
|
26 |
class APIManager:
|
27 |
-
def __init__(self, app: FastAPI,
|
28 |
self.app = app
|
29 |
self.registered_apis = {}
|
30 |
self.logger = logging.getLogger(__name__)
|
|
|
24 |
|
25 |
|
26 |
class APIManager:
|
27 |
+
def __init__(self, app: FastAPI, exclude_patterns=[]):
|
28 |
self.app = app
|
29 |
self.registered_apis = {}
|
30 |
self.logger = logging.getLogger(__name__)
|
modules/api/api_setup.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from modules.devices import devices
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from modules import config
|
7 |
+
from modules.utils import env
|
8 |
+
from modules import generate_audio
|
9 |
+
from modules.api.Api import APIManager
|
10 |
+
|
11 |
+
from modules.api.impl import (
|
12 |
+
style_api,
|
13 |
+
tts_api,
|
14 |
+
ssml_api,
|
15 |
+
google_api,
|
16 |
+
openai_api,
|
17 |
+
refiner_api,
|
18 |
+
speaker_api,
|
19 |
+
ping_api,
|
20 |
+
models_api,
|
21 |
+
)
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def create_api(app, exclude=[]):
|
27 |
+
app_mgr = APIManager(app=app, exclude_patterns=exclude)
|
28 |
+
|
29 |
+
ping_api.setup(app_mgr)
|
30 |
+
models_api.setup(app_mgr)
|
31 |
+
style_api.setup(app_mgr)
|
32 |
+
speaker_api.setup(app_mgr)
|
33 |
+
tts_api.setup(app_mgr)
|
34 |
+
ssml_api.setup(app_mgr)
|
35 |
+
google_api.setup(app_mgr)
|
36 |
+
openai_api.setup(app_mgr)
|
37 |
+
refiner_api.setup(app_mgr)
|
38 |
+
|
39 |
+
return app_mgr
|
40 |
+
|
41 |
+
|
42 |
+
def setup_model_args(parser: argparse.ArgumentParser):
|
43 |
+
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
44 |
+
parser.add_argument(
|
45 |
+
"--half",
|
46 |
+
action="store_true",
|
47 |
+
help="Enable half precision for model inference",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--off_tqdm",
|
51 |
+
action="store_true",
|
52 |
+
help="Disable tqdm progress bar",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--device_id",
|
56 |
+
type=str,
|
57 |
+
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
58 |
+
default=None,
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--use_cpu",
|
62 |
+
nargs="+",
|
63 |
+
help="use CPU as torch device for specified modules",
|
64 |
+
default=[],
|
65 |
+
type=str.lower,
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--lru_size",
|
69 |
+
type=int,
|
70 |
+
default=64,
|
71 |
+
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--debug_generate",
|
75 |
+
action="store_true",
|
76 |
+
help="Enable debug mode for audio generation",
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
def process_model_args(args):
|
81 |
+
lru_size = env.get_and_update_env(args, "lru_size", 64, int)
|
82 |
+
compile = env.get_and_update_env(args, "compile", False, bool)
|
83 |
+
device_id = env.get_and_update_env(args, "device_id", None, str)
|
84 |
+
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
|
85 |
+
half = env.get_and_update_env(args, "half", False, bool)
|
86 |
+
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
|
87 |
+
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
|
88 |
+
|
89 |
+
generate_audio.setup_lru_cache()
|
90 |
+
devices.reset_device()
|
91 |
+
devices.first_time_calculation()
|
92 |
+
|
93 |
+
if debug_generate:
|
94 |
+
generate_audio.logger.setLevel(logging.DEBUG)
|
95 |
+
|
96 |
+
|
97 |
+
def setup_uvicon_args(parser: argparse.ArgumentParser):
|
98 |
+
parser.add_argument("--host", type=str, help="Host to run the server on")
|
99 |
+
parser.add_argument("--port", type=int, help="Port to run the server on")
|
100 |
+
parser.add_argument(
|
101 |
+
"--reload", action="store_true", help="Enable auto-reload for development"
|
102 |
+
)
|
103 |
+
parser.add_argument("--workers", type=int, help="Number of worker processes")
|
104 |
+
parser.add_argument("--log_level", type=str, help="Log level")
|
105 |
+
parser.add_argument("--access_log", action="store_true", help="Enable access log")
|
106 |
+
parser.add_argument(
|
107 |
+
"--proxy_headers", action="store_true", help="Enable proxy headers"
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--timeout_graceful_shutdown",
|
114 |
+
type=int,
|
115 |
+
help="Graceful shutdown timeout duration",
|
116 |
+
)
|
117 |
+
parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
|
118 |
+
parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
|
119 |
+
parser.add_argument(
|
120 |
+
"--ssl_keyfile_password", type=str, help="SSL key file password"
|
121 |
+
)
|
122 |
+
|
123 |
+
|
124 |
+
def setup_api_args(parser: argparse.ArgumentParser):
|
125 |
+
parser.add_argument(
|
126 |
+
"--cors_origin",
|
127 |
+
type=str,
|
128 |
+
help="Allowed CORS origins. Use '*' to allow all origins.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--no_playground",
|
132 |
+
action="store_true",
|
133 |
+
help="Disable the playground entry",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--no_docs",
|
137 |
+
action="store_true",
|
138 |
+
help="Disable the documentation entry",
|
139 |
+
)
|
140 |
+
# 配置哪些api要跳过 比如 exclude="/v1/speakers/*,/v1/tts/*"
|
141 |
+
parser.add_argument(
|
142 |
+
"--exclude",
|
143 |
+
type=str,
|
144 |
+
help="Exclude the specified API from the server",
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
def process_api_args(args, app):
|
149 |
+
cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
|
150 |
+
no_playground = env.get_and_update_env(args, "no_playground", False, bool)
|
151 |
+
no_docs = env.get_and_update_env(args, "no_docs", False, bool)
|
152 |
+
exclude = env.get_and_update_env(args, "exclude", "", str)
|
153 |
+
|
154 |
+
api = create_api(app=app, exclude=exclude.split(","))
|
155 |
+
config.api = api
|
156 |
+
|
157 |
+
if cors_origin:
|
158 |
+
api.set_cors(allow_origins=[cors_origin])
|
159 |
+
|
160 |
+
if not no_playground:
|
161 |
+
api.setup_playground()
|
162 |
+
|
163 |
+
if compile:
|
164 |
+
logger.info("Model compile is enabled")
|
modules/api/app_config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
app_description = """
|
2 |
+
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
|
3 |
+
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
|
4 |
+
|
5 |
+
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
6 |
+
|
7 |
+
> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
|
8 |
+
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
|
9 |
+
|
10 |
+
> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
|
11 |
+
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
|
12 |
+
"""
|
13 |
+
app_title = "ChatTTS Forge API"
|
14 |
+
app_version = "0.1.0"
|
modules/api/impl/google_api.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import base64
|
|
|
2 |
from fastapi import HTTPException
|
3 |
|
4 |
import io
|
@@ -6,7 +7,12 @@ import soundfile as sf
|
|
6 |
from pydantic import BaseModel
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
9 |
from modules.api.Api import APIManager
|
|
|
10 |
from modules.utils.audio import apply_prosody_to_audio_data
|
11 |
from modules.normalization import text_normalize
|
12 |
|
@@ -44,15 +50,25 @@ class AudioConfig(BaseModel):
|
|
44 |
speakingRate: float = 1
|
45 |
pitch: float = 0
|
46 |
volumeGainDb: float = 0
|
47 |
-
sampleRateHertz: int
|
48 |
batchSize: int = 1
|
49 |
spliterThreshold: int = 100
|
50 |
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
class GoogleTextSynthesizeRequest(BaseModel):
|
53 |
input: SynthesisInput
|
54 |
voice: VoiceSelectionParams
|
55 |
-
audioConfig:
|
|
|
56 |
|
57 |
|
58 |
class GoogleTextSynthesizeResponse(BaseModel):
|
@@ -63,6 +79,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
63 |
input = request.input
|
64 |
voice = request.voice
|
65 |
audioConfig = request.audioConfig
|
|
|
66 |
|
67 |
# 提取参数
|
68 |
|
@@ -70,40 +87,41 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
70 |
language_code = voice.languageCode
|
71 |
voice_name = voice.name
|
72 |
infer_seed = voice.seed or 42
|
73 |
-
audio_format = audioConfig.
|
74 |
-
speaking_rate = audioConfig.
|
75 |
-
pitch = audioConfig.
|
76 |
-
volume_gain_db = audioConfig.
|
77 |
|
78 |
-
batch_size = audioConfig.
|
79 |
|
80 |
# TODO spliter_threshold
|
81 |
-
spliter_threshold = audioConfig.
|
82 |
|
83 |
# TODO sample_rate
|
84 |
-
sample_rate_hertz = audioConfig.
|
85 |
|
86 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
87 |
|
88 |
-
# TODO maybe need to change the sample rate
|
89 |
-
sample_rate = 24000
|
90 |
-
|
91 |
# 虽然 calc_spk_style 可以解析 seed 形式,但是这个接口只准备支持 speakers list 中存在的 speaker
|
92 |
if speaker_mgr.get_speaker(voice_name) is None:
|
93 |
raise HTTPException(
|
94 |
-
status_code=
|
95 |
)
|
96 |
|
97 |
if audio_format != "mp3" and audio_format != "wav":
|
98 |
raise HTTPException(
|
99 |
-
status_code=
|
100 |
)
|
101 |
|
|
|
|
|
|
|
|
|
102 |
try:
|
103 |
if input.text:
|
104 |
# 处理文本合成逻辑
|
105 |
text = text_normalize(input.text, is_end=True)
|
106 |
-
sample_rate, audio_data =
|
107 |
text,
|
108 |
temperature=(
|
109 |
voice.temperature
|
@@ -117,6 +135,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
117 |
prompt1=params.get("prompt1", ""),
|
118 |
prompt2=params.get("prompt2", ""),
|
119 |
prefix=params.get("prefix", ""),
|
|
|
|
|
120 |
)
|
121 |
|
122 |
elif input.ssml:
|
@@ -128,7 +148,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
128 |
|
129 |
if len(segments) == 0:
|
130 |
raise HTTPException(
|
131 |
-
status_code=
|
132 |
)
|
133 |
|
134 |
synthesize = SynthesizeSegments(batch_size=batch_size)
|
@@ -144,7 +164,17 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
144 |
|
145 |
else:
|
146 |
raise HTTPException(
|
147 |
-
status_code=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
)
|
149 |
|
150 |
audio_data = apply_prosody_to_audio_data(
|
|
|
1 |
import base64
|
2 |
+
from typing import Literal
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
import io
|
|
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
|
10 |
+
from modules.Enhancer.ResembleEnhance import (
|
11 |
+
apply_audio_enhance,
|
12 |
+
apply_audio_enhance_full,
|
13 |
+
)
|
14 |
from modules.api.Api import APIManager
|
15 |
+
from modules.synthesize_audio import synthesize_audio
|
16 |
from modules.utils.audio import apply_prosody_to_audio_data
|
17 |
from modules.normalization import text_normalize
|
18 |
|
|
|
50 |
speakingRate: float = 1
|
51 |
pitch: float = 0
|
52 |
volumeGainDb: float = 0
|
53 |
+
sampleRateHertz: int = 24000
|
54 |
batchSize: int = 1
|
55 |
spliterThreshold: int = 100
|
56 |
|
57 |
|
58 |
+
class EnhancerConfig(BaseModel):
|
59 |
+
enabled: bool = False
|
60 |
+
model: str = "resemble-enhance"
|
61 |
+
nfe: int = 32
|
62 |
+
solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
|
63 |
+
lambd: float = 0.5
|
64 |
+
tau: float = 0.5
|
65 |
+
|
66 |
+
|
67 |
class GoogleTextSynthesizeRequest(BaseModel):
|
68 |
input: SynthesisInput
|
69 |
voice: VoiceSelectionParams
|
70 |
+
audioConfig: AudioConfig
|
71 |
+
enhancerConfig: EnhancerConfig = None
|
72 |
|
73 |
|
74 |
class GoogleTextSynthesizeResponse(BaseModel):
|
|
|
79 |
input = request.input
|
80 |
voice = request.voice
|
81 |
audioConfig = request.audioConfig
|
82 |
+
enhancerConfig = request.enhancerConfig
|
83 |
|
84 |
# 提取参数
|
85 |
|
|
|
87 |
language_code = voice.languageCode
|
88 |
voice_name = voice.name
|
89 |
infer_seed = voice.seed or 42
|
90 |
+
audio_format = audioConfig.audioEncoding or "mp3"
|
91 |
+
speaking_rate = audioConfig.speakingRate or 1
|
92 |
+
pitch = audioConfig.pitch or 0
|
93 |
+
volume_gain_db = audioConfig.volumeGainDb or 0
|
94 |
|
95 |
+
batch_size = audioConfig.batchSize or 1
|
96 |
|
97 |
# TODO spliter_threshold
|
98 |
+
spliter_threshold = audioConfig.spliterThreshold or 100
|
99 |
|
100 |
# TODO sample_rate
|
101 |
+
sample_rate_hertz = audioConfig.sampleRateHertz or 24000
|
102 |
|
103 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
104 |
|
|
|
|
|
|
|
105 |
# 虽然 calc_spk_style 可以解析 seed 形式,但是这个接口只准备支持 speakers list 中存在的 speaker
|
106 |
if speaker_mgr.get_speaker(voice_name) is None:
|
107 |
raise HTTPException(
|
108 |
+
status_code=422, detail="The specified voice name is not supported."
|
109 |
)
|
110 |
|
111 |
if audio_format != "mp3" and audio_format != "wav":
|
112 |
raise HTTPException(
|
113 |
+
status_code=422, detail="Invalid audio encoding format specified."
|
114 |
)
|
115 |
|
116 |
+
if enhancerConfig.enabled:
|
117 |
+
# TODO enhancer params checker
|
118 |
+
pass
|
119 |
+
|
120 |
try:
|
121 |
if input.text:
|
122 |
# 处理文本合成逻辑
|
123 |
text = text_normalize(input.text, is_end=True)
|
124 |
+
sample_rate, audio_data = synthesize_audio(
|
125 |
text,
|
126 |
temperature=(
|
127 |
voice.temperature
|
|
|
135 |
prompt1=params.get("prompt1", ""),
|
136 |
prompt2=params.get("prompt2", ""),
|
137 |
prefix=params.get("prefix", ""),
|
138 |
+
batch_size=batch_size,
|
139 |
+
spliter_threshold=spliter_threshold,
|
140 |
)
|
141 |
|
142 |
elif input.ssml:
|
|
|
148 |
|
149 |
if len(segments) == 0:
|
150 |
raise HTTPException(
|
151 |
+
status_code=422, detail="The SSML text is empty or parsing failed."
|
152 |
)
|
153 |
|
154 |
synthesize = SynthesizeSegments(batch_size=batch_size)
|
|
|
164 |
|
165 |
else:
|
166 |
raise HTTPException(
|
167 |
+
status_code=422, detail="Either text or SSML input must be provided."
|
168 |
+
)
|
169 |
+
|
170 |
+
if enhancerConfig.enabled:
|
171 |
+
audio_data, sample_rate = apply_audio_enhance_full(
|
172 |
+
audio_data=audio_data,
|
173 |
+
sr=sample_rate,
|
174 |
+
nfe=enhancerConfig.nfe,
|
175 |
+
solver=enhancerConfig.solver,
|
176 |
+
lambd=enhancerConfig.lambd,
|
177 |
+
tau=enhancerConfig.tau,
|
178 |
)
|
179 |
|
180 |
audio_data = apply_prosody_to_audio_data(
|
modules/api/impl/models_api.py
CHANGED
@@ -1,11 +1,18 @@
|
|
|
|
1 |
from modules.api import utils as api_utils
|
2 |
from modules.api.Api import APIManager
|
3 |
-
from modules.models import reload_chat_tts
|
4 |
|
5 |
|
6 |
def setup(app: APIManager):
|
7 |
@app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
|
8 |
async def reload_models():
|
9 |
-
# Reload models
|
10 |
reload_chat_tts()
|
|
|
11 |
return api_utils.success_response("Models reloaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
|
2 |
from modules.api import utils as api_utils
|
3 |
from modules.api.Api import APIManager
|
4 |
+
from modules.models import reload_chat_tts, unload_chat_tts
|
5 |
|
6 |
|
7 |
def setup(app: APIManager):
|
8 |
@app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
|
9 |
async def reload_models():
|
|
|
10 |
reload_chat_tts()
|
11 |
+
reload_enhancer()
|
12 |
return api_utils.success_response("Models reloaded")
|
13 |
+
|
14 |
+
@app.get("/v1/models/unload", response_model=api_utils.BaseResponse)
|
15 |
+
async def reload_models():
|
16 |
+
unload_chat_tts()
|
17 |
+
unload_enhancer()
|
18 |
+
return api_utils.success_response("Models unloaded")
|
modules/api/worker.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import dotenv
|
5 |
+
from fastapi import FastAPI
|
6 |
+
|
7 |
+
from modules.ffmpeg_env import setup_ffmpeg_path
|
8 |
+
|
9 |
+
setup_ffmpeg_path()
|
10 |
+
logging.basicConfig(
|
11 |
+
level=os.getenv("LOG_LEVEL", "INFO"),
|
12 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
13 |
+
)
|
14 |
+
|
15 |
+
from modules.api.api_setup import (
|
16 |
+
process_api_args,
|
17 |
+
process_model_args,
|
18 |
+
setup_api_args,
|
19 |
+
setup_model_args,
|
20 |
+
setup_uvicon_args,
|
21 |
+
)
|
22 |
+
from modules.api.app_config import app_description, app_title, app_version
|
23 |
+
from modules import config
|
24 |
+
from modules.utils.torch_opt import configure_torch_optimizations
|
25 |
+
|
26 |
+
dotenv.load_dotenv(
|
27 |
+
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
|
28 |
+
)
|
29 |
+
parser = argparse.ArgumentParser(
|
30 |
+
description="Start the FastAPI server with command line arguments"
|
31 |
+
)
|
32 |
+
setup_api_args(parser)
|
33 |
+
setup_model_args(parser)
|
34 |
+
setup_uvicon_args(parser)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
app = FastAPI(
|
39 |
+
title=app_title,
|
40 |
+
description=app_description,
|
41 |
+
version=app_version,
|
42 |
+
redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
|
43 |
+
docs_url=None if config.runtime_env_vars.no_docs else "/docs",
|
44 |
+
)
|
45 |
+
|
46 |
+
process_model_args(args)
|
47 |
+
process_api_args(args, app)
|
48 |
+
|
49 |
+
configure_torch_optimizations()
|
modules/config.py
CHANGED
@@ -3,7 +3,7 @@ import sys
|
|
3 |
import torch
|
4 |
from modules.utils.JsonObject import JsonObject
|
5 |
|
6 |
-
from modules.utils import git
|
7 |
|
8 |
# TODO impl RuntimeEnvVars() class
|
9 |
runtime_env_vars = JsonObject({})
|
@@ -20,5 +20,6 @@ versions = JsonObject(
|
|
20 |
"git_tag": git.git_tag(),
|
21 |
"git_branch": git.branch_name(),
|
22 |
"git_commit": git.commit_hash(),
|
|
|
23 |
}
|
24 |
)
|
|
|
3 |
import torch
|
4 |
from modules.utils.JsonObject import JsonObject
|
5 |
|
6 |
+
from modules.utils import git, ffmpeg
|
7 |
|
8 |
# TODO impl RuntimeEnvVars() class
|
9 |
runtime_env_vars = JsonObject({})
|
|
|
20 |
"git_tag": git.git_tag(),
|
21 |
"git_branch": git.branch_name(),
|
22 |
"git_commit": git.commit_hash(),
|
23 |
+
"ffmpeg_version": ffmpeg.ffmpeg_version(),
|
24 |
}
|
25 |
)
|
modules/ffmpeg_env.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules.utils.constants import ROOT_DIR
|
3 |
+
import logging
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
|
8 |
+
def setup_ffmpeg_path():
|
9 |
+
ffmpeg_path = os.path.join(ROOT_DIR, "ffmpeg")
|
10 |
+
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ["PATH"]
|
11 |
+
|
12 |
+
import pydub.utils
|
13 |
+
|
14 |
+
if pydub.utils.which("ffmpeg") is None:
|
15 |
+
logger.error("ffmpeg not found in PATH")
|
16 |
+
raise Exception("ffmpeg not found in PATH")
|
modules/generate_audio.py
CHANGED
@@ -74,10 +74,10 @@ def generate_audio_batch(
|
|
74 |
if isinstance(spk, int):
|
75 |
with SeedContext(spk, True):
|
76 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
77 |
-
logger.
|
78 |
elif isinstance(spk, Speaker):
|
79 |
params_infer_code["spk_emb"] = spk.emb
|
80 |
-
logger.
|
81 |
else:
|
82 |
logger.warn(
|
83 |
f"spk must be int or Speaker, but: <{type(spk)}> {spk}, wiil set to default voice"
|
@@ -85,7 +85,7 @@ def generate_audio_batch(
|
|
85 |
with SeedContext(2, True):
|
86 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
87 |
|
88 |
-
logger.
|
89 |
{
|
90 |
"text": texts,
|
91 |
"infer_seed": infer_seed,
|
|
|
74 |
if isinstance(spk, int):
|
75 |
with SeedContext(spk, True):
|
76 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
77 |
+
logger.debug(("spk", spk))
|
78 |
elif isinstance(spk, Speaker):
|
79 |
params_infer_code["spk_emb"] = spk.emb
|
80 |
+
logger.debug(("spk", spk.name))
|
81 |
else:
|
82 |
logger.warn(
|
83 |
f"spk must be int or Speaker, but: <{type(spk)}> {spk}, wiil set to default voice"
|
|
|
85 |
with SeedContext(2, True):
|
86 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
87 |
|
88 |
+
logger.debug(
|
89 |
{
|
90 |
"text": texts,
|
91 |
"infer_seed": infer_seed,
|
modules/gradio_dcls_fix.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
def dcls_patch():
|
2 |
from gradio import data_classes
|
3 |
|
|
|
4 |
data_classes.PredictBody.__get_pydantic_json_schema__ = lambda x, y: {
|
5 |
"type": "object",
|
6 |
}
|
|
|
1 |
def dcls_patch():
|
2 |
from gradio import data_classes
|
3 |
|
4 |
+
# https://github.com/gradio-app/gradio/pull/8530
|
5 |
data_classes.PredictBody.__get_pydantic_json_schema__ = lambda x, y: {
|
6 |
"type": "object",
|
7 |
}
|
modules/models.py
CHANGED
@@ -55,10 +55,9 @@ def unload_chat_tts():
|
|
55 |
if isinstance(model, torch.nn.Module):
|
56 |
model.cpu()
|
57 |
del model
|
58 |
-
if torch.cuda.is_available():
|
59 |
-
torch.cuda.empty_cache()
|
60 |
-
gc.collect()
|
61 |
chat_tts = None
|
|
|
|
|
62 |
logger.info("ChatTTS models unloaded")
|
63 |
|
64 |
|
|
|
55 |
if isinstance(model, torch.nn.Module):
|
56 |
model.cpu()
|
57 |
del model
|
|
|
|
|
|
|
58 |
chat_tts = None
|
59 |
+
devices.torch_gc()
|
60 |
+
gc.collect()
|
61 |
logger.info("ChatTTS models unloaded")
|
62 |
|
63 |
|
modules/repos_static/resemble_enhance/denoiser/denoiser.py
CHANGED
@@ -65,7 +65,9 @@ class Denoiser(nn.Module):
|
|
65 |
x = x.cpu()
|
66 |
|
67 |
window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
|
68 |
-
s = torch.stft(
|
|
|
|
|
69 |
|
70 |
s = s[..., :-1] # (b f t)
|
71 |
|
@@ -106,6 +108,7 @@ class Denoiser(nn.Module):
|
|
106 |
if s.isnan().any():
|
107 |
logger.warning("NaN detected in ISTFT input.")
|
108 |
|
|
|
109 |
s = F.pad(s, (0, 1), "replicate") # (b f t+1)
|
110 |
|
111 |
window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
|
@@ -168,7 +171,9 @@ class Denoiser(nn.Module):
|
|
168 |
|
169 |
mag, cos, sin = self._stft(x) # (b 2f t)
|
170 |
mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
|
171 |
-
sep_mag, sep_cos, sep_sin = self._separate(
|
|
|
|
|
172 |
|
173 |
o = self._istft(sep_mag, sep_cos, sep_sin)
|
174 |
|
|
|
65 |
x = x.cpu()
|
66 |
|
67 |
window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
|
68 |
+
s = torch.stft(
|
69 |
+
x.float(), **self.stft_cfg, window=window, return_complex=True
|
70 |
+
) # (b f t+1)
|
71 |
|
72 |
s = s[..., :-1] # (b f t)
|
73 |
|
|
|
108 |
if s.isnan().any():
|
109 |
logger.warning("NaN detected in ISTFT input.")
|
110 |
|
111 |
+
s = s.to(torch.complex64)
|
112 |
s = F.pad(s, (0, 1), "replicate") # (b f t+1)
|
113 |
|
114 |
window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
|
|
|
171 |
|
172 |
mag, cos, sin = self._stft(x) # (b 2f t)
|
173 |
mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
|
174 |
+
sep_mag, sep_cos, sep_sin = self._separate(
|
175 |
+
mag, cos, sin, mag_mask, cos_res, sin_res
|
176 |
+
)
|
177 |
|
178 |
o = self._istft(sep_mag, sep_cos, sep_sin)
|
179 |
|
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
CHANGED
@@ -64,7 +64,12 @@ class IRMAE(nn.Module):
|
|
64 |
nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
|
65 |
*[ResBlock(hidden_dim) for _ in range(4)],
|
66 |
# Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
|
67 |
-
*[
|
|
|
|
|
|
|
|
|
|
|
68 |
nn.Tanh(),
|
69 |
)
|
70 |
|
@@ -92,9 +97,10 @@ class IRMAE(nn.Module):
|
|
92 |
self.stats = {}
|
93 |
self.stats["z_mean"] = z.mean().item()
|
94 |
self.stats["z_std"] = z.std().item()
|
95 |
-
|
96 |
-
self.stats["
|
97 |
-
self.stats["
|
|
|
98 |
return z
|
99 |
|
100 |
def decode(self, z):
|
|
|
64 |
nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
|
65 |
*[ResBlock(hidden_dim) for _ in range(4)],
|
66 |
# Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
|
67 |
+
*[
|
68 |
+
nn.Conv1d(
|
69 |
+
hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False
|
70 |
+
)
|
71 |
+
for i in range(num_irms)
|
72 |
+
],
|
73 |
nn.Tanh(),
|
74 |
)
|
75 |
|
|
|
97 |
self.stats = {}
|
98 |
self.stats["z_mean"] = z.mean().item()
|
99 |
self.stats["z_std"] = z.std().item()
|
100 |
+
z_float = z.float()
|
101 |
+
self.stats["z_abs_68"] = z_float.abs().quantile(0.6827).item()
|
102 |
+
self.stats["z_abs_95"] = z_float.abs().quantile(0.9545).item()
|
103 |
+
self.stats["z_abs_99"] = z_float.abs().quantile(0.9973).item()
|
104 |
return z
|
105 |
|
106 |
def decode(self, z):
|
modules/repos_static/resemble_enhance/inference.py
CHANGED
@@ -8,6 +8,8 @@ from torchaudio.functional import resample
|
|
8 |
from torchaudio.transforms import MelSpectrogram
|
9 |
from tqdm import trange
|
10 |
|
|
|
|
|
11 |
from .hparams import HParams
|
12 |
|
13 |
from modules import config
|
@@ -16,7 +18,14 @@ logger = logging.getLogger(__name__)
|
|
16 |
|
17 |
|
18 |
@torch.inference_mode()
|
19 |
-
def inference_chunk(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
|
21 |
del sr
|
22 |
|
@@ -24,10 +33,10 @@ def inference_chunk(model, dwav, sr, device, npad=441):
|
|
24 |
abs_max = dwav.abs().max().clamp(min=1e-7)
|
25 |
|
26 |
assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
|
27 |
-
dwav = dwav.to(device)
|
28 |
dwav = dwav / abs_max # Normalize
|
29 |
dwav = F.pad(dwav, (0, npad))
|
30 |
-
hwav = model(dwav[None])[0].cpu() # (T,)
|
31 |
hwav = hwav[:length] # Trim padding
|
32 |
hwav = hwav * abs_max # Unnormalize
|
33 |
|
@@ -60,6 +69,9 @@ def compute_offset(chunk1, chunk2, sr=44100):
|
|
60 |
f_max=sr // 2,
|
61 |
)
|
62 |
|
|
|
|
|
|
|
63 |
spec1 = mel_fn(chunk1).log1p()
|
64 |
spec2 = mel_fn(chunk2).log1p()
|
65 |
|
@@ -123,7 +135,13 @@ def remove_weight_norm_recursively(module):
|
|
123 |
|
124 |
|
125 |
def inference(
|
126 |
-
model,
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
):
|
128 |
if config.runtime_env_vars.off_tqdm:
|
129 |
trange = range
|
@@ -159,9 +177,11 @@ def inference(
|
|
159 |
|
160 |
chunks = []
|
161 |
for start in trange(0, dwav.shape[-1], hop_length):
|
162 |
-
|
163 |
-
|
164 |
)
|
|
|
|
|
165 |
|
166 |
hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
|
167 |
|
@@ -172,5 +192,6 @@ def inference(
|
|
172 |
logger.info(
|
173 |
f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
|
174 |
)
|
|
|
175 |
|
176 |
return hwav, sr
|
|
|
8 |
from torchaudio.transforms import MelSpectrogram
|
9 |
from tqdm import trange
|
10 |
|
11 |
+
from modules.devices import devices
|
12 |
+
|
13 |
from .hparams import HParams
|
14 |
|
15 |
from modules import config
|
|
|
18 |
|
19 |
|
20 |
@torch.inference_mode()
|
21 |
+
def inference_chunk(
|
22 |
+
model,
|
23 |
+
dwav: torch.Tensor,
|
24 |
+
sr: int,
|
25 |
+
device: torch.device,
|
26 |
+
dtype: torch.dtype,
|
27 |
+
npad=441,
|
28 |
+
) -> torch.Tensor:
|
29 |
assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
|
30 |
del sr
|
31 |
|
|
|
33 |
abs_max = dwav.abs().max().clamp(min=1e-7)
|
34 |
|
35 |
assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
|
36 |
+
dwav = dwav.to(device=device, dtype=dtype)
|
37 |
dwav = dwav / abs_max # Normalize
|
38 |
dwav = F.pad(dwav, (0, npad))
|
39 |
+
hwav: torch.Tensor = model(dwav[None])[0].cpu() # (T,)
|
40 |
hwav = hwav[:length] # Trim padding
|
41 |
hwav = hwav * abs_max # Unnormalize
|
42 |
|
|
|
69 |
f_max=sr // 2,
|
70 |
)
|
71 |
|
72 |
+
chunk1 = chunk1.float()
|
73 |
+
chunk2 = chunk2.float()
|
74 |
+
|
75 |
spec1 = mel_fn(chunk1).log1p()
|
76 |
spec2 = mel_fn(chunk2).log1p()
|
77 |
|
|
|
135 |
|
136 |
|
137 |
def inference(
|
138 |
+
model,
|
139 |
+
dwav,
|
140 |
+
sr,
|
141 |
+
device,
|
142 |
+
dtype,
|
143 |
+
chunk_seconds: float = 30.0,
|
144 |
+
overlap_seconds: float = 1.0,
|
145 |
):
|
146 |
if config.runtime_env_vars.off_tqdm:
|
147 |
trange = range
|
|
|
177 |
|
178 |
chunks = []
|
179 |
for start in trange(0, dwav.shape[-1], hop_length):
|
180 |
+
chunk_dwav = inference_chunk(
|
181 |
+
model, dwav[start : start + chunk_length], sr, device, dtype
|
182 |
)
|
183 |
+
chunks.append(chunk_dwav.cpu())
|
184 |
+
devices.torch_gc()
|
185 |
|
186 |
hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
|
187 |
|
|
|
192 |
logger.info(
|
193 |
f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
|
194 |
)
|
195 |
+
devices.torch_gc()
|
196 |
|
197 |
return hwav, sr
|
modules/speaker.py
CHANGED
@@ -104,7 +104,7 @@ class SpeakerManager:
|
|
104 |
if not os.path.exists(self.speaker_dir + fname):
|
105 |
del self.speakers[fname]
|
106 |
|
107 |
-
def list_speakers(self):
|
108 |
return list(self.speakers.values())
|
109 |
|
110 |
def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
|
|
|
104 |
if not os.path.exists(self.speaker_dir + fname):
|
105 |
del self.speakers[fname]
|
106 |
|
107 |
+
def list_speakers(self) -> list[Speaker]:
|
108 |
return list(self.speakers.values())
|
109 |
|
110 |
def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
|
modules/utils/env.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import os
|
2 |
|
|
|
|
|
3 |
|
4 |
def get_env_val(key, val_type):
|
5 |
env_val = os.getenv(key.upper())
|
@@ -27,3 +29,10 @@ def get_env_or_arg(args, arg_name, default, arg_type):
|
|
27 |
return env_val
|
28 |
|
29 |
return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
|
3 |
+
from modules import config
|
4 |
+
|
5 |
|
6 |
def get_env_val(key, val_type):
|
7 |
env_val = os.getenv(key.upper())
|
|
|
29 |
return env_val
|
30 |
|
31 |
return default
|
32 |
+
|
33 |
+
|
34 |
+
def get_and_update_env(*args):
|
35 |
+
val = get_env_or_arg(*args)
|
36 |
+
key = args[1]
|
37 |
+
config.runtime_env_vars[key] = val
|
38 |
+
return val
|
modules/utils/ffmpeg.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
from functools import lru_cache
|
3 |
+
|
4 |
+
|
5 |
+
@lru_cache()
|
6 |
+
def ffmpeg_version():
|
7 |
+
try:
|
8 |
+
result = subprocess.check_output(
|
9 |
+
["ffmpeg", "-version"], shell=False, encoding="utf8"
|
10 |
+
)
|
11 |
+
version_info = result.split("\n")[0]
|
12 |
+
version_info = version_info.split("ffmpeg version")[1].strip()
|
13 |
+
version_info = version_info.split("Copyright")[0].strip()
|
14 |
+
return version_info
|
15 |
+
except Exception:
|
16 |
+
return "<none>"
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == "__main__":
|
20 |
+
print(ffmpeg_version())
|
modules/utils/git.py
CHANGED
@@ -3,23 +3,15 @@ import os
|
|
3 |
import subprocess
|
4 |
|
5 |
|
6 |
-
from modules.utils import constants
|
7 |
-
|
8 |
-
# 用于判断是否在hf spaces
|
9 |
-
try:
|
10 |
-
import spaces
|
11 |
-
except:
|
12 |
-
spaces = None
|
13 |
|
14 |
git = os.environ.get("GIT", "git")
|
15 |
|
16 |
-
in_hf_spaces = spaces is not None
|
17 |
-
|
18 |
|
19 |
@lru_cache()
|
20 |
def commit_hash():
|
21 |
try:
|
22 |
-
if
|
23 |
return "<hf>"
|
24 |
return subprocess.check_output(
|
25 |
[git, "-C", constants.ROOT_DIR, "rev-parse", "HEAD"],
|
@@ -33,7 +25,7 @@ def commit_hash():
|
|
33 |
@lru_cache()
|
34 |
def git_tag():
|
35 |
try:
|
36 |
-
if
|
37 |
return "<hf>"
|
38 |
return subprocess.check_output(
|
39 |
[git, "-C", constants.ROOT_DIR, "describe", "--tags"],
|
@@ -57,7 +49,7 @@ def git_tag():
|
|
57 |
@lru_cache()
|
58 |
def branch_name():
|
59 |
try:
|
60 |
-
if
|
61 |
return "<hf>"
|
62 |
return subprocess.check_output(
|
63 |
[git, "-C", constants.ROOT_DIR, "rev-parse", "--abbrev-ref", "HEAD"],
|
|
|
3 |
import subprocess
|
4 |
|
5 |
|
6 |
+
from modules.utils import constants, hf
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
git = os.environ.get("GIT", "git")
|
9 |
|
|
|
|
|
10 |
|
11 |
@lru_cache()
|
12 |
def commit_hash():
|
13 |
try:
|
14 |
+
if hf.is_spaces_env:
|
15 |
return "<hf>"
|
16 |
return subprocess.check_output(
|
17 |
[git, "-C", constants.ROOT_DIR, "rev-parse", "HEAD"],
|
|
|
25 |
@lru_cache()
|
26 |
def git_tag():
|
27 |
try:
|
28 |
+
if hf.is_spaces_env:
|
29 |
return "<hf>"
|
30 |
return subprocess.check_output(
|
31 |
[git, "-C", constants.ROOT_DIR, "describe", "--tags"],
|
|
|
49 |
@lru_cache()
|
50 |
def branch_name():
|
51 |
try:
|
52 |
+
if hf.is_spaces_env:
|
53 |
return "<hf>"
|
54 |
return subprocess.check_output(
|
55 |
[git, "-C", constants.ROOT_DIR, "rev-parse", "--abbrev-ref", "HEAD"],
|
modules/utils/hf.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 给huggingface space写的兼容代码
|
2 |
+
|
3 |
+
try:
|
4 |
+
import spaces
|
5 |
+
|
6 |
+
is_spaces_env = True
|
7 |
+
except:
|
8 |
+
|
9 |
+
class NoneSpaces:
|
10 |
+
def __init__(self):
|
11 |
+
pass
|
12 |
+
|
13 |
+
def GPU(self, fn):
|
14 |
+
return fn
|
15 |
+
|
16 |
+
spaces = NoneSpaces()
|
17 |
+
is_spaces_env = False
|
modules/utils/torch_opt.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def configure_torch_optimizations():
|
5 |
+
torch._dynamo.config.cache_size_limit = 64
|
6 |
+
torch._dynamo.config.suppress_errors = True
|
7 |
+
torch.set_float32_matmul_precision("high")
|
modules/webui/app.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
|
4 |
-
import torch
|
5 |
import gradio as gr
|
6 |
|
7 |
from modules import config
|
8 |
-
from modules.webui import gradio_extensions,
|
9 |
|
10 |
from modules.webui.changelog_tab import create_changelog_tab
|
11 |
from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
|
@@ -24,10 +23,6 @@ def webui_init():
|
|
24 |
# fix: If the system proxy is enabled in the Windows system, you need to skip these
|
25 |
os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
|
26 |
|
27 |
-
torch._dynamo.config.cache_size_limit = 64
|
28 |
-
torch._dynamo.config.suppress_errors = True
|
29 |
-
torch.set_float32_matmul_precision("high")
|
30 |
-
|
31 |
if config.runtime_env_vars.language == "en":
|
32 |
webui_config.localization = ENLocalizationVars()
|
33 |
else:
|
@@ -43,6 +38,7 @@ def create_app_footer():
|
|
43 |
git_branch = os.environ.get("V_GIT_BRANCH") or config.versions.git_branch
|
44 |
python_version = config.versions.python_version
|
45 |
torch_version = config.versions.torch_version
|
|
|
46 |
|
47 |
config.versions.gradio_version = gradio_version
|
48 |
|
@@ -53,9 +49,10 @@ def create_app_footer():
|
|
53 |
footer_items.append(f"branch: `{git_branch}`")
|
54 |
footer_items.append(f"python: `{python_version}`")
|
55 |
footer_items.append(f"torch: `{torch_version}`")
|
|
|
56 |
|
57 |
if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
|
58 |
-
footer_items.append(f"[
|
59 |
|
60 |
gr.Markdown(
|
61 |
" | ".join(footer_items),
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
|
|
|
4 |
import gradio as gr
|
5 |
|
6 |
from modules import config
|
7 |
+
from modules.webui import gradio_extensions, webui_config
|
8 |
|
9 |
from modules.webui.changelog_tab import create_changelog_tab
|
10 |
from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
|
|
|
23 |
# fix: If the system proxy is enabled in the Windows system, you need to skip these
|
24 |
os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
|
25 |
|
|
|
|
|
|
|
|
|
26 |
if config.runtime_env_vars.language == "en":
|
27 |
webui_config.localization = ENLocalizationVars()
|
28 |
else:
|
|
|
38 |
git_branch = os.environ.get("V_GIT_BRANCH") or config.versions.git_branch
|
39 |
python_version = config.versions.python_version
|
40 |
torch_version = config.versions.torch_version
|
41 |
+
ffmpeg_version = config.versions.ffmpeg_version
|
42 |
|
43 |
config.versions.gradio_version = gradio_version
|
44 |
|
|
|
49 |
footer_items.append(f"branch: `{git_branch}`")
|
50 |
footer_items.append(f"python: `{python_version}`")
|
51 |
footer_items.append(f"torch: `{torch_version}`")
|
52 |
+
footer_items.append(f"ffmpeg: `{ffmpeg_version}`")
|
53 |
|
54 |
if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
|
55 |
+
footer_items.append(f"[api](/docs)")
|
56 |
|
57 |
gr.Markdown(
|
58 |
" | ".join(footer_items),
|
modules/webui/gradio_extensions.py
CHANGED
@@ -14,7 +14,7 @@ WEBUI_DIR_PATH = Path(os.path.dirname(os.path.realpath(__file__)))
|
|
14 |
|
15 |
|
16 |
def read_file(fp):
|
17 |
-
with open(WEBUI_DIR_PATH / fp, "r") as f:
|
18 |
return f.read()
|
19 |
|
20 |
|
|
|
14 |
|
15 |
|
16 |
def read_file(fp):
|
17 |
+
with open(WEBUI_DIR_PATH / fp, "r", encoding="utf-8") as f:
|
18 |
return f.read()
|
19 |
|
20 |
|
modules/webui/js/index.js
CHANGED
@@ -154,6 +154,7 @@ addObserverIfDesiredNodeAvailable(".toast-wrap", function (added) {
|
|
154 |
added.forEach(function (element) {
|
155 |
if (element.innerText.includes("Connection errored out.")) {
|
156 |
window.setTimeout(function () {
|
|
|
157 |
document.getElementById("reset_button")?.classList.remove("hidden");
|
158 |
document.getElementById("generate_button")?.classList.add("hidden");
|
159 |
document.getElementById("skip_button")?.classList.add("hidden");
|
|
|
154 |
added.forEach(function (element) {
|
155 |
if (element.innerText.includes("Connection errored out.")) {
|
156 |
window.setTimeout(function () {
|
157 |
+
// FIXME: 这几个button好像是没有...a1111里面的gradio版本和我们的不一样
|
158 |
document.getElementById("reset_button")?.classList.remove("hidden");
|
159 |
document.getElementById("generate_button")?.classList.add("hidden");
|
160 |
document.getElementById("skip_button")?.classList.add("hidden");
|
modules/webui/speaker/speaker_creator.py
CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from modules.speaker import Speaker
|
4 |
from modules.utils.SeedContext import SeedContext
|
5 |
-
from modules.hf import spaces
|
6 |
from modules.models import load_chat_tts
|
7 |
from modules.utils.rng import np_rng
|
8 |
from modules.webui import webui_config
|
|
|
2 |
import torch
|
3 |
from modules.speaker import Speaker
|
4 |
from modules.utils.SeedContext import SeedContext
|
5 |
+
from modules.utils.hf import spaces
|
6 |
from modules.models import load_chat_tts
|
7 |
from modules.utils.rng import np_rng
|
8 |
from modules.webui import webui_config
|
modules/webui/speaker/speaker_editor.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from modules.speaker import Speaker
|
4 |
-
from modules.hf import spaces
|
5 |
from modules.webui import webui_config
|
6 |
from modules.webui.webui_utils import tts_generate
|
7 |
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from modules.speaker import Speaker
|
4 |
+
from modules.utils.hf import spaces
|
5 |
from modules.webui import webui_config
|
6 |
from modules.webui.webui_utils import tts_generate
|
7 |
|
modules/webui/speaker/speaker_merger.py
CHANGED
@@ -2,7 +2,7 @@ import io
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
|
5 |
-
from modules.hf import spaces
|
6 |
from modules.webui import webui_config, webui_utils
|
7 |
from modules.webui.webui_utils import get_speakers, tts_generate
|
8 |
from modules.speaker import speaker_mgr, Speaker
|
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
|
5 |
+
from modules.utils.hf import spaces
|
6 |
from modules.webui import webui_config, webui_utils
|
7 |
from modules.webui.webui_utils import get_speakers, tts_generate
|
8 |
from modules.speaker import speaker_mgr, Speaker
|
modules/webui/ssml/podcast_tab.py
CHANGED
@@ -4,68 +4,68 @@ import torch
|
|
4 |
|
5 |
from modules.normalization import text_normalize
|
6 |
from modules.webui import webui_utils
|
7 |
-
from modules.hf import spaces
|
8 |
|
9 |
podcast_default_case = [
|
10 |
[
|
11 |
1,
|
12 |
"female2",
|
13 |
"你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
|
14 |
-
"
|
15 |
],
|
16 |
[
|
17 |
2,
|
18 |
"Alice",
|
19 |
"嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
|
20 |
-
"
|
21 |
],
|
22 |
[
|
23 |
3,
|
24 |
"Bob",
|
25 |
"没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
|
26 |
-
"
|
27 |
],
|
28 |
[
|
29 |
4,
|
30 |
"female2",
|
31 |
"那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
|
32 |
-
"
|
33 |
],
|
34 |
[
|
35 |
5,
|
36 |
"Alice",
|
37 |
"对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
|
38 |
-
"
|
39 |
],
|
40 |
[
|
41 |
6,
|
42 |
"Bob",
|
43 |
"除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
|
44 |
-
"
|
45 |
],
|
46 |
[
|
47 |
7,
|
48 |
"female2",
|
49 |
"对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
|
50 |
-
"
|
51 |
],
|
52 |
[
|
53 |
8,
|
54 |
"Alice",
|
55 |
"还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
|
56 |
-
"
|
57 |
],
|
58 |
[
|
59 |
9,
|
60 |
"Bob",
|
61 |
"不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
|
62 |
-
"
|
63 |
],
|
64 |
[
|
65 |
10,
|
66 |
"female2",
|
67 |
"对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
|
68 |
-
"
|
69 |
],
|
70 |
]
|
71 |
|
|
|
4 |
|
5 |
from modules.normalization import text_normalize
|
6 |
from modules.webui import webui_utils
|
7 |
+
from modules.utils.hf import spaces
|
8 |
|
9 |
podcast_default_case = [
|
10 |
[
|
11 |
1,
|
12 |
"female2",
|
13 |
"你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
|
14 |
+
"podcast",
|
15 |
],
|
16 |
[
|
17 |
2,
|
18 |
"Alice",
|
19 |
"嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
|
20 |
+
"podcast",
|
21 |
],
|
22 |
[
|
23 |
3,
|
24 |
"Bob",
|
25 |
"没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
|
26 |
+
"podcast",
|
27 |
],
|
28 |
[
|
29 |
4,
|
30 |
"female2",
|
31 |
"那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
|
32 |
+
"podcast",
|
33 |
],
|
34 |
[
|
35 |
5,
|
36 |
"Alice",
|
37 |
"对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
|
38 |
+
"podcast",
|
39 |
],
|
40 |
[
|
41 |
6,
|
42 |
"Bob",
|
43 |
"除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
|
44 |
+
"podcast",
|
45 |
],
|
46 |
[
|
47 |
7,
|
48 |
"female2",
|
49 |
"对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
|
50 |
+
"podcast",
|
51 |
],
|
52 |
[
|
53 |
8,
|
54 |
"Alice",
|
55 |
"还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
|
56 |
+
"podcast",
|
57 |
],
|
58 |
[
|
59 |
9,
|
60 |
"Bob",
|
61 |
"不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
|
62 |
+
"podcast",
|
63 |
],
|
64 |
[
|
65 |
10,
|
66 |
"female2",
|
67 |
"对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
|
68 |
+
"podcast",
|
69 |
],
|
70 |
]
|
71 |
|
modules/webui/ssml/spliter_tab.py
CHANGED
@@ -7,7 +7,7 @@ from modules.webui.webui_utils import (
|
|
7 |
get_styles,
|
8 |
split_long_text,
|
9 |
)
|
10 |
-
from modules.hf import spaces
|
11 |
|
12 |
|
13 |
# NOTE: 因为 text_normalize 需要使用 tokenizer
|
|
|
7 |
get_styles,
|
8 |
split_long_text,
|
9 |
)
|
10 |
+
from modules.utils.hf import spaces
|
11 |
|
12 |
|
13 |
# NOTE: 因为 text_normalize 需要使用 tokenizer
|
modules/webui/webui_utils.py
CHANGED
@@ -2,10 +2,10 @@ import io
|
|
2 |
from typing import Union
|
3 |
import numpy as np
|
4 |
|
5 |
-
from modules.Enhancer.ResembleEnhance import
|
6 |
from modules.devices import devices
|
7 |
from modules.synthesize_audio import synthesize_audio
|
8 |
-
from modules.hf import spaces
|
9 |
from modules.webui import webui_config
|
10 |
|
11 |
import torch
|
@@ -85,22 +85,7 @@ def segments_length_limit(
|
|
85 |
@torch.inference_mode()
|
86 |
@spaces.GPU
|
87 |
def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
|
88 |
-
|
89 |
-
return audio_data, sr
|
90 |
-
|
91 |
-
device = devices.device
|
92 |
-
# NOTE: 这里很奇怪按道理得放到 device 上,但是 enhancer 做 chunk 的时候会报错...所以得 cpu()
|
93 |
-
tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
|
94 |
-
enhancer = load_enhancer(device)
|
95 |
-
|
96 |
-
if enable_enhance or enable_denoise:
|
97 |
-
lambd = 0.9 if enable_denoise else 0.1
|
98 |
-
tensor, sr = enhancer.enhance(
|
99 |
-
tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd, device=device
|
100 |
-
)
|
101 |
-
|
102 |
-
audio_data = tensor.cpu().numpy()
|
103 |
-
return audio_data, int(sr)
|
104 |
|
105 |
|
106 |
@torch.inference_mode()
|
|
|
2 |
from typing import Union
|
3 |
import numpy as np
|
4 |
|
5 |
+
from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
|
6 |
from modules.devices import devices
|
7 |
from modules.synthesize_audio import synthesize_audio
|
8 |
+
from modules.utils.hf import spaces
|
9 |
from modules.webui import webui_config
|
10 |
|
11 |
import torch
|
|
|
85 |
@torch.inference_mode()
|
86 |
@spaces.GPU
|
87 |
def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
|
88 |
+
return _apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
|
91 |
@torch.inference_mode()
|
webui.py
CHANGED
@@ -1,21 +1,23 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
logging.basicConfig(
|
5 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
6 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
7 |
)
|
8 |
|
9 |
-
from launch import (
|
10 |
-
get_and_update_env,
|
11 |
-
setup_api_args,
|
12 |
-
setup_model_args,
|
13 |
-
process_api_args,
|
14 |
-
process_model_args,
|
15 |
-
app_description,
|
16 |
-
app_title,
|
17 |
-
app_version,
|
18 |
-
)
|
19 |
from modules.webui import webui_config
|
20 |
from modules import config
|
21 |
from modules.webui.app import webui_init, create_interface
|
@@ -89,6 +91,7 @@ def process_webui_args(args):
|
|
89 |
webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
|
90 |
webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
|
91 |
|
|
|
92 |
webui_init()
|
93 |
demo = create_interface()
|
94 |
|
@@ -102,7 +105,7 @@ def process_webui_args(args):
|
|
102 |
debug=debug,
|
103 |
auth=auth,
|
104 |
show_api=False,
|
105 |
-
prevent_thread_lock=
|
106 |
app_kwargs={
|
107 |
"title": app_title,
|
108 |
"description": app_description,
|
@@ -119,6 +122,18 @@ def process_webui_args(args):
|
|
119 |
),
|
120 |
},
|
121 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
if __name__ == "__main__":
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
+
from modules.api.api_setup import (
|
5 |
+
process_api_args,
|
6 |
+
process_model_args,
|
7 |
+
setup_api_args,
|
8 |
+
setup_model_args,
|
9 |
+
)
|
10 |
+
from modules.ffmpeg_env import setup_ffmpeg_path
|
11 |
+
from modules.utils.env import get_and_update_env
|
12 |
+
from modules.api.app_config import app_description, app_title, app_version
|
13 |
+
from modules.utils.torch_opt import configure_torch_optimizations
|
14 |
+
|
15 |
+
setup_ffmpeg_path()
|
16 |
logging.basicConfig(
|
17 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
18 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
19 |
)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
from modules.webui import webui_config
|
22 |
from modules import config
|
23 |
from modules.webui.app import webui_init, create_interface
|
|
|
91 |
webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
|
92 |
webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
|
93 |
|
94 |
+
configure_torch_optimizations()
|
95 |
webui_init()
|
96 |
demo = create_interface()
|
97 |
|
|
|
105 |
debug=debug,
|
106 |
auth=auth,
|
107 |
show_api=False,
|
108 |
+
prevent_thread_lock=True,
|
109 |
app_kwargs={
|
110 |
"title": app_title,
|
111 |
"description": app_description,
|
|
|
122 |
),
|
123 |
},
|
124 |
)
|
125 |
+
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
126 |
+
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
127 |
+
# running web ui and do whatever the attacker wants, including installing an extension and
|
128 |
+
# running its code. We disable this here. Suggested by RyotaK.
|
129 |
+
app.user_middleware = [
|
130 |
+
x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
|
131 |
+
]
|
132 |
+
|
133 |
+
if api:
|
134 |
+
process_api_args(args, app)
|
135 |
+
|
136 |
+
demo.block_thread()
|
137 |
|
138 |
|
139 |
if __name__ == "__main__":
|