ChatVC / app.py
Hilley's picture
Update app.py
779412e verified
raw
history blame
10.2 kB
import spaces
import os
import random
import argparse
import torch
import gradio as gr
import numpy as np
import ChatTTS
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
import soundfile
from tts_voice import tts_order_voice
import edge_tts
import tempfile
import anyio
# Load the ChatTTS model once, at import time. `chat` is the shared
# module-level instance that chat_tts() calls for both text refinement and
# audio inference.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()
def generate_seed():
    """Build an update payload carrying a freshly drawn random seed.

    Returns:
        A dict with ``"__type__": "update"`` and ``"value"`` set to a random
        integer in the inclusive range 1..100000000. It is wired as the output
        of the dice buttons so the corresponding seed field gets the new value.
    """
    seed_value = random.randint(1, 100000000)
    update_payload = dict(__type__="update", value=seed_value)
    return update_payload
@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize speech for `text` with the module-level ChatTTS instance.

    The speaker vector is drawn with ``torch.randn(768)`` immediately after
    torch is seeded with `audio_seed_input`; torch is then re-seeded with
    `text_seed_input` before any inference call is made.

    When `refine_text_flag` is truthy, the text first goes through a
    refine-only inference pass (using `refine_text_input` as the prompt when
    it is non-empty, otherwise the built-in default prompt) and the refined
    result is what gets synthesized.

    Returns:
        ``[(sample_rate, audio_data), text_data]`` when `output_path` is None.
        Otherwise the audio is written to `output_path` via soundfile and the
        function returns None.
    """
    refine_params = {'prompt': '[oral_2][laugh_0][break_6]'}

    # Seed first, then sample: the speaker vector depends on the audio seed.
    torch.manual_seed(audio_seed_input)
    speaker_vector = torch.randn(768)
    infer_params = dict(
        spk_emb=speaker_vector,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
    )

    # Re-seed so the inference calls below depend on the text seed.
    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        if refine_text_input:
            refine_params['prompt'] = refine_text_input
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=refine_params,
            params_infer_code=infer_params,
        )
        print("Text has been refined!")

    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=refine_params,
        params_infer_code=infer_params,
    )

    sample_rate = 24000
    first_wave = np.array(wav[0])
    audio_data = first_wave.flatten()

    if output_path is not None:
        # File mode: persist the audio and hand nothing back to the caller.
        soundfile.write(output_path, audio_data, sample_rate)
        return None

    # The refine pass may hand back a list; expose only its first entry.
    text_data = text[0] if isinstance(text, list) else text
    return [(sample_rate, audio_data), text_data]
# OpenVoice: checkpoint locations and the module-level model instances shared
# by the voice-conversion functions in this file. Everything here runs at
# import time.
ckpt_base_en = 'checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/converter'
# Hard-coded to the first CUDA device; the commented-out line below is the
# CPU alternative.
device = 'cuda:0'
#device = "cpu"
# Base English speaker TTS, used by vc_en().
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
# Tone color converter, used by generate_audio(), vc_en() and
# text_to_speech_edge(), and passed to se_extractor.get_se().
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
    """Synthesize `text` with ChatTTS and convert it toward the reference voice.

    Steps, in order:
      1. Extract the target speaker's `se` from `audio_ref`.
      2. Run chat_tts() in file mode so the synthesized speech lands in
         ``tmp.wav``.
      3. Extract the source `se` from that synthesized file.
      4. Run the tone color converter from source to target, writing
         ``output.wav``.

    Args:
        text: Text to speak.
        audio_ref: Path to the reference (target speaker) audio.
        temperature, top_P, top_K, audio_seed_input, text_seed_input,
        refine_text_flag, refine_text_input: Forwarded unchanged to chat_tts().

    Returns:
        The path of the converted audio file (``"output.wav"``).

    NOTE: both the intermediate and the output file use fixed names in the
    working directory, so overlapping calls read and write the same files.
    """
    src_path = "tmp.wav"
    save_path = "output.wav"

    # Target speaker `se`, taken from the user-supplied reference audio.
    target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

    # Run the base speaker tts; passing a path makes chat_tts write the wav
    # to disk instead of returning the samples.
    chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path)
    print("Ready for voice cloning!")

    # The source `se` is extracted from the audio that was just synthesized
    # (the speaker is randomly sampled per seed), so no pre-computed default
    # `se` is loaded from the checkpoint directory here.
    source_se, _ = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
    print("Get source segment!")

    # Run the tone color converter
    encode_message = "@Hilley"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)
    print("Finished!")
    return save_path
def vc_en(text, audio_ref, style_mode):
    """Speak `text` with the English base speaker in `style_mode`, then convert
    it toward the reference voice.

    The ``"default"`` style loads ``en_default_se.pth`` as the source `se` and
    speaks at speed 1.0; every other style loads ``en_style_se.pth`` and speaks
    at speed 0.9. In both cases the base TTS writes ``tmp.wav`` and the tone
    color converter writes ``output.wav``.

    Args:
        text: Text to speak.
        audio_ref: Path to the reference (target speaker) audio.
        style_mode: Speaker style name handed to the base speaker TTS.

    Returns:
        The string ``"output.wav"``.
    """
    uses_default_style = style_mode == "default"

    # Only the source checkpoint and the speaking speed depend on the style.
    if uses_default_style:
        se_checkpoint = f'{ckpt_base_en}/en_default_se.pth'
        speaking_speed = 1.0
    else:
        se_checkpoint = f'{ckpt_base_en}/en_style_se.pth'
        speaking_speed = 0.9

    source_se = torch.load(se_checkpoint).to(device)
    target_se, _unused_name = se_extractor.get_se(
        audio_ref, tone_color_converter, target_dir='processed', vad=True
    )

    tts_wav = "tmp.wav"
    converted_wav = "output.wav"

    # Base speaker synthesis; for the default style, `style_mode` is itself
    # the string 'default'.
    base_speaker_tts.tts(text, tts_wav, speaker=style_mode, language='English', speed=speaking_speed)

    # Tone color conversion from the base speaker to the reference speaker.
    tone_color_converter.convert(
        audio_src_path=tts_wav,
        src_se=source_se,
        tgt_se=target_se,
        output_path=converted_wav,
        message="@MyShell",
    )
    return "output.wav"
# Lookup table used by text_to_speech_edge(): its keys populate the language
# dropdown in the UI and its values are passed to edge_tts.Communicate as the
# voice.
language_dict = tts_order_voice
# Source `se` for the edge-tts path, extracted once at import time from the
# bundled base audio. text_to_speech_edge() reads this module-level
# `source_se` as `src_se`.
base_speaker = "base_audio.mp3"
source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter, vad=True)
async def text_to_speech_edge(text, audio_ref, language_code):
    """Synthesize `text` with edge-tts, then convert it toward the reference voice.

    Args:
        text: Text to speak.
        audio_ref: Path to the reference (target speaker) audio.
        language_code: Key into the module-level `language_dict`; the mapped
            value is used as the edge-tts voice.

    Returns:
        The path of the converted audio file (``"output.wav"``).

    Raises:
        KeyError: If `language_code` is not a key of `language_dict`.

    The intermediate MP3 is written to a uniquely named temporary file, which
    is removed again once conversion has finished or failed.
    """
    voice = language_dict[language_code]
    communicate = edge_tts.Communicate(text, voice)

    # Reserve a unique file name. The handle is closed before edge-tts writes
    # to the path so the file is never open twice at once.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    save_path = "output.wav"
    try:
        await communicate.save(tmp_path)

        target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

        # Run the tone color converter; `source_se` is the module-level value
        # extracted from the bundled base audio.
        encode_message = "@MyShell"
        tone_color_converter.convert(
            audio_src_path=tmp_path,
            src_se=source_se,
            tgt_se=target_se,
            output_path=save_path,
            message=encode_message)
    finally:
        # delete=False leaves cleanup to us; without this every request would
        # leave an MP3 behind in the temp directory.
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: a failed delete must not mask the result
            # or the original exception.
            pass
    return save_path
# Gradio UI: one shared text box and one shared reference-audio input, followed
# by three tabs that each drive a different synthesis function.
# NOTE(review): the nesting below was reconstructed from flattened text;
# confirm which components belong to each Row/Tab.
with gr.Blocks() as demo:
    gr.Markdown("# Enjoy chatting with your ai friends on website, telegram and so on! (https://linkin.love)")
    default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
    # Inputs shared by all three tabs' click handlers.
    text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
    voice_ref = gr.Audio(label="Reference Audio", info="Click on the ✎ button to upload your own target speaker audio", type="filepath", value="base_audio.mp3")
    # Tab 1: ChatTTS synthesis followed by tone color conversion (generate_audio).
    with gr.Tab("💕Super Natural"):
        default_refine_text = "[oral_2][laugh_0][break_6]"
        refine_text_checkbox = gr.Checkbox(label="Refine text", info="'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ", value=True)
        refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text)
        # Sampling controls forwarded to chat_tts via generate_audio.
        with gr.Row():
            temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
            top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")
        # Seed fields, each paired with a dice button that fills in a new seed.
        with gr.Row():
            audio_seed_input = gr.Number(value=42, label="Speaker Seed")
            generate_audio_seed = gr.Button("\U0001F3B2")
            text_seed_input = gr.Number(value=42, label="Text Seed")
            generate_text_seed = gr.Button("\U0001F3B2")
        generate_button = gr.Button("Generate!")
        #text_output = gr.Textbox(label="Refined Text", interactive=False)
        audio_output = gr.Audio(label="Output Audio")
        # Dice buttons: generate_seed takes no inputs and updates the seed field.
        generate_audio_seed.click(generate_seed,
                                  inputs=[],
                                  outputs=audio_seed_input)
        generate_text_seed.click(generate_seed,
                                 inputs=[],
                                 outputs=text_seed_input)
        # The order of `inputs` must match generate_audio's positional parameters.
        generate_button.click(generate_audio,
                              inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
                              outputs=audio_output)
    # Tab 2: English base speaker with a selectable style (vc_en).
    with gr.Tab("💕Emotion Control"):
        emo_pick = gr.Dropdown(label="Emotion", info="🙂default😊friendly🤫whispering😄cheerful😱terrified😡angry😢sad", choices=["default", "friendly", "whispering", "cheerful", "terrified", "angry", "sad"], value="default")
        generate_button_emo = gr.Button("Generate!", variant="primary")
        audio_emo = gr.Audio(label="Output Audio", type="filepath")
        generate_button_emo.click(vc_en, [text_input, voice_ref, emo_pick], audio_emo)
    # Tab 3: edge-tts synthesis in the selected language/voice (text_to_speech_edge).
    with gr.Tab("💕multilingual"):
        # The default selection is the 16th key of language_dict, so the dict
        # must hold at least 16 entries or this line raises IndexError.
        language = gr.Dropdown(choices=list(language_dict.keys()), value=list(language_dict.keys())[15], label="请选择文本对应的语言及说话人")
        generate_button_ml = gr.Button("开始语音情感真实复刻吧!", variant="primary")
        audio_ml = gr.Audio(label="为您合成的专属语音", type="filepath")
        generate_button_ml.click(text_to_speech_edge, [text_input, voice_ref, language], audio_ml)
# Command-line options. NOTE: these are parsed at import time, and `args` is
# currently unused because the launch call that consumed it is commented out
# below.
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
parser.add_argument('--server_port', type=int, default=8080, help='Server port')
args = parser.parse_args()
# demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True)
if __name__ == '__main__':
    # Launch with no arguments; --server_name / --server_port are not passed.
    demo.launch()