File size: 5,088 Bytes
cd8454d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Daxin Tan,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from simple_tokenizer_infer import SpeechTokenizer
import argparse
import librosa
import logging
from pathlib import Path
def _load_wavs(base_path, rel_paths, sample_rate=16000):
    """Load every listed audio file with librosa and return the waveforms.

    Args:
        base_path: Directory that the relative paths are resolved against.
        rel_paths: Iterable of audio file paths relative to ``base_path``.
        sample_rate: Value passed as ``sr`` to ``librosa.load``.

    Returns:
        List with the first element returned by ``librosa.load`` (the raw
        waveform) for each path, in input order.
    """
    # The audio is loaded at `sample_rate` (16 kHz by default); passing
    # sr=None instead would keep each file's original sampling rate.
    return [
        librosa.load(base_path / rel_path, sr=sample_rate)[0]
        for rel_path in rel_paths
    ]


def main(args):
    """Extract speech-token sequences for the audio listed in ``args.input_list``.

    Every non-empty line of the list file is split on ``|``. Field 0 is the
    generated-wav id, field 2 the reference wav path, field 3 the text and,
    in ``reconstruct`` mode, field 4 the reconstructed wav path. Wav paths are
    resolved relative to the directory that contains the list file.

    One line per input entry is written to ``args.output_file``:
        tts:          ``ref_path|ref_units|gen_id|text``
        reconstruct:  ``ref_path|ref_units|gen_id|recon_units|text``

    Args:
        args: Namespace providing ``input_list``, ``output_file``,
            ``input_type`` (``"tts"`` or ``"reconstruct"``), ``ckpt``,
            ``cfg_path`` and ``cfg_name``.

    Raises:
        ValueError: If ``args.input_type`` is not a supported mode.
        IndexError: If a non-empty line has too few ``|``-separated fields.
        AssertionError: If the tokenizer returns a different number of
            results than there are input entries.
    """
    # Fail early with a clear message instead of a late NameError/IndexError
    # when the function is called with an unsupported mode.
    if args.input_type not in ("tts", "reconstruct"):
        raise ValueError(
            f"input_type must be 'tts' or 'reconstruct', got {args.input_type!r}"
        )
    is_reconstruct = args.input_type == "reconstruct"

    ref_wav_file_list = []
    reconstruct_wav_file_list = []
    line_info_list = []

    logging.info("loading eval file list")
    base_path = Path(args.input_list).parent
    # Explicit UTF-8 so parsing does not depend on the platform default encoding.
    with open(args.input_list, "r", encoding="utf-8") as input_file:
        for line in input_file:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            fields = line.split("|")
            if is_reconstruct:
                reconstruct_wav_file_list.append(fields[4])
            ref_wav_file_list.append(fields[2])
            # ref wav path, gen wav id, text
            line_info_list.append([fields[2], fields[0], fields[3]])

    logging.info("loading ref audio")
    raw_ref_wavs_list = _load_wavs(base_path, ref_wav_file_list)

    logging.info("extracting token for ref audio")
    if args.ckpt is not None:
        tokenizer = SpeechTokenizer(
            ckpt_path=args.ckpt, cfg_path=args.cfg_path, cfg_name=args.cfg_name
        )
    else:
        # No checkpoint given: rely on the tokenizer's own defaults.
        tokenizer = SpeechTokenizer()
    _, ref_token_info_list = tokenizer.extract(raw_ref_wavs_list)

    recon_token_info_list = []
    if is_reconstruct:
        logging.info("loading reconstruct audio")
        raw_reconstruct_wav_list = _load_wavs(base_path, reconstruct_wav_file_list)
        logging.info("extracting token for reconstruct audio")
        _, recon_token_info_list = tokenizer.extract(raw_reconstruct_wav_list)
        assert len(ref_token_info_list) == len(recon_token_info_list)

    # zip() below would silently truncate on a length mismatch, so check first.
    assert len(ref_token_info_list) == len(line_info_list)

    # The with-statement closes the file; no explicit close() is needed.
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        logging.info("writing output file")
        if not is_reconstruct:
            for ref, line_info in zip(ref_token_info_list, line_info_list):
                ref_units = ref["reduced_unit_sequence"]
                ref_path = str(base_path / line_info[0])
                output_file.write(
                    f"{ref_path}|{ref_units}|{line_info[1]}|{line_info[2]}\n"
                )
        else:
            for ref, recon, line_info in zip(
                ref_token_info_list, recon_token_info_list, line_info_list
            ):
                ref_units = ref["reduced_unit_sequence"]
                recon_units = recon["reduced_unit_sequence"]
                ref_path = str(base_path / line_info[0])
                output_file.write(
                    f"{ref_path}|{ref_units}|{line_info[1]}|{recon_units}|{line_info[2]}\n"
                )
    logging.info("Finished")
    return
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the extraction.

    # The root logger defaults to WARNING, which hides every logging.info()
    # progress message this script emits. basicConfig() does nothing if the
    # root logger already has handlers configured.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        dest="ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--cfg-path",
        dest="cfg_path",
        required=False,
        default="config",
        help="path to config",
    )
    parser.add_argument(
        "--cfg-name",
        dest="cfg_name",
        required=False,
        default="hubert_config",
        help="name of config",
    )
    parser.add_argument(
        "--input-list",
        dest="input_list",
        required=True,
        help="list of input wavform",
    )
    parser.add_argument(
        "--output-file",
        dest="output_file",
        required=True,
        help="file to output speech tokens",
    )
    # `choices` makes argparse reject any other value with a usage message on
    # stderr and a non-zero exit status, instead of exiting silently with 0.
    # NOTE: because the option is required, the default is never used.
    parser.add_argument(
        "--input-type",
        default="tts",
        type=str,
        required=True,
        choices=("tts", "reconstruct"),
        help="test fil list type: tts or reconstruct, seedtts format",
    )
    args = parser.parse_args()
    main(args)
|