fix bug
- config.json +5 -6
- configuration.json +3 -1
- modeling_qwen.py +8 -9
- qwen_generation_utils.py +0 -1
- tokenization_qwen.py +123 -138
- tokenizer_config.json +1 -1
config.json
CHANGED
@@ -1,12 +1,11 @@
 {
-  "_name_or_path": "10302244_iter8000_final/",
   "architectures": [
     "QWenLMHeadModel"
   ],
   "attn_dropout_prob": 0.0,
   "audio": {
     "add_audio_bos_eos_token": true,
-    "audio_start_id":
+    "audio_start_id": 155163,
     "avg_pool": true,
     "n_ctx": 1500,
     "n_head": 20,
@@ -19,7 +18,7 @@
     "AutoConfig": "configuration_qwen.QWenConfig",
     "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
   },
-  "bf16":
+  "bf16": false,
   "emb_dropout_prob": 0.0,
   "fp16": false,
   "fp32": false,
@@ -27,8 +26,8 @@
   "initializer_range": 0.02,
   "intermediate_size": 22016,
   "kv_channels": 128,
-  "layer_norm_epsilon": 1e-
-  "max_position_embeddings":
+  "layer_norm_epsilon": 1e-06,
+  "max_position_embeddings": 2048,
   "model_type": "qwen",
   "no_bias": true,
   "num_attention_heads": 32,
@@ -47,7 +46,7 @@
   "use_cache_kernel": false,
   "use_cache_quantization": false,
   "use_dynamic_ntk": true,
-  "use_flash_attn":
+  "use_flash_attn": "auto",
   "use_logn_attn": true,
   "vocab_size": 155947
 }
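Net effect: the checkpoint-specific `_name_or_path` entry is dropped and the touched fields are pinned to explicit values (`audio_start_id`: 155163, `bf16`: false, `layer_norm_epsilon`: 1e-06, `max_position_embeddings`: 2048, `use_flash_attn`: "auto"). A quick sanity check after the fix — a sketch, with the repo id left as a placeholder and assuming the extra config.json keys surface as attributes on the custom QWenConfig, as unknown keys normally do in transformers:

    from transformers import AutoConfig

    # Placeholder repo id; substitute this model's local path or Hub id.
    config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)

    print(config.audio["audio_start_id"])  # expect 155163
    print(config.bf16)                     # expect False
    print(config.max_position_embeddings)  # expect 2048
    print(config.use_flash_attn)           # expect "auto"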
configuration.json
CHANGED
@@ -1 +1,3 @@
-{"framework":"Pytorch",
+{"framework":"Pytorch",
+"task":"multimodal-dialogue",
+"allow_remote": true}
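configuration.json is ModelScope metadata rather than a transformers file; the old single-line fragment was not valid JSON, and the fix completes it with the task name and `allow_remote`. A usage sketch, assuming the repo is hosted on ModelScope (placeholder id) and that `"allow_remote": true` is what permits its custom code to run there:

    from modelscope.pipelines import pipeline

    # "task" must match the value declared in configuration.json.
    pipe = pipeline(task="multimodal-dialogue", model="<namespace>/<repo-id>")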
modeling_qwen.py
CHANGED
@@ -1015,20 +1015,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self.lm_head.half()
         self.post_init()
 
-
     @classmethod
     def from_pretrained(
-
-
-
-
-
-
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config=None,
+        cache_dir=None,
+        **kwargs,
     ):
         if os.path.isdir(pretrained_model_name_or_path):
             # Local Directory of Models
             mel_filters_path = os.path.join(pretrained_model_name_or_path, 'mel_filters.npz')
-            print(mel_filters_path)
             tgt_cache_path = os.path.join(os.path.dirname(__file__), 'mel_filters.npz')
             shutil.copy(mel_filters_path, tgt_cache_path)
         else:
@@ -1036,7 +1034,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             from huggingface_hub import hf_hub_download
             hf_hub_download(repo_id=pretrained_model_name_or_path, filename="mel_filters.npz",
                             token=kwargs.get('token', None), local_dir=os.path.dirname(__file__))
-        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir,
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir,
+                                       **kwargs)
 
     def get_output_embeddings(self):
         return self.lm_head
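The override exists only to stage `mel_filters.npz` next to `modeling_qwen.py` (copied from a local checkout, or fetched via `hf_hub_download` otherwise) before deferring to the stock loader; the fix restores the mangled signature, removes a debug `print`, and forwards `**kwargs` in the `return` so loader options reach the parent class again. A loading sketch, with a placeholder repo id:

    from transformers import AutoModelForCausalLM

    # Placeholder repo id. Extra kwargs such as device_map or token now pass
    # through to the parent from_pretrained via the restored **kwargs.
    model = AutoModelForCausalLM.from_pretrained(
        "<repo-id>", trust_remote_code=True, device_map="auto"
    )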
qwen_generation_utils.py
CHANGED
@@ -186,7 +186,6 @@ def make_context(
             + nl_tokens
         )
         raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
-        print(raw_text)
         audio_info = tokenizer.process_audio(raw_text)
 
     elif chat_format == "raw":
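The removed line was leftover debugging: `make_context` still builds the same prompt, it just no longer echoes it to stdout. From the f-string above, the appended suffix has the ChatML shape

    <|im_start|>user
    {query}<|im_end|>
    <|im_start|>assistant

with `{query}` standing in for the user's text.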
tokenization_qwen.py
CHANGED
@@ -17,13 +17,11 @@ from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
 
 import tiktoken
 import numpy as np
-
-from PIL import ImageFont
-from PIL import ImageDraw
+
 from transformers import PreTrainedTokenizer, AddedToken
 from transformers.utils import try_to_load_from_cache
-from transformers.tokenization_utils_base import BatchEncoding,PaddingStrategy,TruncationStrategy
-    TextInput,TextInputPair,PreTokenizedInput,PreTokenizedInputPair,TensorType, EncodedInput, EncodedInputPair
+from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy, \
+    TextInput, TextInputPair, PreTokenizedInput, PreTokenizedInputPair, TensorType, EncodedInput, EncodedInputPair
 
 import matplotlib.colors as mcolors
 from matplotlib.font_manager import FontProperties
@@ -31,7 +29,6 @@ from .audio import *
 
 logger = logging.getLogger(__name__)
 
-
 VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
 
 PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@@ -43,11 +40,11 @@ IMEND = "<|im_end|>"
 # as different as possible to minimize the impact
 EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 SPECIAL_TOKENS = (
-
-
-
-) + EXTRAS
-
+    ENDOFTEXT,
+    IMSTART,
+    IMEND,
+) + EXTRAS
+
 LANGUAGES = {
     "en": "english",
     "zh": "chinese",
@@ -68,23 +65,25 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
         for token, rank in (line.split() for line in contents.splitlines() if line)
     }
 
+
 def _list_find(
-
-
-
+    input_list: List[Any],
+    candidates: Tuple[Any],
+    start: int = 0,
 ):
     for i in range(start, len(input_list)):
         if input_list[i] in candidates:
             return i
     return -1
 
+
 def _replace_closed_tag(
-
-
-
-
-
-
+    input_tokens: List[Any],
+    start_tags: Union[Any, Tuple[Any]],
+    end_tags: Union[Any, Tuple[Any]],
+    inclusive_replace_func: Callable,
+    exclusive_replace_func: Callable = lambda x: x,
+    audio_info: Dict = None
 ):
     if isinstance(start_tags, (str, int)):
         start_tags = (start_tags,)
@@ -99,107 +98,93 @@ def _replace_closed_tag(
         start = _list_find(input_tokens, start_tags, end)
         if start == -1:
             break
-        output_tokens.extend(exclusive_replace_func(input_tokens[end
+        output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
         tag_idx = start_tags.index(input_tokens[start])
         end = _list_find(input_tokens, (end_tags[tag_idx],), start)
         if end == -1:
-            raise ValueError("Unclosed
-        output_tokens.extend(inclusive_replace_func(input_tokens[start
+            raise ValueError("Unclosed audio token")
+        output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1], audio_info, audio_idx))
         end += 1
         audio_idx += 1
-    output_tokens.extend(exclusive_replace_func(input_tokens[end
+    output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
     return output_tokens
 
+
 class QWenTokenizer(PreTrainedTokenizer):
     """QWen tokenizer."""
 
     vocab_files_names = VOCAB_FILES_NAMES
 
     def __init__(
-
-
-
-
-
-
+        self,
+        vocab_file,
+        errors="replace",
+        audio_start_tag='<audio>',
+        audio_end_tag='</audio>',
+        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_start_tag = audio_start_tag
        self.audio_end_tag = audio_end_tag
        self.audio_pad_tag = "[[[AUDIO:modality]]]"
-        self.IMAGE_ST = ("<ref>", "</ref>", "<box>", "</box>", "<quad>", "</quad>")
 
        self.AUDIO_ST = (
            '[[[AUDIO:modality]]]',
-
-            "<|
-            #
+            # Transcription Tag
+            "<|startoftranscript|>",  # Transcription
+            "<|startofanalysis|>",  # Analysis
+            # Task Tag
            "<|translate|>",
            "<|transcribe|>",
            "<|caption|>",
            "<|keyword|>",
-            #
-            "<|unknown|>", #
+            # Language Tag
+            "<|unknown|>",  # unknown language
            *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
-            "<|
-            #
+            "<|zh_tr|>",  # traditional Chinese
+            # Timestamps Tag
            "<|notimestamps|>",
            "<|sil|>",
            "<|timestamps|>",
-            *[f"<|{i * 0.01:.2f}|>" for i in range(3001)],
-            #
-            "<|caption_audiocaps|>", #
-            "<|caption_clotho|>", #
-            "<|audioset_ontology|>", #
-            "<|caption_plain|>", #
-            "<|itn|>", #
-            "<|wo_itn|>", #
-            # special task -- entity recognition
+            *[f"<|{i * 0.01:.2f}|>" for i in range(3001)],  # timestamps 0.00-30.00
+            # Output Instruction
+            "<|caption_audiocaps|>",  # Audiocaps caption style
+            "<|caption_clotho|>",  # Clotho caption style
+            "<|audioset_ontology|>",  # Audioset ontology style
+            "<|caption_plain|>",  # plain caption
+            "<|itn|>",  # with inverse text normalization
+            "<|wo_itn|>",  # without inverse text normalization
            "<|startofentityvalue|>",
            "<|endofentityvalue|>",
            "<|startofentitytype|>",
            "<|endofentitytype|>",
-            "<|named_entity_recognition|>",
-
-            "<|grounding|>",
+            "<|named_entity_recognition|>",  # named entity recognition task
+            "<|audio_grounding|>",
            "<|startofword|>",
            "<|endofword|>",
-            "<|delim|>", #
-            #
-            "<|
-            #
-            "<|
-            #
-            "<|
-            "<|
-
-            "<|
-
-            "<|
-            "<|
-            #
-            "<|
-            #
-            "<|
-            #
-            "<|
-            "<|
-            "<|
-            #
-            "<|
-            # subtask -- event
-            "<|event|>",
-            # subtask -- vocal_classification
-            "<|vocal_classification|>",
-            # special task -- SLU
-            "<|speech_understanding|>",
-            "<|scenario|>",
-            "<|action|>",
-            "<|entities|>",
-            # subtask -- speech editing
-            "<|speech_edit|>",
-            # subtask -- command
-            "<|speech_command|>",
+            "<|delim|>",  # delimiter of timestamps pair in audio grounding
+            "<|emotion_recognition|>",  # emotion recognition
+            "<|music_description|>",  # music description
+            "<|note_analysis|>",  # note analysis
+            "<|pitch|>",  # note analysis: pitch
+            *[f"<|midi_pitch_{i}|>" for i in range(128)],  # midi pitch 0-127
+            "<|velocity|>",  # note analysis: velocity
+            *[f"<|midi_velocity_{i}|>" for i in range(128)],  # midi velocity 0-127
+            "<|sonic|>",  # note analysis: sonic
+            "<|instrument|>",  # note analysis: instrument
+            "<|speaker_meta|>",  # meta information of speaker
+            "<|song_meta|>",  # meta information of song
+            "<|question|>",  # AQA: question
+            "<|answer|>",  # AQA: answer
+            "<|choice|>",  # AQA: answer choice
+            "<|scene|>",  # scene recognition
+            "<|event|>",  # sound event
+            "<|vocal_classification|>",  # vocal classification
+            "<|speech_understanding|>",  # speech language understanding
+            "<|scenario|>",  # speech language understanding: scenario
+            "<|action|>",  # speech language understanding: action
+            "<|entities|>",  # speech language understanding: entities
+            "<|speech_edit|>",  # speech edit
            audio_start_tag,
            audio_end_tag
        )
@@ -210,9 +195,8 @@ class QWenTokenizer(PreTrainedTokenizer):
         self.special_tokens = {
             token: index
             for index, token in enumerate(
-                # SPECIAL_TOKENS + self.IMAGE_ST + self.AUDIO_ST, start=len(self.mergeable_ranks)
                 SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)
-
+
             )
         }
         self.audio_start_id = self.special_tokens[self.audio_start_tag]
@@ -229,7 +213,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-
+            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
 
         self.decoder = {
@@ -260,7 +244,6 @@ class QWenTokenizer(PreTrainedTokenizer):
         )
         self.tokenizer = enc
 
-
     def __len__(self) -> int:
         return self.tokenizer.n_vocab
 
@@ -268,7 +251,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         return self.mergeable_ranks
 
     def convert_tokens_to_ids(
-
+        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
     ) -> List[int]:
         ids = []
         if isinstance(tokens, (str, bytes)):
@@ -288,7 +271,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             raise ValueError('Adding regular tokens is not supported')
         for token in new_tokens:
             surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS
+            if surface_form not in SPECIAL_TOKENS + self.AUDIO_ST:
                 raise ValueError('Adding unknown special tokens is not supported')
         return 0
 
@@ -307,12 +290,12 @@ class QWenTokenizer(PreTrainedTokenizer):
         return (file_path,)
 
     def tokenize(
-
-
-
-
-
-
+        self,
+        text: str,
+        allowed_special: Union[Set, str] = "all",
+        disallowed_special: Union[Collection, str] = (),
+        audio_info: Dict = None,
+        **kwargs,
     ) -> List[Union[bytes, str]]:
         """
         Converts a string in a sequence of tokens.
@@ -338,44 +321,46 @@ class QWenTokenizer(PreTrainedTokenizer):
 
         # this implementation takes a detour: text -> token id -> token surface forms
         for t in self.tokenizer.encode(
-
+            text, allowed_special=allowed_special, disallowed_special=disallowed_special
         ):
             tokens.append(self.decoder[t])
 
         def _encode_audiourl(audio_tokens, audio_info, audio_idx):
             assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
             audio_token_span = audio_info['audio_span_tokens'][audio_idx]
-            out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag]*(audio_token_span-2) + [
+            out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag] * (audio_token_span - 2) + [
+                self.audio_end_tag]
             return out_audio_tokens
 
-        return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl,
+        return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl,
+                                   audio_info=audio_info)
 
     def _batch_encode_plus(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+            List[PreTokenizedInputPair],
+            List[EncodedInput],
+            List[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
     ) -> BatchEncoding:
 
         def get_input_ids(text):
@@ -409,7 +394,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         for pair_id in range(len(batch_text_or_text_pairs)):
             kwargs['audio_info'] = audio_info[pair_id]
             ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
-
+            # for ids_or_pair_ids in batch_text_or_text_pairs:
             if not isinstance(ids_or_pair_ids, (list, tuple)):
                 ids, pair_ids = ids_or_pair_ids, None
             elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
@@ -488,23 +473,23 @@ class QWenTokenizer(PreTrainedTokenizer):
         raise NotImplementedError
 
     def _decode(
-
-
-
-
-
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        errors: str = None,
+        **kwargs,
     ) -> str:
         if isinstance(token_ids, int):
             token_ids = [token_ids]
         audio_info = kwargs.pop("audio_info", None)
 
-
         def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
             assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
             audio_url = audio_info["audio_urls"][audio_idx]
             return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]
 
-        token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl,
+        token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl,
+                                        audio_info=audio_info)
 
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.eod_id]
@@ -513,7 +498,7 @@ class QWenTokenizer(PreTrainedTokenizer):
     def to_list_format(self, text: str):
         text = unicodedata.normalize("NFC", text)
         token_ids = self.tokenizer.encode(
-            text, allowed_special=set(self.
+            text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,)))
 
         def _encode_audio_info(tokens):
             if len(tokens) == 0:
@@ -561,10 +546,10 @@ class QWenTokenizer(PreTrainedTokenizer):
 
     def process_audio(self, text):
         audio_urls = self.extract_audio_urls(text)
-        if len(audio_urls)> 0:
+        if len(audio_urls) > 0:
             audios, audio_lens, audio_span_tokens = [], [], []
             for audio_path in audio_urls:
-                if audio_path.startswith("http://") or audio_path.startswith("https://"):
+                if audio_path.startswith("http://") or audio_path.startswith("https://"):  # http
                     data = bytes(requests.get(audio_path, stream=True).content)
                     audio = load_bytesio_audio(data)
                 else:
@@ -578,7 +563,7 @@ class QWenTokenizer(PreTrainedTokenizer):
                 audio_len = [audio_len_after_cnn, audio_token_num]
                 audios.append(mel)
                 audio_lens.append(audio_len)
-                audio_span_tokens.append(audio_token_num+2)
+                audio_span_tokens.append(audio_token_num + 2)  # add audio bos eos
             input_audio_lengths = torch.IntTensor(audio_lens)
             input_audios = torch.stack(audios, dim=0)
             return {"input_audios": input_audios,
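Taken together these hunks repair a file that previously could not even import (the `tokenization_utils_base` import was split across two lines without a trailing backslash), drop the unused PIL imports and the `IMAGE_ST` tuple, restore the truncated `SPECIAL_TOKENS` and `AUDIO_ST` entries with English comments, fix the slice expressions in `_replace_closed_tag`, and thread `audio_info` through `tokenize`, `_batch_encode_plus`, and `_decode`. A usage sketch of the audio round trip, with placeholder repo id and audio path:

    from transformers import AutoTokenizer

    # Placeholder repo id and audio file.
    tokenizer = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)

    text = "<audio>example.wav</audio>What does the speaker say?"
    audio_info = tokenizer.process_audio(text)  # mel features, lengths, span sizes
    inputs = tokenizer(text, return_tensors="pt", audio_info=audio_info)

    # Decoding with the same audio_info restores the <audio>...</audio> span.
    round_trip = tokenizer.decode(inputs["input_ids"][0], audio_info=audio_info)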
tokenizer_config.json
CHANGED
@@ -6,6 +6,6 @@
     ]
   },
   "clean_up_tokenization_spaces": true,
-  "model_max_length":
+  "model_max_length": 2048,
   "tokenizer_class": "QWenTokenizer"
 }
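`model_max_length` is now pinned to 2048, matching the `max_position_embeddings` value set in config.json above, so tokenizer-side truncation agrees with the model's context window.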