from typing import List

from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer

from api.generation.utils import parse_messages
from api.utils.protocol import Role


def build_baichuan_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 4096,
    max_new_tokens: int = 256,
) -> List[int]:
    """
    Builds the input tokens for the Baichuan chat model based on the given messages.

    Refs:
        https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context (default=4096).
        max_new_tokens: The maximum number of new tokens to be generated (default=256).

    Returns:
        List[int]: The input tokens for the Baichuan chat model.
    """
    # Reserve room for generation; only the remainder is available for the prompt.
    max_input_tokens = context_len - max_new_tokens

    system, rounds = parse_messages(messages)
    system_tokens = tokenizer.encode(system)
    max_history_tokens = max_input_tokens - len(system_tokens)

    # Walk the conversation rounds from newest to oldest, keeping as many
    # recent rounds as fit within the history budget.
    history_tokens = []
    for r in rounds[::-1]:
        round_tokens = []
        for message in r:
            # 195 / 196 are Baichuan's special user / assistant token ids.
            if message["role"] == Role.USER:
                round_tokens.append(195)
            else:
                round_tokens.append(196)
            round_tokens.extend(tokenizer.encode(message["content"]))

        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    if messages[-1]["role"] != Role.ASSISTANT:
        # End the prompt with the assistant token id so the model generates a reply.
        input_tokens.append(196)
    return input_tokens[-max_input_tokens:]  # truncate left


def check_is_baichuan(model) -> bool:
    """
    Checks if the given model is a Baichuan model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is a Baichuan model, False otherwise.
    """
    return "BaichuanLayer" in getattr(model, "_no_split_modules", [])
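

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # It assumes the Baichuan-13B-Chat tokenizer can be downloaded from the
    # Hugging Face Hub; `trust_remote_code=True` is required because Baichuan
    # ships a custom tokenizer implementation.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True
    )
    messages: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Hello!"},
    ]
    input_ids = build_baichuan_chat_input(tokenizer, messages)
    # The prompt should end with 196 (the assistant token id), cueing the
    # model to produce the assistant's reply.
    print(f"{len(input_ids)} prompt tokens; tail: {input_ids[-5:]}")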