from typing import List

from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer

from api.generation.utils import parse_messages
from api.utils.protocol import Role


def build_baichuan_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 4096,
    max_new_tokens: int = 256
) -> List[int]:
    """
    Builds the input tokens for the Baichuan chat model based on the given messages.

    Refs:
        https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context (default=4096).
        max_new_tokens: The maximum number of new tokens to be generated (default=256).

    Returns:
        List[int]: The input tokens for the Baichuan chat model.
    """
    max_input_tokens = context_len - max_new_tokens
    system, rounds = parse_messages(messages)
    system_tokens = tokenizer.encode(system)
    max_history_tokens = max_input_tokens - len(system_tokens)

    # Walk the conversation rounds from newest to oldest, keeping as many full
    # rounds as fit into the remaining token budget.
    history_tokens = []
    for r in rounds[::-1]:
        round_tokens = []
        for message in r:
            if message["role"] == Role.USER:
                round_tokens.append(195)  # Baichuan user role token id
            else:
                round_tokens.append(196)  # Baichuan assistant role token id
            round_tokens.extend(tokenizer.encode(message["content"]))

        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    if messages[-1]["role"] != Role.ASSISTANT:
        input_tokens.append(196)  # cue the model to generate an assistant reply

    return input_tokens[-max_input_tokens:]  # truncate left
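# Illustrative usage (a sketch, not part of the original module): how the prompt
# builder above might be called. The checkpoint name and the example messages are
# assumptions for illustration; any Baichuan-style chat tokenizer would work.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(
#         "baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True
#     )
#     messages = [
#         {"role": "user", "content": "Hello, who are you?"},
#     ]
#     input_ids = build_baichuan_chat_input(tokenizer, messages)
#     # input_ids ends with token 196, prompting the model to produce an assistant turn.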
def check_is_baichuan(model) -> bool:
    """
    Checks if the given model is a Baichuan model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is a Baichuan model, False otherwise.
    """
    return "BaichuanLayer" in getattr(model, "_no_split_modules", [])