# NOTE: removed page-scrape artifact ("Spaces: Running") that was not valid Python.
from typing import List, Tuple

from openai.types.chat import ChatCompletionMessageParam
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

from api.utils.protocol import Role
def parse_messages(
    messages: List[ChatCompletionMessageParam], split_role=Role.USER
) -> Tuple[str, List[List[ChatCompletionMessageParam]]]:
    """
    Parse a list of chat completion messages into system and rounds.

    Args:
        messages (List[ChatCompletionMessageParam]): The list of chat completion messages.
        split_role: The role at which to split the rounds. Defaults to Role.USER.

    Returns:
        Tuple[str, List[List[ChatCompletionMessageParam]]]: A tuple containing the system message
            (last one wins if several are present) and a list of rounds, where each round starts
            at a ``split_role`` message.
    """
    system = ""
    rounds: List[List[ChatCompletionMessageParam]] = []
    current_round: List[ChatCompletionMessageParam] = []
    # Fix: the original iterated with enumerate() but never used the index.
    for message in messages:
        if message["role"] == Role.SYSTEM:
            # System messages are hoisted out of the rounds entirely.
            system = message["content"]
            continue
        if message["role"] == split_role and current_round:
            # A new split_role message closes the previous round.
            rounds.append(current_round)
            current_round = []
        current_round.append(message)
    if current_round:
        rounds.append(current_round)
    return system, rounds
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """
    Prepare a list of logits processors based on the provided parameters.

    Args:
        temperature (float): The temperature value for temperature warping.
        repetition_penalty (float): The repetition penalty value.
        top_p (float): The top-p value for top-p warping.
        top_k (int): The top-k value for top-k warping.

    Returns:
        LogitsProcessorList: A list of logits processors; empty when every
            parameter is at a no-op value.
    """
    processors = LogitsProcessorList()

    # TemperatureLogitsWarper rejects 0.0, and 1.0 leaves logits unchanged,
    # so both cases are skipped entirely.
    temperature_is_noop = temperature < 1e-5 or temperature == 1.0
    if not temperature_is_noop:
        processors.append(TemperatureLogitsWarper(temperature))

    if repetition_penalty > 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))

    if 1e-8 <= top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p))

    if top_k > 0:
        processors.append(TopKLogitsWarper(top_k))

    return processors
def is_partial_stop(output: str, stop_str: str):
    """ Check whether the output contains a partial stop str. """
    limit = min(len(output), len(stop_str))
    # NOTE: i == 0 slices the whole output (output[-0:] == output), so this
    # also reports True when the entire output is a prefix of stop_str.
    for i in range(limit):
        if stop_str.startswith(output[-i:]):
            return True
    return False
# Models don't use the same configuration key for determining the maximum
# sequence length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config) -> int:
    """ Get the context length of a model from a huggingface model config.

    Checks SEQUENCE_LENGTH_KEYS in preference order and scales the first
    value found by the rope scaling factor, if any. Falls back to 2048
    when no known key is present.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    # Fix: some configs carry a rope_scaling dict without a "factor" key
    # (e.g. only "type"); indexing it directly raised KeyError. Treat a
    # missing factor as 1 (no scaling).
    rope_scaling_factor = rope_scaling.get("factor", 1) if rope_scaling else 1
    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    # Conservative default when the config exposes no known length key.
    return 2048
def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """
    Apply stopping strings to the reply and check if a stop string is found.

    Args:
        reply (str): The reply to apply stopping strings to.
        stop_strings (List[str]): The list of stopping strings to check for.

    Returns:
        Tuple[str, bool]: A tuple containing the modified reply and a boolean indicating if a stop string was found.
    """
    # First pass: a stop string fully contained in the reply wins outright;
    # the reply is truncated at its first occurrence.
    for stop in stop_strings:
        cut = reply.find(stop)
        if cut != -1:
            return reply[:cut], True

    # Second pass: no full stop string was found, so trim a trailing fragment
    # that could still grow into one (e.g. "\nYo" generated just before
    # "\nYou:" is completed). Longest candidate fragment is checked first,
    # and the first stop string that matches ends the search.
    for stop in stop_strings:
        for size in range(len(stop) - 1, 0, -1):
            if reply.endswith(stop[:size]):
                return reply[:-size], False

    return reply, False