# Imports
import gradio as gr
import spaces
import re
import torch
import torchaudio
import numpy as np
import tempfile
import click
import soundfile as sf

from einops import rearrange
from vocos import Vocos
from pydub import AudioSegment, silence
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (load_checkpoint, get_tokenizer, convert_char_to_pinyin, save_spectrogram)

# Pre-Initialize
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
ode_method = "euler"

def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
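    # Fetch the requested checkpoint from the Hugging Face Hub, wrap the chosen
    # transformer backbone in a CFM model, and load its EMA weights onto DEVICE.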
    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length),
        odeint_kwargs=dict(method=ode_method),
        vocab_char_map=vocab_char_map,
    ).to(DEVICE)
    model = load_checkpoint(model, ckpt_path, DEVICE, use_ema=True)
    return model
    
# Variables
DEFAULT_MODEL = "F5"
DEFAULT_REMOVE_SILENCES = True
DEFAULT_SPEED = 1
DEFAULT_CROSS_FADE = 0.15

target_rms = 0.1
nfe_step = 32
cfg_strength = 2.0
sway_sampling_coef = -1.0

# Silence-removal settings for the reference audio (threshold offset in dB below the clip's dBFS, minimum silence length in ms)
input_silence_offset = 14
input_silence_min_len = 500

# Silence-removal settings for the generated audio
silence_offset = 14
silence_min_len = 200

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)

E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
'''

# Functions
@spaces.GPU(duration=30)
def infer_batch(input_batches, reference_audio, reference_input, model_choice=DEFAULT_MODEL, remove_silences=DEFAULT_REMOVE_SILENCES, speed=DEFAULT_SPEED, cross_fade=DEFAULT_CROSS_FADE):
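    # Synthesize each pre-chunked text batch conditioned on the reference audio and
    # text, then stitch the per-batch waveforms together (optionally cross-faded and
    # with long silences removed).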
    if model_choice == "F5":
        ema_model = F5TTS_ema_model
    elif model_choice == "E2":
        ema_model = E2TTS_ema_model

    audio, sr = reference_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
        
    audio = audio.to(DEVICE)
    generated_waves = []

    # If the reference text ends with a single-byte (ASCII) character, append a space
    # so it separates cleanly from the generated text when concatenated below.
    if len(reference_input[-1].encode('utf-8')) == 1:
        reference_input = reference_input + " "
        
    for i, input in enumerate(input_batches):
        text_list = [reference_input + input]
        final_text_list = convert_char_to_pinyin(text_list)

        reference_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"。,、;:?!"
        reference_input_len = len(reference_input.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, reference_input))
        input_len = len(input.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, input))
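        # Estimate the output length in mel frames: reference frames plus frames scaled by the
        # generated/reference text-length ratio (pause punctuation weighted above) and the speed setting.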
        duration = reference_audio_len + int(reference_audio_len / reference_input_len * input_len / speed)

        # Inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(cond=audio, text=final_text_list, duration=duration, steps=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef)

        generated = generated[:, reference_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        generated_wave = generated_wave.squeeze().cpu().numpy()
        generated_waves.append(generated_wave)

    # Handle combining generated waves with cross-fading
    if cross_fade <= 0:
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            cross_fade_samples = int(cross_fade * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            new_wave = np.concatenate([prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]])

            final_wave = new_wave

    # Handle removing silences
    if remove_silences:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, target_sample_rate)
            aseg = AudioSegment.from_file(f.name)
                
            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=silence_min_len, silence_thresh=aseg.dBFS - silence_offset, keep_silence=250)
            
            non_silent_wave = AudioSegment.empty()
            for seg in non_silent_segs:
                non_silent_wave += seg
    
            aseg = non_silent_wave
            aseg.export(f.name, format="wav")
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    return (target_sample_rate, final_wave)

@spaces.GPU(duration=30)
def infer(input, reference_audio, reference_input, model_choice=DEFAULT_MODEL, remove_silences=DEFAULT_REMOVE_SILENCES, speed=DEFAULT_SPEED, cross_fade=DEFAULT_CROSS_FADE):
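    # Preprocess the reference audio (trim long silences, clip to 15s), normalize the
    # reference text, split the input text into batches, and delegate to infer_batch.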

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(reference_audio)
    
        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=input_silence_min_len, silence_thresh=aseg.dBFS - input_silence_offset, keep_silence=250)
        
        non_silent_wave = AudioSegment.empty()
        for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
    
        aseg = non_silent_wave
    
        audio_duration = len(aseg)
    
        if audio_duration > 15000:
            gr.Warning("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]
    
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Ensure the reference text ends with a period followed by a space.
    if not reference_input.endswith(". "):
        if reference_input.endswith("."):
            reference_input += " "
        else:
            reference_input += ". "

    audio, sr = torchaudio.load(ref_audio)

    # Split the input into chunks sized so that the reference plus generated audio stays under roughly 25s per batch
    max_chars = int(len(reference_input.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
    input_batches = chunk_text(input, max_chars=max_chars)
    
    print("--------------------------------------------- INPUT")
    print(f"Input: {input}")
    print(f"Reference Input: {reference_input}")
    print(f"Batch Inputs:")
    
    for i, batch_text in enumerate(input_batches):
        print(f" {i}: ", batch_text)
        
    print("---------------------------------------------------")
    
    return infer_batch(input_batches, (audio, sr), reference_input, model_choice, remove_silences, speed, cross_fade)

def chunk_text(text, max_chars=135):
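    # Split text into sentence-level chunks whose UTF-8 byte length stays within
    # max_chars, so each batch fits the duration budget estimated in infer.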
    chunks = []
    current_chunk = ""
    
    # Split the text into sentences on ASCII and CJK punctuation
    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)

    for sentence in sentences:
        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    print("-------------------------------------------- CHUNKS")
    print(chunks)
    print("---------------------------------------------------")
    
    return chunks

def cloud():
    print("[CLOUD] | Space maintained.")

# Initialize
with gr.Blocks(css=css) as main:
    with gr.Column():
        gr.Markdown("🪄 Speak text to audio.")

    with gr.Column():
        input = gr.Textbox(lines=1, value="", label="Input")
        reference_audio = gr.Audio(sources="upload", type="filepath", label="Reference Audio")
        reference_input = gr.Textbox(lines=1, value="", label="Reference Text")
        model_choice = gr.Radio(["F5", "E2"], label="TTS Model", value=DEFAULT_MODEL)
    
        remove_silences = gr.Checkbox(value=DEFAULT_REMOVE_SILENCES, label="Remove Silences")
        
        speed = gr.Slider(minimum=0.3, maximum=2.0, value=DEFAULT_SPEED, step=0.1, label="Speed")
        cross_fade = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_CROSS_FADE, step=0.01, label="Audio Cross-Fade Duration Between Sentences")
                
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        output = gr.Audio(label="Output")

    submit.click(infer, inputs=[input, reference_audio, reference_input, model_choice, remove_silences, speed, cross_fade], outputs=output, queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)