import datetime
import json
import os
import re
import time
import numpy as np
import torch
from tqdm import tqdm
import ChatTTS
from config import DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K


def load_chat_tts_model(source='huggingface', force_redownload=False, local_path=None):
    """
    Load the ChatTTS model.
    :param source: model source, 'huggingface' by default
    :param force_redownload: force re-downloading the model files
    :param local_path: local model path, passed through as custom_path
    :return: a loaded ChatTTS.Chat instance
    """
print("Loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models(source=source, force_redownload=force_redownload, custom_path=local_path, compile=False)
return chat


def clear_cuda_cache():
    """
    Clear the CUDA memory cache.
    :return: None
    """
torch.cuda.empty_cache()


def deterministic(seed=0):
    """
    Set random seeds for reproducible generation.
    :param seed: seed shared by the numpy and torch RNGs
    :return: None
    """
# ref: https://github.com/Jackiexiao/ChatTTS-api-ui-docker/blob/main/api.py#L27
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, roleid=None,
                            temperature=DEFAULT_TEMPERATURE,
                            top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False,
                            skip_refine_text=False, speaker_type="seed", pt_file=None):
from utils import combine_audio, save_audio, batch_split
print(f"speaker_type: {speaker_type}")
if speaker_type == "seed":
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)
deterministic(seed)
rnd_spk_emb = chat.sample_random_speaker()
elif speaker_type == "role":
        # Load the speaker presets from the JSON file
        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
            slct_idx_loaded = json.load(json_file)
        # Convert the stored lists back into Tensor objects
        for key in slct_idx_loaded:
            tensor_list = slct_idx_loaded[key]["tensor"]
            slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
        # Pack this fixed speaker-timbre tensor into params_infer_code so the same voice is always used; optionally lower the temperature
        rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
# temperature = 0.001
elif speaker_type == "pt":
print(pt_file)
rnd_spk_emb = torch.load(pt_file)
print(rnd_spk_emb.shape)
if rnd_spk_emb.shape != (768,):
raise ValueError("维度应为 768。")
else:
raise ValueError(f"Invalid speaker_type: {speaker_type}. ")
params_infer_code = {
'spk_emb': rnd_spk_emb,
'prompt': f'[speed_{speed}]',
'top_P': top_P,
'top_K': top_K,
'temperature': temperature
}
params_refine_text = {
'prompt': refine_text_prompt,
'top_P': top_P,
'top_K': top_K,
'temperature': temperature
}
all_wavs = []
start_time = time.time()
total = len(texts)
flag = 0
if not cur_tqdm:
cur_tqdm = tqdm
if re.search(r'\[uv_break\]|\[laugh\]', ''.join(texts)) is not None:
if not skip_refine_text:
print("Detected [uv_break] or [laugh] in text, skipping refine_text")
skip_refine_text = True
for batch in cur_tqdm(batch_split(texts, batch_size), desc=f"Inferring audio for seed={seed}"):
flag += len(batch)
_params_infer_code = {**params_infer_code}
wavs = chat.infer(batch, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
use_decoder=True, skip_refine_text=skip_refine_text)
all_wavs.extend(wavs)
clear_cuda_cache()
if skip_save:
return all_wavs
combined_audio = combine_audio(all_wavs)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Saving audio for seed {seed}, took {elapsed_time:.2f}s")
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
wav_filename = f"chattts-[seed_{seed}][speed_{speed}]{refine_text_prompt}[{timestamp}].wav"
return save_audio(wav_filename, combined_audio)


def generate_refine_text(chat, seed, text, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
                         top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)
deterministic(seed)
params_refine_text = {
'prompt': refine_text_prompt,
'top_P': top_P,
'top_K': top_K,
'temperature': temperature
}
print('params_refine_text:', text)
print('refine_text_prompt:', refine_text_prompt)
refine_text = chat.infer(text, params_refine_text=params_refine_text, refine_text_only=True, skip_refine_text=False)
print('refine_text:', refine_text)
return refine_text


def tts(chat, text_file, seed, speed, oral, laugh, bk, seg, batch, progres=None):
    """
    Text-to-Speech
    :param chat: loaded ChatTTS model
    :param text_file: path to a text file, or the text itself as a string
    :param seed: speaker seed
    :param speed: speech speed, used as the [speed_x] prompt
    :param oral: oral level (0-9)
    :param laugh: laugh level (0-2)
    :param bk: break level (0-7)
    :param seg: minimum segment length used when splitting the text
    :param batch: inference batch size
    :param progres: progress callback (currently unused)
    :return: the saved audio returned by generate_audio_for_seed
    """
from utils import read_long_text, split_text
    if os.path.isfile(text_file):
        content = read_long_text(text_file)
    elif isinstance(text_file, str):
        content = text_file
    else:
        raise TypeError("text_file must be a file path or a text string.")
texts = split_text(content, min_length=seg)
print(texts)
# exit()
if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")
refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
return generate_audio_for_seed(chat, seed, texts, batch, speed, refine_text_prompt)
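

# Example usage: a minimal, illustrative sketch, not part of the original module.
# The input text and the seed/speed/oral/laugh/break/segment/batch values below are
# arbitrary assumptions chosen to stay within the ranges validated by tts().
if __name__ == '__main__':
    chat = load_chat_tts_model(source='huggingface')
    saved = tts(chat, "Hello from ChatTTS.", seed=42, speed=5,
                oral=2, laugh=0, bk=4, seg=50, batch=4)
    print(saved)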