from typing import List
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from api.generation.utils import parse_messages
from api.utils.protocol import Role


def build_xverse_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 8192,
    max_new_tokens: int = 256
) -> List[int]:
    """
    Builds the input tokens for the Xverse chat model based on the given messages.

    Refs:
        https://huggingface.co/xverse/XVERSE-13B-Chat/blob/main/modeling_xverse.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context (default: 8192).
        max_new_tokens: The maximum number of new tokens to generate (default: 256).

    Returns:
        List[int]: The input tokens for the Xverse chat model.
    """
    max_input_tokens = context_len - max_new_tokens
    system, rounds = parse_messages(messages)
    system = f"{system}\n\n" if system else system

    def _tokenize_str(role, content):
        return tokenizer.encode(f"{role}: {content}", return_token_type_ids=False)

    system_tokens = tokenizer.encode(system, return_token_type_ids=False)
    max_history_tokens = max_input_tokens - len(system_tokens)

    # Walk the rounds from newest to oldest, keeping as much history as fits
    # into the token budget that remains after the system prompt.
    history_tokens = []
    for i, r in enumerate(rounds[::-1]):
        round_tokens = []
        for message in r:
            if message["role"] == Role.USER:
                content = f"{message['content']}\n\n"
                if i == 0:
                    # Latest round: append the assistant prefix so the model
                    # continues the reply after "Assistant: ".
                    content += "Assistant: "
                content_tokens = _tokenize_str("Human", content)
            else:
                content_tokens = _tokenize_str("Assistant", f"{message['content']}") + [3]  # add eos token id
            round_tokens.extend(content_tokens)

        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue

        break

    input_tokens = system_tokens + history_tokens
    return input_tokens[-max_input_tokens:]  # truncate left
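

# For reference, the token sequence built above decodes to a prompt shaped
# roughly like the sketch below. This is a hedged illustration inferred from
# the code, not an official XVERSE template; the name _XVERSE_PROMPT_SKETCH
# is introduced here for illustration only, and "<eos>" stands for the
# token id 3 appended after each assistant reply.
_XVERSE_PROMPT_SKETCH = (
    "{system}\n\n"
    "Human: {user_1}\n\nAssistant: {assistant_1}<eos>"
    "Human: {user_last}\n\nAssistant: "
)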


def check_is_xverse(model) -> bool:
    """
    Checks if the given model is an Xverse model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is an Xverse model, False otherwise.
    """
    return "XverseDecoderLayer" in getattr(model, "_no_split_modules", [])