Spaces:
Running
Running
File size: 2,623 Bytes
ca56e6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from typing import List
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer
from api.generation.utils import parse_messages
from api.utils.protocol import Role
def build_xverse_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 8192,
    max_new_tokens: int = 256
) -> List[int]:
    """
    Builds the input tokens for the Xverse chat model based on the given messages.

    Refs:
        https://huggingface.co/xverse/XVERSE-13B-Chat/blob/main/modeling_xverse.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context (default=8192).
        max_new_tokens: The maximum number of new tokens to be added (default=256).

    Returns:
        List[int]: The input tokens for the Xverse chat model.
    """
    # Reserve room for the generation budget; only this many prompt tokens may remain.
    max_input_tokens = context_len - max_new_tokens

    system, rounds = parse_messages(messages)
    system = f"{system}\n\n" if system else system

    # Per the Xverse reference modeling file, id 3 is the eos token that
    # terminates each assistant turn.
    eos_token_id = 3

    def _tokenize_str(role, content):
        # Xverse prompt format: "<Role>: <content>" per turn.
        return tokenizer.encode(f"{role}: {content}", return_token_type_ids=False)

    system_tokens = tokenizer.encode(system, return_token_type_ids=False)
    max_history_tokens = max_input_tokens - len(system_tokens)

    history_tokens = []
    # Iterate rounds newest-first so the most recent context survives truncation.
    for i, r in enumerate(rounds[::-1]):
        round_tokens = []
        for message in r:
            if message["role"] == Role.USER:
                content = f"{message['content']}\n\n"
                if i == 0:
                    # Only the latest round carries the generation prompt.
                    content += "Assistant: "
                content_tokens = _tokenize_str("Human", content)
            else:
                content_tokens = _tokenize_str("Assistant", f"{message['content']}") + [eos_token_id]
            round_tokens.extend(content_tokens)

        # Always keep at least the newest round even if it alone overflows the
        # budget; the final slice below enforces the hard limit either way.
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    return input_tokens[-max_input_tokens:]  # truncate left
def check_is_xverse(model) -> bool:
    """
    Determine whether the given model is an Xverse model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model declares ``XverseDecoderLayer`` among its
        non-splittable modules, False otherwise.
    """
    no_split_modules = getattr(model, "_no_split_modules", [])
    return any(name == "XverseDecoderLayer" for name in no_split_modules)
|