File size: 5,088 Bytes
cd8454d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Daxin Tan,
#                                                                                  Xiao Chen)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from simple_tokenizer_infer import SpeechTokenizer
import argparse
import librosa
import logging
from pathlib import Path


def _load_wavs(base_path, rel_paths, sample_rate=16000):
    """Load each audio file (resolved against base_path) resampled to sample_rate."""
    # librosa.load returns (waveform, sampling_rate); only the waveform is kept.
    return [
        librosa.load(base_path / rel_path, sr=sample_rate)[0]
        for rel_path in rel_paths
    ]


def main(args):
    """Extract speech tokens for the audio listed in an eval list and write them out.

    Each non-empty line of ``args.input_list`` is split on ``|``. Field 0 is
    the generated-wav id, field 2 the reference wav path, field 3 the text
    and, when ``args.input_type`` is ``"reconstruct"``, field 4 the
    reconstructed wav path. Wav paths are resolved relative to the directory
    that contains the input list and loaded at 16 kHz.

    One line per input entry is written to ``args.output_file``:

    * ``tts``:         ``ref_path|ref_units|gen_id|text``
    * ``reconstruct``: ``ref_path|ref_units|gen_id|recon_units|text``

    where ``*_units`` is the ``"reduced_unit_sequence"`` entry of the info
    items returned by ``SpeechTokenizer.extract``.

    Args:
        args: Namespace providing ``input_list``, ``output_file``,
            ``input_type``, ``ckpt``, ``cfg_path`` and ``cfg_name``.

    Raises:
        ValueError: If ``args.input_type`` is neither ``"tts"`` nor
            ``"reconstruct"``.
        IndexError: If a non-empty input line has too few ``|``-separated
            fields.
        AssertionError: If the tokenizer returns a different number of items
            than there are input entries.
    """
    if args.input_type not in ("tts", "reconstruct"):
        # Fail before any audio is loaded or the output file is truncated.
        raise ValueError("Input type must be tts or reconstruct")
    is_tts = args.input_type == "tts"

    ref_wav_file_list = []
    line_info_list = []
    reconstruct_wav_file_list = []

    logging.info("loading eval file list")
    base_path = Path(args.input_list).parent
    min_fields = 4 if is_tts else 5
    # Explicit UTF-8: the list carries transcript text, so do not depend on
    # the platform's default locale encoding.
    with open(args.input_list, "r", encoding="utf-8") as input_file:
        for line_no, raw_line in enumerate(input_file, start=1):
            stripped = raw_line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = stripped.split("|")
            if len(fields) < min_fields:
                raise IndexError(
                    f"{args.input_list}:{line_no}: expected at least "
                    f"{min_fields} '|'-separated fields, got {len(fields)}"
                )
            ref_wav_file_list.append(fields[2])
            if not is_tts:
                reconstruct_wav_file_list.append(fields[4])
            # (ref wav path, gen wav id, text)
            line_info_list.append((fields[2], fields[0], fields[3]))

    logging.info("loading ref audio")
    # Raw waveforms as loaded by librosa, resampled to 16 kHz.
    raw_ref_wavs_list = _load_wavs(base_path, ref_wav_file_list)

    logging.info("extracting token for ref audio")
    if args.ckpt is not None:
        tokenizer = SpeechTokenizer(
            ckpt_path=args.ckpt, cfg_path=args.cfg_path, cfg_name=args.cfg_name
        )
    else:
        tokenizer = SpeechTokenizer()
    # Only the second element of the returned pair (per-utterance info) is used.
    _, ref_token_info_list = tokenizer.extract(raw_ref_wavs_list)

    recon_token_info_list = []
    if not is_tts:
        logging.info("loading reconstruct audio")
        raw_reconstruct_wav_list = _load_wavs(base_path, reconstruct_wav_file_list)

        logging.info("extracting token for reconstruct audio")
        _, recon_token_info_list = tokenizer.extract(raw_reconstruct_wav_list)
        assert len(ref_token_info_list) == len(recon_token_info_list), (
            "reference and reconstruct token counts differ"
        )

    # zip() below would silently truncate on a length mismatch, so check first.
    assert len(ref_token_info_list) == len(line_info_list), (
        "token count does not match the number of input lines"
    )
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        logging.info("writing output file")
        if is_tts:
            for ref, (ref_rel, gen_id, text) in zip(
                ref_token_info_list, line_info_list
            ):
                ref_units = ref["reduced_unit_sequence"]
                ref_path = str(base_path / ref_rel)
                output_file.write(f"{ref_path}|{ref_units}|{gen_id}|{text}\n")
        else:
            for ref, recon, (ref_rel, gen_id, text) in zip(
                ref_token_info_list, recon_token_info_list, line_info_list
            ):
                ref_units = ref["reduced_unit_sequence"]
                recon_units = recon["reduced_unit_sequence"]
                ref_path = str(base_path / ref_rel)
                output_file.write(
                    f"{ref_path}|{ref_units}|{gen_id}|{recon_units}|{text}\n"
                )
    # The with-block has already closed the output file.
    logging.info("Finished")


if __name__ == "__main__":
    # Command-line entry point: parse the options, validate --input-type and
    # hand the parsed namespace to main().

    # The root logger defaults to WARNING, which would drop every
    # logging.info progress message. basicConfig does nothing if the root
    # logger already has handlers, so this cannot clobber an existing setup.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        dest="ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--cfg-path",
        dest="cfg_path",
        required=False,
        default="config",
        help="path to config",
    )
    parser.add_argument(
        "--cfg-name",
        dest="cfg_name",
        required=False,
        default="hubert_config",
        help="name of config",
    )
    parser.add_argument(
        "--input-list",
        dest="input_list",
        required=True,
        help="list of input wavform",
    )
    parser.add_argument(
        "--output-file",
        dest="output_file",
        required=True,
        help="file to output speech tokens",
    )
    # No default here: the option is required, so a default could never apply.
    parser.add_argument(
        "--input-type",
        type=str,
        required=True,
        help="test fil list type: tts or reconstruct, seedtts format",
    )
    args = parser.parse_args()

    if args.input_type not in {"tts", "reconstruct"}:
        # parser.error prints the usage and this message to stderr and exits
        # with status 2, instead of exiting silently with status 0.
        parser.error("Input type must be tts or reconstruct")
    main(args)