# Lemonfoot_GPTSoVITS / GPT_SoVITS/inference_webui.py
# Author: Kit-Lemonfoot
# Commit d848e9a (verified): "Added Ame, Mint and Tenma, did some code changes"
# File size: 11.8 kB
# Based on GPT-SoVITS-fast-inference by ChasonJiang
import random
import os
import torch
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
import re, logging
# Quiet chatty third-party loggers: only ERROR and above are emitted.
for _noisy_logger in (
    "markdown_it",
    "urllib3",
    "httpcore",
    "httpx",
    "asyncio",
    "charset_normalizer",
    "torchaudio._extension",
):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger
import pdb
import json
def env_flag(name, default=False):
    """Read a boolean flag from the environment without using ``eval``.

    The previous code ran ``eval()`` on the raw value, which executes arbitrary
    Python and raises NameError on spellings such as ``false``. Here the value
    is compared case- and whitespace-insensitively against ``1``, ``true``,
    ``yes`` and ``on``; any other value that is set counts as False.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        bool (or ``default`` unchanged when the variable is unset).
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


# Presumably the web UI port (default 9872); read from the environment.
infer_ttswebui = int(os.environ.get("infer_ttswebui", 9872))
is_share = env_flag("is_share", False)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    # NOTE(review): torch.cuda.is_available() has already been called earlier
    # in this file, so this may be too late to change which GPUs torch sees.
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
# Half precision is only enabled when CUDA is available.
is_half = env_flag("is_half", True) and torch.cuda.is_available()
# GPT/SoVITS weights are deliberately not taken from the environment: each
# voice passes its own checkpoint paths to inference(), which loads them on
# demand when the voice changes.
gpt_path = None
sovits_path = None
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
# UI language label -> language code passed to the TTS pipeline.
dict_language = {
    "ZH": "all_zh",       # treat all text as Chinese
    "EN": "en",           # treat all text as English
    "JP": "all_ja",       # treat all text as Japanese
    "ZH/EN": "zh",        # mixed Chinese/English
    "JP/EN": "ja",        # mixed Japanese/English
    "Automatic": "auto",  # split the text and detect each segment's language
}
# UI slicing label -> text-segmentation method id; ids "cut0".."cut5" follow
# the order of the labels below.
cut_method = {
    label: f"cut{index}"
    for index, label in enumerate(
        (
            "None",
            "4 Sentences",
            "50 Characters",
            "ZH/JP Punctuation",
            "EN Punctuation",
            "All Punctuation",
        )
    )
}
# Load the pipeline configuration from YAML, apply the runtime device and
# precision, then any path overrides that are set (None = keep the YAML value).
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
# NOTE(review): "cnhuhbert_base_path" is spelled exactly as in the original
# assignment; presumably this matches TTS_Config's attribute name — confirm
# before "correcting" it.
_path_overrides = (
    ("t2s_weights_path", gpt_path),
    ("vits_weights_path", sovits_path),
    ("cnhuhbert_base_path", cnhubert_base_path),
    ("bert_base_path", bert_path),
)
for _attr_name, _override in _path_overrides:
    if _override is not None:
        setattr(tts_config, _attr_name, _override)
del _path_overrides, _attr_name, _override
print(tts_config)
tts_pipeline = TTS(tts_config)
# Record the weight paths the pipeline was actually built with.
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
# Name of the voice whose weights inference() last switched to ("" = none yet).
clm = ""
def inference(name, gptmp, svmp, sty, text, text_lang,
              ref_audio_path, prompt_text,
              prompt_lang, top_k,
              top_p, temperature,
              text_split_method, batch_size,
              speed_factor,
              split_bucket, fragment_interval,
              seed, keep_random, parallel_infer,
              repetition_penalty):
    """Synthesize ``text`` with one voice and stream the results to Gradio.

    Args:
        name: Voice key, as indexed by ``load_models`` into ``referencedata``.
        gptmp, svmp: GPT and SoVITS weight paths for that voice; loaded only
            when ``name`` differs from the previously used voice.
        sty: Preset style key. Used only when ``ref_audio_path`` is empty.
        text: Text to synthesize.
        text_lang, prompt_lang: UI language labels, mapped via ``dict_language``.
        ref_audio_path, prompt_text: Uploaded reference audio and its transcript.
            When no audio was uploaded, the preset style's audio file and
            transcript are used instead.
        text_split_method: UI label, mapped via ``cut_method``.
        batch_size, speed_factor: Cast to int / float, then forwarded.
        top_k, top_p, temperature, split_bucket, fragment_interval,
        parallel_infer, repetition_penalty: Forwarded unchanged to the pipeline.
        seed, keep_random: ``keep_random`` (or a seed of -1, "" or None) makes
            a fresh random 32-bit seed be drawn.

    Yields:
        ``(item, actual_seed)`` for every item produced by ``tts_pipeline.run``,
        so the UI also receives the seed that was actually used.

    Raises:
        gr.Error: No audio was uploaded and the voice has no usable preset
            audio/transcript for ``sty``.
    """
    global clm
    if not ref_audio_path:
        # No uploaded audio: fall back to the voice's preset style. Voices whose
        # presets failed validation have styles=None, and the hidden placeholder
        # style "none" is not a real key, so check before indexing.
        styles, transcripts = referencedata[name][0], referencedata[name][1]
        if not styles or sty not in styles or sty not in (transcripts or {}):
            raise gr.Error(
                f"No preset style available for {name}; "
                "please upload a reference audio."
            )
        ref_audio_path = f"referenceaudio/{name}/" + styles[sty]
        prompt_text = transcripts[sty]
    # Live switching: reload the weights only when a different voice is requested.
    if clm != name:
        print(f"Switching to model {name}")
        tts_pipeline.init_t2s_weights(gptmp)
        tts_pipeline.init_vits_weights(svmp)
        # Mark the voice as loaded only after both loads succeeded, so a failed
        # load is retried on the next request instead of reusing stale weights.
        clm = name
    seed = -1 if keep_random else seed
    actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
    print(f"TMP: {temperature} | SPDFCT: {speed_factor} | STY: {sty} | LANG: {text_lang}")
    inputs = {
        "text": text,
        "text_lang": dict_language[text_lang],
        "ref_audio_path": ref_audio_path,
        "prompt_text": prompt_text,
        "prompt_lang": dict_language[prompt_lang],
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": cut_method[text_split_method],
        "batch_size": int(batch_size),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "return_fragment": False,
        "fragment_interval": fragment_interval,
        "seed": actual_seed,
        "parallel_infer": parallel_infer,
        "repetition_penalty": repetition_penalty,
    }
    for item in tts_pipeline.run(inputs):
        yield item, actual_seed
def custom_sort_key(s):
    """Return a natural-sort key for *s*, so that e.g. "e2" sorts before "e10".

    The string is split on runs of digits. ``re.split`` with one capturing
    group always returns text, digits, text, digits, ..., text, so the
    digit runs sit at odd indices. Those become ints and the text pieces stay
    strings. Keys from different strings therefore compare str-to-str and
    int-to-int position by position.

    Using the index parity (instead of ``str.isdigit`` on every piece) avoids
    a ValueError on text such as "²", which ``isdigit`` accepts but ``int``
    rejects.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (a SyntaxWarning since Python 3.12).
    parts = re.split(r"(\d+)", s)
    return [int(piece) if index % 2 else piece for index, piece in enumerate(parts)]
def change_choices():
    """Rescan the weight folders and build refreshed choice lists.

    Returns:
        A pair ``(sovits_update, gpt_update)``. Each is a dict of the form
        ``{"choices": [...], "__type__": "update"}`` whose choices are sorted
        with ``custom_sort_key``. Presumably these are consumed by Gradio as
        component updates.
    """
    sovits_names, gpt_names = get_weights_names()

    def _as_update(paths):
        # Naturally sorted choices wrapped in an update dict.
        return {"choices": sorted(paths, key=custom_sort_key), "__type__": "update"}

    sovits_update = _as_update(sovits_names)
    gpt_update = _as_update(gpt_names)
    return sovits_update, gpt_update
pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
SoVITS_weight_root = "GPT_SoVITS/SoVITS_weights/"
GPT_weight_root = "GPT_SoVITS/GPT_weights/"


def _list_weights(root, suffix):
    """Return the paths of files in *root* whose names end with *suffix*.

    Returns an empty list when *root* does not exist, instead of raising
    FileNotFoundError. Paths are built with ``os.path.join`` so a root that
    already ends in "/" does not produce "//" in the result.
    """
    try:
        entries = os.listdir(root)
    except FileNotFoundError:
        return []
    return [os.path.join(root, entry) for entry in entries if entry.endswith(suffix)]


def get_weights_names():
    """List the available SoVITS (.pth) and GPT (.ckpt) weight files.

    Returns:
        ``(sovits_names, gpt_names)``. Each list starts with the pretrained
        model path, followed by matching files from the corresponding weight
        folder in ``os.listdir`` order. A missing folder contributes nothing.
    """
    sovits_names = [pretrained_sovits_name] + _list_weights(SoVITS_weight_root, ".pth")
    gpt_names = [pretrained_gpt_name] + _list_weights(GPT_weight_root, ".ckpt")
    return sovits_names, gpt_names
def load_models():
    """Read ``voicelist.json`` and index every enabled voice.

    Returns:
        ``(voices, ustyles)``:

        * ``voices``: list of ``(name, title, gpt_path, sovits_path, author,
          image)`` tuples, one per enabled voice, in file order.
        * ``ustyles``: dict ``name -> [styles, styletrans]``. ``styles`` maps a
          style key to an audio file under ``referenceaudio/<name>/``.
          ``styletrans`` is the JSON ``styletrans`` entry, indexed by the same
          keys. ``styles`` is set to None when the presets are unusable: no
          styles defined, a missing audio file, or a missing transcript.
    """
    print("Loading models...")
    voices = []
    ustyles = {}
    with open("voicelist.json", "r", encoding="utf-8") as f:
        voc_info = json.load(f)
    for name, info in voc_info.items():
        if not info['enable']:
            continue
        title = info['title']
        gptmodelpath = f"{GPT_weight_root}/{info['gpt_model_path']}"
        sovitsmodelpath = f"{SoVITS_weight_root}/{info['sovits_model_path']}"
        author = info['modelauthor']
        image = info['cover']
        styles = info['styles']
        styletrans = info['styletrans']
        # Validate the presets here so the UI and inference() never index into
        # an empty/None style table or a style without a transcript.
        if not styles:
            print(f"WARNING : No preset styles defined for model {name}, skipping")
            styles = None
        elif any(not os.path.exists(f"referenceaudio/{name}/{s}") for s in styles.values()):
            print(f"WARNING : Some defined preset styles do not exist for model {name}, skipping")
            styles = None
        elif any(key not in (styletrans or {}) for key in styles):
            print(f"WARNING : Some preset styles have no transcript for model {name}, skipping")
            styles = None
        voices.append((name, title, gptmodelpath, sovitsmodelpath, author, image))
        ustyles[name] = [styles, styletrans]
        print(f"Indexed model {title}")
    return voices, ustyles


modeldata, referencedata = load_models()
# Gradio preload: shared controls created outside the Blocks context and
# rendered later with .render() inside the app layout.
def _slider(label, minimum, maximum, step, value):
    """Build an interactive gr.Slider with the given range, step and default."""
    return gr.Slider(
        minimum=minimum,
        maximum=maximum,
        step=step,
        label=label,
        value=value,
        interactive=True,
    )


def _checkbox(label, value):
    """Build an interactive gr.Checkbox whose label is shown."""
    return gr.Checkbox(label=label, value=value, interactive=True, show_label=True)


text = gr.TextArea(
    label="Input Text",
    value="Hello there! This is test audio of a new text to speech tool.",
)
text_language = gr.Dropdown(
    label="Language",
    choices=["EN", "JP", "ZH", "ZH/EN", "JP/EN", "Automatic"],
    value="EN",
)
how_to_cut = gr.Dropdown(
    label="Slicing Method",
    choices=["None", "4 Sentences", "50 Characters", "ZH/JP Punctuation", "EN Punctuation", "All Punctuation"],
    value="4 Sentences",
    interactive=True,
)
top_k = _slider("Top_k", 1, 100, 1, 5)
top_p = _slider("Top_p", 0, 1, 0.05, 1)
temperature = _slider("Temperature", 0, 1, 0.05, 0.7)
batch_size = _slider("Batch Size", 1, 200, 1, 20)
fragment_interval = _slider("Fragment Interval", 0.01, 1, 0.01, 0.3)
speed_factor = _slider("Speed Factor", 0.50, 2, 0.05, 1.0)
repetition_penalty = _slider("Repetition Penalty", 0, 2, 0.05, 1.35)
parallel_infer = _checkbox("Parallel Infer", True)
split_bucket = _checkbox("Split Bucket", True)
seed = gr.Number(label="Random Seed", value=-1, interactive=True, show_label=True)
keep_random = _checkbox("Use Randomized Seed", True)
# Main Gradio app: one tab per indexed voice, then the shared controls, which
# are rendered once below the tabs.
with gr.Blocks(title="Lemonfoot GPT-SoVITS") as app:
    gr.Markdown(
        "# Lemonfoot GPT-SoVITS 🚀🍋\n"
        "### Space by Kit Lemonfoot / Noel Shirogane's High Flying Birds\n"
        "Based on code originally by RVC_Boss and ChasonJiang\n\n"
        "Do no evil.\n\n"
    )
    for (name, title, gptmodelpath, sovitsmodelpath, author, image) in modeldata:
        with gr.TabItem(name):
            with gr.Row():
                with gr.Column():
                    # Hidden textboxes carry this tab's constants into inference().
                    n = gr.Textbox(value=name, visible=False, interactive=False)
                    gptmp = gr.Textbox(value=gptmodelpath, visible=False, interactive=False)
                    svmp = gr.Textbox(value=sovitsmodelpath, visible=False, interactive=False)
                    gr.Markdown(f"**{title}**\n\n Dataset author: {author}")
                    gr.Image(
                        f"images/{image}",
                        label=None,
                        show_label=False,
                        width=300,
                        show_download_button=False,
                        container=False,
                        show_share_button=False,
                    )
                with gr.Column():
                    preset_styles = referencedata[name][0]
                    # Render the preset picker only when there is at least one
                    # style. Checking truthiness (not just "is not None") also
                    # skips an empty dict, which would make rd[0] raise IndexError.
                    if preset_styles:
                        rd = list(preset_styles.keys())
                        with gr.TabItem("Style using a preset"):
                            sty = gr.Dropdown(
                                label="Current style",
                                choices=rd,
                                value=rd[0],
                                interactive=True,
                            )
                    else:
                        # Placeholder so inference() still receives a style argument.
                        sty = gr.Textbox(value="none", visible=False, interactive=False)
                    with gr.TabItem("Style using a different audio"):
                        with gr.Column():
                            ref_audio_path = gr.Audio(label="Reference Audio", type="filepath")
                            prompt_text = gr.Textbox(
                                label="Reference Audio Text",
                                interactive=True,
                                placeholder="Leave blank to use no-text reference mode.",
                            )
                            prompt_language = gr.Dropdown(
                                label="Reference Audio Language",
                                choices=["EN", "JP", "ZH", "ZH/EN", "JP/EN", "Automatic"],
                                value="EN",
                            )
                with gr.Column():
                    inference_button = gr.Button("Synthesize", variant="primary")
                    output = gr.Audio(label="Output")
            inference_button.click(
                inference,
                inputs=[n, gptmp, svmp, sty, text, text_language, ref_audio_path, prompt_text, prompt_language, top_k, top_p, temperature, how_to_cut, batch_size, speed_factor, split_bucket, fragment_interval, seed, keep_random, parallel_infer, repetition_penalty],
                outputs=[output, seed],
            )
    # Bottom info: the shared controls are rendered exactly once, outside the
    # per-voice loop; every tab's click handler above uses these same components.
    with gr.Row():
        with gr.Column():
            text.render()
            text_language.render()
            how_to_cut.render()
        with gr.Column():
            temperature.render()
            speed_factor.render()
            with gr.Accordion("Advanced Inference Parameters", open=False):
                top_k.render()
                top_p.render()
                batch_size.render()
                fragment_interval.render()
                repetition_penalty.render()
                parallel_infer.render()
                split_bucket.render()
                seed.render()
                keep_random.render()
app.queue().launch()