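# NOTE: "Spaces: Runtime error" -- page-status banner captured together with
# this file when it was copied; it is not part of the program.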
try:
    import spaces
except ImportError:
    # The optional `spaces` package is not installed (presumably because we
    # are running outside Hugging Face Spaces -- TODO confirm), so provide a
    # stand-in whose `GPU` decorator leaves the function untouched.

    class NoneSpaces:
        """No-op replacement for the optional `spaces` module."""

        def GPU(self, fn=None, **_options):
            """Return `fn` unchanged.

            Works both as `@spaces.GPU` and as `@spaces.GPU(duration=...)`;
            in the second form the options are ignored and a pass-through
            decorator is returned.
            """
            if fn is None:
                return lambda wrapped: wrapped
            return fn

    spaces = NoneSpaces()
import os | |
import logging | |
import numpy as np | |
from modules.devices import devices | |
from modules.synthesize_audio import synthesize_audio | |
from modules.utils.cache import conditional_cache | |
# Configure the root logger once at import time. LOG_LEVEL is upper-cased so
# that lowercase spellings such as "debug" are accepted as level names too.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
import gradio as gr | |
import torch | |
from modules.ssml import parse_ssml | |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments | |
from modules.speaker import speaker_mgr | |
from modules.data import styles_mgr | |
from modules.api.utils import calc_spk_style | |
import modules.generate_audio as generate | |
from modules.normalization import text_normalize | |
from modules import refiner, config | |
from modules.utils import env, audio | |
from modules.SentenceSplitter import SentenceSplitter | |
torch._dynamo.config.cache_size_limit = 64 | |
torch._dynamo.config.suppress_errors = True | |
torch.set_float32_matmul_precision("high") | |
# Limits used by the web UI. The `__main__` entry point overwrites
# "tts_max", "ssml_max" and "max_batch_size" from CLI/env values.
webui_config = {
    # Max characters of plain text kept by tts_generate (rest is truncated).
    "tts_max": 1000,
    # Max cumulative text length accepted by synthesize_ssml.
    "ssml_max": 5000,
    # Threshold handed to SentenceSplitter in split_long_text.
    "spliter_threshold": 100,
    # Upper bound of the "Batch Size" sliders.
    "max_batch_size": 8,
}
def get_speakers():
    """Return whatever `speaker_mgr.list_speakers()` reports."""
    return speaker_mgr.list_speakers()
def get_styles():
    """Return whatever `styles_mgr.list_items()` reports."""
    return styles_mgr.list_items()
def segments_length_limit(segments, total_max: int):
    """Keep leading segments until their combined text exceeds `total_max`.

    Args:
        segments: Iterable of mapping-like segments. Only segments that have
            a "text" key count towards the limit.
        total_max: Maximum cumulative length of the "text" values.

    Returns:
        A new list with the segments that fit. The first text segment that
        would push the running total above `total_max` and everything after
        it are dropped.

    Segments without "text" are passed through unchanged instead of being
    discarded, so they stay in the result.
    NOTE(review): assumes such segments are things like <break/> pauses that
    the synthesizer knows how to handle -- confirm against parse_ssml.
    """
    kept = []
    used = 0
    for seg in segments:
        if "text" in seg:
            used += len(seg["text"])
            if used > total_max:
                break
        kept.append(seg)
    return kept
def synthesize_ssml(ssml: str, batch_size=4):
    """Synthesize an SSML document and return the combined audio.

    Args:
        ssml: SSML markup; surrounding whitespace is ignored.
        batch_size: Batch size passed to SynthesizeSegments. Values that
            cannot be converted to int fall back to the default of 4.

    Returns:
        The result of `audio.pydub_to_np` on the combined audio, or None when
        the input is blank or no segment survives the length limit.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Same fallback as the declared default (and as tts_generate).
        batch_size = 4

    ssml = ssml.strip()
    if not ssml:
        return None

    segments = segments_length_limit(parse_ssml(ssml), webui_config["ssml_max"])
    if not segments:
        return None

    synthesizer = SynthesizeSegments(batch_size=batch_size)
    audio_segments = synthesizer.synthesize_segments(segments)
    return audio.pydub_to_np(combine_audio_segments(audio_segments))
def tts_generate(
    text,
    temperature,
    top_p,
    top_k,
    spk,
    infer_seed,
    use_decoder,
    prompt1,
    prompt2,
    prefix,
    style,
    disable_normalize=False,
    batch_size=4,
):
    """Generate speech for plain text using the values from the TTS tab.

    The text is stripped and truncated to `webui_config["tts_max"]`
    characters. Values resolved by `calc_spk_style(spk, style)` fill in
    seed, temperature, prefix and prompts when the corresponding argument is
    falsy; a resolved "spk" always replaces the given speaker.

    Args:
        text: Input text.
        temperature, top_p, top_k: Sampling parameters forwarded to
            `synthesize_audio`; a float `top_k` is converted to int.
        spk: Speaker name or seed as entered in the UI.
        infer_seed: Inference seed; clipped to [-1, 2**32 - 1].
        use_decoder, prompt1, prompt2, prefix: Forwarded to
            `synthesize_audio`.
        style: Style name; "*auto" means no explicit style.
        disable_normalize: Skip `text_normalize` when true.
        batch_size: Batch size; falls back to 4 if not convertible to int.

    Returns:
        `(sample_rate, int16_audio)` or None when the text is blank.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        batch_size = 4

    text = text.strip()[: webui_config["tts_max"]]
    if not text:
        return None

    if style == "*auto":
        style = None

    # Sliders may deliver floats; top_k must be integral.
    if isinstance(top_k, float):
        top_k = int(top_k)

    params = calc_spk_style(spk=spk, style=style)
    spk = params.get("spk", spk)

    # Explicit UI values win; speaker/style presets only fill the gaps.
    infer_seed = infer_seed or params.get("seed", infer_seed)
    temperature = temperature or params.get("temperature", temperature)
    prefix = prefix or params.get("prefix", prefix)
    prompt1 = prompt1 or params.get("prompt1", "")
    prompt2 = prompt2 or params.get("prompt2", "")

    infer_seed = int(np.clip(infer_seed, -1, 2**32 - 1))

    if not disable_normalize:
        text = text_normalize(text)

    sample_rate, audio_data = synthesize_audio(
        text=text,
        temperature=temperature,
        top_P=top_p,
        top_K=top_k,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        batch_size=batch_size,
    )
    return sample_rate, audio.audio_to_int16(audio_data)
def refine_text(text: str, prompt: str):
    """Normalize `text`, then run it through `refiner.refine_text`.

    Args:
        text: Text to refine.
        prompt: Refine prompt forwarded as the `prompt` keyword.

    Returns:
        Whatever `refiner.refine_text` returns for the normalized text.
    """
    normalized = text_normalize(text)
    return refiner.refine_text(normalized, prompt=prompt)
def read_local_readme():
    """Return README.md from the current directory, starting at the title.

    Everything before the "# 🗣️ ChatTTS-Forge" heading is cut off. If the
    heading is not present the whole file is returned instead of raising
    ValueError.

    Raises:
        OSError: If README.md cannot be opened (e.g. it does not exist).
    """
    with open("README.md", "r", encoding="utf-8") as file:
        content = file.read()
    start = content.find("# 🗣️ ChatTTS-Forge")
    if start == -1:
        return content
    return content[start:]
# Demo sample texts offered in the "Examples" dropdown of the TTS tab.
sample_texts = [
    {
        "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
    },
    {
        "text": "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
    },
    {
        "text": "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
    },
    {
        "text": "今天的午餐菜单包括烤鸡、沙拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
    },
    {
        "text": "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
    },
    {
        "text": "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
    },
    {
        "text": "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
    },
    {
        "text": "这块黄金重达324.75克 [lbreak]",
    },
    {
        "text": "我们班的最高总分为583分 [lbreak]",
    },
    {
        "text": "12~23 [lbreak]",
    },
    {
        "text": "-1.5~2 [lbreak]",
    },
    {
        "text": "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
    },
    {
        "text": "等会请在12:05请通知我 [lbreak]",
    },
    {
        "text": "今天的最低气温达到-10°C [lbreak]",
    },
    {
        "text": "现场有7/12的观众投出了赞成票 [lbreak]",
    },
    {
        "text": "明天有62%的概率降雨 [lbreak]",
    },
    {
        "text": "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
    },
    {
        "text": "这是固话0421-33441122 [lbreak]",
    },
    {
        "text": "这是手机+86 18544139121 [lbreak]",
    },
]
# SSML example 1: multi-speaker, multi-style audiobook excerpt.
ssml_example1 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        黛玉冷笑道:[lbreak]
    </voice>
    <voice spk="female2" seed="42" style="angry">
        我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        宝玉道:[lbreak]
    </voice>
    <voice spk="Alice" seed="42" style="unfriendly">
        “只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]
    </voice>
    <voice spk="female2" seed="42" style="angry">
        “好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]
    </voice>
    <voice spk="Bob" seed="42" style="narration-relaxed">
        说着,便赌气回房去了 [lbreak]
    </voice>
</speak>
"""
# SSML example 2: <prosody> controlling rate, pitch and volume.
ssml_example2 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]
        <prosody>
            无任何限制将会继承父级voice配置进行生成 [lbreak]
        </prosody>
        <prosody rate="1.5">
            设置 rate 大于1表示加速,小于1为减速 [lbreak]
        </prosody>
        <prosody pitch="6">
            设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]
        </prosody>
        <prosody volume="2">
            设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]
        </prosody>
        在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]
    </voice>
</speak>
"""
# SSML example 3: <break> inserting a pause between two sentences.
ssml_example3 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        使用 break 标签将会简单的 [lbreak]
        <break time="500" />
        插入一段空白到生成结果中 [lbreak]
    </voice>
</speak>
"""
# SSML example 4: mixed Chinese/English sentences separated by pauses.
ssml_example4 = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="excited">
        temperature for sampling (may be overridden by style or speaker) [lbreak]
        <break time="500" />
        温度值用于采样,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
        <break time="500" />
        temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
        <break time="500" />
        温度值用于采样,(may be overridden by style or speaker) [lbreak]
    </voice>
</speak>
"""
# Initial content of the SSML input box.
default_ssml = """
<speak version="0.1">
    <voice spk="Bob" seed="42" style="narration-relaxed">
        这里是一个简单的 SSML 示例 [lbreak]
    </voice>
</speak>
"""
def create_tts_interface():
    """Build the widgets and event wiring of the "TTS" tab.

    Creates the sampling, style, speaker, seed and prompt controls, the text
    input with control-token buttons, the refiner, the example dropdown and
    the audio output, then connects the buttons to `refine_text` and
    `tts_generate`. Called by `create_interface` inside the "TTS" tab.

    NOTE(review): the nesting of the layout containers below was
    reconstructed from a copy of the file that had lost its indentation --
    confirm the layout against the running UI.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Speakers without a concrete gender are listed by name only.
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]
    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🎛️Sampling")
                temperature_input = gr.Slider(
                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
                )
                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
                batch_size_input = gr.Slider(
                    1,
                    webui_config["max_batch_size"],
                    value=4,
                    step=1,
                    label="Batch Size",
                )

            with gr.Row():
                with gr.Group():
                    gr.Markdown("🎭Style")
                    gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
                    style_input_dropdown = gr.Dropdown(
                        choices=styles,
                        # label="Choose Style",
                        interactive=True,
                        show_label=False,
                        value="*auto",
                    )

            with gr.Row():
                with gr.Group():
                    gr.Markdown("🗣️Speaker (Name or Seed)")
                    spk_input_text = gr.Textbox(
                        label="Speaker (Text or Seed)",
                        value="female2",
                        show_label=False,
                    )
                    spk_input_dropdown = gr.Dropdown(
                        choices=speaker_names,
                        # label="Choose Speaker",
                        interactive=True,
                        value="female : female2",
                        show_label=False,
                    )
                    spk_rand_button = gr.Button(
                        value="🎲",
                        # tooltip="Random Seed",
                        variant="secondary",
                    )
                    # "*..." entries (e.g. "*random") map to "-1"; otherwise
                    # keep the name after the "gender : " prefix.
                    spk_input_dropdown.change(
                        fn=lambda x: (
                            "-1" if x.startswith("*") else x.split(":")[-1].strip()
                        ),
                        inputs=[spk_input_dropdown],
                        outputs=[spk_input_text],
                    )
                    spk_rand_button.click(
                        lambda x: str(torch.randint(0, 2**32 - 1, (1,)).item()),
                        inputs=[spk_input_text],
                        outputs=[spk_input_text],
                    )

            with gr.Group():
                gr.Markdown("💃Inference Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    # tooltip="Random Seed",
                    variant="secondary",
                )

            # Hidden control: always passes True to tts_generate.
            use_decoder_input = gr.Checkbox(
                value=True, label="Use Decoder", visible=False
            )

            with gr.Group():
                gr.Markdown("🔧Prompt engineering")
                prompt1_input = gr.Textbox(label="Prompt 1")
                prompt2_input = gr.Textbox(label="Prompt 2")
                prefix_input = gr.Textbox(label="Prefix")

            infer_seed_rand_button.click(
                lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
                inputs=[infer_seed_input],
                outputs=[infer_seed_input],
            )

        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Group():
                        input_title = gr.Markdown(
                            "📝Text Input",
                            elem_id="input-title",
                        )
                        gr.Markdown(
                            f"- 字数限制{webui_config['tts_max']:,}字,超过部分截断"
                        )
                        gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
                        gr.Markdown(
                            "- If the input text is all in English, it is recommended to check disable_normalize"
                        )
                        text_input = gr.Textbox(
                            show_label=False,
                            label="Text to Speech",
                            lines=10,
                            placeholder="输入文本或选择示例",
                            elem_id="text-input",
                        )
                        # TODO: character counter. Easy to implement, but it
                        # triggers the loading indicator and needs a round
                        # trip to the backend.
                        # text_input.change(
                        #     fn=lambda x: (
                        #         f"📝Text Input ({len(x)} char)"
                        #         if x
                        #         else (
                        #             "📝Text Input (0 char)"
                        #             if not x
                        #             else "📝Text Input (0 char)"
                        #         )
                        #     ),
                        #     inputs=[text_input],
                        #     outputs=[input_title],
                        # )

                    with gr.Row():
                        control_tokens = [
                            "[laugh]",
                            "[uv_break]",
                            "[v_break]",
                            "[lbreak]",
                        ]
                        for tk in control_tokens:
                            t_btn = gr.Button(tk)
                            # `tk=tk` binds the token per button (avoids the
                            # late-binding closure pitfall).
                            t_btn.click(
                                lambda text, tk=tk: text + " " + tk,
                                inputs=[text_input],
                                outputs=[text_input],
                            )

                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("🎶Refiner")
                        refine_prompt_input = gr.Textbox(
                            label="Refine Prompt",
                            value="[oral_2][laugh_0][break_6]",
                        )
                        refine_button = gr.Button("✍️Refine Text")
                        # TODO: split sentences, build SSML with the current
                        # settings, then send it to the SSML tab.
                        # send_button = gr.Button("📩Split and send to SSML")

                    with gr.Group():
                        gr.Markdown("🔊Generate")
                        disable_normalize_input = gr.Checkbox(
                            value=False, label="Disable Normalize"
                        )
                        tts_button = gr.Button(
                            "🔊Generate Audio",
                            variant="primary",
                            elem_classes="big-button",
                        )

            with gr.Group():
                gr.Markdown("🎄Examples")
                sample_dropdown = gr.Dropdown(
                    choices=[sample["text"] for sample in sample_texts],
                    show_label=False,
                    value=None,
                    interactive=True,
                )
                # Copy the chosen example into the text box.
                sample_dropdown.change(
                    fn=lambda x: x,
                    inputs=[sample_dropdown],
                    outputs=[text_input],
                )

            with gr.Group():
                gr.Markdown("🎨Output")
                tts_output = gr.Audio(label="Generated Audio")

    refine_button.click(
        refine_text,
        inputs=[text_input, refine_prompt_input],
        outputs=[text_input],
    )

    # The input order must match the positional parameters of tts_generate.
    tts_button.click(
        tts_generate,
        inputs=[
            text_input,
            temperature_input,
            top_p_input,
            top_k_input,
            spk_input_text,
            infer_seed_input,
            use_decoder_input,
            prompt1_input,
            prompt2_input,
            prefix_input,
            style_input_dropdown,
            disable_normalize_input,
            batch_size_input,
        ],
        outputs=tts_output,
    )
def create_ssml_interface():
    """Build the "SSML" tab and return its input textbox.

    Creates the SSML input, a batch-size slider, clickable examples and the
    audio output, and wires the button to `synthesize_ssml`.

    Returns:
        The SSML input `gr.Textbox`, so other tabs can write into it.

    NOTE(review): container nesting was reconstructed from a copy of the
    file that had lost its indentation -- confirm against the running UI.
    """
    examples = [
        ssml_example1,
        ssml_example2,
        ssml_example3,
        ssml_example4,
    ]

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝SSML Input")
                gr.Markdown(f"- 最长{webui_config['ssml_max']:,}字符,超过会被截断")
                gr.Markdown("- 尽量保证使用相同的 seed")
                gr.Markdown(
                    "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
                )
                ssml_input = gr.Textbox(
                    label="SSML Input",
                    lines=10,
                    value=default_ssml,
                    placeholder="输入 SSML 或选择示例",
                    elem_id="ssml_input",
                    show_label=False,
                )
                ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")

        with gr.Column(scale=1):
            with gr.Group():
                # Parameters
                gr.Markdown("🎛️Parameters")
                # Batch size
                batch_size_input = gr.Slider(
                    label="Batch Size",
                    value=4,
                    minimum=1,
                    maximum=webui_config["max_batch_size"],
                    step=1,
                )
            with gr.Group():
                gr.Markdown("🎄Examples")
                gr.Examples(
                    examples=examples,
                    inputs=[ssml_input],
                )

    ssml_output = gr.Audio(label="Generated Audio")

    ssml_button.click(
        synthesize_ssml,
        inputs=[ssml_input, batch_size_input],
        outputs=ssml_output,
    )
    return ssml_input
def split_long_text(long_text_input):
    """Split long text into normalized sentences for the output table.

    The text is parsed by a `SentenceSplitter` built with
    `webui_config["spliter_threshold"]`, and each sentence is passed through
    `text_normalize`.

    Returns:
        A list of `[index, text, length]` rows, matching the headers of the
        "Long Text" output table.
    """
    splitter = SentenceSplitter(webui_config["spliter_threshold"])
    sentences = [text_normalize(s) for s in splitter.parse(long_text_input)]
    return [[i, text, len(text)] for i, text in enumerate(sentences)]
def merge_dataframe_to_ssml(dataframe, spk, style, seed):
    """Turn the split-text table into an SSML document.

    Each row becomes one `<voice>` element whose body is the normalized text
    from the row's second column.

    Args:
        dataframe: pandas DataFrame; the second column holds the text.
        spk: Speaker; "-1"/-1 or an empty value omits the `spk` attribute.
        style: Style name; "*auto" or an empty value omits `style`.
        seed: Seed; -1/"-1", None or "" omits `seed`. A seed of 0 is kept.

    Returns:
        The SSML string wrapped in `<speak version='0.1'>`.
    """
    from xml.sax.saxutils import escape

    if style == "*auto":
        style = None
    if spk == "-1" or spk == -1:
        spk = None
    if seed == -1 or seed == "-1":
        seed = None
    # Number inputs may arrive as floats; render 42.0 as "42".
    if isinstance(seed, float) and seed.is_integer():
        seed = int(seed)

    def attr(name, value):
        # Escape &, <, > and the double quote used as the delimiter.
        return f' {name}="{escape(str(value), {chr(34): "&quot;"})}"'

    attrs = ""
    if spk:
        attrs += attr("spk", spk)
    if style:
        attrs += attr("style", style)
    # Explicit None/"" check: a truthiness test would drop a seed of 0.
    if seed is not None and seed != "":
        attrs += attr("seed", seed)

    indent = " " * 2
    # Text is escaped so "<" or "&" in user input cannot break the markup.
    # NOTE(review): assumes the SSML parser decodes XML entities -- confirm.
    voices = [
        f"{indent}<voice{attrs}>\n"
        f"{indent}{indent}{escape(str(text_normalize(row.iloc[1])))}\n"
        f"{indent}</voice>\n"
        for _, row in dataframe.iterrows()
    ]
    return f"<speak version='0.1'>\n{''.join(voices)}</speak>"
# Long-text processing.
# Long text can be entered and split; after splitting, the assembled SSML can
# be sent to the SSML tab. The split result is shown in a data table.
# (Original note: splitting is done on the full stop "。".)
def create_long_content_tab(ssml_input, tabs):
    """Build the "Long Text" tab.

    Args:
        ssml_input: The SSML textbox (from `create_ssml_interface`) that
            receives the generated SSML.
        tabs: The `gr.Tabs` container, used to switch to the tab with
            id "ssml" after sending.

    NOTE(review): container nesting was reconstructed from a copy of the
    file that had lost its indentation -- confirm against the running UI.
    """
    speakers = get_speakers()

    def get_speaker_show_name(spk):
        # Speakers without a concrete gender are listed by name only.
        if spk.gender == "*" or spk.gender == "":
            return spk.name
        return f"{spk.gender} : {spk.name}"

    speaker_names = ["*random"] + [
        get_speaker_show_name(speaker) for speaker in speakers
    ]
    styles = ["*auto"] + [s.get("name") for s in get_styles()]

    with gr.Row():
        with gr.Column(scale=1):
            # Pick speaker, style and seed.
            with gr.Group():
                gr.Markdown("🗣️Speaker")
                spk_input_text = gr.Textbox(
                    label="Speaker (Text or Seed)",
                    value="female2",
                    show_label=False,
                )
                spk_input_dropdown = gr.Dropdown(
                    choices=speaker_names,
                    interactive=True,
                    value="female : female2",
                    show_label=False,
                )
                spk_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )
            with gr.Group():
                gr.Markdown("🎭Style")
                style_input_dropdown = gr.Dropdown(
                    choices=styles,
                    interactive=True,
                    show_label=False,
                    value="*auto",
                )
            with gr.Group():
                gr.Markdown("🗣️Seed")
                infer_seed_input = gr.Number(
                    value=42,
                    label="Inference Seed",
                    show_label=False,
                    minimum=-1,
                    maximum=2**32 - 1,
                )
                infer_seed_rand_button = gr.Button(
                    value="🎲",
                    variant="secondary",
                )

            send_btn = gr.Button("📩Send to SSML", variant="primary")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("📝Long Text Input")
                gr.Markdown("- 此页面用于处理超长文本")
                gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
                long_text_input = gr.Textbox(
                    label="Long Text Input",
                    lines=10,
                    placeholder="输入长文本",
                    elem_id="long-text-input",
                    show_label=False,
                )
                long_text_split_button = gr.Button("🔪Split Text")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("🎨Output")
                long_text_output = gr.DataFrame(
                    headers=["index", "text", "length"],
                    datatype=["number", "str", "number"],
                    elem_id="long-text-output",
                    interactive=False,
                    wrap=True,
                    value=[],
                )

    # "*..." entries (e.g. "*random") map to "-1"; otherwise keep the name
    # after the "gender : " prefix.
    spk_input_dropdown.change(
        fn=lambda x: "-1" if x.startswith("*") else x.split(":")[-1].strip(),
        inputs=[spk_input_dropdown],
        outputs=[spk_input_text],
    )
    # Returns a string, like the same button on the TTS tab (target is a
    # Textbox).
    spk_rand_button.click(
        lambda x: str(torch.randint(0, 2**32 - 1, (1,)).item()),
        inputs=[spk_input_text],
        outputs=[spk_input_text],
    )
    # Registered once; a second identical handler would only draw a second
    # random number on every click.
    infer_seed_rand_button.click(
        lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
        inputs=[infer_seed_input],
        outputs=[infer_seed_input],
    )
    long_text_split_button.click(
        split_long_text,
        inputs=[long_text_input],
        outputs=[long_text_output],
    )
    send_btn.click(
        merge_dataframe_to_ssml,
        inputs=[
            long_text_output,
            spk_input_text,
            style_input_dropdown,
            infer_seed_input,
        ],
        outputs=[ssml_input],
    )

    def change_tab():
        # Select the tab registered with id="ssml".
        return gr.Tabs(selected="ssml")

    send_btn.click(change_tab, inputs=[], outputs=[tabs])
def create_readme_tab():
    """Render the local README (via `read_local_readme`) as Markdown."""
    gr.Markdown(read_local_readme())
def create_interface():
    """Assemble the whole Gradio app and return the `gr.Blocks` instance.

    The app has four tabs (TTS, SSML, Long Text, README). The SSML tab gets
    id "ssml" so the Long Text tab can switch to it, and its input textbox
    is handed to `create_long_content_tab`.
    """
    # Reloads the page once with ?__theme=dark if that theme is not set.
    js_func = """
    function refresh() {
        const url = new URL(window.location);

        if (url.searchParams.get('__theme') !== 'dark') {
            url.searchParams.set('__theme', 'dark');
            window.location.href = url.href;
        }
    }
    """
    head_js = """
    <script>
    </script>
    """

    with gr.Blocks(js=js_func, head=head_js, title="ChatTTS Forge WebUI") as demo:
        # The selector uses "input-title", the elem_id actually assigned to
        # the text-input title in the TTS tab ("#input_title" matched
        # nothing).
        css = """
        <style>
        .big-button {
            height: 80px;
        }
        #input-title div.eta-bar {
            display: none !important; transform: none !important;
        }
        </style>
        """
        gr.HTML(css)

        with gr.Tabs() as tabs:
            with gr.TabItem("TTS"):
                create_tts_interface()
            with gr.TabItem("SSML", id="ssml"):
                ssml_input = create_ssml_interface()
            with gr.TabItem("Long Text"):
                create_long_content_tab(ssml_input, tabs=tabs)
            with gr.TabItem("README"):
                create_readme_tab()

        gr.Markdown(
            "此项目基于 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge) "
        )
    return demo
if __name__ == "__main__":
    # Script entry point: load the env file, parse CLI flags, record the
    # resolved settings, build the UI and launch the server.
    import argparse

    import dotenv

    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")

    args = parser.parse_args()

    def get_and_update_env(cli_args, key, default, value_type):
        """Resolve one setting via env.get_env_or_arg and record it.

        The value is stored in `config.runtime_env_vars[key]` and returned.
        """
        val = env.get_env_or_arg(cli_args, key, default, value_type)
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)

    # The following values are not used below; the calls matter because they
    # register each setting in config.runtime_env_vars.
    half = get_and_update_env(args, "half", False, bool)
    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
    lru_size = get_and_update_env(args, "lru_size", 64, int)
    device_id = get_and_update_env(args, "device_id", None, str)
    use_cpu = get_and_update_env(args, "use_cpu", [], list)
    # Named `compile_model` so the builtin `compile` is not shadowed.
    compile_model = get_and_update_env(args, "compile", False, bool)

    # UI limits must be final before the interface is built, because slider
    # bounds and hint texts read them at construction time.
    webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)

    demo = create_interface()

    if auth:
        # Split on the first ":" only, so passwords may contain colons.
        username, sep, password = auth.partition(":")
        if not sep:
            parser.error("--auth must have the form username:password")
        auth = (username, password)

    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )