yangapku committed on
Commit
7309552
·
1 Parent(s): 83cdd36

update tokenization.py

Browse files
Files changed (1) hide show
  1. tokenization_qwen.py +44 -15
tokenization_qwen.py CHANGED
@@ -27,11 +27,21 @@ IMEND = "<|im_end|>"
27
  # regular texts, the surface forms of special tokens need to be
28
  # as different as possible to minimize the impact
29
  EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
- SPECIAL_TOKENS = (
31
- ENDOFTEXT,
32
- IMSTART,
33
- IMEND,
34
- ) + EXTRAS
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
@@ -42,6 +52,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
42
  for token, rank in (line.split() for line in contents.splitlines() if line)
43
  }
44
 
 
45
  class QWenTokenizer(PreTrainedTokenizer):
46
  """QWen tokenizer."""
47
 
@@ -51,20 +62,35 @@ class QWenTokenizer(PreTrainedTokenizer):
51
  self,
52
  vocab_file,
53
  errors="replace",
 
54
  **kwargs,
55
  ):
56
  super().__init__(**kwargs)
57
 
58
- self.errors = errors # how to handle errors in decoding
 
 
59
 
60
- self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
61
  self.special_tokens = {
62
  token: index
63
- for index, token in enumerate(
64
- SPECIAL_TOKENS, start=len(self.mergeable_ranks)
65
- )
66
  }
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  enc = tiktoken.Encoding(
69
  "Qwen",
70
  pat_str=PAT_STR,
@@ -89,7 +115,7 @@ class QWenTokenizer(PreTrainedTokenizer):
89
  def __getstate__(self):
90
  # for pickle lovers
91
  state = self.__dict__.copy()
92
- del state['tokenizer']
93
  return state
94
 
95
  def __setstate__(self, state):
@@ -103,7 +129,6 @@ class QWenTokenizer(PreTrainedTokenizer):
103
  )
104
  self.tokenizer = enc
105
 
106
-
107
  def __len__(self) -> int:
108
  return self.tokenizer.n_vocab
109
 
@@ -126,13 +151,17 @@ class QWenTokenizer(PreTrainedTokenizer):
126
  ids.append(self.mergeable_ranks.get(token))
127
  return ids
128
 
129
- def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
 
 
 
 
130
  if not special_tokens and new_tokens:
131
- raise ValueError('Adding regular tokens is not supported')
132
  for token in new_tokens:
133
  surface_form = token.content if isinstance(token, AddedToken) else token
134
  if surface_form not in SPECIAL_TOKENS:
135
- raise ValueError('Adding unknown special tokens is not supported')
136
  return 0
137
 
138
  def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
 
27
  # regular texts, the surface forms of special tokens need to be
28
  # as different as possible to minimize the impact
29
  EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ # changed to use actual index to avoid misconfiguration with vocabulary expansion
31
+ SPECIAL_START_ID = 151643
32
+ SPECIAL_TOKENS = tuple(
33
+ enumerate(
34
+ (
35
+ (
36
+ ENDOFTEXT,
37
+ IMSTART,
38
+ IMEND,
39
+ )
40
+ + EXTRAS
41
+ ),
42
+ start=SPECIAL_START_ID,
43
+ )
44
+ )
45
 
46
 
47
  def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
 
52
  for token, rank in (line.split() for line in contents.splitlines() if line)
53
  }
54
 
55
+
56
  class QWenTokenizer(PreTrainedTokenizer):
57
  """QWen tokenizer."""
58
 
 
62
  self,
63
  vocab_file,
64
  errors="replace",
65
+ extra_vocab_file=None,
66
  **kwargs,
67
  ):
68
  super().__init__(**kwargs)
69
 
70
+ # how to handle errors in decoding UTF-8 byte sequences
71
+ # use ignore if you are in streaming inference
72
+ self.errors = errors
73
 
74
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
75
  self.special_tokens = {
76
  token: index
77
+ for index, token in SPECIAL_TOKENS
 
 
78
  }
79
 
80
+ # try load extra vocab from file
81
+ if extra_vocab_file is not None:
82
+ used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
83
+ extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
84
+ for token, index in extra_mergeable_ranks.items():
85
+ if token in self.mergeable_ranks:
86
+ logger.info(f"extra token {token} exists, skipping")
87
+ continue
88
+ if index in used_ids:
89
+ logger.info(f'the index {index} for extra token {token} exists, skipping')
90
+ continue
91
+ self.mergeable_ranks[token] = index
92
+ # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
93
+
94
  enc = tiktoken.Encoding(
95
  "Qwen",
96
  pat_str=PAT_STR,
 
115
  def __getstate__(self):
116
  # for pickle lovers
117
  state = self.__dict__.copy()
118
+ del state["tokenizer"]
119
  return state
120
 
121
  def __setstate__(self, state):
 
129
  )
130
  self.tokenizer = enc
131
 
 
132
  def __len__(self) -> int:
133
  return self.tokenizer.n_vocab
134
 
 
151
  ids.append(self.mergeable_ranks.get(token))
152
  return ids
153
 
154
def _add_tokens(
    self,
    new_tokens: Union[List[str], List[AddedToken]],
    special_tokens: bool = False,
) -> int:
    """Validate a request to add tokens; the vocabulary itself is never changed.

    Only special tokens whose surface form is already declared in
    ``SPECIAL_TOKENS`` are accepted.

    Args:
        new_tokens: Tokens to add, as plain strings or ``AddedToken`` objects.
        special_tokens: Whether ``new_tokens`` are special tokens.

    Returns:
        Always ``0``: no new entries are created.

    Raises:
        ValueError: If regular (non-special) tokens are passed, or if a
            special token is not one of the predefined surface forms.
    """
    if not special_tokens and new_tokens:
        raise ValueError("Adding regular tokens is not supported")

    # SPECIAL_TOKENS is a tuple of (index, surface_form) pairs, so a bare
    # string is never "in" it. Compare against the surface forms instead;
    # otherwise every special token would be rejected as unknown.
    known_surface_forms = {surface for _, surface in SPECIAL_TOKENS}

    for token in new_tokens:
        surface_form = token.content if isinstance(token, AddedToken) else token
        if surface_form not in known_surface_forms:
            raise ValueError("Adding unknown special tokens is not supported")
    return 0
166
 
167
  def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: