# xtts-ft / app.py
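"""Gradio demo for XTTS inference with fine-tuned voices.

Pick one of the bundled fine-tuned checkpoints ("dingzhen" or "kobe"), choose a
reference audio clip, and synthesize speech from text in any of the 16 supported
languages. Designed to run on Hugging Face Spaces with the ZeroGPU `spaces` package.
"""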
import argparse
import logging
import os
import sys
import tempfile
import traceback

import gradio as gr
import librosa.display
import numpy as np
import torch
import torchaudio
import spaces

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
def clear_gpu_cache():
    # clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
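# Global handle to the currently loaded XTTS model; populated by load_model() below.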
XTTS_MODEL = None
def load_model(choice):
    global XTTS_MODEL
    clear_gpu_cache()
    xtts_checkpoint = xtts_config = xtts_vocab = None
    if choice == "dingzhen":
        xtts_checkpoint = "./finetune_models/run/training/GPT_XTTS_FT-July-04-2024_01+29PM-44c61c9/best_model.pth"
        xtts_config = "./finetune_models/run/training/XTTS_v2.0_original_model_files/config.json"
        xtts_vocab = "./finetune_models/run/training/XTTS_v2.0_original_model_files/vocab.json"
    elif choice == "kobe":
        xtts_checkpoint = "./finetune_models_kobe/run/training/GPT_XTTS_FT-July-05-2024_09+09AM-44c61c9/best_model.pth"
        xtts_config = "./finetune_models_kobe/run/training/XTTS_v2.0_original_model_files/config.json"
        xtts_vocab = "./finetune_models_kobe/run/training/XTTS_v2.0_original_model_files/vocab.json"
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields!"
    config = XttsConfig()
    config.load_json(xtts_config)
    XTTS_MODEL = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    XTTS_MODEL.load_checkpoint(
        config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        speaker_file_path="./speakers_xtts.pth",
        use_deepspeed=False,
    )
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()
    print("Model loaded successfully!")
    return "Model loaded successfully!"
@spaces.GPU
def run_tts(lang, tts_text, speaker_audio_file):
    if XTTS_MODEL is None or not speaker_audio_file:
        return "You need to complete Step 1 first - load the model", None, None
    # Strip any stray newlines from the selected reference-audio path.
    speaker_audio_file = "".join([item for item in speaker_audio_file.strip().split("\n") if item != ""])
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
        max_ref_length=XTTS_MODEL.config.max_ref_len,
        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
    )
    out = XTTS_MODEL.inference(
        text=tts_text.strip(),
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
        length_penalty=XTTS_MODEL.config.length_penalty,
        repetition_penalty=XTTS_MODEL.config.repetition_penalty,
        top_k=XTTS_MODEL.config.top_k,
        top_p=XTTS_MODEL.config.top_p,
    )
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
        out_path = fp.name
        torchaudio.save(out_path, out["wav"], 24000)
    return "Inference succeeded, come have a listen!", out_path, speaker_audio_file
# Define a logger that mirrors stdout/stderr to a log file.
class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout
        self.log = open(self.log_file, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False
# Redirect stdout and stderr to the log file as well as the terminal.
sys.stdout = Logger()
sys.stderr = sys.stdout

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
def read_logs():
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r") as f:
        return f.read()
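# read_logs() returns the mirrored log contents; it is not wired to a UI
# component in this app, but could be polled from Gradio if desired.

# Build the Gradio UI: model selection and loading (step 1), reference-audio,
# text, and language selection, then synthesis (step 2).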
with gr.Blocks(title="XTTS WebUI") as app:
    gr.Markdown("# <center>🌊💕🎶 XTTS Fine-tuning: 2 Minutes of Speech Unlocks Realistic Voice Cloning in 16 Languages, Including Chinese, Japanese, and English</center>")
    gr.Markdown("## <center>🌟 Fine-tune the strongest multilingual model online in one click, with just 2 minutes of speech</center>")
    gr.Markdown("### <center>🤗 More highlights at [滔滔AI](https://www.talktalkai.com/); 滔滔AI, made with love!💕</center>")
    with gr.Row():
        with gr.Column() as col1:
            choice = gr.Dropdown(label="Select your preferred model", value="dingzhen", choices=["dingzhen", "kobe"])
            progress_load = gr.Label(
                label="Model loading progress"
            )
            load_btn = gr.Button(value="1. Load the fine-tuned model", variant="primary")
        with gr.Column() as col2:
            speaker_reference_audio = gr.Dropdown(
                label="Select a reference audio clip",
                info="Different reference clips give different synthesis results. Feel free to try several, selecting one audio path at a time.",
                value="dingzhen1.wav",
                choices=["dingzhen1.wav", "dingzhen2.wav", "dingzhen3.wav", "dingzhen4.wav", "dingzhen5.wav", "dingzhen6.wav"]
            )
            tts_text = gr.Textbox(
                label="Enter the text to synthesize🍻",
                placeholder="So much I want to say but have not said yet",
            )
            tts_language = gr.Dropdown(
                label="Select the language of the text",
                value="zh",
                choices=[
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh",
                    "hu",
                    "ko",
                    "ja",
                ]
            )
            tts_btn = gr.Button(value="2. Start your AI voice journey💕", variant="primary")
        with gr.Column() as col3:
            progress_gen = gr.Label(
                label="Speech synthesis progress"
            )
            tts_output_audio = gr.Audio(label="Your personalized synthesized audio🎶")
            reference_audio = gr.Audio(label="The reference audio you used")

    load_btn.click(
        fn=load_model,
        inputs=[
            choice
        ],
        outputs=[progress_load],
    )
    tts_btn.click(
        fn=run_tts,
        inputs=[
            tts_language,
            tts_text,
            speaker_reference_audio,
        ],
        outputs=[progress_gen, tts_output_audio, reference_audio],
    )

    gr.Markdown("### <center>Note❗: Please do not generate content that could harm individuals or organizations. This program is intended for research, learning, and personal entertainment only. Use it responsibly and in compliance with applicable rules; the developer assumes no liability.</center>")
    gr.HTML('''
    <div class="footer">
        <p>🌊🏞️🎶 - The river rushes eastward, its surging sound without end. (Gu Lin, Ming dynasty)
        </p>
    </div>
    ''')
app.queue().launch(
    share=True,
    show_error=True,
)
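# Note: when hosted on Hugging Face Spaces the public URL is provided by the
# platform and Gradio may ignore share=True; the flag matters mainly for local runs.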