# Repository path: DSTK/semantic_tokenizer/f40ms/infer_for_eval.py
# Uploaded by: gooorillax
# Commit message: first push of codes and models for g2p, t2u, tokenizer and detokenizer
# Commit: cd8454d
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Daxin Tan,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from simple_tokenizer_infer import SpeechTokenizer
import argparse
import librosa
import logging
from pathlib import Path
def _read_eval_list(list_path, input_type):
    """Parse the "|"-separated evaluation list at *list_path*.

    Fields used on every line: 0 = generated wav id, 2 = reference wav path,
    3 = text; for the "reconstruct" type field 4 = reconstructed wav path.
    Blank lines are skipped.

    Returns:
        (ref_wav_paths, reconstruct_wav_paths, line_infos) where each
        line_info is [ref wav path, gen wav id, text].  The reconstruct list
        stays empty for the "tts" type.

    Raises:
        ValueError: a non-blank line has too few fields for *input_type*.
    """
    min_fields = 4 if input_type == "tts" else 5
    ref_wav_paths = []
    reconstruct_wav_paths = []
    line_infos = []
    with open(list_path, "r", encoding="utf-8") as list_file:
        for line_no, raw_line in enumerate(list_file, start=1):
            stripped = raw_line.strip()
            if not stripped:
                # Tolerate empty / trailing lines instead of crashing on them.
                continue
            fields = stripped.split("|")
            if len(fields) < min_fields:
                raise ValueError(
                    f"{list_path}: line {line_no}: expected at least "
                    f"{min_fields} '|'-separated fields, got {len(fields)}"
                )
            if input_type != "tts":
                reconstruct_wav_paths.append(fields[4])
            ref_wav_paths.append(fields[2])
            line_infos.append([fields[2], fields[0], fields[3]])
    return ref_wav_paths, reconstruct_wav_paths, line_infos


def _load_wavs(base_path, rel_paths, sample_rate=16000):
    """Load each path (relative to *base_path*) with librosa, resampled to *sample_rate* Hz."""
    return [
        librosa.load(base_path / rel_path, sr=sample_rate)[0]
        for rel_path in rel_paths
    ]


def main(args: argparse.Namespace) -> None:
    """Extract speech tokens for an evaluation list and write them to a file.

    Reads ``args.input_list``, loads the referenced audio (paths are resolved
    relative to the list file's directory), runs ``SpeechTokenizer.extract``
    on it and writes one "|"-separated line per entry to ``args.output_file``:

    * "tts":         ``ref_path|ref_units|gen_id|text``
    * "reconstruct": ``ref_path|ref_units|gen_id|recon_units|text``

    Args:
        args: namespace with ``input_list``, ``output_file``, ``input_type``,
            ``ckpt``, ``cfg_path`` and ``cfg_name``.  When ``ckpt`` is None the
            tokenizer is built with its own defaults.

    Raises:
        ValueError: ``input_type`` is not "tts"/"reconstruct", or the list
            contains a malformed line.
        RuntimeError: the tokenizer returned a different number of results
            than there are list entries.
    """
    if args.input_type not in ("tts", "reconstruct"):
        # Fail before any audio is loaded rather than with a NameError later.
        raise ValueError(
            f"Input type must be tts or reconstruct, got {args.input_type!r}"
        )
    is_reconstruct = args.input_type == "reconstruct"

    logging.info("loading eval file list")
    base_path = Path(args.input_list).parent
    ref_wav_file_list, reconstruct_wav_file_list, line_info_list = _read_eval_list(
        args.input_list, args.input_type
    )

    logging.info("loading ref audio")
    raw_ref_wavs_list = _load_wavs(base_path, ref_wav_file_list)

    logging.info("extracting token for ref audio")
    if args.ckpt is not None:
        tokenizer = SpeechTokenizer(
            ckpt_path=args.ckpt, cfg_path=args.cfg_path, cfg_name=args.cfg_name
        )
    else:
        tokenizer = SpeechTokenizer()
    _, ref_token_info_list = tokenizer.extract(raw_ref_wavs_list)

    # Explicit checks (not asserts, which vanish under -O): zip() below would
    # otherwise silently drop entries on a length mismatch.
    if len(ref_token_info_list) != len(line_info_list):
        raise RuntimeError(
            f"tokenizer returned {len(ref_token_info_list)} results for "
            f"{len(line_info_list)} reference entries"
        )

    recon_token_info_list = []
    if is_reconstruct:
        logging.info("loading reconstruct audio")
        raw_reconstruct_wav_list = _load_wavs(base_path, reconstruct_wav_file_list)
        logging.info("extracting token for reconstruct audio")
        _, recon_token_info_list = tokenizer.extract(raw_reconstruct_wav_list)
        if len(ref_token_info_list) != len(recon_token_info_list):
            raise RuntimeError(
                f"tokenizer returned {len(recon_token_info_list)} reconstruct "
                f"results for {len(ref_token_info_list)} reference results"
            )

    # The with-block closes the file; no explicit close() is needed.
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        logging.info("writing output file")
        if is_reconstruct:
            for ref, recon, line_info in zip(
                ref_token_info_list, recon_token_info_list, line_info_list
            ):
                ref_units = ref["reduced_unit_sequence"]
                recon_units = recon["reduced_unit_sequence"]
                ref_path = str(base_path / line_info[0])
                output_file.write(
                    f"{ref_path}|{ref_units}|{line_info[1]}|{recon_units}|{line_info[2]}\n"
                )
        else:
            for ref, line_info in zip(ref_token_info_list, line_info_list):
                ref_units = ref["reduced_unit_sequence"]
                ref_path = str(base_path / line_info[0])
                output_file.write(
                    f"{ref_path}|{ref_units}|{line_info[1]}|{line_info[2]}\n"
                )
    logging.info("Finished")
if __name__ == "__main__":
    # Make the INFO-level progress messages emitted by main() visible.
    # basicConfig() does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        dest="ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--cfg-path",
        dest="cfg_path",
        required=False,
        default="config",
        help="path to config",
    )
    parser.add_argument(
        "--cfg-name",
        dest="cfg_name",
        required=False,
        default="hubert_config",
        help="name of config",
    )
    parser.add_argument(
        "--input-list",
        dest="input_list",
        required=True,
        help="list of input wavform",
    )
    parser.add_argument(
        "--output-file",
        dest="output_file",
        required=True,
        help="file to output speech tokens",
    )
    # The option is required, so a default would never be used.  `choices`
    # makes argparse reject any other value with a usage message on stderr
    # and a non-zero exit status, instead of exiting silently with status 0.
    parser.add_argument(
        "--input-type",
        dest="input_type",
        type=str,
        required=True,
        choices=["tts", "reconstruct"],
        help="test fil list type: tts or reconstruct, seedtts format",
    )
    main(parser.parse_args())