# Provenance (huggingface.co file-page residue, kept as a comment so the module parses):
#   gordonchan's picture — Upload 41 files — ca56e6a verified
#   raw / history / blame / contribute / delete — No virus — 2.62 kB
from typing import List
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from api.generation.utils import parse_messages
from api.utils.protocol import Role
def build_xverse_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 8192,
    max_new_tokens: int = 256
) -> List[int]:
    """
    Builds the input tokens for the Xverse chat model based on the given messages.

    Rounds are accumulated from the most recent backwards until the token
    budget (``context_len - max_new_tokens``, minus the system prompt) is
    exhausted. The newest round is always kept even if it alone exceeds the
    budget; the final sequence is then truncated from the left.

    Refs:
        https://huggingface.co/xverse/XVERSE-13B-Chat/blob/main/modeling_xverse.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context (default=8192).
        max_new_tokens: The maximum number of new tokens to be added (default=256).

    Returns:
        List[int]: The input tokens for the Xverse chat model.
    """
    # Hard-coded Xverse eos token id — see modeling_xverse.py in the ref above.
    eos_token_id = 3

    max_input_tokens = context_len - max_new_tokens
    system, rounds = parse_messages(messages)
    system = f"{system}\n\n" if system else system

    def _tokenize_str(role, content):
        # Xverse turns are plain "<Role>: <content>" text, no token type ids.
        return tokenizer.encode(f"{role}: {content}", return_token_type_ids=False)

    system_tokens = tokenizer.encode(system, return_token_type_ids=False)
    max_history_tokens = max_input_tokens - len(system_tokens)

    history_tokens = []
    for i, r in enumerate(rounds[::-1]):  # iterate newest round first
        round_tokens = []
        for message in r:
            if message["role"] == Role.USER:
                content = f"{message['content']}\n\n"
                if i == 0:
                    # Cue the model to generate its reply after the latest user turn.
                    content += "Assistant: "
                content_tokens = _tokenize_str("Human", content)
            else:
                content_tokens = _tokenize_str("Assistant", f"{message['content']}") + [eos_token_id]
            round_tokens.extend(content_tokens)

        # Always keep the newest round; prepend older rounds while budget remains.
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    return input_tokens[-max_input_tokens:]  # truncate left
def check_is_xverse(model) -> bool:
    """
    Determines whether *model* is an Xverse model.

    The check relies on the model class advertising ``XverseDecoderLayer``
    in its ``_no_split_modules`` attribute.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is a Xverse model, False otherwise.
    """
    no_split_modules = getattr(model, "_no_split_modules", [])
    return "XverseDecoderLayer" in no_split_modules