# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
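
"""Reconstruction example for DSTK: extract semantic tokens from an input
waveform with the speech tokenizer, then resynthesize audio with the semantic
detokenizer, conditioned on a reference waveform.

Example invocation (the paths below are placeholders; point them at your
local files):

    python reconstuction_example.py \
        --detok-vocoder /path/to/vocoder \
        --ref-wav ref.wav \
        --input-wav input.wav \
        --output-wav reconstructed.wav
"""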
import argparse
import logging
import sys
from pathlib import Path

import librosa
import soundfile as sf

# Make the bundled sub-packages importable before importing from them.
sub_modules = ["", "semantic_tokenizer/f40ms", "semantic_detokenizer"]
for sub in sub_modules:
    sys.path.append(str((Path(__file__).parent / sub).absolute()))

from semantic_tokenizer.f40ms.simple_tokenizer_infer import SpeechTokenizer, TOKENIZER_CFG_NAME
from semantic_detokenizer.chunk_infer import SpeechDetokenizer


class ReconstructionPipeline:
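    """Speech -> semantic tokens -> speech round-trip pipeline.

    The semantic tokenizer converts 16 kHz waveforms into discrete token
    sequences; the semantic detokenizer resynthesizes audio from those tokens
    chunk by chunk, conditioned on a reference waveform.
    """
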
    def __init__(
        self,
        detok_vocoder: str,
        tokenizer_cfg_name: str = TOKENIZER_CFG_NAME,
        tokenizer_cfg_path: str = str(
            (Path(__file__).parent / "semantic_tokenizer/f40ms/config").absolute()
        ),
        tokenizer_ckpt: str = str(
            (
                Path(__file__).parent / "semantic_tokenizer/f40ms/ckpt/model.pt"
            ).absolute()
        ),
        detok_model_cfg: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.yaml").absolute()
        ),
        detok_ckpt: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.pt").absolute()
        ),
        detok_vocab: str = str(
            (
                Path(__file__).parent / "semantic_detokenizer/ckpt/vocab_4096.txt"
            ).absolute()
        ),
    ):
        self.tokenizer_cfg_name = tokenizer_cfg_name
        self.tokenizer = SpeechTokenizer(
            ckpt_path=tokenizer_ckpt,
            cfg_path=tokenizer_cfg_path,
            cfg_name=self.tokenizer_cfg_name,
        )
        self.device = "cuda:0"
        self.detoker = SpeechDetokenizer(
            vocoder_path=detok_vocoder,
            model_cfg=detok_model_cfg,
            ckpt_file=detok_ckpt,
            vocab_file=detok_vocab,
            device=self.device,
        )
        # Chunked-generation settings passed to the detokenizer.
        self.token_chunk_len = 75
        self.chunk_cond_proportion = 0.3
        self.chunk_look_ahead = 10
        self.max_ref_duration = 4.5
        self.ref_audio_cut_from_head = False

    def reconstruct(self, ref_wav, input_wav):
        """Tokenize the reference and input waveforms, then resynthesize the
        input tokens with the detokenizer, conditioned on the reference audio.

        Returns (waveform, sample_rate), or (None, None) if generation fails.
        """
        wav_list = []
        raw_ref_wav, _ = librosa.load(ref_wav, sr=16000)
        wav_list.append(raw_ref_wav)
        raw_input_wav, _ = librosa.load(input_wav, sr=16000)
        wav_list.append(raw_input_wav)
        _, token_info_list = self.tokenizer.extract(wav_list)
        ref_tokens = token_info_list[0]["reduced_unit_sequence"]
        input_tokens = token_info_list[1]["reduced_unit_sequence"]
        logging.info("tokens for ref wav: %s are [%s]", ref_wav, ref_tokens)
        logging.info("tokens for input wav: %s are [%s]", input_wav, input_tokens)
        generated_wave, target_sample_rate = self.detoker.chunk_generate(
            ref_wav,
            ref_tokens.split(),
            input_tokens.split(),
            self.token_chunk_len,
            self.chunk_cond_proportion,
            self.chunk_look_ahead,
            self.max_ref_duration,
            self.ref_audio_cut_from_head,
        )
        if generated_wave is None:
            logging.error("generation FAILED")
            return None, None
        return generated_wave, target_sample_rate


def main(args):
    # Initialize the reconstruction pipeline (only the vocoder path is taken
    # from the command line here; the remaining paths use the class defaults).
    reconstructor = ReconstructionPipeline(
        detok_vocoder=args.detok_vocoder,
    )
    generated_wave, target_sample_rate = reconstructor.reconstruct(
        args.ref_wav, args.input_wav
    )
    if generated_wave is None:
        logging.error("reconstruction failed, no output written")
        return
    sf.write(args.output_wav, generated_wave, target_sample_rate)
    logging.info("write output to: %s", args.output_wav)
    logging.info("Finished")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer-ckpt",
        required=False,
        help="path to tokenizer checkpoint",
    )
    parser.add_argument(
        "--tokenizer-cfg-path",
        required=False,
        default="semantic_tokenizer/f40ms/config",
        help="path to tokenizer config directory",
    )
    parser.add_argument(
        "--detok-ckpt",
        required=False,
        help="path to detokenizer checkpoint",
    )
    parser.add_argument(
        "--detok-model-cfg",
        required=False,
        help="path to detokenizer model config",
    )
    parser.add_argument(
        "--detok-vocab",
        required=False,
        help="path to detokenizer vocab",
    )
    parser.add_argument(
        "--detok-vocoder",
        required=True,
        help="path to vocoder",
    )
    parser.add_argument(
        "--ref-wav",
        required=True,
        help="path to reference wav",
    )
    parser.add_argument(
        "--output-wav",
        required=True,
        help="path to output reconstructed wav",
    )
    parser.add_argument(
        "--input-wav",
        required=True,
        help="path to input wav to reconstruct",
    )
    args = parser.parse_args()
    main(args)