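"""Chunked speech detokenization with a CADiT flow-matching model.

Reads a pipe-separated testset, resynthesizes audio from token sequences in
overlapping chunks, and fans tasks out to one worker process per visible
device.

Example invocation (script name and data paths are illustrative):
    CUDA_VISIBLE_DEVICES=0,1 python chunk_infer.py \
        --vocoder /path/to/vocoder \
        --testset /path/to/testset.txt \
        --output ./out \
        --chunk-token 25 \
        --chunk-cond-portion 0.5
"""
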
import argparse
import logging
import os
import shutil
import sys
import time
from pathlib import Path

import soundfile as sf
import torch.multiprocessing as mp
import torchaudio
import tqdm
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    load_model,
    load_vocoder,
    remove_silence_for_generated_wav,
)

# Local modules are resolved relative to this script.
sys.path.append(str(Path(__file__).parent))
from utils_infer import (
    mel_spec_type,
    target_rms,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    chunk_infer_batch_process,
)
from model.cadit import CADiT

# Install a single console handler on the root logger.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Iterate over a copy: removing handlers while iterating the live list
# would skip entries.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)

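# Opt-in Ascend NPU support: with TOKENIZE_ON_NPU=1, torch_npu's
# transfer_to_npu shim redirects CUDA-targeted calls to NPU devices, and the
# local f5tts_npu_patch module applies the project's remaining NPU fixes.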
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":
    import torch_npu
    import f5tts_npu_patch
    from torch_npu.contrib import transfer_to_npu

    logging.info("Applying patches for NPU")
    f5tts_npu_patch.patch_for_npu()


class SpeechDetokenizer:
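    """Loads the CADiT acoustic model and a vocoder, then synthesizes audio
    from token sequences chunk by chunk."""
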
    def __init__(
        self,
        vocoder_path: str,
        model_cfg: str = str((Path(__file__).parent / "ckpt/model.yaml").absolute()),
        ckpt_file: str = str((Path(__file__).parent / "ckpt/model.pt").absolute()),
        vocab_file: str = str((Path(__file__).parent / "ckpt/vocab_4096.txt").absolute()),
        device="cuda:0",
    ):
        self.model_cfg = model_cfg
        self.ckpt_file = ckpt_file
        self.vocab_file = vocab_file
        self.vocoder_path = vocoder_path
        self.device = device

        # No cross-fading between chunks by default.
        self.cross_fade_duration = 0
        self.initialize()

    def initialize(self):
        self.model = "CADiT"
        load_vocoder_from_local = True
        self.vocoder_name = mel_spec_type
        vocoder_local_path = self.vocoder_path

        model_cls = CADiT
        model_cfg = OmegaConf.load(self.model_cfg).model.arch
        logging.info(f"Using {self.model}...")

        self.vocoder = load_vocoder(
            vocoder_name=self.vocoder_name,
            is_local=load_vocoder_from_local,
            local_path=vocoder_local_path,
            device=self.device,
        )

        self.ema_model = load_model(
            model_cls,
            model_cfg,
            self.ckpt_file,
            mel_spec_type=self.vocoder_name,
            vocab_file=self.vocab_file,
            device=self.device,
        )

    def chunk_text_with_look_ahead(
        self, text_list, chunk_look_ahead_len, chunk_len=135, merge_short_last=False
    ):
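        """Split text_list into chunks of at most chunk_len tokens, with
        consecutive chunks overlapping by chunk_look_ahead_len tokens so each
        chunk also sees the start of the next one; merge_short_last folds a
        too-short final chunk into its predecessor.

        Worked example: chunk_len=5 and chunk_look_ahead_len=2 give stride 3,
        so ten tokens t0..t9 chunk as [t0..t4], [t3..t7], [t6..t9]; trimming
        the 2-token look-ahead tail from every chunk but the last restores
        t0..t9, which is exactly what the asserts below verify.
        """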
        chunks = []

        # Consecutive chunks overlap by chunk_look_ahead_len tokens.
        stride = chunk_len - chunk_look_ahead_len
        for i in range(0, len(text_list), stride):
            chk = text_list[i : i + chunk_len]
            chunks.append(chk)
            if i + chunk_len >= len(text_list):
                break

        # Optionally fold a too-short final chunk into its predecessor.
        if merge_short_last and len(chunks) >= 2 and len(chunks[-1]) < stride:
            last = chunks.pop()
            second_last = chunks.pop()
            if chunk_look_ahead_len <= 0:
                chunks.append(second_last + last)
            else:
                chunks.append(second_last[:-chunk_look_ahead_len] + last)

        # Sanity check: trimming each chunk's look-ahead tail (except on the
        # last chunk) must reconstruct the original token sequence.
        actual_chunks = []
        for idx, chk in enumerate(chunks):
            if chunk_look_ahead_len <= 0 or idx == len(chunks) - 1:
                actual_chunks.extend(chk)
            else:
                actual_chunks.extend(chk[:-chunk_look_ahead_len])

        assert len(actual_chunks) == len(text_list)
        assert actual_chunks == text_list
        return chunks

    def chunk_generate(
        self,
        ref_audio,
        ref_text_list,
        gen_text_list,
        token_chunk_len,
        chunk_cond_proportion,
        chunk_look_ahead_len=0,
        max_ref_duration=4.5,
        ref_head_cut=False,
    ):
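        """Synthesize gen_text_list chunk by chunk, conditioned on ref_audio
        and ref_text_list; returns (waveform, sample_rate), or (None, None)
        when there is nothing to generate."""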
        gen_text_batches = self.chunk_text_with_look_ahead(
            gen_text_list,
            chunk_look_ahead_len,
            chunk_len=token_chunk_len,
            merge_short_last=True,
        )

        if len(gen_text_batches) == 0:
            return None, None

        for i, gen_text in enumerate(gen_text_batches):
            logging.info(f"gen_text {i} with {len(gen_text)} tokens : {gen_text}")

        audio, sr = torchaudio.load(ref_audio)
        logging.info(f"Generating audio in {len(gen_text_batches)} batches...")

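        # chunk_cond_proportion: tail portion of the previous chunk fed back as
        # the condition for the next one; chunk_look_ahead: how many tokens of
        # overlap each chunk shares with the next (see chunk_text_with_look_ahead).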
        target_wave, target_sample_rate, combined_spectrogram = chunk_infer_batch_process(
            (audio, sr),
            ref_text_list,
            gen_text_batches,
            self.ema_model,
            self.vocoder,
            mel_spec_type=mel_spec_type,
            progress=tqdm,
            target_rms=target_rms,
            cross_fade_duration=self.cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
            chunk_cond_proportion=chunk_cond_proportion,
            chunk_look_ahead=chunk_look_ahead_len,
            max_ref_duration=max_ref_duration,
            ref_head_cut=ref_head_cut,
        )
        return target_wave, target_sample_rate


def get_audio_duration(audio_path):
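    """Return the duration of the audio file at audio_path, in seconds."""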
    audio, sample_rate = torchaudio.load(audio_path)
    return audio.shape[1] / sample_rate


def get_test_list(testset_path):
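    """Parse the pipe-separated testset into task entries of
    [ref_text, ref_audio, gen_text] or [ref_text, ref_audio, gen_text, gen_audio].

    A 2/3-field line (ref_audio|ref_text[|...]) reuses the reference text as
    the generation target; a 4/5-field line
    (ref_audio|ref_text|gen_audio|gen_text[|...]) names a separate target.
    """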
    testset_file_path = testset_path
    testset_list = []

    with open(testset_file_path, "r") as f:
        for line in f:
            content = line.strip().split("|")
            if len(content) == 2 or len(content) == 3:
                # Reconstruction task: regenerate the reference text itself.
                testset_list.append([content[1], content[0], content[1]])
            elif len(content) == 4 or len(content) == 5:
                # Separate generation target with its own audio and text.
                testset_list.append([content[1], content[0], content[3], content[2]])
    return testset_list


def infer(args, task_queue, rank=0):
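    """Worker loop for one device: pull tasks from task_queue until the None
    sentinel arrives, synthesize each utterance, and write the resulting WAVs
    (plus reference copies) under args.output."""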
    device_spec = f"cuda:{rank}"

    # Fall back to the default checkpoint/config bundled next to this script
    # unless all three model paths are given explicitly.
    if args.model_cfg is None or args.ckpt is None or args.vocab is None:
        detoker = SpeechDetokenizer(
            vocoder_path=args.vocoder,
            device=device_spec,
        )
    else:
        detoker = SpeechDetokenizer(
            vocoder_path=args.vocoder,
            model_cfg=args.model_cfg,
            ckpt_file=args.ckpt,
            vocab_file=args.vocab,
            device=device_spec,
        )

    token_chunk_len = args.chunk_token
    chunk_cond_proportion = args.chunk_cond_portion
    # Fall back to 0.5 if the requested portion is outside (0, 1].
    if chunk_cond_proportion > 1 or chunk_cond_proportion <= 0:
        chunk_cond_proportion = 0.5

    # The look-ahead must be shorter than the chunk itself.
    chunk_look_ahead = args.chunk_look_ahead
    if chunk_look_ahead >= token_chunk_len:
        chunk_look_ahead = 0

    remove_silence = False

    output_dir = args.output
    os.makedirs(output_dir, exist_ok=True)

    logging.info(f"infer with chunks of {token_chunk_len} tokens")
    logging.info(f"the last {chunk_cond_proportion} of each chunk is added to the condition")
    logging.info(f"using the last {chunk_look_ahead} tokens of each chunk as look-ahead")

    gen_nums = 0
    while True:
        try:
            _tst = task_queue.get()
            if _tst is None:
                logging.info("FINISH processing all inputs")
                break

            ref_text_list = _tst[0].split()
            ref_audio = _tst[1]
            gen_text_list = _tst[2].split()

            gen_audio = _tst[3] if len(_tst) == 4 else None

            ref_wave_path = (
                Path(output_dir) / f"{ref_audio.split('/')[-1].split('.')[0]}_ref.wav"
            )
            if gen_audio is None:
                gen_wave_path = (
                    Path(output_dir)
                    / f"{ref_audio.split('/')[-1].split('.')[0]}_gen.wav"
                )
                orig_wave_path = None
            else:
                gen_wave_path = (
                    Path(output_dir)
                    / f"{gen_audio.split('/')[-1].split('.')[0]}_gen.wav"
                )
                orig_wave_path = (
                    Path(output_dir)
                    / f"{gen_audio.split('/')[-1].split('.')[0]}_orig.wav"
                )

            if os.path.exists(gen_wave_path):
                logging.info(f"{gen_wave_path} already exists, skip")
                continue

            # Keep copies of the reference (and original target) audio next to
            # the generated output for easy comparison.
            if not os.path.exists(ref_wave_path):
                shutil.copy(ref_audio, ref_wave_path)
            if gen_audio is not None and os.path.exists(gen_audio) and orig_wave_path:
                shutil.copy(gen_audio, orig_wave_path)

            generated_wave, target_sample_rate = detoker.chunk_generate(
                ref_audio,
                ref_text_list,
                gen_text_list,
                token_chunk_len,
                chunk_cond_proportion,
                chunk_look_ahead,
                args.max_ref_duration,
                args.ref_audio_cut_from_head,
            )

            if generated_wave is None:
                continue

            sf.write(gen_wave_path, generated_wave, target_sample_rate)
            if remove_silence:
                remove_silence_for_generated_wav(gen_wave_path)
            logging.info(f"write output to: {gen_wave_path}")

            gen_nums += 1

        except Exception:
            logging.exception("Failed to process task")

    logging.info(f"rank {rank} generated {gen_nums} utterances")


def run_infer_mp(args):
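    """Spawn one inference worker per visible device and distribute the
    testset to them through a shared task queue."""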
    # One worker per visible device; workers address the renumbered devices
    # 0..n-1, so only the count matters here. Default to a single device.
    device_list = [0]
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
    elif "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
        device_list = [int(x.strip()) for x in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")]

    logging.info(f"Using devices: {device_list}")
    n_procs = len(device_list)

    testset_list = get_test_list(args.testset_path)

    # CUDA/NPU workers require the spawn start method.
    ctx = mp.get_context("spawn")
    with ctx.Manager() as manager:
        task_queue = manager.Queue()
        for task in testset_list:
            task_queue.put(task)

        processes = []
        for rank in range(n_procs):
            # One None sentinel per worker signals it to exit.
            task_queue.put(None)
            p = ctx.Process(target=infer, args=(args, task_queue, rank))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

    # Keep a copy of the testset next to the generated audio.
    shutil.copy(args.testset_path, args.output)
    logging.info(f"Finished processing with {n_procs} worker processes")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--model-cfg",
        required=False,
        help="path to model_cfg",
    )
    parser.add_argument(
        "--vocab",
        required=False,
        help="path to vocab",
    )
    parser.add_argument(
        "--vocoder",
        required=True,
        help="path to vocoder",
    )
    parser.add_argument(
        "--testset",
        dest="testset_path",
        required=True,
        help="path to testset file",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="path to output generated audio",
    )
    parser.add_argument(
        "--chunk-token",
        required=True,
        type=int,
        help="max number of tokens in a chunk",
    )
    parser.add_argument(
        "--chunk-look-ahead",
        required=False,
        type=int,
        default=0,
        help="number of tokens at the end of each chunk used as look-ahead",
    )
    parser.add_argument(
        "--chunk-cond-portion",
        required=True,
        type=float,
        help="portion at the tail of the previous chunk used as condition, in (0, 1]",
    )
    parser.add_argument(
        "--max-ref-duration",
        required=False,
        type=float,
        default=4.5,
        help="max duration of ref audio in seconds",
    )
    parser.add_argument(
        "--ref-audio-cut-from-head",
        default=False,
        action="store_true",
        help="cut ref audio from the head; if not set, cut from the tail",
    )

    args = parser.parse_args()

    start_time = time.perf_counter()
    run_infer_mp(args)
    end_time = time.perf_counter()

    logging.info("processing time: %f sec" % (end_time - start_time))
    logging.info(f"Finished! output to: {args.output}")