|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
# Let PyTorch fall back to the CPU for operators the MPS backend does not
# implement, instead of raising. The variable name must be spelled exactly
# as PyTorch expects; a misspelled name is silently ignored. It is set here,
# before torch is imported further down this file.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
from importlib.resources import files |
|
|
import matplotlib |
|
|
|
|
|
# Select matplotlib's "Agg" backend at import time; presumably so this module
# can run headless (no display attached) — confirm where plotting happens.
matplotlib.use("Agg")
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torchaudio |
|
|
import tqdm |
|
|
import logging |
|
|
|
|
|
|
|
|
from f5_tts.model.utils import ( |
|
|
get_tokenizer, |
|
|
convert_char_to_pinyin, |
|
|
) |
|
|
from f5_tts.model.modules import MelSpec |
|
|
|
|
|
|
|
|
# Default compute device, chosen by priority: CUDA, Intel XPU, Apple MPS, CPU.
# `torch.xpu` is absent from older PyTorch releases, so probe for it before
# calling it; otherwise importing this module would fail with AttributeError
# on machines without CUDA.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Audio / spectrogram settings.
# `target_sample_rate` is the rate reference audio is resampled to, and
# `hop_length` converts seconds into mel frames in this file; the remaining
# values here are not referenced in the visible part of the module.
# ---------------------------------------------------------------------------
target_sample_rate: int = 24000
n_mel_channels: int = 100
hop_length: int = 256
win_length: int = 1024
n_fft: int = 1024
mel_spec_type: str = "vocos"
target_rms: float = 0.1

# ---------------------------------------------------------------------------
# Inference defaults.
# ---------------------------------------------------------------------------
cross_fade_duration: float = 0.15
ode_method: str = "euler"
nfe_step: int = 32
cfg_strength: float = 2.0
sway_sampling_coef: float = -1.0
speed: float = 1.0
fix_duration = None  # None => duration is estimated from the text length
seed: int = 3214  # passed as `seed=` to every model sampling call
|
|
|
|
|
|
|
|
def _tail_slice(seq, count):
    """Return the last ``count`` items of ``seq``.

    Unlike ``seq[-count:]``, a ``count`` of 0 (or less) yields an empty slice
    instead of the whole sequence, and a ``count`` larger than ``len(seq)``
    yields the whole sequence.
    """
    start = min(max(len(seq) - count, 0), len(seq))
    return seq[start:]


def _cross_fade_concat(waves, cross_fade_duration, sample_rate):
    """Join numpy waveforms end to end, blending each seam with a cross-fade.

    Args:
        waves: Non-empty list of numpy waveforms in playback order; they are
            joined along their first axis.
        cross_fade_duration: Overlap length in seconds; ``<= 0`` means plain
            concatenation with no overlap.
        sample_rate: Sample rate of ``waves``, used to turn the duration into
            a sample count.

    Returns:
        The combined waveform as a numpy array.
    """
    if cross_fade_duration <= 0:
        logging.info("simply concatenate")
        return np.concatenate(waves)

    final_wave = waves[0]
    for next_wave in waves[1:]:
        prev_wave = final_wave
        # The overlap can never be longer than either side of the seam.
        cross_fade_samples = min(
            int(cross_fade_duration * sample_rate), len(prev_wave), len(next_wave)
        )
        if cross_fade_samples <= 0:
            final_wave = np.concatenate([prev_wave, next_wave])
            continue

        # The first (rising) half of a Hamming window fades the next chunk in;
        # the second (falling) half fades the previous chunk out.
        window = np.hamming(2 * cross_fade_samples)
        fade_in = window[:cross_fade_samples]
        fade_out = window[cross_fade_samples:]
        cross_faded_overlap = (
            prev_wave[-cross_fade_samples:] * fade_out
            + next_wave[:cross_fade_samples] * fade_in
        )
        final_wave = np.concatenate(
            [
                prev_wave[:-cross_fade_samples],
                cross_faded_overlap,
                next_wave[cross_fade_samples:],
            ]
        )
    return final_wave


def chunk_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    fix_duration=None,
    device=None,
    chunk_cond_proportion=0.5,
    chunk_look_ahead=0,
    max_ref_duration=4.5,
    ref_head_cut=False,
):
    """Synthesize speech for ``gen_text_batches`` one chunk at a time.

    Each chunk is conditioned on the (possibly trimmed) reference clip plus
    the trailing part of the previously generated chunk. The chunk waveforms
    are then joined with a cross-fade.

    Args:
        ref_audio: ``(audio, sr)`` pair. ``audio`` is indexed as
            ``(channels, samples)`` below; multi-channel input is averaged to
            mono.
        ref_text: List of tokens transcribing ``ref_audio`` (asserted to be a
            list after optional trimming).
        gen_text_batches: Sequence of token lists to synthesize, in order.
            ``len()`` is taken, so it must not be a one-shot iterator.
        model_obj: Object whose ``sample(cond=, text=, duration=, steps=,
            cfg_strength=, sway_sampling_coef=, seed=)`` returns
            ``(mel, _)``.
        vocoder: Has ``.decode(mel)`` for ``"vocos"``; called directly for
            ``"bigvgan"``.
        mel_spec_type: ``"vocos"`` or ``"bigvgan"``.
        progress: Object exposing ``.tqdm(iterable)`` (the ``tqdm`` module by
            default).
        target_rms: Quiet (non-silent) references are boosted to this RMS,
            and the output is scaled back down by the same factor.
        cross_fade_duration: Seconds of overlap between consecutive chunks;
            ``<= 0`` concatenates them directly.
        nfe_step, cfg_strength, sway_sampling_coef: Forwarded to
            ``model_obj.sample``.
        speed: Divides the estimated frame count of each generated chunk.
        fix_duration: If set, every chunk uses
            ``int(fix_duration * target_sample_rate / hop_length)`` total
            frames instead of the text-based estimate.
        device: Passed to ``audio.to(device)``. It shadows the module-level
            ``device``, and ``None`` is forwarded unchanged.
        chunk_cond_proportion: Fraction of each chunk (mel frames and text
            tokens) that is carried over as conditioning for the next chunk.
        chunk_look_ahead: Number of trailing tokens per non-final chunk whose
            estimated frames are cut from that chunk's output.
        max_ref_duration: References longer than this many seconds are
            trimmed, and the text is trimmed proportionally.
        ref_head_cut: Keep the first (``True``) or last (``False``)
            ``max_ref_duration`` seconds of the reference.

    Returns:
        ``(final_wave, target_sample_rate, combined_spectrogram)``. The
        spectrograms of all chunks are concatenated along axis 1.

    Raises:
        ValueError: If ``mel_spec_type`` is not ``"vocos"`` or ``"bigvgan"``.
        AssertionError: If ``ref_text`` or a ``gen_text`` (when
            ``fix_duration`` is None) is not a list.
    """
    # Fail fast, before any model sampling, rather than with an unbound
    # `generated_wave` later on.
    if mel_spec_type not in ("vocos", "bigvgan"):
        raise ValueError(
            f"Unsupported mel_spec_type {mel_spec_type!r}; expected 'vocos' or 'bigvgan'"
        )

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Boost quiet references up to target_rms; the output is scaled back by
    # the same factor. A silent clip (rms == 0) is left alone, because
    # dividing by its rms would fill it with inf/NaN.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    rms_rescaled = bool(rms > 0) and bool(rms < target_rms)
    if rms_rescaled:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)

    logging.info(
        "audio shape:" + str(audio.shape) + "; ref_text shape:" + str(len(ref_text))
    )
    ref_duration = audio.shape[1] / target_sample_rate
    if ref_duration > max_ref_duration:
        reserved_ref_audio_len = round(max_ref_duration * target_sample_rate)
        # Keep the same fraction of the transcript as of the audio.
        reserved_ref_text_len = round(max_ref_duration * len(ref_text) / ref_duration)
        if ref_head_cut:
            logging.info(f"Using the first {max_ref_duration} seconds as ref audio")
            audio = audio[:, :reserved_ref_audio_len]
            ref_text = ref_text[:reserved_ref_text_len]
        else:
            logging.info(f"Using the last {max_ref_duration} seconds as ref audio")
            # Explicit start indices avoid the `x[-0:]` pitfall (whole sequence).
            audio = audio[:, max(audio.shape[1] - reserved_ref_audio_len, 0) :]
            ref_text = _tail_slice(ref_text, reserved_ref_text_len)
        logging.info(
            "audio shape:" + str(audio.shape) + "; ref_text shape:" + str(len(ref_text))
        )
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # Reference mel features. After the permute, dim 1 is used as the frame
    # axis throughout (assumes MelSpec returns (batch, n_mels, frames) — TODO
    # confirm).
    mel_spec_module = MelSpec(mel_spec_type=mel_spec_type)
    fixed_ref_audio_mel_spec = mel_spec_module(audio)
    fixed_ref_audio_mel_cond = fixed_ref_audio_mel_spec.permute(0, 2, 1)
    fixed_ref_audio_len = fixed_ref_audio_mel_cond.shape[1]

    assert isinstance(ref_text, list) is True
    fixed_ref_text = ref_text[:]
    fixed_ref_text_len = len(fixed_ref_text)

    # Condition for the current chunk: reference frames + previous-chunk tail.
    mel_cond = fixed_ref_audio_mel_cond.clone()
    prev_chunk_audio_len = 0

    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        final_text_list = [ref_text + gen_text]
        logging.info(f"final_text_list: {final_text_list}")

        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            assert isinstance(gen_text, list) is True
            gen_text_len = len(gen_text)
            # Conditioning frames plus an estimate for the new text, using the
            # reference's frames-per-token rate.
            duration = (
                fixed_ref_audio_len
                + prev_chunk_audio_len
                + int(fixed_ref_audio_len / fixed_ref_text_len * gen_text_len / speed)
            )
        logging.info(f"Duration: {duration}")

        with torch.inference_mode():
            logging.info(f"generate with nfe_step:{nfe_step}, cfg_strength:{cfg_strength}, sway_sampling_coef:{sway_sampling_coef}")
            generated, _ = model_obj.sample(
                cond=mel_cond,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                seed=seed,
            )
            generated = generated.to(torch.float32)
            logging.info("gen mel shape: " + str(generated.shape))

            # Drop the conditioning frames; only the newly generated part stays.
            cond_frames = fixed_ref_audio_len + prev_chunk_audio_len
            stripped_generated = generated[:, cond_frames:, :]

            # Frames attributed to the look-ahead tokens, proportional to their
            # share of this chunk's text.
            look_ahead_mel_len = 0
            if chunk_look_ahead > 0 and len(gen_text) > 0:
                look_ahead_mel_len = round(
                    (duration - cond_frames) * chunk_look_ahead / len(gen_text)
                )
            if look_ahead_mel_len > 0 and i < len(gen_text_batches) - 1:
                stripped_generated_without_look_ahead = stripped_generated[
                    :, :-look_ahead_mel_len, :
                ]
                gen_text = gen_text[:-chunk_look_ahead]
            else:
                stripped_generated_without_look_ahead = stripped_generated
            logging.info("gen mel shape: %s, gen text len: %d" % (str(stripped_generated_without_look_ahead.shape), len(gen_text)))

            chunk_audio_len = stripped_generated_without_look_ahead.shape[1]
            generated_mel_spec = stripped_generated_without_look_ahead.permute(0, 2, 1)
            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            else:  # "bigvgan" (validated at the top)
                generated_wave = vocoder(generated_mel_spec)
            # Undo the reference gain so the output matches the original loudness.
            if rms_rescaled:
                generated_wave = generated_wave * rms / target_rms
            logging.info("gen wav shape: " + str(generated_wave.shape))

            generated_waves.append(generated_wave.squeeze().cpu().numpy())
            spectrograms.append(generated_mel_spec[0].cpu().numpy())

            # Carry the trailing chunk_cond_proportion of this chunk (frames and
            # tokens) into the next chunk's condition. Explicit start indices
            # keep a zero-length tail empty instead of the whole chunk.
            cond_audio_len = max(0, round(chunk_cond_proportion * chunk_audio_len))
            if chunk_audio_len > cond_audio_len:
                gen_text_cond = _tail_slice(
                    gen_text, round(chunk_cond_proportion * len(gen_text))
                )
                generated_cond = stripped_generated_without_look_ahead[
                    :, chunk_audio_len - cond_audio_len :, :
                ]
                prev_chunk_audio_len = cond_audio_len
            else:
                generated_cond = stripped_generated_without_look_ahead
                gen_text_cond = gen_text
                prev_chunk_audio_len = chunk_audio_len
            logging.info("gen text cond len: %d, gen mel cond len: %d" % (len(gen_text_cond), generated_cond.shape[1]))

            ref_text = fixed_ref_text + gen_text_cond
            mel_cond = torch.cat([fixed_ref_audio_mel_cond, generated_cond], dim=1)

    final_wave = _cross_fade_concat(generated_waves, cross_fade_duration, target_sample_rate)
    combined_spectrogram = np.concatenate(spectrograms, axis=1)
    return final_wave, target_sample_rate, combined_spectrogram
|
|
|