# Page-header residue from the hosting page ("Spaces: Running"); not part of the program.
import argparse | |
import os | |
import sys | |
import tempfile | |
import gradio as gr | |
import librosa.display | |
import numpy as np | |
import os | |
import torch | |
import torchaudio | |
import traceback | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
import spaces | |
def clear_gpu_cache():
    """Ask PyTorch to empty its CUDA cache; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# The currently loaded XTTS model, or None until load_model() has succeeded.
XTTS_MODEL = None

# Files needed for each selectable fine-tuned voice:
# (checkpoint path, config path, vocab path).
_XTTS_MODEL_FILES = {
    "dingzhen": (
        "./finetune_models/run/training/GPT_XTTS_FT-July-04-2024_01+29PM-44c61c9/best_model.pth",
        "./finetune_models/run/training/XTTS_v2.0_original_model_files/config.json",
        "./finetune_models/run/training/XTTS_v2.0_original_model_files/vocab.json",
    ),
    "kobe": (
        "./finetune_models_kobe/run/training/GPT_XTTS_FT-July-05-2024_09+09AM-44c61c9/best_model.pth",
        "./finetune_models_kobe/run/training/XTTS_v2.0_original_model_files/config.json",
        "./finetune_models_kobe/run/training/XTTS_v2.0_original_model_files/vocab.json",
    ),
}


def load_model(choice):
    """Load the fine-tuned XTTS model named by ``choice`` into ``XTTS_MODEL``.

    Args:
        choice: Name of the voice to load; must be a key of
            ``_XTTS_MODEL_FILES`` ("dingzhen" or "kobe").

    Returns:
        A status message: the success text once the model is loaded, or the
        "You need to run the previous steps ..." text when ``choice`` does not
        name a known model (previously this case raised ``UnboundLocalError``).

    Errors raised while reading the config or checkpoint propagate to the
    caller; in that case ``XTTS_MODEL`` is left as ``None`` rather than
    pointing at a partially initialised model.
    """
    global XTTS_MODEL

    model_files = _XTTS_MODEL_FILES.get(choice)
    if model_files is None:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
    xtts_checkpoint, xtts_config, xtts_vocab = model_files

    # Drop the reference to any previously loaded model *before* emptying the
    # cache; memory still owned by a live model cannot be released.
    XTTS_MODEL = None
    clear_gpu_cache()

    config = XttsConfig()
    config.load_json(xtts_config)

    # Build into a local and publish it only after every step succeeded, so
    # run_tts() never sees a model whose checkpoint failed to load.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model! ")
    model.load_checkpoint(
        config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        speaker_file_path="./speakers_xtts.pth",
        use_deepspeed=False,
    )
    if torch.cuda.is_available():
        model.cuda()

    XTTS_MODEL = model
    print("模型已成功加载!")
    return "模型已成功加载!"
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` with the loaded model, conditioned on a reference audio.

    Args:
        lang: Language code forwarded to the model as ``language``.
        tts_text: Text to speak; surrounding whitespace is stripped.
        speaker_audio_file: Path of the reference audio. Surrounding
            whitespace and any newline characters are removed before use.

    Returns:
        A ``(status_message, output_wav_path, reference_audio_path)`` tuple.
        When the model is not loaded, or the reference audio or text is
        missing, the two paths are ``None`` and the message says which input
        is missing.
    """
    # Validate each precondition separately so the message names the real
    # problem instead of always asking the user to load the model.
    if XTTS_MODEL is None:
        return "您需要先执行第1步 - 加载模型", None, None
    if not speaker_audio_file or not speaker_audio_file.strip():
        return "请先选择一个参考音频", None, None
    if not tts_text or not tts_text.strip():
        return "请先填写语音合成的文本", None, None

    # Same result as joining the non-empty lines: strip, then drop newlines.
    speaker_audio_file = speaker_audio_file.strip().replace("\n", "")

    model_config = XTTS_MODEL.config
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=model_config.gpt_cond_len,
        max_ref_length=model_config.max_ref_len,
        sound_norm_refs=model_config.sound_norm_refs,
    )
    out = XTTS_MODEL.inference(
        text=tts_text.strip(),
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=model_config.temperature,  # Add custom parameters here
        length_penalty=model_config.length_penalty,
        repetition_penalty=model_config.repetition_penalty,
        top_k=model_config.top_k,
        top_p=model_config.top_p,
    )

    # Add a leading dimension before saving, as the original code did.
    wav = torch.tensor(out["wav"]).unsqueeze(0)

    # Reserve a persistent temp path, then close the handle before writing:
    # re-opening a still-open NamedTemporaryFile by name is platform-dependent.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    # NOTE(review): 24000 is presumably the model's output sample rate — confirm.
    torchaudio.save(out_path, wav, 24000)

    return "推理成功,快来听听吧!", out_path, speaker_audio_file
# Tee-style stream used to redirect output: every write goes to the terminal
# and to a log file.
class Logger:
    """File-like object that duplicates writes to the terminal and a log file."""

    def __init__(self, filename="log.out"):
        """Remember the current ``sys.stdout`` and open ``filename`` for writing.

        Opening in "w" mode truncates any existing file of that name.
        """
        self.log_file = filename
        self.terminal = sys.stdout
        self.log = open(self.log_file, "w")

    def write(self, message):
        """Write ``message`` to the terminal first, then to the log file."""
        for sink in (self.terminal, self.log):
            sink.write(message)

    def flush(self):
        """Flush the terminal first, then the log file."""
        for sink in (self.terminal, self.log):
            sink.flush()

    def isatty(self):
        """Always report that this stream is not a TTY."""
        return False
# Route both stdout and stderr through a single Logger instance so that
# everything printed is also written to its log file.
import logging

sys.stdout = sys.stderr = Logger()

# Root logger: WARNING and above, emitted through the redirected stdout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.WARNING,
)
def read_logs():
    """Flush ``sys.stdout`` and return the full text of its log file.

    Reads ``sys.stdout.log_file``, so ``sys.stdout`` must be an object that
    exposes that attribute (the ``Logger`` installed by this script).
    """
    sys.stdout.flush()
    log_path = sys.stdout.log_file
    with open(log_path, "r") as log_handle:
        contents = log_handle.read()
    return contents
# Three-column UI: (1) model selection and loading, (2) synthesis inputs,
# (3) status and audio outputs.
# NOTE(review): the window title says "GPT-SoVITS" while the page content is
# about XTTS — confirm whether the title is intentional.
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown("# <center>🌊💕🎶 XTTS 微调:2分钟语音,开启中日英16种语言真实拟声</center>")
    gr.Markdown("## <center>🌟 只需2分钟的语音,一键在线微调 最强多语种模型</center>")
    gr.Markdown("### <center>🤗 更多精彩,尽在[滔滔AI](https://www.talktalkai.com/);滔滔AI,为爱滔滔!💕</center>")

    with gr.Row():
        # Column 1: choose a fine-tuned voice and load it (step 1).
        with gr.Column() as col1:
            choice = gr.Dropdown(
                label="请选择您喜欢的模型",
                value="dingzhen",
                choices=["dingzhen", "kobe"],
            )
            progress_load = gr.Label(label="模型加载进程")
            load_btn = gr.Button(value="1. 加载已训练好的模型", variant="primary")

        # Column 2: reference audio, text and language for synthesis (step 2).
        with gr.Column() as col2:
            # NOTE(review): only dingzhen reference files are listed, even when
            # the "kobe" model is selected — confirm this is intended.
            speaker_reference_audio = gr.Dropdown(
                label="请选择一个参考音频",
                info="不同参考音频对应的合成效果不同。您可以尝试多次,每次选择一个音频路径",
                value="dingzhen1.wav",
                choices=[
                    "dingzhen1.wav", "dingzhen2.wav", "dingzhen3.wav",
                    "dingzhen4.wav", "dingzhen5.wav", "dingzhen6.wav",
                ],
            )
            tts_text = gr.Textbox(
                label="请填写语音合成的文本🍻",
                placeholder="想说却还没说的,还很多",
            )
            tts_language = gr.Dropdown(
                label="请选择文本对应的语言",
                value="zh",
                choices=[
                    "en", "es", "fr", "de", "it", "pt", "pl", "tr",
                    "ru", "nl", "cs", "ar", "zh", "hu", "ko", "ja",
                ],
            )
            tts_btn = gr.Button(value="2. 开启AI语音之旅吧💕", variant="primary")

        # Column 3: synthesis status, generated audio and the reference used.
        with gr.Column() as col3:
            progress_gen = gr.Label(label="语音合成进程")
            tts_output_audio = gr.Audio(label="为您合成的专属音频🎶")
            reference_audio = gr.Audio(label="您使用的参考音频")

    # Step 1 button: load the selected model and show the status message.
    load_btn.click(fn=load_model, inputs=[choice], outputs=[progress_load])

    # Step 2 button: run synthesis; outputs are status, result and reference.
    tts_btn.click(
        fn=run_tts,
        inputs=[tts_language, tts_text, speaker_reference_audio],
        outputs=[progress_gen, tts_output_audio, reference_audio],
    )

    gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。请自觉合规使用此程序,程序开发者不负有任何责任。</center>")
    gr.HTML('''
<div class="footer">
<p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
</p>
</div>
''')

# Enable the request queue, then start the server with a share link and
# error display turned on.
app.queue().launch(share=True, show_error=True)