from copy import deepcopy
import json
from typing import Any, Dict, List, Literal, Optional, Union

import jsonref
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import TensorType, logging


logger = logging.get_logger(__name__)

SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""

# String repeated once per nesting level in front of every rendered schema line.
# NOTE(review): reproduced exactly as it appears in this file; if the file went
# through whitespace normalisation, confirm the intended width against the
# prompt format the model expects.
_INDENT_UNIT = " "


def _indent(depth: int) -> str:
    """Return the leading whitespace for a line nested ``depth`` levels deep."""
    return _INDENT_UNIT * depth if depth >= 1 else ""


class Function(BaseModel):
    """Description of a callable function exposed to the model."""

    # Name rendered as ``type <name>`` in the generated schema.
    name: str
    # Free-text description rendered as a ``//`` comment above the type.
    description: Optional[str] = Field(default="")
    # JSON-schema object describing the function arguments, if any.
    parameters: Optional[dict] = None


class Tool(BaseModel):
    """A tool entry: either a function description or the code interpreter."""

    type: Literal["function", "code_interpreter"]
    function: Optional[Function] = None

    @model_validator(mode="after")
    def check_type_function_matches(self) -> Self:
        """Ensure ``function`` is present exactly when ``type`` is ``"function"``.

        Raises:
            ValueError: if the two fields disagree. Raised explicitly rather
                than via ``assert`` so the check still runs under ``python -O``.
        """
        if self.type == "function":
            if self.function is None:
                raise ValueError('"function" must contain function description when `"type": "function"`')
        elif self.function is not None:
            raise ValueError('"function" must not be provided when `"type": "code_interpreter"`')
        return self


def convert_data_type(param_type: str) -> str:
    """Convert a JSON-schema data type to its TypeScript name.

    Args:
        param_type (str): JSON-schema type name

    Returns:
        str: ``"number"`` for ``"integer"`` and ``"float"``; any other value is
            returned unchanged
    """
    if param_type in ("integer", "float"):
        return "number"
    return param_type


def get_param_type(param: Dict) -> str:
    """Get the TypeScript type of a parameter.

    Args:
        param (Dict): param dict in properties

    Returns:
        str: the declared ``type`` (a list of types is joined with ``" | "``),
            otherwise the union of the types found under ``oneOf``, otherwise
            ``"any"``
    """
    param_type = "any"
    if "type" in param:
        raw_param_type = param["type"]
        if isinstance(raw_param_type, list):
            param_type = " | ".join(raw_param_type)
        else:
            param_type = raw_param_type
    elif "oneOf" in param:
        # in many cases, the json schema contains: oneOf instead of "type"
        converted = [convert_data_type(item["type"]) for item in param["oneOf"] if "type" in item]
        # Deduplicate while keeping first-seen order: going through a set would
        # make the rendered union depend on string hash randomisation.
        one_of_types = list(dict.fromkeys(converted))
        if one_of_types:
            # When no alternative declares a type we keep "any" instead of
            # emitting an empty type.
            param_type = " | ".join(one_of_types)
    return convert_data_type(param_type)


def get_format_param(param: Dict) -> Optional[str]:
    """Get "format" from param. There are cases where format is not directly in param but in oneOf

    Args:
        param (Dict): param dict in properties

    Returns:
        Optional[str]: ``param["format"]`` when present; otherwise the formats
            found in the ``oneOf`` alternatives joined with ``" or "``;
            ``None`` when no format is declared anywhere
    """
    if "format" in param:
        return param["format"]
    if "oneOf" in param:
        formats = [item["format"] for item in param["oneOf"] if "format" in item]
        if formats:
            return " or ".join(formats)
    return None


def get_param_info(param: Dict) -> Optional[str]:
    """Get additional information about parameter such as: format, default value, min, max, ...

    Args:
        param (Dict): param dict in properties

    Returns:
        Optional[str]: a single-line ``// ...`` comment built from the
            description, default, format and length/range constraints, or
            ``None`` when the param carries none of them
    """
    param_type = param.get("type", "any")
    info_list = []
    if "description" in param:
        desc = param["description"]
        if not desc.endswith("."):
            desc += "."
        info_list.append(desc)

    if "default" in param:
        default_value = param["default"]
        if param_type == "string":
            default_value = f'"{default_value}"'  # if string --> add ""
        info_list.append(f"Default={default_value}.")

    format_param = get_format_param(param)
    if format_param is not None:
        info_list.append("Format=" + format_param)

    for field, field_name in [
        ("maximum", "Maximum"),
        ("minimum", "Minimum"),
        ("maxLength", "Maximum length"),
        ("minLength", "Minimum length"),
    ]:
        if field in param:
            info_list.append(f"{field_name}=" + str(param[field]))

    if not info_list:
        return None
    # The comment must stay on one line, so newlines become spaces.
    return ("// " + " ".join(info_list)).replace("\n", " ")


def append_new_param_info(
    info_list: List[str],
    param_declaration: str,
    comment_info: Optional[str],
    examples_info: List,
    depth: int,
):
    """Append a new parameter with comment to the info_list

    The lines are appended in place in the order: comment, examples,
    declaration.

    Args:
        info_list (List[str]): current info_list, extended in place
        param_declaration (str): param: type
        comment_info (Optional[str]): information of comment
        examples_info (List): information of examples given
        depth (int): level of nested param
    """
    offset = _indent(depth)
    if comment_info is not None:
        # format: //comment\nparam: type
        info_list.append(f"{offset}{comment_info}")
        info_list.extend(f"{offset}{example}" for example in examples_info)
    # NOTE(review): examples are only emitted when a comment line exists, while
    # the object/array branches of get_parameter_typescript emit them
    # regardless. Behaviour kept as-is; confirm whether this is intended.
    info_list.append(f"{offset}{param_declaration}")


def get_examples_info(param_name: str, examples: List) -> List:
    """Get information about examples provided

    Args:
        param_name (str): name of the parameter the examples belong to
        examples (List): example values; dicts and lists are JSON-encoded,
            anything else goes through ``str``

    Returns:
        List: a ``// Example <param_name>:`` header line followed by one
            ``// <example>`` line per example, with newlines escaped as ``\\n``
    """
    examples_list = [f"// Example {param_name}:"]
    for example in examples:
        if isinstance(example, (dict, list)):
            example_str = json.dumps(example, ensure_ascii=False)
        else:
            example_str = str(example)
        example_str = example_str.replace("\n", "\\n")
        examples_list.append(f"// {example_str}")
    return examples_list


def get_enum_option_str(enum_options: List) -> str:
    """Get enum option separated by: "|"

    Args:
        enum_options (List): list of options

    Returns:
        str: concatenation of options separated by " | "; string options are
            wrapped in double quotes
    """
    # if each option is string --> add quote
    return " | ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in enum_options)


def get_array_typescript(
    param_name: Optional[str], param_dic: dict, depth: int = 0
) -> str:
    """Recursive implementation for generating type script of array

    Args:
        param_name (Optional[str]): name of param, optional
        param_dic (dict): param_dic
        depth (int, optional): nested level. Defaults to 0.

    Returns:
        str: typescript of array. Only the "named array of plain type" case
            ends with a trailing comma; get_parameter_typescript adds the
            comma itself when it is missing.
    """
    offset = _indent(depth)
    items_info = param_dic.get("items", {})

    if len(items_info) == 0:
        if param_name is not None:
            return f"{offset}{param_name}: []"
        return "[]"

    array_type = get_param_type(items_info)
    if array_type == "object":
        child_lines = get_parameter_typescript(
            items_info.get("properties", {}), items_info.get("required", []), depth + 1
        )
        if param_name is not None:
            opening = f"{offset}{param_name}" + ": {"
        else:
            opening = f"{offset}" + "{"
        info_lines = [opening, *child_lines, f"{offset}" + "}[]"]
        return "\n".join(info_lines)

    if array_type == "array":
        item_info = get_array_typescript(None, items_info, depth + 1)
        if param_name is None:
            return f"{item_info}[]"
        return f"{offset}{param_name}: {item_info.strip()}[]"

    if "enum" in items_info:
        item_type = get_enum_option_str(items_info["enum"])
        if param_name is None:
            return f"({item_type})[]"
        return f"{offset}{param_name}: ({item_type})[]"

    if param_name is None:
        return f"{array_type}[]"
    return f"{offset}{param_name}: {array_type}[],"


def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursion, returning the information about parameters including data type, description and other information
    These kinds of information will be put into the prompt

    Args:
        properties (dict): properties in parameters
        required_params (list): List of required parameters
        depth (int, optional): the depth of params (nested level). Defaults to 0.

    Returns:
        List[str]: list of lines containing information about all parameters
    """
    tp_lines = []
    offset = _indent(depth)

    def emit_comment_and_examples(comment_info, examples_info):
        # Shared by the object and typed-array branches: both emit the
        # examples even when there is no comment line.
        if comment_info is not None:
            tp_lines.append(f"{offset}{comment_info}")
        tp_lines.extend(f"{offset}{example}" for example in examples_info)

    for param_name, param in properties.items():
        # Sometimes properties have "required" field as a list of string.
        # Even though its supposed to be not under properties. So we skip it
        if not isinstance(param, dict):
            continue

        # Param Description
        comment_info = get_param_info(param)
        # Param Examples
        examples_info = []
        if "examples" in param:
            examples_info = get_examples_info(param_name, param["examples"])
        # Param Name declaration; optional params get a "?" suffix
        param_declaration = f"{param_name}"
        if isinstance(required_params, list) and param_name not in required_params:
            param_declaration += "?"
        param_type = get_param_type(param)

        if param_type == "object":  # param_type is object
            child_lines = get_parameter_typescript(
                param.get("properties", {}), param.get("required", []), depth + 1
            )
            emit_comment_and_examples(comment_info, examples_info)
            param_declaration += ": {"
            tp_lines.append(f"{offset}{param_declaration}")
            tp_lines.extend(child_lines)
            tp_lines.append(f"{offset}" + "},")

        elif param_type == "array":  # param_type is an array
            item_info = param.get("items", {})
            if "type" not in item_info:  # don't know type of array
                param_declaration += ": [],"
                append_new_param_info(
                    tp_lines, param_declaration, comment_info, examples_info, depth
                )
            else:
                array_declaration = get_array_typescript(
                    param_declaration, param, depth
                )
                if not array_declaration.endswith(","):
                    array_declaration += ","
                emit_comment_and_examples(comment_info, examples_info)
                tp_lines.append(array_declaration)
        else:
            if "enum" in param:
                param_type = get_enum_option_str(param["enum"])
            if "nullable" in param and param["nullable"] is True:
                param_type += " | null"
            param_declaration += f": {param_type},"
            append_new_param_info(
                tp_lines, param_declaration, comment_info, examples_info, depth
            )

    return tp_lines


def generate_schema_from_functions(
    functions: List[Function], namespace="functions"
) -> str:
    """
    Convert functions schema to a schema that language models can understand.

    Args:
        functions (List[Function]): functions to render; plain dicts with the
            same keys are accepted as well. Entries without a name are skipped.
        namespace (str, optional): name of the TypeScript namespace wrapping
            the definitions. Defaults to "functions".

    Returns:
        str: the TypeScript-style schema text
    """
    parts = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
    ]

    for function in functions:
        # Convert a Function object to dict, if necessary
        func_dict = function if isinstance(function, dict) else function.model_dump()

        function_name = func_dict.get("name", None)
        if function_name is None:
            continue

        description = func_dict.get("description", "")
        parts.append(f"// {description}\n")
        parts.append(f"type {function_name}")

        parameters = func_dict.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            # Inline "$ref" entries so the renderer only sees concrete schemas.
            parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
            tp_lines = get_parameter_typescript(
                parameters.get("properties"),
                parameters.get("required", []),
                0,
            )
            parts.append(" = (_: {\n")
            parts.append("\n".join(tp_lines))
            parts.append("\n}) => any;\n\n")
        else:
            # Doesn't have any parameters
            parts.append(" = () => any;\n\n")

    parts.append(f"}} // namespace {namespace}")
    return "".join(parts)


class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose chat template is fed a tool schema and system prompt."""

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]],
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render (and optionally tokenize) a conversation with its tools.

        Two system messages are placed in front of every chat before
        rendering: the schema generated from the function tools, then the
        system prompt (the code-interpreter prompt when any tool is of type
        ``"code_interpreter"``). The caller's ``conversation`` is never
        modified.

        Args:
            conversation: a list of message dicts, a list of such lists (a
                batch), or object(s) exposing a ``messages`` attribute.
            tools: tool dicts validated against ``Tool``; ``None`` is treated
                as an empty list.
            chat_template: a template string, or the name of a template when
                the tokenizer holds a dict of templates.
            add_generation_prompt: forwarded to the template.
            tokenize: when False the rendered string(s) are returned.
            padding, truncation, max_length, return_tensors: forwarded to the
                tokenizer call.
            return_dict: return the full tokenizer output instead of only
                ``input_ids``.
            tokenizer_kwargs: extra keyword arguments for the tokenizer call.
            **kwargs: extra variables made available to the template.

        Returns:
            The rendered text, or the tokenizer output / ``input_ids``; one
            item per chat when the input is a batch.

        Raises:
            ValueError: if ``return_dict=True`` with ``tokenize=False``, or if
                several templates exist and none can be selected.
        """
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
            if chat_template is not None and chat_template in template_dict:
                # The user can pass the name of a template to the chat template argument instead of an entire template
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            # These are the cases when the model has a single template
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                chat_template = self.default_chat_template
                using_default_template = True

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        # Prepare tools/functions into schema
        functions_pydantic_to_render = []
        has_code_interpreter = False
        for tool in tools or []:
            tool_pydantic = Tool.model_validate(tool)
            if tool_pydantic.type == "function":
                functions_pydantic_to_render.append(tool_pydantic.function)
            else:
                has_code_interpreter = True

        system_prompt_to_use = SYSTEM_PROMPT if not has_code_interpreter else CODE_INTERPRETER_SYSTEM_PROMPT
        system_messages = [
            {"role": "system", "content": generate_schema_from_functions(functions_pydantic_to_render)},
            {"role": "system", "content": system_prompt_to_use},
        ]

        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        # Batching is decided on the caller's input, before any system message
        # is added, so a list of chats is recognised as a batch.
        if (
            isinstance(conversation, (list, tuple))
            and len(conversation) > 0
            and (isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages"))
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            # Build a fresh message list per chat instead of inserting into the
            # caller's list, so repeated calls do not accumulate system messages.
            messages = [dict(message) for message in system_messages]
            messages.extend(chat)
            rendered_chat = compiled_template.render(
                messages=messages, add_generation_prompt=add_generation_prompt, **template_kwargs
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if not tokenize:
            return rendered

        out = self(
            rendered,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **tokenizer_kwargs,
        )
        if return_dict:
            return out
        return out["input_ids"]