import json import re from copy import deepcopy from typing import List, Union, Optional, Dict, Any, Tuple from fastapi import HTTPException from loguru import logger from openai.types.chat import ( ChatCompletionMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam, ) from transformers import PreTrainedTokenizer from api.generation.utils import parse_messages from api.utils.protocol import Role TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs: {tools_text} Use the following format: Question: the input question you must answer Thought: you should always think about what to do Action: the action to take, should be one of [{tools_name_text}] Action Input: the input to the action Observation: the result of the action ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) Thought: I now know the final answer Final Answer: the final answer to the original input question Begin!""" _TEXT_COMPLETION_CMD = object() def build_qwen_chat_input( tokenizer: PreTrainedTokenizer, messages: List[ChatCompletionMessageParam], context_len: int = 8192, max_new_tokens: int = 256, functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, tools: Optional[List[Dict[str, Any]]] = None, ) -> List[int]: """ Builds the input tokens for Qwen chat generation. Refs: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py Args: tokenizer: The tokenizer used to encode the input tokens. messages: The list of chat messages. context_len: The maximum length of the context. max_new_tokens: The maximum number of new tokens to add. functions: Optional dictionary or list of dictionaries representing the functions. tools: Optional list of dictionaries representing the tools. Returns: The list of input tokens. """ query, history = process_qwen_messages(messages, functions, tools) if query is _TEXT_COMPLETION_CMD: return build_last_message_input(tokenizer, history) messages = [] for q, r in history: messages.extend( [ ChatCompletionUserMessageParam(role="user", content=q), ChatCompletionAssistantMessageParam(role="assistant", content=r) ] ) messages.append(ChatCompletionUserMessageParam(role="user", content=query)) max_input_tokens = context_len - max_new_tokens system, rounds = parse_messages(messages) system = f"You are a helpful assistant.{system}" im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id] nl_tokens = tokenizer.encode("\n") def _tokenize_str(role, content): return tokenizer.encode( role, allowed_special=set() ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) system_tokens_part = _tokenize_str("system", system) system_tokens = im_start_tokens + system_tokens_part + im_end_tokens max_history_tokens = max_input_tokens - len(system_tokens) history_tokens = [] for r in rounds[::-1]: round_tokens = [] for message in r: if round_tokens: round_tokens += nl_tokens if message["role"] == Role.USER: content_tokens = im_start_tokens + _tokenize_str("user", message["content"]) + im_end_tokens else: content_tokens = im_start_tokens + _tokenize_str("assistant", message["content"]) + im_end_tokens round_tokens.extend(content_tokens) if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: if history_tokens: history_tokens = nl_tokens + history_tokens history_tokens = round_tokens + history_tokens # concat left if len(history_tokens) < max_history_tokens: continue break input_tokens = system_tokens + nl_tokens + history_tokens if messages[-1]["role"] != Role.ASSISTANT: input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens return input_tokens[-max_input_tokens:] # truncate left def check_is_qwen(model) -> bool: """ Checks if the given model is a Qwen model. Args: model: The model to be checked. Returns: bool: True if the model is a Qwen model, False otherwise. """ return "QWenBlock" in getattr(model, "_no_split_modules", []) def process_qwen_messages( messages: List[ChatCompletionMessageParam], functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, tools: Optional[List[Dict[str, Any]]] = None, ) -> Tuple[str, List[List[str]]]: """ Process the Qwen messages and generate a query and history. Args: messages (List[ChatCompletionMessageParam]): The list of chat completion messages. functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]): The functions to be used. tools (Optional[List[Dict[str, Any]]]): The tools to be used. Returns: Tuple[str, List[List[str]]]: The generated query and history. """ if all(m["role"] != Role.USER for m in messages): raise HTTPException( status_code=400, detail=f"Invalid request: Expecting at least one user message.", ) messages = deepcopy(messages) default_system = "You are a helpful assistant." system = "" if messages[0]["role"] == Role.SYSTEM: system = messages.pop(0)["content"].lstrip("\n").rstrip() if system == default_system: system = "" if tools: functions = [t["function"] for t in tools] if functions: tools_text = [] tools_name_text = [] for func_info in functions: name = func_info.get("name", "") name_m = func_info.get("name_for_model", name) name_h = func_info.get("name_for_human", name) desc = func_info.get("description", "") desc_m = func_info.get("description_for_model", desc) tool = TOOL_DESC.format( name_for_model=name_m, name_for_human=name_h, # Hint: You can add the following format requirements in description: # "Format the arguments as a JSON object." # "Enclose the code within triple backticks (`) at the beginning and end of the code." description_for_model=desc_m, parameters=json.dumps(func_info["parameters"], ensure_ascii=False), ) tools_text.append(tool) tools_name_text.append(name_m) tools_text = "\n\n".join(tools_text) tools_name_text = ", ".join(tools_name_text) system += "\n\n" + REACT_INSTRUCTION.format( tools_text=tools_text, tools_name_text=tools_name_text, ) system = system.lstrip("\n").rstrip() dummy_thought = { "en": "\nThought: I now know the final answer.\nFinal answer: ", "zh": "\nThought: 我会作答了。\nFinal answer: ", } _messages = messages messages = [] for m_idx, m in enumerate(_messages): role, content = m["role"], m["content"] func_call, tool_calls = m.get("function_call", None), m.get("tool_calls", None) if content: content = content.lstrip("\n").rstrip() if role in [Role.FUNCTION, Role.TOOL]: if (len(messages) == 0) or (messages[-1]["role"] != Role.ASSISTANT): raise HTTPException( status_code=400, detail=f"Invalid request: Expecting role assistant before role function.", ) messages[-1]["content"] += f"\nObservation: {content}" if m_idx == len(_messages) - 1: messages[-1]["content"] += "\nThought:" elif role == Role.ASSISTANT: if len(messages) == 0: raise HTTPException( status_code=400, detail=f"Invalid request: Expecting role user before role assistant.", ) last_msg = messages[-1]["content"] last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 if func_call is None and tool_calls is None: if functions or tool_calls: content = dummy_thought["zh" if last_msg_has_zh else "en"] + content else: if func_call: f_name, f_args = func_call.get("name"), func_call.get("arguments") else: f_name, f_args = tool_calls[0]["function"]["name"], tool_calls[0]["function"]["arguments"] if not content: if last_msg_has_zh: content = f"Thought: 我可以使用 {f_name} API。" else: content = f"Thought: I can use {f_name}." if messages[-1]["role"] == Role.USER: messages.append( ChatCompletionAssistantMessageParam(role="assistant", content=content.lstrip("\n").rstrip()) ) else: messages[-1]["content"] += content elif role == Role.USER: messages.append( ChatCompletionUserMessageParam(role="user", content=content.lstrip("\n").rstrip()) ) else: raise HTTPException( status_code=400, detail=f"Invalid request: Incorrect role {role}." ) query = _TEXT_COMPLETION_CMD if messages[-1]["role"] == Role.USER: query = messages[-1]["content"] messages = messages[:-1] if len(messages) % 2 != 0: raise HTTPException(status_code=400, detail="Invalid request") history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] for i in range(0, len(messages), 2): if messages[i]["role"] == Role.USER and messages[i + 1]["role"] == Role.ASSISTANT: usr_msg = messages[i]["content"].lstrip("\n").rstrip() bot_msg = messages[i + 1]["content"].lstrip("\n").rstrip() if system and (i == len(messages) - 2): usr_msg = f"{system}\n\nQuestion: {usr_msg}" system = "" for t in dummy_thought.values(): t = t.lstrip("\n") if bot_msg.startswith(t) and ("\nAction: " in bot_msg): bot_msg = bot_msg[len(t):] history.append([usr_msg, bot_msg]) else: raise HTTPException( status_code=400, detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", ) if system: assert query is not _TEXT_COMPLETION_CMD query = f"{system}\n\nQuestion: {query}" return query, history def build_last_message_input(tokenizer: PreTrainedTokenizer, history: list): im_start = "<|im_start|>" im_end = "<|im_end|>" prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" for i, (query, response) in enumerate(history): query = query.lstrip("\n").rstrip() response = response.lstrip("\n").rstrip() prompt += f"\n{im_start}user\n{query}{im_end}" prompt += f"\n{im_start}assistant\n{response}{im_end}" prompt = prompt[:-len(im_end)] logger.debug(f"==== Prompt with tools ====\n{prompt}") return tokenizer.encode(prompt)