Jeffrey Fong committed on
Commit
7fae8ba
1 Parent(s): 6f7ed43

add tokenizer remote code

Browse files
tokenization_functionary.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ import json
3
+ from typing import Any, Dict, List, Literal, Optional, Union
4
+
5
+ import jsonref
6
+ from pydantic import BaseModel, Field, model_validator
7
+ from typing_extensions import Self
8
+
9
+ from transformers.tokenization_utils_base import BatchEncoding
10
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
11
+ from transformers.utils import TensorType, logging
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+ SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
16
+ CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""
17
+
18
class Function(BaseModel):
    """OpenAI-style schema describing one callable function exposed to the model."""

    name: str
    # Short natural-language description shown to the model; defaults to "".
    description: Optional[str] = Field(default="")
    # JSON-schema dict for the function's arguments; None when the function
    # takes no arguments.
    parameters: Optional[dict] = None
22
+
23
+
24
class Tool(BaseModel):
    """A tool entry: either a function (with schema) or the code interpreter."""

    type: Literal["function", "code_interpreter"]
    # Must be set exactly when type == "function" (enforced below).
    function: Optional[Function] = None

    @model_validator(mode="after")
    def check_type_function_matches(self) -> Self:
        # Cross-field validation: `function` is required for "function" tools
        # and forbidden for "code_interpreter" tools.
        if self.type == "function":
            assert self.function is not None, '"function" must contain function description when `"type": "function"`'
        else:
            assert self.function is None, '"function" must not be provided when `"type": "code_interpreter"`'
        return self
35
+
36
+
37
def convert_data_type(param_type: str) -> str:
    """Map a JSON-schema primitive type name to its TypeScript equivalent.

    Args:
        param_type (str): JSON-schema type name

    Returns:
        str: TypeScript type name ("integer" and "float" become "number";
        everything else passes through unchanged)
    """
    return "number" if param_type in ("integer", "float") else param_type
49
+
50
+
51
def get_param_type(param: Dict) -> str:
    """Resolve the TypeScript type of a JSON-schema parameter.

    Handles three shapes: a plain "type" string, a "type" given as a list of
    alternatives, and a "oneOf" list of sub-schemas. Falls back to "any" when
    no type information is present.

    Args:
        param (Dict): a single property's schema dict

    Returns:
        str: the TypeScript type; alternatives are joined with " | "
    """
    if "type" in param:
        raw_param_type = param["type"]
        if isinstance(raw_param_type, list):
            # BUGFIX: convert each alternative individually. The previous code
            # joined first and converted the joined string, so a list such as
            # ["integer", "string"] produced "integer | string" instead of
            # "number | string".
            return " | ".join(convert_data_type(t) for t in raw_param_type)
        return convert_data_type(raw_param_type)

    # In many cases the json schema contains "oneOf" instead of "type".
    if "oneOf" in param:
        one_of_types = []
        for item in param["oneOf"]:
            if "type" in item:
                one_of_types.append(convert_data_type(item["type"]))
        # Deduplicate (order is unspecified, matching previous behavior).
        one_of_types = list(set(one_of_types))
        return convert_data_type(" | ".join(one_of_types))

    return convert_data_type("any")
77
+
78
+
79
def get_format_param(param: Dict) -> Optional[str]:
    """Extract the "format" hint from a parameter schema.

    The format may sit directly on the param, or inside its "oneOf"
    sub-schemas, in which case every format found is joined with " or ".

    Args:
        param (Dict): a single property's schema dict

    Returns:
        Optional[str]: the format string, or None when no format is present
    """
    if "format" in param:
        return param["format"]
    collected = [item["format"] for item in param.get("oneOf", []) if "format" in item]
    return " or ".join(collected) if collected else None
98
+
99
+
100
def get_param_info(param: Dict) -> Optional[str]:
    """Build a single-line "// ..." comment describing a parameter.

    Collects the description, default value, format, and numeric/length
    bounds into one TypeScript-style comment.

    Args:
        param (Dict): a single property's schema dict

    Returns:
        Optional[str]: "// ..." comment, or None when there is nothing to say
    """
    declared_type = param.get("type", "any")
    notes: List[str] = []

    if "description" in param:
        desc = param["description"]
        # Normalize the description to end with a period.
        if not desc.endswith("."):
            desc += "."
        notes.append(desc)

    if "default" in param:
        default_value = param["default"]
        # Quote string defaults so they read as literals.
        if declared_type == "string":
            default_value = f'"{default_value}"'
        notes.append(f"Default={default_value}.")

    fmt = get_format_param(param)
    if fmt is not None:
        notes.append("Format=" + fmt)

    for key, label in (
        ("maximum", "Maximum"),
        ("minimum", "Minimum"),
        ("maxLength", "Maximum length"),
        ("minLength", "Minimum length"),
    ):
        if key in param:
            notes.append(f"{label}=" + str(param[key]))

    if not notes:
        return None
    # Flatten embedded newlines so the comment stays on one line.
    return ("// " + " ".join(notes)).replace("\n", " ")
141
+
142
+
143
def append_new_param_info(
    info_list: List[str],
    param_declaration: str,
    comment_info: Optional[str],
    examples_info: List,
    depth: int,
):
    """Append a parameter declaration, preceded by its comment and examples.

    Args:
        info_list (List[str]): accumulated output lines; mutated in place
        param_declaration (str): "param: type" declaration text
        comment_info (Optional[str]): "// ..." comment line, or None
        examples_info (List): "// ..." example lines
        depth (int): nesting level; one leading space per level
    """
    indent = " " * depth if depth >= 1 else ""
    if comment_info is None:
        info_list.append(f"{indent}{param_declaration}")
        return
    # Order: comment, then examples (only when a comment exists), then the
    # declaration itself.
    info_list.append(f"{indent}{comment_info}")
    for example in examples_info:
        info_list.append(f"{indent}{example}")
    info_list.append(f"{indent}{param_declaration}")
173
+
174
+
175
def get_examples_info(param_name: str, examples: List) -> List:
    """Render a parameter's "examples" as "// ..." comment lines.

    Args:
        param_name (str): name of the parameter
        examples (List): example values taken from the schema

    Returns:
        List: comment lines, starting with a "// Example <name>:" header
    """
    lines = [f"// Example {param_name}:"]
    for example in examples:
        # Dict/list examples are serialized as JSON; anything else via str().
        if isinstance(example, (dict, list)):
            rendered = json.dumps(example, ensure_ascii=False)
        else:
            rendered = str(example)
        # Escape newlines so each example stays on one comment line.
        lines.append("// " + rendered.replace('\n', '\\n'))
    return lines
194
+
195
+
196
def get_enum_option_str(enum_options: List) -> str:
    """Join enum options with " | ", double-quoting string options.

    Args:
        enum_options (List): the allowed values

    Returns:
        str: concatenation such as '"a" | "b" | 3'
    """
    rendered = []
    for option in enum_options:
        # Strings get quotes so they read as TS string literals; note the
        # exact-type check (str subclasses are not quoted), kept as-is.
        rendered.append(f'"{option}"' if type(option) is str else str(option))
    return " | ".join(rendered)
207
+
208
+
209
def get_array_typescript(
    param_name: Optional[str], param_dic: dict, depth: int = 0
) -> str:
    """Recursively render the TypeScript type of an array parameter.

    Args:
        param_name (Optional[str]): name of the param; None when rendering an
            anonymous item type (used while recursing into nested arrays)
        param_dic (dict): the array's schema dict
        depth (int, optional): nesting level (one leading space per level).
            Defaults to 0.

    Returns:
        str: TypeScript declaration text for the array
    """
    indent = " " * depth if depth >= 1 else ""
    items_info = param_dic.get("items", {})

    # No item schema at all --> untyped "[]".
    if len(items_info) == 0:
        return "[]" if param_name is None else f"{indent}{param_name}: []"

    array_type = get_param_type(items_info)

    if array_type == "object":
        # Render the object's members, then wrap them as "{ ... }[]".
        member_lines = get_parameter_typescript(
            items_info.get("properties", {}), items_info.get("required", []), depth + 1
        )
        out = []
        if param_name is not None:
            out.append(f"{indent}{param_name}" + ": {")
        else:
            out.append(f"{indent}" + "{")
        out.extend(member_lines)
        out.append(f"{indent}" + "}[]")
        return "\n".join(out)

    if array_type == "array":
        # Nested array: render the item type anonymously, append "[]".
        inner = get_array_typescript(None, items_info, depth + 1)
        if param_name is None:
            return f"{inner}[]"
        return f"{indent}{param_name}: {inner.strip()}[]"

    if "enum" in items_info:
        enum_str = get_enum_option_str(items_info["enum"])
        if param_name is None:
            return f"({enum_str})[]"
        return f"{indent}{param_name}: ({enum_str})[]"

    if param_name is None:
        return f"{array_type}[]"
    # NOTE: only this simple named case carries a trailing comma; the caller
    # appends one to the other forms itself.
    return f"{indent}{param_name}: {array_type}[],"
266
+
267
+
268
def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursively render JSON-schema properties as TypeScript member lines.

    Each parameter becomes a declaration line optionally preceded by a
    "// ..." comment and example lines; these lines are later joined into
    the schema prompt, so ordering and trailing commas are significant.

    Args:
        properties (_type_): the "properties" dict of a JSON schema
        required_params (_type_): list of required parameter names (optional
            params get a "?" suffix); non-list values disable the "?" marking
        depth (int, optional): nesting level (one leading space per level).
            Defaults to 0.

    Returns:
        List[str]: lines describing all parameters at this nesting level
    """
    tp_lines = []
    for param_name, param in properties.items():
        # Sometimes properties have "required" field as a list of string.
        # Even though its supposed to be not under properties. So we skip it
        if not isinstance(param, dict):
            continue
        # Param Description ("// ..." comment, or None)
        comment_info = get_param_info(param)
        # Param Examples ("// ..." lines)
        examples_info = []
        if "examples" in param:
            examples_info = get_examples_info(param_name, param["examples"])
        # Param Name declaration; optional params are marked with "?"
        param_declaration = f"{param_name}"
        if isinstance(required_params, list):
            if param_name not in required_params:
                param_declaration += "?"
        param_type = get_param_type(param)

        offset = ""
        if depth >= 1:
            offset = "".join([" " for _ in range(depth)])

        if param_type == "object":  # param_type is object: recurse into members
            child_lines = get_parameter_typescript(
                param.get("properties", {}), param.get("required", []), depth + 1
            )
            if comment_info is not None:
                tp_lines.append(f"{offset}{comment_info}")
            if len(examples_info) > 0:
                for example in examples_info:
                    tp_lines.append(f"{offset}{example}")

            # Emit "name: {" ... members ... "},"
            param_declaration += ": {"
            tp_lines.append(f"{offset}{param_declaration}")
            tp_lines.extend(child_lines)
            tp_lines.append(f"{offset}" + "},")

        elif param_type == "array":  # param_type is an array
            item_info = param.get("items", {})
            if "type" not in item_info:  # don't know type of array items
                param_declaration += ": [],"
                append_new_param_info(
                    tp_lines, param_declaration, comment_info, examples_info, depth
                )
            else:
                array_declaration = get_array_typescript(
                    param_declaration, param, depth
                )
                # get_array_typescript only adds the trailing comma on its
                # simple named branch; normalize the rest here.
                if not array_declaration.endswith(","):
                    array_declaration += ","
                if comment_info is not None:
                    tp_lines.append(f"{offset}{comment_info}")
                if len(examples_info) > 0:
                    for example in examples_info:
                        tp_lines.append(f"{offset}{example}")
                tp_lines.append(array_declaration)
        else:  # scalar (or enum / nullable) parameter
            if "enum" in param:
                # Enum replaces the base type with '"a" | "b" | ...'
                param_type = get_enum_option_str(param["enum"])
            if "nullable" in param and param["nullable"] is True:
                param_type += " | null"
            param_declaration += f": {param_type},"
            append_new_param_info(
                tp_lines, param_declaration, comment_info, examples_info, depth
            )

    return tp_lines
349
+
350
def generate_schema_from_functions(
    functions: List[Function], namespace="functions"
) -> str:
    """Render function schemas as a TypeScript namespace the model can read.

    Args:
        functions (List[Function]): functions to render; each element may be
            a pydantic Function or an equivalent plain dict
        namespace: name of the emitted namespace. Defaults to "functions".

    Returns:
        str: a TypeScript-style declaration block
    """
    parts = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
    ]

    for function in functions:
        # Accept either a pydantic Function or a ready-made dict.
        if not isinstance(function, dict):
            function = function.model_dump()
        function_name = function.get("name", None)
        if function_name is None:
            # A nameless entry cannot be declared; skip it.
            continue

        description = function.get("description", "")
        parts.append(f"// {description}\n")
        parts.append(f"type {function_name}")

        parameters = function.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            # Inline any JSON-schema $refs before rendering the members.
            parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
            parts.append(" = (_: {\n")
            member_lines = get_parameter_typescript(
                parameters.get("properties"),
                parameters.get("required", []),
                0,
            )
            parts.append("\n".join(member_lines))
            parts.append("\n}) => any;\n\n")
        else:
            # No parameters at all.
            parts.append(" = () => any;\n\n")

    parts.append(f"}} // namespace {namespace}")
    return "".join(parts)
391
+
392
class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """PreTrainedTokenizerFast whose `apply_chat_template` injects the
    Functionary tool schema and a system prompt before rendering."""

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render the conversation with the chat template, after prepending
        (1) a system message holding the TypeScript schema generated from
        `tools` and (2) the standard or code-interpreter system prompt.

        Args:
            conversation: chat messages, a batch of chats, or a Conversation
                object (anything with a `.messages` attribute).
            tools: OpenAI-style tool definitions validated via `Tool`.
                BUGFIX: now defaults to None and is treated like an empty
                list; previously `tools=None` crashed on `len(None)`.
            chat_template: a template string, or (for dict templates) the
                name of a template to use.
            add_generation_prompt: ask the template to append the tokens that
                start an assistant reply.
            tokenize: return token ids instead of the rendered string.
            padding, truncation, max_length, return_tensors, tokenizer_kwargs:
                forwarded to tokenization when `tokenize=True`.
            return_dict: return the full BatchEncoding instead of input_ids.
            **kwargs: extra variables passed through to the Jinja template
                (they override special-token entries on name collision).

        Returns:
            Rendered string(s), token id list(s), or a BatchEncoding,
            depending on `tokenize` / `return_dict` / `return_tensors`.

        Raises:
            ValueError: when `return_dict=True` with `tokenize=False`, or when
                a dict of templates has no default and no name was given.
        """

        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
            if chat_template is not None and chat_template in template_dict:
                # The user can pass the name of a template to the chat template argument instead of an entire template
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            # These are the cases when the model has a single template
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                chat_template = self.default_chat_template
                using_default_template = True

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        # Prepare tools/functions into schema
        functions_pydantic_to_render = []
        has_code_interpreter = False
        # Treat tools=None like an empty tool list (previously crashed).
        for tool in tools or []:
            tool_pydantic = Tool.model_validate(tool)
            if tool_pydantic.type == "function":
                functions_pydantic_to_render.append(tool_pydantic.function)
            else:
                has_code_interpreter = True
        # NOTE(review): these inserts mutate the caller's `conversation` list
        # and assume a single, non-batched chat — confirm before passing a
        # batch (list of chats) with tools.
        conversation.insert(0, {"role": "system", "content": generate_schema_from_functions(functions_pydantic_to_render)})
        # Insert system prompt
        system_prompt_to_use = SYSTEM_PROMPT if not has_code_interpreter else CODE_INTERPRETER_SYSTEM_PROMPT
        conversation.insert(1, {"role": "system", "content": system_prompt_to_use})

        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            rendered_chat = compiled_template.render(
                messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if tokenize:
            out = self(
                rendered,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=False,
                return_tensors=return_tensors,
                **tokenizer_kwargs,
            )
            if return_dict:
                return out
            else:
                return out["input_ids"]
        else:
            return rendered
tokenizer_config.json CHANGED
@@ -2061,5 +2061,8 @@
2061
  "model_max_length": 8192,
2062
  "pad_token": "<|end_of_text|>",
2063
  "padding_side": "right",
2064
- "tokenizer_class": "PreTrainedTokenizerFast"
 
 
 
2065
  }
 
2061
  "model_max_length": 8192,
2062
  "pad_token": "<|end_of_text|>",
2063
  "padding_side": "right",
2064
+ "tokenizer_class": "PreTrainedTokenizerFast",
2065
+ "auto_map": {
2066
+ "AutoTokenizer": ["tokenization_functionary.FunctionaryTokenizer", null]
2067
+ }
2068
  }