zR commited on
Commit
547dc25
·
1 Parent(s): 1b77b8d

fix convert_tokens_to_string

Browse files
Files changed (2) hide show
  1. README_en.md +0 -1
  2. tokenization_chatglm.py +4 -4
README_en.md CHANGED
@@ -135,7 +135,6 @@ sampling_params = SamplingParams(temperature=0.95, max_tokens=1024, stop_token_i
135
 
136
  inputs = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
137
  outputs = llm.generate(prompts=inputs, sampling_params=sampling_params)
138
-
139
  print(outputs[0].outputs[0].text)
140
  ```
141
 
 
135
 
136
  inputs = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
137
  outputs = llm.generate(prompts=inputs, sampling_params=sampling_params)
 
138
  print(outputs[0].outputs[0].text)
139
  ```
140
 
tokenization_chatglm.py CHANGED
@@ -63,22 +63,22 @@ class ChatGLM4Tokenizer(PreTrainedTokenizer):
63
  vocab.update(self.added_tokens_encoder)
64
  return vocab
65
 
66
- def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
67
  """
68
  Converts a sequence of tokens in a single string.
69
  """
70
  text = ""
71
  temp = b""
72
  for t in tokens:
 
 
73
  if isinstance(t, str):
74
  if temp:
75
  text += temp.decode("utf-8", errors="replace")
76
- temp = b""
77
- text += t
78
  elif isinstance(t, bytes):
79
  temp += t
80
  else:
81
- raise TypeError("token should only be of type types or str")
82
  if temp:
83
  text += temp.decode("utf-8", errors="replace")
84
  return text
 
63
  vocab.update(self.added_tokens_encoder)
64
  return vocab
65
 
66
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str, int]]) -> str:
67
  """
68
  Converts a sequence of tokens in a single string.
69
  """
70
  text = ""
71
  temp = b""
72
  for t in tokens:
73
+ if isinstance(t, int):
74
+ t = chr(t)
75
  if isinstance(t, str):
76
  if temp:
77
  text += temp.decode("utf-8", errors="replace")
 
 
78
  elif isinstance(t, bytes):
79
  temp += t
80
  else:
81
+ raise TypeError("token should only be of type int, bytes or str")
82
  if temp:
83
  text += temp.decode("utf-8", errors="replace")
84
  return text