# Uploaded by gordonchan (commit ca56e6a, "Upload 41 files")
from typing import List, Tuple
from openai.types.chat import ChatCompletionMessageParam
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from api.utils.protocol import Role
def parse_messages(
    messages: List[ChatCompletionMessageParam], split_role=Role.USER
) -> Tuple[str, List[List[ChatCompletionMessageParam]]]:
    """
    Split a chat history into the system prompt and conversation rounds.

    A new round starts whenever a message with ``split_role`` arrives and the
    current round already holds at least one message. System messages never
    appear in any round; their content is extracted separately, and a later
    system message overrides an earlier one.

    Args:
        messages (List[ChatCompletionMessageParam]): Ordered chat messages.
        split_role: The role at which to split the rounds. Defaults to Role.USER.

    Returns:
        Tuple[str, List[List[ChatCompletionMessageParam]]]: The system prompt
        (empty string if none) and the list of rounds.
    """
    system, rounds = "", []
    current: List[ChatCompletionMessageParam] = []
    # The previous enumerate() index was unused, so iterate directly.
    for message in messages:
        if message["role"] == Role.SYSTEM:
            # Only the content is kept; the message itself joins no round.
            system = message["content"]
            continue
        if message["role"] == split_role and current:
            rounds.append(current)
            current = []
        current.append(message)
    if current:
        rounds.append(current)
    return system, rounds
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """
    Build the chain of logits processors for sampling.

    Each processor is only added when its parameter actually changes the
    distribution; no-op settings are skipped entirely.

    Args:
        temperature (float): Softmax temperature; skipped when ~0 or exactly 1.0.
        repetition_penalty (float): Penalty > 1.0 discourages repeated tokens.
        top_p (float): Nucleus-sampling threshold; active only in [1e-8, 1.0).
        top_k (int): Keep only the k most likely tokens when k > 0.

    Returns:
        LogitsProcessorList: The enabled processors, in application order.
    """
    # Factories are lazy so a disabled processor is never constructed
    # (e.g. TemperatureLogitsWarper rejects 0.0, and 1.0 would be a no-op).
    specs = (
        (temperature >= 1e-5 and temperature != 1.0,
         lambda: TemperatureLogitsWarper(temperature)),
        (repetition_penalty > 1.0,
         lambda: RepetitionPenaltyLogitsProcessor(repetition_penalty)),
        (1e-8 <= top_p < 1.0, lambda: TopPLogitsWarper(top_p)),
        (top_k > 0, lambda: TopKLogitsWarper(top_k)),
    )
    chain = LogitsProcessorList()
    for enabled, build in specs:
        if enabled:
            chain.append(build())
    return chain
def is_partial_stop(output: str, stop_str: str):
    """ Check whether the output contains a partial stop str. """
    # True when some suffix of `output` is a prefix of `stop_str`, i.e. the
    # stop string may still be mid-generation. Note suffix_len == 0 yields
    # output[-0:] == output, so a short output that prefixes the stop string
    # also counts (parity with the original any()-based form).
    limit = min(len(output), len(stop_str))
    for suffix_len in range(limit):
        if stop_str.startswith(output[-suffix_len:]):
            return True
    return False
# Models don't use the same configuration key for determining the maximum
# sequence length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]
def get_context_length(config) -> int:
    """
    Get the context length of a model from a huggingface model config.

    Reads the first attribute in SEQUENCE_LENGTH_KEYS that the config
    defines, scaled by the RoPE scaling factor when one is configured.

    Args:
        config: A huggingface model config object (read via attribute access).

    Returns:
        int: The usable context length, or 2048 if nothing is configured.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    # Some configs carry a rope_scaling dict without a "factor" entry
    # (e.g. type-only settings); default to 1 instead of raising KeyError.
    rope_scaling_factor = rope_scaling.get("factor", 1) if rope_scaling else 1
    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """
    Truncate a reply at the first stop string and report whether one matched.

    If no stop string occurs in full, the tail of the reply is trimmed when it
    is a proper prefix of some stop string (a stop sequence that is still
    being generated), without marking the reply as stopped.

    Args:
        reply (str): The generated text to post-process.
        stop_strings (List[str]): The stop sequences to search for.

    Returns:
        Tuple[str, bool]: The possibly truncated reply and True iff a full
        stop string was found.
    """
    # Full match: cut at the occurrence of the first stop string found.
    for stop in stop_strings:
        position = reply.find(stop)
        if position != -1:
            return reply[:position], True

    # Partial match: something like "\nYo" generated just before "\nYou:"
    # completes. Trim the longest reply-suffix that is a prefix of a stop
    # string, taking the first stop string (in order) that matches.
    for stop in stop_strings:
        trimmed = False
        for prefix_len in range(len(stop) - 1, 0, -1):
            if reply[-prefix_len:] == stop[:prefix_len]:
                reply = reply[:-prefix_len]
                trimmed = True
                break
        if trimmed:
            break

    return reply, False