# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import numpy as np
import torch
import torchaudio
import tqdm
import logging
# torch.set_printoptions(profile="full")
# from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.model.modules import MelSpec
device = (
    "cuda"
    if torch.cuda.is_available()
    else (
        "xpu"
        if torch.xpu.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
)
# -----------------------------------------
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32 # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
seed = 3214
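# Note: durations below are measured in mel frames; with the settings above that is
# target_sample_rate / hop_length = 24000 / 256 = 93.75 frames per second of audio.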
# -----------------------------------------
def chunk_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    fix_duration=None,
    device=None,
    chunk_cond_proportion=0.5,
    chunk_look_ahead=0,
    max_ref_duration=4.5,
    ref_head_cut=False,
):
    """Generate audio chunk by chunk from `gen_text_batches`.

    Each chunk is conditioned on the reference audio mel plus the tail
    `chunk_cond_proportion` of the previously generated chunk; the chunk
    waveforms are then combined with cross-fading.

    Returns (final_wave, target_sample_rate, combined_spectrogram).
    """
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    logging.info(
        "audio shape: " + str(audio.shape) + "; ref_text length: " + str(len(ref_text))
    )

    # Trim the reference audio (and proportionally its text) to at most max_ref_duration
    ref_duration = audio.shape[1] / target_sample_rate
    if ref_duration > max_ref_duration:
        reserved_ref_audio_len = round(max_ref_duration * target_sample_rate)
        if ref_head_cut:
            logging.info(f"Using the first {max_ref_duration} seconds as ref audio")
            audio = audio[:, :reserved_ref_audio_len]
            ref_text = ref_text[
                : round(max_ref_duration * len(ref_text) / ref_duration)
            ]
        else:
            logging.info(f"Using the last {max_ref_duration} seconds as ref audio")
            audio = audio[:, -reserved_ref_audio_len:]
            ref_text = ref_text[
                -round(max_ref_duration * len(ref_text) / ref_duration) :
            ]
        logging.info(
            "audio shape: " + str(audio.shape) + "; ref_text length: " + str(len(ref_text))
        )
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # Mel spectrogram of the (possibly trimmed) reference audio, kept as the fixed
    # conditioning prefix for every chunk
    # fixed_ref_audio_len = audio.shape[-1] // hop_length
    mel_spec_module = MelSpec(mel_spec_type=mel_spec_type)
    fixed_ref_audio_mel_spec = mel_spec_module(audio)
    # The last dim should be num_channels
    fixed_ref_audio_mel_cond = fixed_ref_audio_mel_spec.permute(0, 2, 1)
    fixed_ref_audio_len = fixed_ref_audio_mel_cond.shape[1]

    assert isinstance(ref_text, list)
    fixed_ref_text = ref_text[:]
    fixed_ref_text_len = len(fixed_ref_text)

    mel_cond = fixed_ref_audio_mel_cond.clone()
    prev_chunk_audio_len = 0
    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        # Prepare the text: the reference text followed by the current chunk's text
        final_text_list = [ref_text + gen_text]
        logging.info(f"final_text_list: {final_text_list}")

        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Estimate the duration (in mel frames) from the reference speaking rate
            assert isinstance(gen_text, list)
            gen_text_len = len(gen_text)
            duration = (
                fixed_ref_audio_len
                + prev_chunk_audio_len
                + int(fixed_ref_audio_len / fixed_ref_text_len * gen_text_len / speed)
            )
        logging.info(f"Duration: {duration}")

        # Inference
        with torch.inference_mode():
            logging.info(
                f"generate with nfe_step: {nfe_step}, cfg_strength: {cfg_strength}, "
                f"sway_sampling_coef: {sway_sampling_coef}"
            )
            # logging.info("mel_cond: " + str(mel_cond))
            generated, _ = model_obj.sample(
                cond=mel_cond,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                seed=seed,
            )
            generated = generated.to(torch.float32)
            logging.info("gen mel shape: " + str(generated.shape))

            # Remove the conditioning mel (fixed reference + previous chunk) from the head
            stripped_generated = generated[
                :, (fixed_ref_audio_len + prev_chunk_audio_len) :, :
            ]

            # Remove chunk_look_ahead from the tail of each generated mel
            look_ahead_mel_len = round(
                (duration - fixed_ref_audio_len - prev_chunk_audio_len)
                * chunk_look_ahead
                / len(gen_text)
            )
            if look_ahead_mel_len > 0 and i < len(gen_text_batches) - 1:
                stripped_generated_without_look_ahead = stripped_generated[
                    :,
                    :(-look_ahead_mel_len),
                    :,
                ]
                # Also remove the chunk_look_ahead tokens from the tail of gen_text
                gen_text = gen_text[:-chunk_look_ahead]
            else:
                stripped_generated_without_look_ahead = stripped_generated
            logging.info(
                "gen mel shape: %s, gen text len: %d"
                % (str(stripped_generated_without_look_ahead.shape), len(gen_text))
            )
            # logging.info("generated mel: " + str(generated))

            # prev_chunk_audio_len is the length without the fixed condition and chunk look-ahead
            prev_chunk_audio_len = stripped_generated_without_look_ahead.shape[1]
            # prev_chunk_audio_len_with_look_ahead = stripped_generated.shape[1]

            # Generate wav from the mel (look-ahead already stripped)
            generated_mel_spec = stripped_generated_without_look_ahead.permute(0, 2, 1)
            # generated_mel_spec = stripped_generated.permute(0, 2, 1)
            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated_mel_spec)
            else:
                raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
            # strip look ahead wav from generated wav
            # if look_ahead_mel_len > 0 and i < len(gen_text_batches) - 1:
            #     look_ahead_wav_len = round(
            #         look_ahead_mel_len
            #         * generated_wave.shape[1]
            #         / prev_chunk_audio_len_with_look_ahead
            #     )
            #     generated_wave = generated_wave[:, :-look_ahead_wav_len]
            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms
            logging.info("gen wav shape: " + str(generated_wave.shape))
            # logging.info("generated wav: " + str(generated_wave))

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()

            generated_waves.append(generated_wave)
            spectrograms.append(generated_mel_spec[0].cpu().numpy())

            # Keep the tail chunk_cond_proportion of this chunk (mel and text) as the
            # rolling condition for the next chunk
            prev_chunk_cond_audio_len = round(chunk_cond_proportion * prev_chunk_audio_len)
            if prev_chunk_audio_len > prev_chunk_cond_audio_len:
                gen_text_cond = gen_text[-round(chunk_cond_proportion * len(gen_text)) :]
                prev_chunk_audio_len = prev_chunk_cond_audio_len
                generated_cond = stripped_generated_without_look_ahead[
                    :, (-prev_chunk_audio_len):, :
                ]
            else:
                generated_cond = stripped_generated_without_look_ahead
                gen_text_cond = gen_text
            logging.info(
                "gen text cond len: %d, gen mel cond len: %d"
                % (len(gen_text_cond), generated_cond.shape[1])
            )
            ref_text = fixed_ref_text + gen_text_cond
            mel_cond = torch.cat([fixed_ref_audio_mel_cond, generated_cond], dim=1)
    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0:
        # Simply concatenate
        logging.info("simply concatenate")
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Overlapping parts
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            # Fade out and fade in
            # fade_out = np.linspace(1, 0, cross_fade_samples)
            # fade_in = np.linspace(0, 1, cross_fade_samples)
            wave_window = np.hamming(2 * cross_fade_samples)
            fade_out = wave_window[cross_fade_samples:]
            fade_in = wave_window[:cross_fade_samples]

            # Cross-faded overlap
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            # Combine
            new_wave = np.concatenate(
                [
                    prev_wave[:-cross_fade_samples],
                    cross_faded_overlap,
                    next_wave[cross_fade_samples:],
                ]
            )

            final_wave = new_wave

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, target_sample_rate, combined_spectrogram
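

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): it assumes a CFM-style `model`
# exposing .sample(...) and a Vocos-style `vocoder` have been loaded elsewhere;
# the paths, transcripts, and the `soundfile` dependency are placeholders and
# not part of this module.
#
#     import soundfile as sf
#
#     ref_audio = torchaudio.load("ref.wav")  # -> (waveform [channels, samples], sample_rate)
#     ref_text = list("reference transcript")  # token list, as asserted above
#     gen_text_batches = [list("first text chunk"), list("second text chunk")]
#
#     wave, sr, spectrogram = chunk_infer_batch_process(
#         ref_audio,
#         ref_text,
#         gen_text_batches,
#         model,
#         vocoder,
#         mel_spec_type="vocos",
#         device=device,
#     )
#     sf.write("generated.wav", wave, sr)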