BBuf committed
Commit 8343faa
1 Parent(s): 262527f

Upload 9 files

config.json CHANGED
@@ -21,6 +21,5 @@
   "tie_word_embeddings": false,
   "transformers_version": "4.33.1",
   "use_cache": true,
-  "use_cache_kernel": true,
   "vocab_size": 65536
 }
configuration_rwkv5.py CHANGED
@@ -101,7 +101,6 @@ class Rwkv5Config(PretrainedConfig):
         eos_token_id=0,
         rescale_every=6,
         tie_word_embeddings=False,
-        use_cache_kernel=True,
         use_cache=True,
         model_version="5_2",
         **kwargs,
@@ -115,7 +114,6 @@ class Rwkv5Config(PretrainedConfig):
         self.intermediate_size = None
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rescale_every = rescale_every
-        self.use_cache_kernel = use_cache_kernel
         self.use_cache = use_cache

         self.bos_token_id = bos_token_id
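
The net effect of the two config changes above is that use_cache_kernel disappears from both config.json and the Rwkv5Config signature, leaving only use_cache. A minimal sketch of what that looks like from the outside (the repo id is an assumption used only for illustration; it is not part of this commit):

# Hedged sketch, not from the commit: with the flag removed from config.json and
# from Rwkv5Config.__init__, the loaded config simply no longer has the attribute.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("RWKV/rwkv-5-world", trust_remote_code=True)
print(config.use_cache)                     # True, still driven by config.json
print(hasattr(config, "use_cache_kernel"))  # False once this commit is applied
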
generation_config.json ADDED
@@ -0,0 +1,12 @@
+{
+  "chat_format": "chatml",
+  "eos_token_id": 0,
+  "pad_token_id": 0,
+  "max_window_size": 4096,
+  "max_new_tokens": 4096,
+  "do_sample": true,
+  "top_k": 0,
+  "top_p": 0.1,
+  "repetition_penalty": 1.0,
+  "transformers_version": "4.31.1"
+}
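
A generation_config.json that sits next to the weights is picked up automatically as the default for model.generate(); non-standard fields such as chat_format and max_window_size are carried along as plain attributes rather than interpreted by generate(). For reference, a minimal sketch of the same sampling defaults built in code (values copied from the JSON above; not part of the commit):

# Hedged sketch: the equivalent GenerationConfig constructed by hand.
from transformers import GenerationConfig

gen_cfg = GenerationConfig(
    eos_token_id=0,
    pad_token_id=0,
    max_new_tokens=4096,
    do_sample=True,
    top_k=0,
    top_p=0.1,
    repetition_penalty=1.0,
)
# outputs = model.generate(**inputs, generation_config=gen_cfg)
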
modeling_rwkv5.py CHANGED
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch RWKV5 World model."""
+
 import math
 from dataclasses import dataclass
 from pathlib import Path
@@ -36,7 +37,8 @@ from transformers.utils import (
     logging,
 )
 from .configuration_rwkv5 import Rwkv5Config
-from .cpp_kernels import cache_wkv5
+
+
 logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world"
@@ -46,30 +48,6 @@ RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [

 ]

-def rwkv_linear_attention_v5_2_cuda(B, T, C, H, state, r, k, v, w, u, cache_kernels):
-    assert HEAD_SIZE == C // H
-    ctx.B = B
-    ctx.T = T
-    ctx.C = C
-    ctx.H = H
-    assert state.dtype == torch.float32
-    assert w.dtype == torch.float32
-    assert r.is_contiguous()
-    assert k.is_contiguous()
-    assert v.is_contiguous()
-    assert w.is_contiguous()
-    assert u.is_contiguous()
-    assert state.is_contiguous()
-
-    y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
-    if r.dtype == torch.bfloat16:
-        cache_kernels.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
-    elif r.dtype == torch.float16:
-        cache_kernels.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
-    elif r.dtype == torch.float32:
-        cache_kernels.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
-    return y, state
-
 def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptance, key, value, lxw, lxb, ow, state, return_state=False, seq_mode=True):
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1)
     time_first = torch.exp(time_first.float()).reshape(-1,1,1)
@@ -107,7 +85,7 @@ def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptan

     return out, state

-def rwkv_linear_attention_v5_2_cpu(H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
+def rwkv_linear_attention_v5_2(H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
     time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
     lxw = lxw.float()
@@ -136,55 +114,43 @@ def rwkv_linear_attention_v5_2_cpu(H, S, T, n_head, hidden, time_decay, time_fir
     out = out @ ow

     return out, state
+
+
 class RwkvSelfAttention(nn.Module):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.config = config
         self.layer_id = layer_id
-        if config.use_cache_kernel:
-            # pre check if the support files existing
-            module_root = pathlib.Path(__file__).parent
-            src_files = ("rwkv5_op.cpp", "rwkv5.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
-                warnings.warn("State cache kernel source files (.cpp and .cu) not found.")
-                self.cache_kernels = None
-            else:
-                try:
-                    from .cpp_kernels import cache_wkv5
-                    self.cache_kernels = cache_wkv5
-                except ImportError:
-                    warnings.warn("Failed to import KV cache kernels.")
-                    self.cache_kernels = None
-        self.hidden_size = config.hidden_size
+        hidden_size = config.hidden_size
         # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
-        num_attention_heads = self.hidden_size // config.head_size
+        num_attention_heads = hidden_size // config.head_size
         self.num_attention_heads = num_attention_heads
         attention_hidden_size = (
-            config.attention_hidden_size if config.attention_hidden_size is not None else self.hidden_size
+            config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
         )
         self.attention_hidden_size = attention_hidden_size

         if self.config.model_version == "5_2":
             self.time_decay = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
             self.time_faaaa = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
-            self.time_mix_gate = nn.Parameter(torch.empty(1, 1, self.hidden_size))
+            self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
         else:
             self.time_decay = nn.Parameter(torch.empty(num_attention_heads))
             self.time_first = nn.Parameter(torch.empty(num_attention_heads))

-        self.time_mix_key = nn.Parameter(torch.empty(1, 1, self.hidden_size))
-        self.time_mix_value = nn.Parameter(torch.empty(1, 1, self.hidden_size))
-        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, self.hidden_size))
+        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
+        self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
+        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-        self.key = nn.Linear(self.hidden_size, attention_hidden_size, bias=False)
-        self.value = nn.Linear(self.hidden_size, attention_hidden_size, bias=False)
-        self.receptance = nn.Linear(self.hidden_size, attention_hidden_size, bias=False)
+        self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
+        self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
+        self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
         if self.config.model_version == "5_2":
-            self.gate = nn.Linear(self.hidden_size, attention_hidden_size, bias=False)
-        self.output = nn.Linear(attention_hidden_size, self.hidden_size, bias=False)
+            self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
+        self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
         # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
-        self.ln_x = nn.GroupNorm(self.hidden_size // config.head_size, self.hidden_size)
+        self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)

     # TODO: maybe jit, otherwise move inside forward
     def extract_key_value(self, H, S, T, hidden, state=None):
@@ -200,18 +166,19 @@ class RwkvSelfAttention(nn.Module):
         receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
         if self.config.model_version == "5_2":
             gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
-            gate = F.silu(self.gate(gate))

-        if self.cache_kernels is None:
-            if hidden.size(1) == 1 and state is not None:
-                receptance = self.receptance(receptance).to(torch.float32).view(H, 1, S)
-                key = self.key(key).to(torch.float32).view(H, S, 1)
-                value = self.value(value).to(torch.float32).view(H, 1, S)
-            else:
-                # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
-                key = self.key(key).to(torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1)
-                value = self.value(value).to(torch.float32).view(T, H, S).transpose(0, 1)
-                receptance = self.receptance(receptance).to(torch.float32).view(T, H, S).transpose(0, 1)
+        if hidden.size(1) == 1 and state is not None:
+            receptance = self.receptance(receptance).to(torch.float32).view(H, 1, S)
+            key = self.key(key).to(torch.float32).view(H, S, 1)
+            value = self.value(value).to(torch.float32).view(H, 1, S)
+        else:
+            # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
+            key = self.key(key).to(torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1)
+            value = self.value(value).to(torch.float32).view(T, H, S).transpose(0, 1)
+            receptance = self.receptance(receptance).to(torch.float32).view(T, H, S).transpose(0, 1)
+
+        if self.config.model_version == "5_2":
+            gate = F.silu(self.gate(gate))

         if state is not None:
             state[0][:, :, self.layer_id] = hidden[:, -1]
@@ -231,34 +198,25 @@ class RwkvSelfAttention(nn.Module):
         receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
         if self.config.model_version == "5_2":
-            if self.cache_kernels is not None and seq_mode:
-                rwkv, layer_state = rwkv_linear_attention_v5_2_cuda(1, T, self.hidden_size, H, layer_state.transpose(-1, -2).contiguous(),
-                                                                    receptance, key, value, self.time_decay, self.time_faaaa,)
-                layer_state = layer_state.transpose(-1,-2)
-                rwkv = rwkv.reshape(T, H*N)
-                rwkv = F.group_norm(rwkv, num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias)
-                rwkv = rwkv.to(dtype=hidden.dtype) * gate
-                rwkv = rwkv @ self.output.weight.t()
-            else:
-                rwkv, layer_state = rwkv_linear_attention_v5_2_cpu(
-                    H,
-                    S,
-                    T,
-                    self.num_attention_heads,
-                    hidden,
-                    self.time_decay,
-                    self.time_faaaa,
-                    receptance,
-                    key,
-                    value,
-                    gate,
-                    self.ln_x.weight,
-                    self.ln_x.bias,
-                    self.output.weight.t(),
-                    state=layer_state,
-                    return_state=use_cache,
-                    seq_mode=seq_mode,
-                )
+            rwkv, layer_state = rwkv_linear_attention_v5_2(
+                H,
+                S,
+                T,
+                self.num_attention_heads,
+                hidden,
+                self.time_decay,
+                self.time_faaaa,
+                receptance,
+                key,
+                value,
+                gate,
+                self.ln_x.weight,
+                self.ln_x.bias,
+                self.output.weight.t(),
+                state=layer_state,
+                return_state=use_cache,
+                seq_mode=seq_mode,
+            )
         else:
             rwkv, layer_state = rwkv_linear_attention_v5_0(
                 H,
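
The commit drops the compiled cache_wkv5 CUDA path and its wrapper (rwkv_linear_attention_v5_2_cuda) and renames the remaining pure-PyTorch path from rwkv_linear_attention_v5_2_cpu to rwkv_linear_attention_v5_2, so one implementation serves both prefill and single-token decode. For orientation, a minimal sketch of the head-wise recurrence that path computes; this is not the repo's code, and the shapes follow the views built in extract_key_value (per head: receptance/value (T, S), key (S, T), state (S, S)), with w = exp(-exp(time_decay)) and u = time_first assumed broadcastable as (H, S, 1):

import torch

# Hedged sketch of the RWKV-5 "5_2" linear-attention recurrence for one sequence.
# r: (H, T, S), k: (H, S, T), v: (H, T, S), state: (H, S, S), w/u: (H, S, 1).
def rwkv5_recurrence(r, k, v, w, u, state):
    H, T, S = r.shape
    out = torch.zeros(H, T, S, dtype=r.dtype)
    for t in range(T):
        kt = k[:, :, t:t + 1]            # (H, S, 1)
        vt = v[:, t:t + 1, :]            # (H, 1, S)
        at = kt @ vt                     # (H, S, S) outer product for this step
        out[:, t:t + 1, :] = r[:, t:t + 1, :] @ (u * at + state)
        state = at + w * state           # decay the running state, add the new term
    return out, state
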
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
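
The new vocabulary is a plain-text file rather than JSON. Going by the parser added to RWKVWorldTokenizer.__init__ in the next file, each line is "<token id> <Python literal for the token (str or bytes)> <byte length>". A hedged sketch of parsing one such line (the sample line is made up for illustration, not taken from the file):

# Hedged sketch: mirrors the per-line parsing in the new tokenizer __init__.
line = "2276 'Hello' 5"                              # made-up example line
idx = int(line[:line.index(' ')])                    # leading token id
tok = eval(line[line.index(' '):line.rindex(' ')])   # literal between first and last space
tok = tok.encode("utf-8") if isinstance(tok, str) else tok
assert len(tok) == int(line[line.rindex(' '):])      # trailing field is the byte length
print(idx, tok)                                      # 2276 b'Hello'
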
 
tokenization_rwkv_world.py CHANGED
@@ -52,186 +52,52 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {
-    "vocab_file": "rwkv_vocab_v20230424.json",
+    "vocab_file": "rwkv_vocab_v20230424.txt",
 }

-
-class DATrie:
-    class Node:
-        def __init__(self, is_leaf=False, leaf_data=None, tail=""):
-            self._is_leaf = is_leaf
-            self._leaf_data = leaf_data
-            self._tail = tail
-            self._next_map = {}
-
-        def is_leaf(self):
-            return self._is_leaf
-
-        def set_leaf(self):
-            self._is_leaf = True
-
-        def has_next(self, w):
-            if w in self._next_map:
-                return True
-            return False
-
-        def add_node(self, w, node):
-            self._next_map[w] = node
-
-        def get_node(self, w):
-            if w in self._next_map:
-                return self._next_map[w]
-            return None
-
-        def get_tail(self):
-            return self._tail
-
-        def get_data(self):
-            return self._leaf_data
-
-        def set_data(self, data):
-            self._leaf_data = data
-
-    def __init__(self, special_ids):
-        self.root = self.Node()
-        self.data = {}
-        self.r_data = {}
-        self.special_ids = special_ids
-
-    def insert(self, word, data):
-        self.data[word] = data
-        self.r_data[data] = word
-        idx = 0
-        node = self.root
-        while idx < len(word):
-            w = word[idx]
-            is_leaf = (idx == (len(word) - 1))
-            leaf_data = (data if is_leaf else None)
-            # insert if it does not exist
-            if not node.has_next(w):
-                node.add_node(w, self.Node(is_leaf=is_leaf, leaf_data=leaf_data))
-            # last word
-            node = node.get_node(w)
-            idx += 1
-        if not node.is_leaf():
-            node.set_leaf()
-            node.set_data(data)
-
-    def findStrict(self, word):
-        idx = 0
-        node = self.root
-        while node is not None and idx < len(word):
-            w = word[idx]
-            if not node.has_next(w):
-                return None
-            # last word
-            node = node.get_node(w)
-            idx += 1
-        if node.is_leaf():
-            return node.get_data()
-        return None
-
-    def prefix(self, word):
-        idx = 0
-        node = self.root
-        result = []
-        while node is not None and idx < len(word):
-            w = word[idx]
-            if not node.has_next(w):
-                return result
-            # last word
-            node = node.get_node(w)
-            if node.is_leaf():
-                result.append([word[:idx + 1], node.get_data()])
-            idx += 1
-        return result
-
-    def max_prefix(self, content, start_idx):
-        idx = start_idx
-        node = self.root
-        l = len(content)
-        result = [["", ], ]
-        while node is not None and idx < l:
-            w = content[idx]
-            if not node.has_next(w):
-                return result[-1]
-            # last word
-            node = node.get_node(w)
-            if node.is_leaf():
-                result.append([content[start_idx:idx + 1], node.get_data()])
+class TRIE:
+    __slots__ = tuple("ch,to,values,front".split(","))
+    to:list
+    values:set
+    def __init__(self, front=None, ch=None):
+        self.ch = ch
+        self.to = [None for ch in range(256)]
+        self.values = set()
+        self.front = front
+
+    def __repr__(self):
+        fr = self
+        ret = []
+        while(fr!=None):
+            if(fr.ch!=None):
+                ret.append(fr.ch)
+            fr = fr.front
+        return "<TRIE %s %s>"%(ret[::-1], self.values)
+
+    def add(self, key:bytes, idx:int=0, val=None):
+        if(idx == len(key)):
+            if(val is None):
+                val = key
+            self.values.add(val)
+            return self
+        ch = key[idx]
+        if(self.to[ch] is None):
+            self.to[ch] = TRIE(front=self, ch=ch)
+        return self.to[ch].add(key, idx=idx+1, val=val)
+
+    def find_longest(self, key:bytes, idx:int=0):
+        u:TRIE = self
+        ch:int = key[idx]
+
+        while(u.to[ch] is not None):
+            u = u.to[ch]
             idx += 1
-        return result[-1]
-
-    def max_score(self, content, start_idx):
-        idx = start_idx
-        node = self.root
-        l = len(content)
-        result = [["", (3, 0)], ]
-        while node is not None and idx < l:
-            w = content[idx]
-            if not node.has_next(w):
-                break
-            # last word
-            node = node.get_node(w)
-            if node.is_leaf():
-                result.append([content[start_idx:idx + 1], node.get_data()])
-            idx += 1
-        if len(result) > 1:
-            result = sorted(result, key=lambda x: x[1][1])
-        return result[-1]
-
-    def match(self, content, add_unk=True, unk_id=-1, **kwargs):
-        # length
-        l = len(content)
-        i = 0
-        result_list = []
-        while i < l:
-            match_word = self.max_prefix(content=content, start_idx=i)
-            # print(match_word)
-            w = match_word[0]
-            if len(w) > 0:
-                result_list.append(match_word[1])
-                i += len(w)
-            else:
-                if add_unk:
-                    result_list.append(unk_id)
-                i += 1
-        return result_list
-
-    def id2str(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
-        res_str = ""
-        for rid in ids:
-            if rid in self.r_data:
-                if rid in end_ids:
-                    break
-                if escape_special_ids and rid in self.special_ids:
-                    continue
-                rstr = self.r_data[rid]
-                res_str += rstr
-            elif rid == 0:
+            if(u.values):
+                ret = idx, u, u.values
+            if(idx==len(key)):
                 break
-            else:
-                print("ERROR unknown id %d" % rid)
-                res_str += "UNK"
-        return res_str
-
-    def id2str_v2(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
-        res_str = ""
-        for rid in ids:
-            if rid in self.r_data:
-                if rid in end_ids:
-                    break
-                rstr = self.r_data[rid]
-                if escape_special_ids and rid in self.special_ids:
-                    continue
-                res_str += rstr
-            elif rid == 0:
-                break
-            else:
-                print("ERROR unknown id %d" % rid)
-                res_str += "UNK"
-        return res_str
-
+            ch = key[idx]
+        return ret

 class RWKVWorldTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
@@ -244,17 +110,30 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
         **kwargs
     ):
         self.add_bos_token = False
+        self.encoder = {}
+        sorted = [] # must be already sorted
+        with open(vocab_file, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for l in lines:
+                idx = int(l[:l.index(' ')])
+                x = eval(l[l.index(' '):l.rindex(' ')])
+                x = x.encode("utf-8") if isinstance(x, str) else x
+                assert isinstance(x, bytes)
+                assert len(x) == int(l[l.rindex(' '):])
+                sorted += [x]
+                self.encoder[idx] = x

-        with open(vocab_file, encoding="utf-8") as vocab_handle:
-            self.encoder = json.load(vocab_handle)
         super().__init__(
             errors=errors,
            **kwargs,
         )
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.trie = DATrie(self.all_special_ids)
-        for k, v in self.encoder.items():
-            self.trie.insert(k, v)
+        self.decoder = {}
+        for k,v in self.encoder.items():
+            self.decoder[v] = int(k)
+
+        self.trie = TRIE()
+        for t, i in self.decoder.items():
+            _ = self.trie.add(t, val=(t, i))
         self.errors = errors  # how to handle errors in decoding
         self.cache = {}

@@ -311,9 +190,23 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
             return [1] + ([0] * len(token_ids_0))
         return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

+    def encodeBytes(self, src:bytes):
+        idx:int = 0
+        tokens = []
+        while (idx < len(src)):
+            _idx:int = idx
+            idx, _, values = self.trie.find_longest(src, idx)
+            assert(idx != _idx)
+            _, token = next(iter(values))
+            tokens.append(token)
+        return tokens
+
+    def decodeBytes(self, tokens):
+        return b''.join(map(lambda i: self.encoder[i], tokens))
+
     def _tokenize(self, text, **kwargs):
         """Tokenize a string."""
-        return self.trie.match(text, unk_id=self.unk_token_id, **kwargs)
+        return self.encodeBytes(text.encode("utf-8"))

     def _decode(self,
                 token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
@@ -326,13 +219,9 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
         if isinstance(token_ids, int):
            if token_ids in self.all_special_ids and skip_special_tokens:
                return ""
-            return self.decoder.get(token_ids, self.unk_token)
+            return self.encoder.get(token_ids, self.unk_token)
         elif isinstance(token_ids, list):
-            return self.trie.id2str(
-                token_ids,
-                escape_special_ids=skip_special_tokens,
-                **kwargs
-            )
+            return self.decodeBytes(tokens).decode('utf-8')
         else:
             return token_ids

@@ -383,10 +272,10 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
     ) -> BatchEncoding:
         def get_input_ids(text):
             if isinstance(text, str):
-                text_id = self.trie.match(text, unk_id=self.unk_token_id)
+                text_id = self._tokenize(text)
                 return text_id
             elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
-                return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
+                return [self._tokenize(t) for t in text]
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                 return text
             else:
@@ -448,10 +337,10 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
     ) -> BatchEncoding:
         def get_input_ids(text):
             if isinstance(text, str):
-                text_id = self.trie.match(text, unk_id=self.unk_token_id)
+                text_id = self._tokenize(text)
                 return text_id
            elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
-                return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
+                return [self._tokenize(t) for t in text]
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                 return text
             else:
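
This commit replaces the string-level DATrie tokenizer with a byte-level trie and greedy longest-match encoding: encodeBytes repeatedly calls TRIE.find_longest until the input bytes are consumed. A minimal sketch of that matching behaviour on a toy three-token vocabulary (the tokens and ids below are made up; the real ones come from rwkv_vocab_v20230424.txt):

# Hedged sketch using the TRIE class added above: greedy longest match over bytes.
trie = TRIE()
toy_vocab = {1: b"a", 2: b"b", 3: b"ab"}        # made-up ids and tokens
for token_id, token in toy_vocab.items():
    trie.add(token, val=(token, token_id))

def encode(src: bytes):
    idx, tokens = 0, []
    while idx < len(src):
        new_idx, _, values = trie.find_longest(src, idx)
        assert new_idx != idx                    # every byte must be covered by the vocab
        _, token_id = next(iter(values))
        tokens.append(token_id)
        idx = new_idx
    return tokens

print(encode(b"abab"))  # [3, 3] -- "ab" beats "a" because the longer match wins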