from typing import List, Tuple from openai.types.chat import ChatCompletionMessageParam from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, ) from api.utils.protocol import Role def parse_messages( messages: List[ChatCompletionMessageParam], split_role=Role.USER ) -> Tuple[str, List[List[ChatCompletionMessageParam]]]: """ Parse a list of chat completion messages into system and rounds. Args: messages (List[ChatCompletionMessageParam]): The list of chat completion messages. split_role: The role at which to split the rounds. Defaults to Role.USER. Returns: Tuple[str, List[List[ChatCompletionMessageParam]]]: A tuple containing the system message and a list of rounds. """ system, rounds = "", [] r = [] for i, message in enumerate(messages): if message["role"] == Role.SYSTEM: system = message["content"] continue if message["role"] == split_role and r: rounds.append(r) r = [] r.append(message) if r: rounds.append(r) return system, rounds def prepare_logits_processor( temperature: float, repetition_penalty: float, top_p: float, top_k: int ) -> LogitsProcessorList: """ Prepare a list of logits processors based on the provided parameters. Args: temperature (float): The temperature value for temperature warping. repetition_penalty (float): The repetition penalty value. top_p (float): The top-p value for top-p warping. top_k (int): The top-k value for top-k warping. Returns: LogitsProcessorList: A list of logits processors. """ processor_list = LogitsProcessorList() # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases. if temperature >= 1e-5 and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) if repetition_penalty > 1.0: processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) if 1e-8 <= top_p < 1.0: processor_list.append(TopPLogitsWarper(top_p)) if top_k > 0: processor_list.append(TopKLogitsWarper(top_k)) return processor_list def is_partial_stop(output: str, stop_str: str): """ Check whether the output contains a partial stop str. """ return any( stop_str.startswith(output[-i:]) for i in range(0, min(len(output), len(stop_str))) ) # Models don't use the same configuration key for determining the maximum # sequence length. Store them here so we can sanely check them. # NOTE: The ordering here is important. Some models have two of these, and we # have a preference for which value gets used. SEQUENCE_LENGTH_KEYS = [ "max_sequence_length", "seq_length", "max_position_embeddings", "max_seq_len", "model_max_length", ] def get_context_length(config) -> int: """ Get the context length of a model from a huggingface model config. """ rope_scaling = getattr(config, "rope_scaling", None) rope_scaling_factor = config.rope_scaling["factor"] if rope_scaling else 1 for key in SEQUENCE_LENGTH_KEYS: val = getattr(config, key, None) if val is not None: return int(rope_scaling_factor * val) return 2048 def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]: """ Apply stopping strings to the reply and check if a stop string is found. Args: reply (str): The reply to apply stopping strings to. stop_strings (List[str]): The list of stopping strings to check for. Returns: Tuple[str, bool]: A tuple containing the modified reply and a boolean indicating if a stop string was found. """ stop_found = False for string in stop_strings: idx = reply.find(string) if idx != -1: reply = reply[:idx] stop_found = True break if not stop_found: # If something like "\nYo" is generated just before "\nYou: is completed, trim it for string in stop_strings: for j in range(len(string) - 1, 0, -1): if reply[-j:] == string[:j]: reply = reply[:-j] break else: continue break return reply, stop_found