# text to semantic
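#
# Converts a text file of "utt_id sentence" pairs into semantic token sequences
# with the SoundStorm AR S1 (text-to-semantic) model, conditioning on a prompt
# wav for speaker timbre.
#
# Example invocation (the script name and the hubert/quantizer/prompt paths are
# illustrative placeholders, not taken from the repo):
#   python t2s.py \
#       --config_file conf/default.yaml \
#       --ckpt_path exp/default/ckpt/epoch=99-step=49000.ckpt \
#       --text_file sentences.txt \
#       --prompt_wav_path prompt.wav \
#       --hubert_path hubert_base.pt \
#       --quantizer_path quantizer.pt \
#       --output_dir exp/default/output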
import argparse
import os
import re
import time
from pathlib import Path

import librosa
import numpy as np
import torch
import whisper
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.text_processing.phonemizer import GruutPhonemizer
from AR.utils.io import load_yaml_config

# NOTE: SemanticTokenizer (HuBERT features + k-means quantizer) is instantiated
# in main() but is not imported here; add the import that matches this repo's
# layout, e.g. (path assumed, verify against your checkout):
# from soundstorm.s2.models.hubert.semantic_tokenizer import SemanticTokenizer


def get_batch(text, phonemizer):
    # phoneme_ids and phoneme_ids_len are what the model needs
    phoneme = phonemizer.phonemize(text, espeak=False)
    phoneme_ids = phonemizer.transform(phoneme)
    phoneme_ids_len = len(phoneme_ids)
    phoneme_ids = np.array(phoneme_ids)
    # add batch axis here
    phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0)
    phoneme_ids_len = torch.tensor([phoneme_ids_len])
    print("phoneme:", phoneme)
    batch = {
        # torch.Tensor (B, max_phoneme_length) 
        "phoneme_ids": phoneme_ids,
        # torch.Tensor (B)
        "phoneme_ids_len": phoneme_ids_len
    }
    return batch


def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
    sample_rate = 16000
    # to get prompt
    prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
    wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
    # take the last 3 s, but exclude the final 0.1 s so AR S1 inference does not stop early
    wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
    # trailing silence should also be trimmed from wav, otherwise inference may stop early as well
    prompt_text = asr_model.transcribe(wav)["text"]
    # remove the trailing period to keep AR S1 inference from stopping early; a period may introduce a pause
    prompt_text = prompt_text.replace(".", "")
    prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
    prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
    prompt_phoneme_ids_len = len(prompt_phoneme_ids)
    # get prompt_semantic
    # (T) -> (1, T)
    wav = torch.tensor(wav).unsqueeze(0)
    wav = wav.cuda()
    # (1, T)
    prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
    prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
    prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])

    result = {
        'prompt_name': prompt_name,
        'prompt_phoneme_ids': prompt_phoneme_ids,
        'prompt_semantic_tokens': prompt_semantic_tokens,
        'prompt_phoneme_ids_len': prompt_phoneme_ids_len
    }

    return result


def parse_args():
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for input text file")

    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')

    parser.add_argument(
        "--text_file",
        type=str,
        help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
    )

    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')

    parser.add_argument(
        '--prompt_wav_path',
        type=str,
        default=None,
        help='extract prompt semantic tokens and prompt phonemes from the prompt wav')

    # to get semantic tokens from prompt_wav
    parser.add_argument("--hubert_path", type=str, default=None)
    parser.add_argument("--quantizer_path", type=str, default=None)

    parser.add_argument("--output_dir", type=str, help="output dir.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    hz = 50
    max_sec = config['data']['max_sec']

    # get models
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')

    # models for prompt
    asr_model = whisper.load_model("tiny.en")

    semantic_tokenizer = SemanticTokenizer(
        hubert_path=args.hubert_path,
        quantizer_path=args.quantizer_path,
        duplicate=True)

    prompt_result = get_prompt(
        prompt_wav_path=args.prompt_wav_path,
        asr_model=asr_model,
        phonemizer=phonemizer,
        semantic_tokenizer=semantic_tokenizer)

    # zero prompt => the output semantic tokens have the right content, but the timbre is scrambled
    # (B, 1)
    # prompt = torch.ones(
    #     batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0

    prompt = prompt_result['prompt_semantic_tokens']
    prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
    prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']

    sentences = []
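    # Expected --text_file layout (utt_ids and sentences below are illustrative),
    # one "utt_id sentence" pair per line, whitespace-separated:
    #   utt_001 Hello world, this is a test.
    #   utt_002 Another sentence to synthesize.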
    with open(args.text_file, 'rt', encoding='utf-8') as f:
        for line in f:
            if line.strip() != "":
                items = re.split(r"\s+", line.strip(), maxsplit=1)
                utt_id = items[0]
                sentence = " ".join(items[1:])
                # append inside the if-block so blank lines are skipped instead
                # of duplicating the previous pair
                sentences.append((utt_id, sentence))
    semantic_data = [['item_name', 'semantic_audio']]
    # the text file has no header row (see --text_file help), so process every sentence
    for utt_id, sentence in sentences:
        # build a pseudo batch by hand to feed the model
        batch = get_batch(sentence, phonemizer)
        # concatenate the prompt with the actual input
        all_phoneme_ids = torch.cat(
            [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
        # alternatively, this is just all_phoneme_ids.shape[-1]
        all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
        st = time.time()
        with torch.no_grad():
            pred_semantic = t2s_model.model.infer(
                all_phoneme_ids.cuda(),
                all_phoneme_len.cuda(),
                prompt.cuda(),
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        print(f'{time.time() - st} sec used in T2S')

        # drop the part of the output that corresponds to the prompt
        prompt_len = prompt.shape[-1]
        pred_semantic = pred_semantic[:, prompt_len:]

        # bs = 1
        pred_semantic = pred_semantic[0]
        semantic_token = pred_semantic.detach().cpu().numpy().tolist()
        semantic_token_str = ' '.join(str(x) for x in semantic_token)
        semantic_data.append([utt_id, semantic_token_str])

        delimiter = '\t'
        filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
        with open(filename, 'w', encoding='utf-8') as writer:
            for row in semantic_data:
                line = delimiter.join(row)
                writer.write(line + '\n')
        # reset semantic_data for the next sentence
        semantic_data = [['item_name', 'semantic_audio']]


if __name__ == "__main__":
    main()