chenjgtea committed
Commit · 214ea91
Parent(s): 394c436
提交代码 (Submit code)
Browse files
- .gitignore +5 -0
- .idea/.gitignore +10 -0
- README.md +1 -1
- requirements.txt +29 -0
- test/__init__.py +0 -0
- test/api.py +72 -0
- test/common_test.py +24 -0
- tool/__init__.py +4 -0
- tool/av.py +79 -0
- tool/ctx.py +14 -0
- tool/func.py +35 -0
- tool/logger/__init__.py +1 -0
- tool/logger/log.py +73 -0
- tool/np.py +11 -0
- tool/pcm.py +21 -0
- web/__init__.py +0 -0
- web/app.py +246 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+/.idea/misc.xml
+/.idea/modules.xml
+/.idea/inspectionProfiles/profiles_settings.xml
+/.idea/inspectionProfiles/Project_Default.xml
+/.idea/vcs.xml
.idea/.gitignore
ADDED
@@ -0,0 +1,10 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+/.idea/
+/chat-tts.iml
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: blue
 colorTo: purple
 sdk: gradio
 sdk_version: 4.41.0
-app_file: app.py
+app_file: web\app.py
 pinned: false
 ---
 
requirements.txt
ADDED
@@ -0,0 +1,29 @@
+# PyTorch and related libraries
+torch
+torchvision
+torchaudio
+
+# Hugging Face transformers library
+transformers
+
+# Configuration management with OmegaConf
+omegaconf
+
+# Interactive widgets for Jupyter Notebooks
+ipywidgets
+
+# Gradio for creating web UIs
+gradio
+
+# Vector quantization for PyTorch
+vector_quantize_pytorch
+# Hugging Face Hub client
+huggingface_hub
+
+vocos
+
+spaces
+
+ChatTTS
+
+av
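Note: none of these entries pins a version, so a fresh environment installs the latest release of each; the standard `pip install -r requirements.txt` pulls in the whole set.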
test/__init__.py
ADDED
File without changes
test/api.py
ADDED
@@ -0,0 +1,72 @@
+# Import necessary libraries and configure settings
+
+import torch
+import ChatTTS
+import os, sys
+from common_test import *
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from tool.logger import get_logger
+
+
+torch._dynamo.config.cache_size_limit = 64
+torch._dynamo.config.suppress_errors = True
+torch.set_float32_matmul_precision('high')
+
+logger = get_logger("api")
+# Initialize and load the model:
+chat = ChatTTS.Chat()
+if chat.load(source="custom", custom_path="D:\\chenjgspace\\ai-model\\chattts", coef=None):
+    print("Models loaded successfully.")
+else:
+    print("Models load failed.")
+    sys.exit(1)
+
+# Define the text input for inference (supports batching)
+texts = [
+    "我真的不敢相信,他那么年轻武功居然这么好",
+]
+
+
+# Using a random speaker seed means each generated audio file gets a random timbre
+rand_spk = chat.sample_random_speaker()
+print(rand_spk)  # save it for later timbre recovery
+
+params_infer_code = ChatTTS.Chat.InferCodeParams(
+    spk_emb=rand_spk,  # add sampled speaker
+    temperature=.3,    # using custom temperature
+    top_P=0.7,         # top P decode
+    top_K=20,          # top K decode
+)
+
+###################################
+# For sentence-level manual control,
+
+# use oral_(0-9), laugh_(0-2), break_(0-7)
+# to generate special tokens in the text to synthesize.
+params_refine_text = ChatTTS.Chat.RefineTextParams(
+    prompt='[oral_2][laugh_0][break_6]',
+)
+
+wavs = chat.infer(
+    texts,
+    params_refine_text=params_refine_text,
+    params_infer_code=params_infer_code,
+)
+
+
+# Perform inference and play the generated audio
+#wavs = chat.infer(texts)
+#Audio(wavs[0], rate=24_000, autoplay=True)
+
+# Save the generated audio
+#torchaudio.save("D:\\Download\\output.wav", torch.from_numpy(wavs[0]), 24000)
+prefix_name = "D:\\Download\\" + get_date_time()
+
+for index, wav in enumerate(wavs):
+    save_mp3_file(wav, index, prefix_name)
test/common_test.py
ADDED
@@ -0,0 +1,24 @@
+import datetime
+import time
+import os, sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from tool.logger import get_logger
+
+logger = get_logger("common-test")
+def save_mp3_file(wav, index, prefix_name):
+    from tool.pcm import pcm_arr_to_mp3_view
+    data = pcm_arr_to_mp3_view(wav)
+    mp3_filename = prefix_name + "_" + str(index) + ".mp3"
+    with open(mp3_filename, "wb") as f:
+        f.write(data)
+    logger.info(f"Audio saved to {mp3_filename}")
+
+
+def get_date_time():
+    # Get the current timestamp (in seconds)
+    current_timestamp = int(time.time())
+    # Convert the timestamp to a datetime object
+    current_datetime = datetime.datetime.fromtimestamp(current_timestamp)
+    return current_datetime.strftime("%Y-%m-%d-%H-%M-%S")
tool/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .av import load_audio
+from .pcm import pcm_arr_to_mp3_view
+from .np import float_to_int16
+from .ctx import TorchSeedContext
tool/av.py
ADDED
@@ -0,0 +1,79 @@
+from io import BufferedWriter, BytesIO
+from pathlib import Path
+from typing import Dict
+
+import av
+from av.audio.resampler import AudioResampler
+import numpy as np
+
+
+video_format_dict: Dict[str, str] = {
+    "m4a": "mp4",
+}
+
+audio_format_dict: Dict[str, str] = {
+    "ogg": "libvorbis",
+    "mp4": "aac",
+}
+
+
+def wav2(i: BytesIO, o: BufferedWriter, format: str):
+    """
+    https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L20
+    """
+    inp = av.open(i, "r")
+    format = video_format_dict.get(format, format)
+    out = av.open(o, "w", format=format)
+    format = audio_format_dict.get(format, format)
+
+    ostream = out.add_stream(format)
+
+    for frame in inp.decode(audio=0):
+        for p in ostream.encode(frame):
+            out.mux(p)
+
+    for p in ostream.encode(None):
+        out.mux(p)
+
+    out.close()
+    inp.close()
+
+
+def load_audio(file: str, sr: int) -> np.ndarray:
+    """
+    https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39
+    """
+
+    if not Path(file).exists():
+        raise FileNotFoundError(f"File not found: {file}")
+
+    try:
+        container = av.open(file)
+        resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
+
+        # Estimated maximum total number of samples to pre-allocate the array
+        # AV stores length in microseconds by default
+        estimated_total_samples = int(container.duration * sr // 1_000_000)
+        decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
+
+        offset = 0
+        for frame in container.decode(audio=0):
+            frame.pts = None  # Clear presentation timestamp to avoid resampling issues
+            resampled_frames = resampler.resample(frame)
+            for resampled_frame in resampled_frames:
+                frame_data = resampled_frame.to_ndarray()[0]
+                end_index = offset + len(frame_data)
+
+                # Check if decoded_audio has enough space, and resize if necessary
+                if end_index > decoded_audio.shape[0]:
+                    decoded_audio = np.resize(decoded_audio, end_index + 1)
+
+                decoded_audio[offset:end_index] = frame_data
+                offset += len(frame_data)
+
+        # Truncate the array to the actual size
+        decoded_audio = decoded_audio[:offset]
+    except Exception as e:
+        raise RuntimeError(f"Failed to load audio: {e}")
+
+    return decoded_audio
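A minimal usage sketch for load_audio above; the file name is a hypothetical placeholder, and any container/codec the bundled FFmpeg understands should work:

    from tool.av import load_audio

    samples = load_audio("sample.m4a", 24000)  # decoded to mono float32 at 24 kHz
    print(samples.dtype, samples.shape)        # float32 (n_samples,)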
tool/ctx.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+
+
+class TorchSeedContext:
+    def __init__(self, seed):
+        self.seed = seed
+        self.state = None
+
+    def __enter__(self):
+        self.state = torch.random.get_rng_state()
+        torch.manual_seed(self.seed)
+
+    def __exit__(self, type, value, traceback):
+        torch.random.set_rng_state(self.state)
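TorchSeedContext snapshots the global torch RNG state, seeds it, and restores it on exit, so seeded sampling leaves the surrounding program's randomness untouched. A quick sketch:

    import torch
    from tool.ctx import TorchSeedContext

    with TorchSeedContext(42):
        a = torch.rand(3)
    with TorchSeedContext(42):
        b = torch.rand(3)
    assert torch.equal(a, b)  # same seed, identical draws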
tool/func.py
ADDED
@@ -0,0 +1,35 @@
+
+import gradio as gr
+import random
+
+seed_min = 1
+seed_max = 4294967295
+
+seeds = {
+    "旁白": {"seed": 4444},
+    "中年女性": {"seed": 7869},
+    "年轻女性": {"seed": 6615},
+    "中年男性": {"seed": 4099},
+    "年轻男性": {"seed": 6653},
+}
+
+# Timbre options: preset speaker timbres
+voices = {
+    "旁白": {"seed": 2},
+    "Timbre1": {"seed": 1111},
+    "Timbre2": {"seed": 2222},
+    "Timbre3": {"seed": 3333},
+    "Timbre4": {"seed": 4444},
+    "Timbre5": {"seed": 5555},
+    "Timbre6": {"seed": 6666},
+    "Timbre7": {"seed": 7777},
+    "Timbre8": {"seed": 8888},
+    "Timbre9": {"seed": 9999},
+}
+
+def on_voice_change(voice_selection):
+    return voices.get(voice_selection)["seed"]
+
+
+def generate_seed():
+    return gr.update(value=random.randint(seed_min, seed_max))
tool/logger/__init__.py
ADDED
@@ -0,0 +1 @@
+from .log import get_logger
tool/logger/log.py
ADDED
@@ -0,0 +1,73 @@
+import platform, sys
+import logging
+from datetime import datetime, timezone
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING)
+logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING)
+
+# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
+colorCodePanic = "\x1b[1;31m"
+colorCodeFatal = "\x1b[1;31m"
+colorCodeError = "\x1b[31m"
+colorCodeWarn = "\x1b[33m"
+colorCodeInfo = "\x1b[37m"
+colorCodeDebug = "\x1b[32m"
+colorCodeTrace = "\x1b[36m"
+colorReset = "\x1b[0m"
+
+log_level_color_code = {
+    logging.DEBUG: colorCodeDebug,
+    logging.INFO: colorCodeInfo,
+    logging.WARN: colorCodeWarn,
+    logging.ERROR: colorCodeError,
+    logging.FATAL: colorCodeFatal,
+}
+
+log_level_msg_str = {
+    logging.DEBUG: "DEBU",
+    logging.INFO: "INFO",
+    logging.WARN: "WARN",
+    logging.ERROR: "ERRO",
+    logging.FATAL: "FATL",
+}
+
+
+class Formatter(logging.Formatter):
+    def __init__(self, color=platform.system().lower() != "windows"):
+        # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone
+        self.tz = datetime.now(timezone.utc).astimezone().tzinfo
+        self.color = color
+
+    def format(self, record: logging.LogRecord):
+        logstr = "[" + datetime.now(self.tz).strftime("%z %Y%m%d %H:%M:%S") + "] ["
+        if self.color:
+            logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
+        logstr += log_level_msg_str.get(record.levelno, record.levelname)
+        if self.color:
+            logstr += colorReset
+        if sys.version_info >= (3, 9):
+            fn = record.filename.removesuffix(".py")
+        elif record.filename.endswith(".py"):
+            fn = record.filename[:-3]
+        logstr += f"] {str(record.name)} | {fn} | {str(record.msg)%record.args}"
+        return logstr
+
+
+def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False):
+    logger = logging.getLogger(name)
+    logger.setLevel(lv)
+    if remove_exist and logger.hasHandlers():
+        logger.handlers.clear()
+    if not logger.hasHandlers():
+        syslog = logging.StreamHandler()
+        syslog.setFormatter(Formatter())
+        logger.addHandler(syslog)
+    else:
+        for h in logger.handlers:
+            h.setFormatter(Formatter())
+    if format_root:
+        for h in logger.root.handlers:
+            h.setFormatter(Formatter())
+    return logger
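A minimal sketch of the factory above; the exact prefix varies with the local timezone and the calling file name:

    from tool.logger import get_logger

    log = get_logger("demo")
    log.info("hello %s", "world")
    # e.g. [+0000 20240815 10:30:05] [INFO] demo | app | hello world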
tool/np.py
ADDED
@@ -0,0 +1,11 @@
+import math
+
+import numpy as np
+from numba import jit
+
+
+@jit
+def float_to_int16(audio: np.ndarray) -> np.ndarray:
+    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
+    am = 32767 * 32768 // am
+    return np.multiply(audio, am).astype(np.int16)
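float_to_int16 rescales so the peak lands just under int16 full scale: am = 32767 * 32768 // (ceil(peak) * 32768), so a signal peaking at 1.0 is multiplied by exactly 32767 before the cast. A quick check:

    import numpy as np
    from tool.np import float_to_int16

    wav = np.array([0.0, 0.5, -1.0], dtype=np.float32)
    print(float_to_int16(wav))  # -> [0 16383 -32767]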
tool/pcm.py
ADDED
@@ -0,0 +1,21 @@
+import wave
+from io import BytesIO
+
+import numpy as np
+
+from .np import float_to_int16
+from .av import wav2
+
+
+def pcm_arr_to_mp3_view(wav: np.ndarray):
+    buf = BytesIO()
+    with wave.open(buf, "wb") as wf:
+        wf.setnchannels(1)  # Mono channel
+        wf.setsampwidth(2)  # Sample width in bytes
+        wf.setframerate(24000)  # Sample rate in Hz
+        wf.writeframes(float_to_int16(wav))
+    buf.seek(0, 0)
+    buf2 = BytesIO()
+    wav2(buf, buf2, "mp3")
+    buf.seek(0, 0)
+    return buf2.getbuffer()
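An end-to-end sketch of the PCM-to-MP3 path above, with a synthetic tone standing in for model output (written to a hypothetical tone.mp3):

    import numpy as np
    from tool.pcm import pcm_arr_to_mp3_view

    sr = 24000
    t = np.arange(sr, dtype=np.float32) / sr      # 1 second at 24 kHz
    wav = 0.5 * np.sin(2 * np.pi * 440.0 * t)     # 440 Hz test tone
    with open("tone.mp3", "wb") as f:
        f.write(pcm_arr_to_mp3_view(wav))         # WAV built in memory, re-encoded to mp3 via PyAV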
web/__init__.py
ADDED
File without changes
web/app.py
ADDED
@@ -0,0 +1,246 @@
+import os, sys
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+from tool.logger import get_logger
+import ChatTTS
+import argparse
+import gradio as gr
+from tool.func import *
+from tool.ctx import TorchSeedContext
+from tool.np import *
+
+logger = get_logger("app")
+
+# Initialize and load the model:
+chat = ChatTTS.Chat()
+
+
+def init_chat(args):
+    global chat
+    # Read the launch mode
+    MODEL = os.getenv('MODEL')
+    logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))
+    source = "custom"
+    # In huggingface deployment mode, use the model data hosted on HF directly
+    if MODEL == "HF":
+        source = "huggingface"
+
+    if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts", coef=None):
+        print("Models loaded successfully.")
+    else:
+        print("Models load failed.")
+        sys.exit(1)
+
+
+def main(args):
+    with gr.Blocks() as demo:
+        gr.Markdown("# ChatTTS demo")
+        with gr.Row():
+            with gr.Column(scale=1):
+                text_input = gr.Textbox(
+                    label="转换内容",
+                    lines=4,
+                    max_lines=4,
+                    placeholder="Please Input Text...",
+                    value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
+                    interactive=True,
+                )
+                with gr.Row():
+                    refine_text_checkBox = gr.Checkbox(
+                        label="是否优化文本,如是则先对文本内容做优化分词",
+                        interactive=True,
+                        value=True,
+                    )
+                    temperature_slider = gr.Slider(
+                        minimum=0.00001,
+                        maximum=1.0,
+                        step=0.00001,
+                        value=0.3,
+                        interactive=True,
+                        label="模型 Temperature 参数设置",
+                    )
+                    top_p_slider = gr.Slider(
+                        minimum=0.1,
+                        maximum=0.9,
+                        step=0.05,
+                        value=0.7,
+                        label="模型 top_P 参数设置",
+                        interactive=True,
+                    )
+                    top_k_slider = gr.Slider(
+                        minimum=1,
+                        maximum=20,
+                        step=1,
+                        value=20,
+                        label="模型 top_K 参数设置",
+                        interactive=True,
+                    )
+                with gr.Row():
+                    voice_selection = gr.Dropdown(
+                        label="Timbre",
+                        choices=voices.keys(),
+                        value="旁白",
+                        interactive=True,
+                        show_label=True,
+                    )
+                    audio_seed_input = gr.Number(
+                        value=2,
+                        label="音色种子",
+                        interactive=True,
+                        minimum=seed_min,
+                        maximum=seed_max,
+                    )
+                    generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
+                    text_seed_input = gr.Number(
+                        value=42,
+                        label="文本种子",
+                        interactive=True,
+                        minimum=seed_min,
+                        maximum=seed_max,
+                    )
+                    generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
+
+                with gr.Row():
+                    spk_emb_text = gr.Textbox(
+                        label="Speaker Embedding",
+                        max_lines=3,
+                        show_copy_button=True,
+                        interactive=False,
+                        scale=2,
+                    )
+                    reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
+
+                with gr.Row():
+                    generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
+
+                with gr.Row():
+                    text_output = gr.Textbox(
+                        label="输出文本",
+                        interactive=False,
+                        show_copy_button=True,
+                    )
+
+                audio_output = gr.Audio(
+                    label="输出音频",
+                    value=None,
+                    format="wav",
+                    autoplay=False,
+                    streaming=False,
+                    interactive=False,
+                    show_label=True,
+                    waveform_options=gr.WaveformOptions(
+                        sample_rate=24000,
+                    ),
+                )
+        # Register event listeners for the page elements
+        voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
+
+        audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
+
+        generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
+
+        generate_text_seed.click(fn=generate_seed, outputs=text_seed_input)
+
+        # reload_chat_button.click()
+
+        generate_button.click(fn=get_chat_infer_text,
+                              inputs=[text_input,
+                                      text_seed_input,
+                                      refine_text_checkBox,
+                                      ],
+                              outputs=[text_output],
+                              ).then(fn=get_chat_infer_audio,
+                                     inputs=[text_output,
+                                             temperature_slider,
+                                             top_p_slider,
+                                             top_k_slider,
+                                             audio_seed_input,
+                                             spk_emb_text,
+                                             ],
+                                     outputs=[audio_output])
+        # Initialize the spk_emb_text value
+        spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
+        logger.info("元素初始化完成,启动gradio服务=======")
+
+    # Launch the gradio service
+    demo.launch(
+        server_name=args.server_name,
+        server_port=args.server_port,
+        inbrowser=True,
+        show_api=False)
+
+
+def get_chat_infer_audio(chat_txt,
+                         temperature_slider,
+                         top_p_slider,
+                         top_k_slider,
+                         audio_seed_input,
+                         spk_emb_text):
+    logger.info("========开始生成音频文件=====")
+    # Audio inference parameters
+    params_infer_code = ChatTTS.Chat.InferCodeParams(
+        spk_emb=spk_emb_text,  # add sampled speaker
+        temperature=temperature_slider,  # using custom temperature
+        top_P=top_p_slider,  # top P decode
+        top_K=top_k_slider,  # top K decode
+    )
+
+    with TorchSeedContext(audio_seed_input):
+        wav = chat.infer(
+            text=chat_txt,
+            skip_refine_text=True,  # skip text refinement
+            params_infer_code=params_infer_code,
+        )
+    yield 24000, float_to_int16(wav[0]).T
+
+
+def get_chat_infer_text(text, seed, refine_text_checkBox):
+    logger.info("========开始优化文本内容=====")
+    global chat
+    if not refine_text_checkBox:
+        logger.info("========文本内容无需优化=====")
+        return text
+
+    params_refine_text = ChatTTS.Chat.RefineTextParams(
+        prompt='[oral_2][laugh_0][break_6]',
+    )
+
+    with TorchSeedContext(seed):
+        chat_text = chat.infer(
+            text=text,
+            skip_refine_text=False,
+            refine_text_only=True,  # return only the refined text
+            params_refine_text=params_refine_text,
+        )
+
+    return chat_text[0] if isinstance(chat_text, list) else chat_text
+
+
+def on_audio_seed_change(audio_seed_input):
+    global chat
+    with TorchSeedContext(audio_seed_input):
+        rand_spk = chat.sample_random_speaker()
+    return rand_spk
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
+    parser.add_argument(
+        "--server_name", type=str, default="0.0.0.0", help="server name"
+    )
+    parser.add_argument("--server_port", type=int, default=8080, help="server port")
+    parser.add_argument(
+        "--custom_path", type=str, default="D:\\chenjgspace\\ai-model\\chattts", help="custom model path"
+    )
+    parser.add_argument(
+        "--coef", type=str, default=None, help="custom dvae coefficient"
+    )
+    args = parser.parse_args()
+    init_chat(args)
+    main(args)
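For reference, a sketch of exercising the two handlers without the UI, mirroring the generate button's refine-then-synthesize chain (assumes init_chat has already loaded a model; the text and seeds here are arbitrary):

    refined = get_chat_infer_text("你好,世界", 42, True)
    spk = on_audio_seed_change(2)
    sr, pcm = next(get_chat_infer_audio(refined, 0.3, 0.7, 20, 2, spk))  # generator: take the first yield
    print(sr, pcm.dtype)  # 24000 int16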