# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
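"""Chunked inference for the DSTK semantic detokenizer.

Reads a pipe-separated testset file (see ``get_test_list``), splits each token
sequence into overlapping chunks, and synthesizes audio chunk by chunk with a
CADiT model and a mel-spectrogram vocoder loaded through f5_tts, spawning one
worker process per visible CUDA/NPU device.
"""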
import argparse
import logging
import os
import shutil
import sys
import time
from pathlib import Path

import soundfile as sf
import torch.multiprocessing as mp
import torchaudio
import tqdm
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
load_model,
load_vocoder,
remove_silence_for_generated_wav,
)
sys.path.append(str(Path(__file__).parent))
from utils_infer import (
mel_spec_type,
target_rms,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
chunk_infer_batch_process
)
from model.cadit import CADiT

console_format = logging.Formatter(
"[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Replace any pre-existing root handlers; iterate over a copy, since
# removeHandler mutates the handler list in place.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
# Optionally remap CUDA calls to Ascend NPU. Importing transfer_to_npu has the
# side effect of redirecting torch.cuda APIs to torch_npu.
if os.environ.get("TOKENIZE_ON_NPU") == "1":
    import torch_npu  # noqa: F401
    import f5tts_npu_patch
    from torch_npu.contrib import transfer_to_npu  # noqa: F401

    logging.info("Applying patches for NPU")
    f5tts_npu_patch.patch_for_npu()


class SpeechDetokenizer:
    def __init__(
        self,
        vocoder_path: str,
        model_cfg: str = str((Path(__file__).parent / "ckpt/model.yaml").absolute()),
        ckpt_file: str = str((Path(__file__).parent / "ckpt/model.pt").absolute()),
        vocab_file: str = str((Path(__file__).parent / "ckpt/vocab_4096.txt").absolute()),
        device="cuda:0",
    ):
        self.model_cfg = model_cfg
        self.ckpt_file = ckpt_file
        self.vocab_file = vocab_file
        self.vocoder_path = vocoder_path
        self.device = device
        self.cross_fade_duration = 0
        self.initialize()

    def initialize(self):
        self.model = "CADiT"
        load_vocoder_from_local = True
        self.vocoder_name = mel_spec_type
        vocoder_local_path = self.vocoder_path
        # TTS model class and config
        model_cls = CADiT
        model_cfg = OmegaConf.load(self.model_cfg).model.arch
        logging.info(f"Using {self.model}...")
# Load vocoder
self.vocoder = load_vocoder(
vocoder_name=self.vocoder_name,
is_local=load_vocoder_from_local,
local_path=vocoder_local_path,
device=self.device,
)
# Load TTS model
self.ema_model = load_model(
model_cls,
model_cfg,
self.ckpt_file,
mel_spec_type=self.vocoder_name,
vocab_file=self.vocab_file,
device=self.device,
)
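
    # Typical standalone use (a sketch; the vocoder directory is a placeholder):
    #   detoker = SpeechDetokenizer(vocoder_path="/path/to/vocos_mel_24khz")
    #   wave, sr = detoker.chunk_generate(
    #       "ref.wav", ref_tokens, gen_tokens,
    #       token_chunk_len=25, chunk_cond_proportion=0.5,
    #   )
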
def chunk_text_with_look_ahead(
self, text_list, chunk_look_ahead_len, chunk_len=135, merge_short_last=False
):
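        """Split ``text_list`` into chunks of at most ``chunk_len`` tokens,
        where consecutive chunks overlap by ``chunk_look_ahead_len`` look-ahead
        tokens. E.g. 10 tokens with ``chunk_len=4`` and
        ``chunk_look_ahead_len=1`` yield slices [0:4], [3:7], [6:10]; dropping
        the look-ahead tail of every non-final chunk reconstructs the original
        list, which the assert at the end verifies.
        """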
chunks = []
stride = chunk_len - chunk_look_ahead_len
for i in range(0, len(text_list), stride):
chk = text_list[i : i + chunk_len]
chunks.append(chk)
if i + chunk_len >= len(text_list):
break
        if merge_short_last and len(chunks) >= 2 and len(chunks[-1]) < stride:
# Merge the last two chunks
last = chunks.pop()
second_last = chunks.pop()
if chunk_look_ahead_len <= 0:
chunks.append(second_last + last)
else:
chunks.append(second_last[:-chunk_look_ahead_len] + last)
actual_chunks = []
for idx in range(len(chunks)):
chk = chunks[idx]
if chunk_look_ahead_len <= 0:
actual_chunks.extend(chk)
else:
if idx < len(chunks) - 1:
actual_chunks.extend(chk[:-chunk_look_ahead_len])
else:
actual_chunks.extend(chk)
        # Sanity check: stripping look-ahead tails must reconstruct the input.
        assert actual_chunks == text_list
return chunks
def chunk_generate(
self,
ref_audio,
ref_text_list,
gen_text_list,
token_chunk_len,
chunk_cond_proportion,
chunk_look_ahead_len=0,
max_ref_duration=4.5,
ref_head_cut=False,
):
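        """Generate audio for ``gen_text_list`` chunk by chunk, conditioning
        each chunk on the reference audio/text and on the tail
        ``chunk_cond_proportion`` of the previously generated chunk. The
        reference audio is capped at ``max_ref_duration`` seconds, trimmed from
        the head if ``ref_head_cut`` is set and from the tail otherwise.
        Returns ``(wave, sample_rate)``, or ``(None, None)`` if there is
        nothing to generate.
        """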
gen_text_batches = self.chunk_text_with_look_ahead(
gen_text_list,
chunk_look_ahead_len,
chunk_len=token_chunk_len,
merge_short_last=True,
)
if len(gen_text_batches) == 0:
return None, None
for i, gen_text in enumerate(gen_text_batches):
logging.info(f"gen_text {i} with {len(gen_text)} tokens : {gen_text}")
audio, sr = torchaudio.load(ref_audio)
logging.info(f"Generating audio in {len(gen_text_batches)} batches...")
target_wave, target_sample_rate, combined_spectrogram = chunk_infer_batch_process(
(audio, sr),
ref_text_list,
gen_text_batches,
self.ema_model,
self.vocoder,
mel_spec_type=mel_spec_type,
progress=tqdm,
target_rms=target_rms,
cross_fade_duration=self.cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=self.device,
chunk_cond_proportion=chunk_cond_proportion,
chunk_look_ahead=chunk_look_ahead_len,
max_ref_duration=max_ref_duration,
ref_head_cut=ref_head_cut,
)
return target_wave, target_sample_rate
def get_audio_duration(audio_path):
audio, sample_rate = torchaudio.load(audio_path)
return audio.shape[1] / sample_rate
def get_test_list(testset_path):
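    """Parse a pipe-separated testset file.

    Each line is either ``ref_audio|ref_text`` (the reference text is
    re-generated with the reference voice) or
    ``ref_audio|ref_text|gen_audio|gen_text``; one trailing extra field is
    tolerated. Returns entries of the form
    [ref_text, ref_audio, gen_text, gen_audio (optional)].
    """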
    testset_list = []
    with open(testset_path, "r") as f:
for line in f:
content = line.strip().split("|")
if len(content) == 2 or len(content) == 3:
testset_list.append([content[1], content[0], content[1]])
elif len(content) == 4 or len(content) == 5:
testset_list.append([content[1], content[0], content[3], content[2]])
return testset_list
def infer(args, task_queue, rank=0):
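    """Worker loop: pull [ref_text, ref_audio, gen_text, gen_audio?] tasks
    from ``task_queue`` on device ``cuda:{rank}`` until a ``None`` sentinel
    arrives.
    """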
device_spec = f"cuda:{rank}"
if args.model_cfg is None or args.ckpt is None or args.vocab is None:
detoker = SpeechDetokenizer(
vocoder_path=args.vocoder,
device=device_spec,
)
else:
detoker = SpeechDetokenizer(
vocoder_path=args.vocoder,
model_cfg=args.model_cfg,
ckpt_file=args.ckpt,
vocab_file=args.vocab,
device=device_spec,
)
token_chunk_len = args.chunk_token
chunk_cond_proportion = args.chunk_cond_portion
if chunk_cond_proportion > 1 or chunk_cond_proportion <= 0:
chunk_cond_proportion = 0.5 # set default
chunk_look_ahead = args.chunk_look_ahead
if chunk_look_ahead >= token_chunk_len:
chunk_look_ahead = 0
remove_silence = False
output_dir = args.output
    os.makedirs(output_dir, exist_ok=True)
logging.info(f"infer with chunk of {token_chunk_len} tokens")
logging.info(f"the last {chunk_cond_proportion} of each chunk added into condition")
logging.info(f"Using the last {chunk_look_ahead} tokens as look ahead")
gen_nums = 0
while True:
try:
_tst = task_queue.get()
if _tst is None:
logging.info("FINISH processing all inputs")
break
ref_text_list = _tst[0].split()
ref_audio = _tst[1]
gen_text_list = _tst[2].split()
if len(_tst) == 4:
gen_audio = _tst[3]
else:
gen_audio = None
ref_wave_path = (
Path(output_dir) / f"{ref_audio.split('/')[-1].split('.')[0]}_ref.wav"
)
if gen_audio is None:
gen_wave_path = (
Path(output_dir)
/ f"{ref_audio.split('/')[-1].split('.')[0]}_gen.wav"
)
orig_wave_path = None
else:
gen_wave_path = (
Path(output_dir)
/ f"{gen_audio.split('/')[-1].split('.')[0]}_gen.wav"
)
orig_wave_path = (
Path(output_dir)
/ f"{gen_audio.split('/')[-1].split('.')[0]}_orig.wav"
)
if os.path.exists(gen_wave_path):
logging.info(f"{gen_wave_path} already exist, skip")
continue
            if not os.path.exists(ref_wave_path):
                shutil.copy(ref_audio, ref_wave_path)
            if gen_audio is not None and os.path.exists(gen_audio) and orig_wave_path:
                shutil.copy(gen_audio, orig_wave_path)
generated_wave, target_sample_rate = detoker.chunk_generate(
ref_audio,
ref_text_list,
gen_text_list,
token_chunk_len,
chunk_cond_proportion,
chunk_look_ahead,
args.max_ref_duration,
args.ref_audio_cut_from_head,
)
if generated_wave is None:
continue
            sf.write(str(gen_wave_path), generated_wave, target_sample_rate)
            if remove_silence:
                remove_silence_for_generated_wav(str(gen_wave_path))
            logging.info(f"Wrote output to: {gen_wave_path}")
            gen_nums += 1
        except Exception:
            logging.exception("Error while processing task; skipping")
def run_infer_mp(args):
device_list = [0]
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_list = [int(x.strip()) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
elif "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
device_list = [int(x.strip()) for x in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")]
logging.info(f"Using devices: {device_list}")
n_procs = len(device_list)
# load testset
testset_list = get_test_list(args.testset_path)
ctx = mp.get_context("spawn")
with ctx.Manager() as manager:
task_queue = manager.Queue()
for task in testset_list:
task_queue.put(task)
processes = []
for idx in range(n_procs):
task_queue.put(None)
            # Visible devices are renumbered from 0, so the worker index is its device rank.
            rank = idx
            p = ctx.Process(target=infer, args=(args, task_queue, rank))
p.start()
processes.append(p)
for p in processes:
p.join()
os.system(f"cp {args.testset_path} {args.output}")
logging.info(f"Finish processing of {n_procs}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
required=False,
help="path to ckpt",
)
parser.add_argument(
"--model-cfg",
required=False,
help="path to model_cfg",
)
parser.add_argument(
"--vocab",
required=False,
help="path to vocab",
)
parser.add_argument(
"--vocoder",
required=True,
help="path to vocoder",
)
parser.add_argument(
"--testset",
dest="testset_path",
required=True,
help="path of testset file",
)
parser.add_argument(
"--output",
required=True,
help="path to output generated audio",
)
    parser.add_argument(
        "--chunk-token",
        required=True,
        type=int,
        help="max number of tokens in a chunk",
    )
parser.add_argument(
"--chunk-look-ahead",
required=False,
type=int,
default=0,
help="number of tokens in a chunk as look ahead",
)
    parser.add_argument(
        "--chunk-cond-portion",
        required=True,
        type=float,
        help="portion at the tail of the previous chunk used as condition, in (0, 1]",
    )
parser.add_argument(
"--max-ref-duration",
required=False,
type=float,
default=4.5,
help="the max duration of ref audio in seconds",
)
parser.add_argument(
"--ref-audio-cut-from-head",
default=False,
action="store_true",
help="cut ref audio from head, if not set, from tail by default",
)
args = parser.parse_args()
start_time = time.perf_counter()
run_infer_mp(args)
end_time = time.perf_counter()
logging.info("processig time: %f sec\n" % (end_time - start_time))
logging.info(f"Finished! output to : {args.output}")