LittleMouse committed
Commit d054f6c · Parent: 6912cd9

Upload file
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
CosyVoice-BlankEN/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
CosyVoice-BlankEN/tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
+   "bos_token": null,
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 32768,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
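
This config is consumed through transformers' AutoTokenizer (tokenizer_class is Qwen2Tokenizer; vocab.json and merges.txt are uploaded alongside it). A minimal sketch of loading it from the repo root and rendering the ChatML chat_template; the example message is illustrative:

from transformers import AutoTokenizer

# Load the Qwen2 tokenizer shipped in CosyVoice-BlankEN and render its chat template.
tokenizer = AutoTokenizer.from_pretrained("CosyVoice-BlankEN")
messages = [{"role": "user", "content": "Hello"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant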
CosyVoice-BlankEN/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
asset/en_man1.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:461dd4cc9cf5bf6b774a9978cc9b7ca96033b214714b12413ecfe9eb1bf03ab9
+ size 15309
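
The binary assets in this commit are stored as Git LFS pointers: a spec version line, a sha256 oid, and the byte size. A minimal sketch of verifying a fetched blob against such a pointer; the verify_lfs_pointer helper and the .pointer path are hypothetical, for illustration only:

import hashlib

def verify_lfs_pointer(pointer_path, blob_path):
    # Parse "key value" lines (version / oid sha256:<hex> / size <bytes>).
    fields = dict(line.split(" ", 1) for line in open(pointer_path) if " " in line)
    expected_oid = fields["oid"].strip().split(":", 1)[1]
    expected_size = int(fields["size"])
    data = open(blob_path, "rb").read()
    assert len(data) == expected_size, "size mismatch"
    assert hashlib.sha256(data).hexdigest() == expected_oid, "sha256 mismatch"

# verify_lfs_pointer("en_man1.mp3.pointer", "asset/en_man1.mp3")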
asset/en_man1.txt ADDED
@@ -0,0 +1 @@
+ Because he has zero capacity to respond to the two and a half hour
asset/en_woman1.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:872ff69b74b37763cfc4a49bdd39d8a2acf51f428e42e1ab9fa3dfc0c4a2e3d4
+ size 16941
asset/en_woman1.txt ADDED
@@ -0,0 +1 @@
+ But many of these southern girls have the same trouble, said Holly.
asset/zh_man1.txt ADDED
@@ -0,0 +1 @@
+ 南方高温卷土重来,全国秋老虎地图出炉。
asset/zh_man1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da1153fca1303cd20470317a4ba93027cc5e172214b777747215add36f41109e
+ size 1536044
asset/zh_man2.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd012ac30fe1ffb5bc3e356a84f4f668a25a62c72f810ffae218f83cbcfdf53e
+ size 31761
asset/zh_man2.txt ADDED
@@ -0,0 +1 @@
+ 所以呢目标是非常有威力的,它是创造原则的全部。
asset/zh_woman1.txt ADDED
@@ -0,0 +1 @@
+ 希望你以后能够做的比我还好呦。
asset/zh_woman1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd199eb7109fd6ce9943cb297e3cf350c1073af014063dfadbdc100230526243
+ size 111496
frontend-onnx/campplus.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
+ size 28303423
frontend-onnx/speech_tokenizer_v2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71
+ size 496082973
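
scripts/frontend.py loads both graphs with onnxruntime (campplus on CPU, the speech tokenizer on CUDA when available). A quick sketch for inspecting the inputs each session expects before wiring anything up:

import onnxruntime

# List input names/shapes for the two frontend ONNX models.
for path in ("frontend-onnx/campplus.onnx", "frontend-onnx/speech_tokenizer_v2.onnx"):
    session = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])
    print(path, [(i.name, i.shape) for i in session.get_inputs()])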
pengzhendong/wetext ADDED
@@ -0,0 +1 @@
+ Subproject commit 8e93692beb2e7f7d0aab4807819abfff0c3dbe6d
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchaudio
+ soundfile
+ numpy
+ onnxruntime
+ openai-whisper
+ inflect
+ transformers
+ librosa
+ wetext==0.0.4
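
A quick sanity check after pip install -r requirements.txt; note that openai-whisper imports as whisper, and scripts/audio.py additionally uses scipy, which librosa pulls in transitively:

# Verify every dependency imports under its runtime module name.
for mod in ("torch", "torchaudio", "soundfile", "numpy", "onnxruntime",
            "whisper", "inflect", "transformers", "librosa", "wetext", "scipy"):
    __import__(mod)
print("all requirements import cleanly")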
scripts/audio.py ADDED
@@ -0,0 +1,83 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ # caches keyed by fmax/device so the mel filterbank and window are built once
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window  # pylint: disable=global-statement
+     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[str(y.device)],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
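
A minimal usage sketch with the exact parameters the frontend binds via functools.partial (80 mel bins, 24 kHz, 1920-sample window, 480-sample hop); the input here is one second of synthetic audio:

import torch
from audio import mel_spectrogram

# One second of fake 24 kHz audio in [-1, 1] -> (1, 80, T) log-mel features.
y = torch.rand(1, 24000) * 2 - 1
mel = mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000,
                      hop_size=480, win_size=1920, fmin=0, fmax=8000, center=False)
print(mel.shape)  # torch.Size([1, 80, 50])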
scripts/frontend.py ADDED
@@ -0,0 +1,251 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from functools import partial
+ from typing import Generator
+ import json
+ import os
+ import re
+ import onnxruntime
+ import torch
+ import numpy as np
+ import whisper
+ import torchaudio
+ import torchaudio.compliance.kaldi as kaldi
+ import inflect
+ from tokenizer.tokenizer import get_qwen_tokenizer
+ from audio import mel_spectrogram
+ # NOTE: text_normalize below also relies on contains_chinese, replace_blank,
+ # replace_corner_mark, remove_bracket, split_paragraph, spell_out_number and
+ # is_only_punctuation, which are not defined in this upload; they come from
+ # CosyVoice's cosyvoice.utils.frontend_utils and must be on the import path.
+
+ try:
+     import ttsfrd
+     use_ttsfrd = True
+ except ImportError:
+     from wetext import Normalizer as ZhNormalizer
+     from wetext import Normalizer as EnNormalizer
+     use_ttsfrd = False
+
+ import logging
+ logging.getLogger('frontend').setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.DEBUG,
+                     format='%(asctime)s %(levelname)s %(message)s')
+
+
+ class CosyVoiceFrontEnd:
+
+     def __init__(self,
+                  pretrained_path: str,
+                  wetext_dir: str,
+                  campplus_model: str,
+                  speech_tokenizer_model: str,
+                  spk2info: str = '',
+                  allowed_special: str = 'all'):
+         self.tokenizer = get_qwen_tokenizer(pretrained_path, True)
+         self.feat_extractor = partial(
+             mel_spectrogram,
+             n_fft=1920,
+             num_mels=80,
+             sampling_rate=24000,
+             hop_size=480,
+             win_size=1920,
+             fmin=0,
+             fmax=8000,
+             center=False)
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                 "CPUExecutionProvider"])
+         if os.path.exists(spk2info):
+             self.spk2info = torch.load(spk2info, map_location=self.device)
+         else:
+             self.spk2info = {}
+         self.allowed_special = allowed_special
+         self.use_ttsfrd = use_ttsfrd
+         if self.use_ttsfrd:
+             self.frd = ttsfrd.TtsFrontendEngine()
+             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+             assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                 'failed to initialize ttsfrd resource'
+             self.frd.set_lang_type('pinyinvg')
+         else:
+             # NOTE: the original upload passed tagger.fst as verbalizer_path for both
+             # normalizers and built the English normalizer from the zh resources;
+             # verbalizer.fst and the en resources are assumed to be the intent here.
+             self.zh_tn_model = ZhNormalizer(remove_erhua=False, lang="zh",
+                                             tagger_path=f"{wetext_dir}/zh/tn/tagger.fst",
+                                             verbalizer_path=f"{wetext_dir}/zh/tn/verbalizer.fst")
+             self.en_tn_model = EnNormalizer(lang="en",
+                                             tagger_path=f"{wetext_dir}/en/tn/tagger.fst",
+                                             verbalizer_path=f"{wetext_dir}/en/tn/verbalizer.fst")
+             self.inflect_parser = inflect.engine()
+
+     def _extract_text_token(self, text):
+         if isinstance(text, Generator):
+             logging.info('get tts_text generator, will return _extract_text_token_generator!')
+             # NOTE: add a dummy text_token_len for compatibility
+             return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+         else:
+             text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+             text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+             return text_token, text_token_len
+
+     def _extract_text_token_generator(self, text_generator):
+         for text in text_generator:
+             text_token, _ = self._extract_text_token(text)
+             for i in range(text_token.shape[1]):
+                 yield text_token[:, i: i + 1]
+
+     def _extract_speech_token(self, speech):
+         assert speech.shape[1] / 16000 <= 30, 'speech token extraction does not support audio longer than 30s'
+         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+         speech_token = self.speech_tokenizer_session.run(None,
+                                                          {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                           feat.detach().cpu().numpy(),
+                                                           self.speech_tokenizer_session.get_inputs()[1].name:
+                                                           np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_token, speech_token_len
+
+     def _extract_spk_embedding(self, speech):
+         feat = kaldi.fbank(speech,
+                            num_mel_bins=80,
+                            dither=0,
+                            sample_frequency=16000)
+         feat = feat - feat.mean(dim=0, keepdim=True)
+         embedding = self.campplus_session.run(None,
+                                               {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+         embedding = torch.tensor([embedding]).to(self.device)
+         return embedding
+
+     def _extract_speech_feat(self, speech):
+         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+         speech_feat = speech_feat.unsqueeze(dim=0)
+         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_feat, speech_feat_len
+
+     def text_normalize(self, text, split=True, text_frontend=True):
+         if isinstance(text, Generator):
+             logging.info('get tts_text generator, will skip text_normalize!')
+             return [text]
+         if text_frontend is False or text == '':
+             return [text] if split is True else text
+         text = text.strip()
+         if self.use_ttsfrd:
+             texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+             text = ''.join(texts)
+         else:
+             if contains_chinese(text):
+                 text = self.zh_tn_model.normalize(text)
+                 text = text.replace("\n", "")
+                 text = replace_blank(text)
+                 text = replace_corner_mark(text)
+                 text = text.replace(".", "。")
+                 text = text.replace(" - ", ",")
+                 text = remove_bracket(text)
+                 text = re.sub(r'[,,、]+$', '。', text)
+                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                              token_min_n=60, merge_len=20, comma_split=False))
+             else:
+                 text = self.en_tn_model.normalize(text)
+                 text = spell_out_number(text, self.inflect_parser)
+                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                              token_min_n=60, merge_len=20, comma_split=False))
+         texts = [i for i in texts if not is_only_punctuation(i)]
+         return texts if split is True else text
+
+     def frontend_sft(self, tts_text, spk_id):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         embedding = self.spk2info[spk_id]['embedding']
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         if zero_shot_spk_id == '':
+             prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+             speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+             speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+             if resample_rate == 24000:
+                 # cosyvoice2: enforce a 2:1 ratio between speech_feat and speech_token lengths
+                 token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                 speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                 speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+             embedding = self._extract_spk_embedding(prompt_speech_16k)
+             model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                            'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                            'llm_embedding': embedding, 'flow_embedding': embedding}
+         else:
+             model_input = self.spk2info[zero_shot_spk_id]
+         model_input['text'] = tts_text_token
+         model_input['text_len'] = tts_text_token_len
+         return model_input
+
+     def process_prompt(self, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+         if zero_shot_spk_id == '':
+             prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+             speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+             speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+             if resample_rate == 24000:
+                 # cosyvoice2: enforce a 2:1 ratio between speech_feat and speech_token lengths
+                 token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                 speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                 speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+             embedding = self._extract_spk_embedding(prompt_speech_16k)
+             model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                            'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                            'llm_embedding': embedding, 'flow_embedding': embedding}
+         else:
+             model_input = self.spk2info[zero_shot_spk_id]
+         return model_input
+
+     def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+         model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+         # in cross-lingual mode, the prompt text is removed from the llm input
+         del model_input['prompt_text']
+         del model_input['prompt_text_len']
+         del model_input['llm_prompt_speech_token']
+         del model_input['llm_prompt_speech_token_len']
+         return model_input
+
+     def frontend_instruct(self, tts_text, spk_id, instruct_text):
+         model_input = self.frontend_sft(tts_text, spk_id)
+         # in instruct mode, the spk_embedding is removed from the llm input to avoid information leakage
+         del model_input['llm_embedding']
+         instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+         model_input['prompt_text'] = instruct_text_token
+         model_input['prompt_text_len'] = instruct_text_token_len
+         return model_input
+
+     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+         model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+         del model_input['llm_prompt_speech_token']
+         del model_input['llm_prompt_speech_token_len']
+         return model_input
+
+     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
+         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+         prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+         embedding = self._extract_spk_embedding(prompt_speech_16k)
+         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                        'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                        'flow_embedding': embedding}
+         return model_input
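
A hedged end-to-end sketch of the zero-shot path, reusing the soundfile-based load_wav from scripts/process_prompt.py and one of the prompt pairs uploaded in this commit (paths assume the repo root; the tts_text is illustrative):

from frontend import CosyVoiceFrontEnd
from process_prompt import load_wav

# Build the zero-shot model_input dict from a prompt transcript + 16 kHz prompt audio.
frontend = CosyVoiceFrontEnd("CosyVoice-BlankEN", "pengzhendong/wetext",
                             "frontend-onnx/campplus.onnx",
                             "frontend-onnx/speech_tokenizer_v2.onnx")
prompt_speech_16k = load_wav("asset/zh_woman1.wav", 16000)
model_input = frontend.frontend_zero_shot("你好,世界。", "希望你以后能够做的比我还好呦。",
                                          prompt_speech_16k, 24000, "")
print(sorted(model_input.keys()))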
scripts/process_prompt.py ADDED
@@ -0,0 +1,70 @@
+ import argparse
+ import os
+ import torch
+ import torchaudio
+ import numpy as np
+ import soundfile as sf
+ from frontend import CosyVoiceFrontEnd
+
+
+ def load_wav(wav, target_sr):
+     speech, sample_rate = sf.read(wav, dtype='float32')
+     if speech.ndim == 1:
+         speech = torch.from_numpy(speech).unsqueeze(0)  # (1, T)
+     else:
+         speech = torch.from_numpy(speech).transpose(0, 1)  # (C, T)
+     speech = speech.mean(dim=0, keepdim=True)  # downmix to mono, (1, T)
+     if sample_rate != target_sr:
+         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+     return speech
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--model_dir', type=str, default="CosyVoice-BlankEN", help="tokenizer configuration directory")
+     parser.add_argument('--wetext_dir', type=str, default="pengzhendong/wetext", help="path to wetext")
+     parser.add_argument('--sample_rate', type=int, default=24000, help="sampling rate for prompt audio")
+     parser.add_argument('--prompt_text', type=str, default="希望你以后能够做的比我还好呦。", help="text content of the prompt (reference) audio; literal text or a file path")
+     parser.add_argument('--prompt_speech', type=str, default="asset/zero_shot_prompt.wav", help="path to the prompt (reference) audio")
+     parser.add_argument('--output', type=str, default="prompt_files", help="output data storage directory")
+     args = parser.parse_args()
+
+     os.makedirs(args.output, exist_ok=True)
+
+     frontend = CosyVoiceFrontEnd(f"{args.model_dir}",
+                                  args.wetext_dir,
+                                  "frontend-onnx/campplus.onnx",
+                                  "frontend-onnx/speech_tokenizer_v2.onnx",
+                                  f"{args.model_dir}/spk2info.pt",
+                                  "all")
+
+     prompt_speech_16k = load_wav(args.prompt_speech, 16000)
+     zero_shot_spk_id = ""
+
+     if os.path.isfile(args.prompt_text):
+         with open(args.prompt_text, "r") as f:
+             prompt_text = f.read()
+     else:
+         prompt_text = args.prompt_text
+     print("prompt_text", prompt_text)
+     model_input = frontend.process_prompt(prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)
+
+     # model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+     #                'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+     #                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+     #                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+     #                'llm_embedding': embedding, 'flow_embedding': embedding}
+     print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
+     assert model_input["flow_prompt_speech_token"].shape[1] >= 75, \
+         f"speech_token length should be >= 75, but got {model_input['flow_prompt_speech_token'].shape[1]}"
+     for k, v in model_input.items():
+         if "_len" in k:
+             continue
+         if v.dtype in (torch.int32, torch.int64):
+             np.savetxt(f"{args.output}/{k}.txt", v.detach().cpu().numpy().reshape(-1), fmt="%d", delimiter=",")
+         else:
+             np.savetxt(f"{args.output}/{k}.txt", v.detach().cpu().numpy().reshape(-1), delimiter=",")
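
The exported txt files are flat (one value per line) and named after the model_input keys; a sketch of reloading two of them, noting that shapes are not stored, so 2-D tensors must be reshaped by hand:

import numpy as np

# Reload the flattened tensors written by process_prompt.py (default --output dir).
speech_token = np.loadtxt("prompt_files/flow_prompt_speech_token.txt", dtype=np.int64)
speech_feat = np.loadtxt("prompt_files/prompt_speech_feat.txt", dtype=np.float32)
print(speech_token.shape, speech_feat.shape)  # (T,), (T_feat * 80,)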
scripts/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
scripts/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,151 @@
+ import base64
+ import os
+ from functools import lru_cache
+ from typing import Optional
+ import torch
+ from transformers import AutoTokenizer
+ import tiktoken
+
+ LANGUAGES = {
+     "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
+     "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
+     "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
+     "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew",
+     "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish",
+     "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian",
+     "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh",
+     "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
+     "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
+     "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
+     "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi",
+     "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
+     "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik",
+     "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek",
+     "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
+     "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan",
+     "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian",
+     "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
+     "yue": "cantonese", "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect", "zh/en": "zh/en", "en/zh": "en/zh"
+ }
+
+ TO_LANGUAGE_CODE = {
+     **{language: code for code, language in LANGUAGES.items()},
+     "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb",
+     "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si",
+     "castilian": "es", "mandarin": "zh",
+ }
+
+ AUDIO_EVENT = {
+     "ASR": "ASR", "AED": "AED", "SER": "SER", "Speech": "Speech", "/Speech": "/Speech",
+     "BGM": "BGM", "/BGM": "/BGM", "Laughter": "Laughter", "/Laughter": "/Laughter",
+     "Applause": "Applause", "/Applause": "/Applause",
+ }
+
+ EMOTION = {
+     "HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL",
+ }
+
+ TTS_Vocal_Token = {
+     "TTS/B": "TTS/B", "TTS/O": "TTS/O", "TTS/Q": "TTS/Q", "TTS/A": "TTS/A", "TTS/CO": "TTS/CO",
+     "TTS/CL": "TTS/CL", "TTS/H": "TTS/H", **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+ }
+
+
+ # ===== Build the tiktoken Encoding =====
+ @lru_cache(maxsize=None)
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
+     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+     ranks = {
+         base64.b64decode(token): int(rank)
+         for token, rank in (line.split() for line in open(vocab_path) if line)
+     }
+     n_vocab = len(ranks)
+     special_tokens = {}
+     specials = [
+         "<|endoftext|>", "<|startoftranscript|>",
+         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+         *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+         *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+         "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
+         "<|nospeech|>", "<|notimestamps|>",
+         *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
+         *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
+         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+     ]
+     for token in specials:
+         special_tokens[token] = n_vocab
+         n_vocab += 1
+     return tiktoken.Encoding(
+         name=os.path.basename(vocab_path),
+         explicit_n_vocab=n_vocab,
+         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+         mergeable_ranks=ranks,
+         special_tokens=special_tokens,
+     )
+
+
+ class SimpleTokenizer:
+     def __init__(self, encoding, num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None):
+         self.encoding = encoding
+         self.num_languages = num_languages
+         self.language = language
+         self.task = task
+
+     def encode(self, text: str):
+         return self.encoding.encode(text)
+
+     def decode(self, tokens: list):
+         return self.encoding.decode(tokens)
+
+
+ @lru_cache(maxsize=None)
+ def get_tokenizer(
+     multilingual: bool,
+     *,
+     num_languages: int = 99,
+     language: Optional[str] = None,
+     task: Optional[str] = None,
+ ) -> SimpleTokenizer:
+     if language is not None:
+         language = language.lower()
+         if language not in LANGUAGES:
+             if language in TO_LANGUAGE_CODE:
+                 language = TO_LANGUAGE_CODE[language]
+             else:
+                 raise ValueError(f"Unsupported language: {language}")
+     if multilingual:
+         encoding_name = "multilingual_zh_ja_yue_char_del"
+         language = language or "en"
+         task = task or "transcribe"
+     else:
+         encoding_name = "gpt2"
+         language = None
+         task = None
+     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+     return SimpleTokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
+
+
+ class QwenTokenizer:
+     def __init__(self, token_path, skip_special_tokens=True):
+         special_tokens = {
+             'eos_token': '<|endoftext|>',
+             'pad_token': '<|endoftext|>',
+             'additional_special_tokens': [
+                 '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                 '[breath]', '<strong>', '</strong>', '[noise]',
+                 '[laughter]', '[cough]', '[clucking]', '[accent]',
+                 '[quick_breath]',
+                 "<laughter>", "</laughter>",
+                 "[hissing]", "[sigh]", "[vocalized-noise]",
+                 "[lipsmack]", "[mn]"
+             ]
+         }
+         self.special_tokens = special_tokens
+         self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+         self.tokenizer.add_special_tokens(special_tokens)
+         self.skip_special_tokens = skip_special_tokens
+
+     def encode(self, text, **kwargs):
+         tokens = self.tokenizer([text], return_tensors="pt")
+         return tokens["input_ids"][0].cpu().tolist()
+
+     def decode(self, tokens):
+         tokens = torch.tensor(tokens, dtype=torch.int64)
+         return self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+
+
+ @lru_cache(maxsize=None)
+ def get_qwen_tokenizer(token_path: str, skip_special_tokens: bool) -> QwenTokenizer:
+     return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
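
A round-trip sketch of the Qwen tokenizer wrapper exactly as the frontend constructs it (get_qwen_tokenizer(pretrained_path, True), i.e. special tokens skipped on decode); run from the scripts/ directory so the tokenizer package resolves:

from tokenizer.tokenizer import get_qwen_tokenizer

# Encode/decode round trip through the cached Qwen tokenizer wrapper.
tokenizer = get_qwen_tokenizer("../CosyVoice-BlankEN", True)
ids = tokenizer.encode("Hello, CosyVoice!")
print(ids)
print(tokenizer.decode(ids))  # -> Hello, CosyVoice!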