import datetime
import json
import os
import re
import time

import numpy as np
import torch
from tqdm import tqdm

import ChatTTS
from config import DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K


def load_chat_tts_model(source='huggingface', force_redownload=False, local_path=None):
    """
    Load ChatTTS model
    :param source:
    :param force_redownload:
    :param local_path:
    :return:
    """
    print("Loading ChatTTS model...")
    chat = ChatTTS.Chat()
    chat.load_models(source=source, force_redownload=force_redownload, custom_path=local_path, compile=False)
    return chat
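

# Usage sketch (hedged): load once at startup and reuse the instance. The
# 'local' source and the './models/ChatTTS' path are assumptions; point them
# at wherever the weights actually live.
#
#   chat = load_chat_tts_model()  # pull weights from Hugging Face
#   chat = load_chat_tts_model(source='local', local_path='./models/ChatTTS')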


def clear_cuda_cache():
    """
    Clear CUDA cache
    :return:
    """
    torch.cuda.empty_cache()


def deterministic(seed=0):
    """
    Set random seed for reproducibility
    :param seed:
    :return:
    """
    # ref: https://github.com/Jackiexiao/ChatTTS-api-ui-docker/blob/main/api.py#L27
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, roleid=None,
                            temperature=DEFAULT_TEMPERATURE,
                            top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False,
                            skip_refine_text=False, speaker_type="seed", pt_file=None):
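    """
    Generate audio for the given texts using one of three speaker sources:
    a random speaker derived from `seed`, a preset role looked up in
    slct_voice_240605.json, or a speaker embedding loaded from a .pt file.
    Returns the raw wav list when skip_save is True, otherwise the saved path.
    """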
    from utils import combine_audio, save_audio, batch_split
    print(f"speaker_type: {speaker_type}")
    if speaker_type == "seed":
        if seed in [None, -1, 0, "", "random"]:
            seed = np.random.randint(0, 9999)
        deterministic(seed)
        rnd_spk_emb = chat.sample_random_speaker()
    elif speaker_type == "role":
        # Read the role presets from the JSON file
        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
            slct_idx_loaded = json.load(json_file)
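        # Assumed file layout (illustrative, not verified here): one entry per
        # role id, each with a "tensor" field holding the 768-dim speaker
        # embedding as a plain list, e.g. {"1": {"tensor": [0.01, ...]}, ...}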
        # Convert the stored lists back into Tensor objects
        for key in slct_idx_loaded:
            tensor_list = slct_idx_loaded[key]["tensor"]
            slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
        # Pack the voice embedding into params_infer_code so this fixed voice
        # is used (temperature can optionally be lowered, see the line below)
        rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
        # temperature = 0.001
    elif speaker_type == "pt":
        print(f"Loading speaker embedding from: {pt_file}")
        rnd_spk_emb = torch.load(pt_file)
        print(f"Speaker embedding shape: {rnd_spk_emb.shape}")
        if rnd_spk_emb.shape != (768,):
            raise ValueError("Speaker embedding must have shape (768,).")
    else:
        raise ValueError(f"Invalid speaker_type: {speaker_type}")

    params_infer_code = {
        'spk_emb': rnd_spk_emb,
        'prompt': f'[speed_{speed}]',
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    all_wavs = []
    start_time = time.time()
    if not cur_tqdm:
        cur_tqdm = tqdm

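    # Rationale (assumption, inferred from the message below): text that
    # already carries manual control tokens like [uv_break]/[laugh] is treated
    # as pre-refined, so the refine pass is skipped rather than letting it
    # rewrite those tokens.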
    if re.search(r'\[uv_break\]|\[laugh\]', ''.join(texts)) is not None:
        if not skip_refine_text:
            print("Detected [uv_break] or [laugh] in text, skipping refine_text")
        skip_refine_text = True

    for batch in cur_tqdm(batch_split(texts, batch_size), desc=f"Inferring audio for seed={seed}"):
        _params_infer_code = {**params_infer_code}
        wavs = chat.infer(batch, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
                          use_decoder=True, skip_refine_text=skip_refine_text)
        all_wavs.extend(wavs)
        clear_cuda_cache()
    if skip_save:
        return all_wavs
    combined_audio = combine_audio(all_wavs)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Saving audio for seed {seed}, took {elapsed_time:.2f}s")
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
    wav_filename = f"chattts-[seed_{seed}][speed_{speed}]{refine_text_prompt}[{timestamp}].wav"
    return save_audio(wav_filename, combined_audio)
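

# Usage sketch (hedged): synthesize a few segments with a fixed seed. The
# texts, batch size, and prompt values below are illustrative, not project
# defaults.
#
#   wav_path = generate_audio_for_seed(
#       chat, seed=42, texts=["Hello there.", "Nice to meet you."],
#       batch_size=2, speed=3, refine_text_prompt="[oral_2][laugh_0][break_4]")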


def generate_refine_text(chat, seed, text, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
                         top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
    if seed in [None, -1, 0, "", "random"]:
        seed = np.random.randint(0, 9999)

    deterministic(seed)

    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    print('text to refine:', text)
    print('refine_text_prompt:', refine_text_prompt)
    refine_text = chat.infer(text, params_refine_text=params_refine_text, refine_text_only=True, skip_refine_text=False)
    print('refine_text:', refine_text)
    return refine_text
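

# Usage sketch (hedged): refine a raw sentence before synthesis; the values
# are illustrative.
#
#   refined = generate_refine_text(
#       chat, seed=42, text="Hello there.",
#       refine_text_prompt="[oral_2][laugh_0][break_4]")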


def tts(chat, text_file, seed, speed, oral, laugh, bk, seg, batch, progress=None):
    """
    Text-to-Speech entry point.
    :param chat: loaded ChatTTS model
    :param text_file: path to a text file, or the text itself
    :param seed: random seed (None/-1/0/""/"random" picks one at random)
    :param speed: speech speed level (0-9)
    :param oral: oral/colloquial level (0-9)
    :param laugh: laugh level (0-2)
    :param bk: break (pause) level (0-7)
    :param seg: minimum segment length used when splitting the text
    :param batch: batch size for inference
    :param progress: optional progress callback (currently unused)
    :return: path of the saved WAV file
    """
    from utils import read_long_text, split_text

    if os.path.isfile(text_file):
        content = read_long_text(text_file)
    elif isinstance(text_file, str):
        content = text_file
    else:
        raise TypeError("text_file must be a file path or a text string")
    texts = split_text(content, min_length=seg)

    print(f"Split into {len(texts)} segments: {texts}")

    if not (0 <= oral <= 9 and 0 <= laugh <= 2 and 0 <= bk <= 7):
        raise ValueError("out of range: oral in 0-9, laugh in 0-2, break in 0-7")

    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    return generate_audio_for_seed(chat, seed, texts, batch, speed, refine_text_prompt)
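

if __name__ == '__main__':
    # Smoke test (hedged): the text and parameter values below are
    # illustrative, not project conventions; adjust seg/batch to taste.
    chat = load_chat_tts_model()
    tts(chat, "Hello there. Nice to meet you.", seed=42, speed=3,
        oral=2, laugh=0, bk=4, seg=50, batch=4)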