# NOTE(review): removed web-viewer scrape residue (page header, file size,
# commit hash, and line-number gutter) that preceded the actual source.
from typing import List, Tuple
from openai.types.chat import ChatCompletionMessageParam
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from api.utils.protocol import Role
def parse_messages(
    messages: List[ChatCompletionMessageParam], split_role=Role.USER
) -> Tuple[str, List[List[ChatCompletionMessageParam]]]:
    """
    Split a chat transcript into a system prompt and conversation rounds.

    System messages are pulled out of the stream (the last one seen wins);
    every other message is grouped into "rounds", where a new round starts
    each time a ``split_role`` message arrives and the current round is
    non-empty.

    Args:
        messages (List[ChatCompletionMessageParam]): Ordered chat messages.
        split_role: The role that opens a new round. Defaults to Role.USER.

    Returns:
        Tuple[str, List[List[ChatCompletionMessageParam]]]: The system
        prompt (empty string if none) and the list of rounds.
    """
    system = ""
    rounds: List[List[ChatCompletionMessageParam]] = []
    current: List[ChatCompletionMessageParam] = []

    for message in messages:
        role = message["role"]
        if role == Role.SYSTEM:
            # Keep the system prompt out of the rounds; last one wins.
            system = message["content"]
            continue
        if role == split_role and current:
            rounds.append(current)
            current = []
        current.append(message)

    # Flush the trailing, still-open round.
    if current:
        rounds.append(current)
    return system, rounds
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """
    Build the list of logits processors for the given sampling parameters.

    Each processor is only added when its parameter actually changes the
    distribution; no-op values are skipped entirely.

    Args:
        temperature (float): Softmax temperature. Skipped below 1e-5
            (TemperatureLogitsWarper rejects 0.0) and at exactly 1.0 (no-op).
        repetition_penalty (float): Penalty for repeated tokens; skipped
            at 1.0 or below.
        top_p (float): Nucleus-sampling threshold; skipped outside [1e-8, 1.0).
        top_k (int): Top-k cutoff; skipped when non-positive.

    Returns:
        LogitsProcessorList: Processors in application order
        (temperature, repetition penalty, top-p, top-k).
    """
    # (enabled?, factory) pairs, listed in the order they must be applied.
    candidates = (
        (temperature >= 1e-5 and temperature != 1.0,
         lambda: TemperatureLogitsWarper(temperature)),
        (repetition_penalty > 1.0,
         lambda: RepetitionPenaltyLogitsProcessor(repetition_penalty)),
        (1e-8 <= top_p < 1.0,
         lambda: TopPLogitsWarper(top_p)),
        (top_k > 0,
         lambda: TopKLogitsWarper(top_k)),
    )
    processors = LogitsProcessorList()
    for enabled, build in candidates:
        if enabled:
            processors.append(build())
    return processors
def is_partial_stop(output: str, stop_str: str):
    """ Check whether ``output`` ends with an incomplete prefix of ``stop_str``. """
    # size == 0 compares the whole output against stop_str's start, which
    # catches the case where output is shorter than the stop string.
    for size in range(min(len(output), len(stop_str))):
        if stop_str.startswith(output[-size:]):
            return True
    return False
# Models don't use the same configuration key for determining the maximum
# sequence length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]
def get_context_length(config) -> int:
    """
    Get the context length of a model from a huggingface model config.

    Args:
        config: A huggingface model config object, read via attribute access.

    Returns:
        int: The first value found among SEQUENCE_LENGTH_KEYS, multiplied by
        the rope scaling factor when one is configured. Falls back to 2048
        when no known key is present.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    # Some configs carry a rope_scaling dict without a "factor" entry;
    # use .get() with a neutral default instead of raising KeyError.
    rope_scaling_factor = rope_scaling.get("factor", 1) if rope_scaling else 1
    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """
    Truncate ``reply`` at the first stop string, or trim a trailing partial one.

    Args:
        reply (str): The generated text to post-process.
        stop_strings (List[str]): The stop strings to search for.

    Returns:
        Tuple[str, bool]: The (possibly shortened) reply, and True only when
        a complete stop string was found and removed.
    """
    # A full stop string anywhere in the reply wins: cut it and everything
    # after it, and report that generation should stop.
    for stop in stop_strings:
        position = reply.find(stop)
        if position >= 0:
            return reply[:position], True

    # No complete match. If something like "\nYo" was generated just before
    # "\nYou:" is completed, trim that dangling prefix (longest match first).
    trimmed = False
    for stop in stop_strings:
        if trimmed:
            break
        for length in range(len(stop) - 1, 0, -1):
            if reply.endswith(stop[:length]):
                reply = reply[:-length]
                trimmed = True
                break
    return reply, False