diff --git "a/llmlingua/prompt_compressor.py" "b/llmlingua/prompt_compressor.py" new file mode 100644 --- /dev/null +++ "b/llmlingua/prompt_compressor.py" @@ -0,0 +1,2412 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import bisect +import re +from collections import defaultdict +from typing import List + +import numpy as np +import torch + +import nltk +import tiktoken +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoTokenizer, +) +import torch.nn.functional as F +import string +import copy +from torch.utils.data import DataLoader + +from .utils import TokenClfDataset, seed_everything, is_begin_of_new_word, replace_added_token, get_pure_token + + +class PromptCompressor: + """ + PromptCompressor is designed for compressing prompts based on a given language model. + + This class initializes with the language model and its configuration, preparing it for prompt compression tasks. + The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing. + Users can specify different model names and configurations as needed for their particular use case. The architecture is + based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu, + Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models." + arXiv preprint arXiv:2310.05736 (2023). + + Args: + model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf". + device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda". + model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary. + open_api_config (dict, optional): A dictionary containing configuration for OpenAI APIs that may be used in conjunction with the model. Default is an empty dictionary. + use_llmlingua2 (bool, optional): Whether to use the llmlingua-2 compressor based on the paper + "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". + Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang. + "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". arXiv preprint, 2024. + Default is True. + llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is + { + "max_batch_size": 50, + "max_force_token": 100, # max number of tokens that will be forcibly preserved + } + Example: + >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True, ) + >>> context = ["This is the first context sentence.", "Here is another context sentence."] + >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5) + >>> print(result["compressed_prompt"]) + # This will print the compressed version of the context. + + Note: + The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
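+
+    Example (LLMLingua-1 / LongLLMLingua path):
+        A minimal sketch, assuming the default Llama-2 checkpoint above is available and fits on the chosen device;
+        any other causal LM checkpoint can be substituted.
+        >>> compressor = PromptCompressor(model_name="NousResearch/Llama-2-7b-hf", use_llmlingua2=False)
+        >>> result = compressor.compress_prompt(context, instruction="", question="", target_token=200)
+        >>> print(result["rate"])  # achieved compression rate as a human-readable percentage string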
+ """ + + def __init__( + self, + model_name: str = "NousResearch/Llama-2-7b-hf", + device_map: str = "cuda", + model_config: dict = {}, + open_api_config: dict = {}, + use_llmlingua2: bool = True, + llmlingua2_config: dict = {}, + ): + self.model_name = model_name + self.use_llmlingua2 = use_llmlingua2 + self.retrieval_model = None + self.retrieval_model_name = None + self.open_api_config = open_api_config + self.cache_bos_num = 10 + self.prefix_bos_num = 100 + self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") + + self.load_model(model_name, device_map, model_config) + if use_llmlingua2: + self.init_llmlingua2(**llmlingua2_config) + + def init_llmlingua2( + self, + max_batch_size: int = 50, + max_force_token: int = 100, + ): + + seed_everything(42) + self.max_batch_size = max_batch_size + self.max_seq_len = 512 + self.max_force_token = max_force_token + self.special_tokens = set(self.tokenizer.special_tokens_map.values()) + + self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)] + self.tokenizer.add_special_tokens( + {"additional_special_tokens": self.added_tokens} + ) + self.model.resize_token_embeddings(len(self.tokenizer)) + + def load_model( + self, model_name: str, device_map: str = "cuda", model_config: dict = {} + ): + trust_remote_code = model_config.get("trust_remote_code", True) + if "trust_remote_code" not in model_config: + model_config["trust_remote_code"] = trust_remote_code + config = AutoConfig.from_pretrained(model_name, **model_config) + tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config) + if model_config.get("pad_to_left", True): + tokenizer.padding_side = "left" + tokenizer.pad_token_id = ( + config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id + ) + MODEL_CLASS = ( + AutoModelForTokenClassification + if any("ForTokenClassification" in ar for ar in config.architectures) + else AutoModelForCausalLM + ) + self.device = ( + device_map + if any(key in device_map for key in ["cuda", "cpu", "mps"]) + else "cuda" + ) + if "cuda" in device_map or "cpu" in device_map: + model = MODEL_CLASS.from_pretrained( + model_name, + torch_dtype=model_config.get( + "torch_dtype", "auto" if device_map == "cuda" else torch.float32 + ), + device_map=device_map, + config=config, + ignore_mismatched_sizes=True, + **model_config, + ) + else: + model = MODEL_CLASS.from_pretrained( + model_name, + device_map=device_map, + torch_dtype=model_config.get("torch_dtype", "auto"), + pad_token_id=tokenizer.pad_token_id, + **model_config, + ) + self.tokenizer = tokenizer + self.model = model + self.context_idxs = [] + self.max_position_embeddings = config.max_position_embeddings + + def get_ppl( + self, + text: str, + granularity: str = "sentence", + input_ids=None, + attention_mask=None, + past_key_values=None, + return_kv=False, + end=None, + condition_mode: str = "none", + condition_pos_id: int = 0, + ): + if input_ids is None: + tokenized_text = self.tokenizer(text, return_tensors="pt") + input_ids = tokenized_text["input_ids"].to(self.device) + attention_mask = tokenized_text["attention_mask"].to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + else: + past_length = 0 + if end is None: + end = input_ids.shape[1] + end = min(end, past_length + self.max_position_embeddings) + with torch.no_grad(): + response = self.model( + input_ids[:, past_length:end], + attention_mask=attention_mask[:, :end], + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = 
response.past_key_values + + shift_logits = response.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., past_length + 1 : end].contiguous() + # Flatten the tokens + active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1) + active_logits = shift_logits.view(-1, shift_logits.size(-1))[active] + active_labels = shift_labels.view(-1)[active] + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(active_logits, active_labels) + if condition_mode == "before": + loss = loss[:condition_pos_id] + elif condition_mode == "after": + loss = loss[condition_pos_id:] + res = loss.mean() if granularity == "sentence" else loss + return (res, past_key_values) if return_kv else res + + def __call__(self, *args, **kwargs): + return self.compress_prompt(*args, **kwargs) + + def structured_compress_prompt( + self, + context: List[str], + instruction: str = "", + question: str = "", + rate: float = 0.5, + target_token: float = -1, + iterative_size: int = 200, + force_context_ids: List[int] = None, + force_context_number: int = None, + use_sentence_level_filter: bool = False, + use_context_level_filter: bool = True, + use_token_level_filter: bool = True, + keep_split: bool = False, + keep_first_sentence: int = 0, + keep_last_sentence: int = 0, + keep_sentence_number: int = 0, + high_priority_bonus: int = 100, + context_budget: str = "+100", + token_budget_ratio: float = 1.4, + condition_in_question: str = "none", + reorder_context: str = "original", + dynamic_context_compression_ratio: float = 0.0, + condition_compare: bool = False, + add_instruction: bool = False, + rank_method: str = "llmlingua", + concate_question: bool = True, + ): + """ + Compresses the given prompt context based on a specified structure. + + Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags. + Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=False>'), + indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'. + When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment. + + Args: + context (List[str]): List of context strings divided by '<llmlingua></llmlingua>' tags with optional compression settings. + instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string. + question (str, optional): A specific question that the prompt is addressing. Default is an empty string. + rate (float, optional): The compression rate is defined the same as in the paper "Language Modeling Is Compression". + Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern, + Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023): + .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}} + Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be + fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal + to 1.0, representing the target compression rate. ``rate`` is applicable only within the context-level filter + and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate. + However, for segments where no specific rate is defined, the global rate serves as the default value.
The final + compression rate of the entire text is a composite result of multiple compression rates applied across different sections. + target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no + specific target. The actual number of tokens after compression should generally be less than the specified target_token, + but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as + the sole criterion, overriding the ``rate``. ``target_token``, is applicable only within the context-level + filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token. + However, for segments where no specific rate is defined, the global rate calculated from global target token serves + as the default value. The final target token of the entire text is a composite result of multiple compression rates + applied across different sections. + iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200. + force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None. + force_context_number (int, optional): The number of context sections to forcibly include. Default is None. + use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False. + use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True. + use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True. + keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False. + keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0. + keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0. + keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0. + high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100. + context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100". + token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4. + condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none". + reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original". + dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0. + condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False. + add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False. + rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua". + concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True. + + Returns: + dict: A dictionary containing: + - "compressed_prompt" (str): The resulting compressed prompt. 
+ - "origin_tokens" (int): The original number of tokens in the input. + - "compressed_tokens" (int): The number of tokens in the compressed output. + - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression. + - "rate" (str): The compression rate achieved, in a human-readable format. + - "saving" (str): Estimated savings in GPT-4 token usage. + """ + if not context: + context = [" "] + if isinstance(context, str): + context = [context] + context = [ + self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids) + for c in context + ] + context_tokens_length = [self.get_token_length(c) for c in context] + instruction_tokens_length, question_tokens_length = self.get_token_length( + instruction + ), self.get_token_length(question) + if target_token == -1: + target_token = ( + ( + instruction_tokens_length + + question_tokens_length + + sum(context_tokens_length) + ) + * rate + - instruction_tokens_length + - (question_tokens_length if concate_question else 0) + ) + else: + rate = target_token / sum(context_tokens_length) + ( + context, + context_segs, + context_segs_rate, + context_segs_compress, + ) = self.segment_structured_context(context, rate) + return self.compress_prompt( + context, + instruction, + question, + rate, + target_token, + iterative_size, + force_context_ids, + force_context_number, + use_sentence_level_filter, + use_context_level_filter, + use_token_level_filter, + keep_split, + keep_first_sentence, + keep_last_sentence, + keep_sentence_number, + high_priority_bonus, + context_budget, + token_budget_ratio, + condition_in_question, + reorder_context, + dynamic_context_compression_ratio, + condition_compare, + add_instruction, + rank_method, + concate_question, + context_segs=context_segs, + context_segs_rate=context_segs_rate, + context_segs_compress=context_segs_compress, + ) + + def compress_prompt( + self, + context: List[str], + instruction: str = "", + question: str = "", + rate: float = 0.5, + target_token: float = -1, + iterative_size: int = 200, + force_context_ids: List[int] = None, + force_context_number: int = None, + use_sentence_level_filter: bool = False, + use_context_level_filter: bool = True, + use_token_level_filter: bool = True, + keep_split: bool = False, + keep_first_sentence: int = 0, + keep_last_sentence: int = 0, + keep_sentence_number: int = 0, + high_priority_bonus: int = 100, + context_budget: str = "+100", + token_budget_ratio: float = 1.4, + condition_in_question: str = "none", + reorder_context: str = "original", + dynamic_context_compression_ratio: float = 0.0, + condition_compare: bool = False, + add_instruction: bool = False, + rank_method: str = "llmlingua", + concate_question: bool = True, + context_segs: List[str] = None, + context_segs_rate: List[float] = None, + context_segs_compress: List[bool] = None, + target_context: int = -1, + context_level_rate: float = 1.0, + context_level_target_token: int = -1, + return_word_label: bool = False, + word_sep: str = "\t\t|\t\t", + label_sep: str = " ", + token_to_word: str = "mean", + force_tokens: List[str] = [], + force_reserve_digit: bool = False, + drop_consecutive: bool = False, + chunk_end_tokens: List[str] = [".", "\n"], + ): + """ + Compresses the given context. + + Args: + context (List[str]): List of context strings that form the basis of the prompt. + instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string. 
+ question (str, optional): A specific question that the prompt is addressing. Default is an empty string. + rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined + the same as in paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, + Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression." + arXiv preprint arXiv:2309.10668 (2023): + .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}} + Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be + fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal + to 1.0, representing the target compression rate. + target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target. + The actual number of tokens after compression should generally be less than the specified target_token, but there can + be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as + the sole criterion, overriding the ``rate``. + iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200. + force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None. + force_context_number (int, optional): The number of context sections to forcibly include. Default is None. + use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False. + use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True. + use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True. + keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False. + keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0. + keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0. + keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0. + high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100. + context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100". + token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4. + condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none". + reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original". + dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0. + condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False. + add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False. + rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua". 
+ concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True. + + target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target. + context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0. + context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression. + Default is -1, indicating no specific target. Only used in the coarse-to-fine compression senario. + force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None. + return_word_label (bool, optional): Whether to return word with corresponding label. Default is False. + word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t". + label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ". + token_to_word (str, optional): How to convert token probability to word probability. Default is "mean". + force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is []. + force_reserve_digit (bool, optional): Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False. + drop_consecutive (bool, optinal): Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt. + Default is False. + chunk_end_tokens (List[str], optinal): The early stop tokens for segmenting chunk. Default is [".", "\n"], + Returns: + dict: A dictionary containing: + - "compressed_prompt" (str): The resulting compressed prompt. + - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2. + - "fn_labeled_original_prompt" (str): original words along with their labels + indicating whether to reserve in compressed prompt, in the format (word label_sep label) + Only used in llmlingua2 when return_word_label = True. + - "origin_tokens" (int): The original number of tokens in the input. + - "compressed_tokens" (int): The number of tokens in the compressed output. + - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression. + - "rate" (str): The compression rate achieved, in a human-readable format. + - "saving" (str): Estimated savings in GPT-4 token usage. + """ + if self.use_llmlingua2: + return self.compress_prompt_llmlingua2( + context, + rate=rate, + target_token=target_token, + use_context_level_filter=use_context_level_filter, + use_token_level_filter=use_token_level_filter, + target_context=target_context, + context_level_rate=context_level_rate, + context_level_target_token=context_level_target_token, + force_context_ids=force_context_ids, + return_word_label=return_word_label, + word_sep=word_sep, + label_sep=label_sep, + token_to_word=token_to_word, + force_tokens=force_tokens, + force_reserve_digit=force_reserve_digit, + drop_consecutive=drop_consecutive, + chunk_end_tokens=chunk_end_tokens, + ) + assert ( + rate <= 1.0 + ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]." 
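+
+        # The rest of this method implements the coarse-to-fine (Long)LLMLingua pipeline:
+        #   1. normalize the context and turn ``rate`` / ``target_token`` into a token budget,
+        #   2. optionally drop whole contexts via control_context_budget,
+        #   3. optionally drop sentences via control_sentence_budget,
+        #   4. iteratively drop tokens via iterative_compress_prompt, then reassemble the prompt.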
+ + if not context: + context = [" "] + if isinstance(context, str): + context = [context] + assert not ( + rank_method == "longllmlingua" and not question + ), "In the LongLLMLingua, it is necessary to set a question." + if condition_compare and "_condition" not in condition_in_question: + condition_in_question += "_condition" + if rank_method == "longllmlingua": + if condition_in_question == "none": + condition_in_question = "after" + elif rank_method == "llmlingua": + condition_in_question = ( + "none" + if "_condition" not in condition_in_question + else "none_condition" + ) + origin_tokens = len( + self.oai_tokenizer.encode( + "\n\n".join([instruction] + context + [question]).strip() + ) + ) + context_tokens_length = [self.get_token_length(c) for c in context] + instruction_tokens_length, question_tokens_length = self.get_token_length( + instruction + ), self.get_token_length(question) + if target_token == -1: + target_token = ( + ( + instruction_tokens_length + + question_tokens_length + + sum(context_tokens_length) + ) + * rate + - instruction_tokens_length + - (question_tokens_length if concate_question else 0) + ) + condition_flag = "_condition" in condition_in_question + condition_in_question = condition_in_question.replace("_condition", "") + + if len(context) > 1 and use_context_level_filter: + context, dynamic_ratio, context_used = self.control_context_budget( + context, + context_tokens_length, + target_token, + force_context_ids, + force_context_number, + question, + condition_in_question, + reorder_context=reorder_context, + dynamic_context_compression_ratio=dynamic_context_compression_ratio, + rank_method=rank_method, + context_budget=context_budget, + context_segs=context_segs, + context_segs_rate=context_segs_rate, + context_segs_compress=context_segs_compress, + ) + if context_segs is not None: + context_segs = [context_segs[idx] for idx in context_used] + context_segs_rate = [context_segs_rate[idx] for idx in context_used] + context_segs_compress = [ + context_segs_compress[idx] for idx in context_used + ] + else: + dynamic_ratio = [0.0] * len(context) + + segments_info = [] + if use_sentence_level_filter: + context, segments_info = self.control_sentence_budget( + context, + target_token, + keep_first_sentence=keep_first_sentence, + keep_last_sentence=keep_last_sentence, + keep_sentence_number=keep_sentence_number, + high_priority_bonus=high_priority_bonus, + token_budget_ratio=token_budget_ratio, + question=question, + condition_in_question=condition_in_question, + rank_method=rank_method, + context_segs=context_segs, + context_segs_rate=context_segs_rate, + context_segs_compress=context_segs_compress, + ) + elif context_segs is not None: + for context_idx in range(len(context)): + segments_info.append( + [ + (len(seg_text), seg_rate, seg_compress) + for seg_text, seg_rate, seg_compress in zip( + context_segs[context_idx], + context_segs_rate[context_idx], + context_segs_compress[context_idx], + ) + ] + ) + segments_info = [ + self.concate_segment_info(segment_info) for segment_info in segments_info + ] + + if condition_flag: + prefix = question + "\n\n" + instruction if add_instruction else question + if ( + self.get_token_length(prefix + "\n\n") + iterative_size * 2 + > self.max_position_embeddings + ): + tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids + prefix = self.tokenizer.decode( + tokens[: self.prefix_bos_num] + + tokens[ + len(tokens) + - self.max_position_embeddings + + 2 + + self.prefix_bos_num + + 2 * iterative_size : + ] + ) + start = 
self.get_prefix_length(prefix + "\n\n", context[0]) + context = [prefix] + context + else: + start = 0 + + if use_token_level_filter: + context = self.iterative_compress_prompt( + context, + target_token, + iterative_size=iterative_size, + keep_split=keep_split, + start=start, + dynamic_ratio=dynamic_ratio, + condition_compare=condition_compare, + segments_info=segments_info, + ) + compressed_prompt = ( + self.tokenizer.batch_decode(context[0])[0] + .replace("<s> ", "") + .replace("<s>", "") + ) + else: + if condition_flag: + context = context[1:] + compressed_prompt = "\n\n".join(context) + + res = [] + if instruction: + res.append(instruction) + if compressed_prompt.strip(): + res.append(compressed_prompt) + if question and concate_question: + res.append(question) + + compressed_prompt = "\n\n".join(res) + + compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt)) + saving = (origin_tokens - compressed_tokens) * 0.06 / 1000 + ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens + rate = 1 / ratio + return { + "compressed_prompt": compressed_prompt, + "origin_tokens": origin_tokens, + "compressed_tokens": compressed_tokens, + "ratio": f"{ratio:.1f}x", + "rate": f"{rate * 100:.1f}%", + "saving": f", Saving ${saving:.1f} in GPT-4.", + } + + def compress_prompt_llmlingua2( + self, + context: List[str], + rate: float = 0.5, + target_token: int = -1, + use_context_level_filter: bool = False, + use_token_level_filter: bool = True, + target_context: int = -1, + context_level_rate: float = 1.0, + context_level_target_token: int = -1, + force_context_ids: List[int] = [], + return_word_label: bool = False, + word_sep: str = "\t\t|\t\t", + label_sep: str = " ", + token_to_word: str = "mean", + force_tokens: List[str] = [], + force_reserve_digit: bool = False, + drop_consecutive: bool = False, + chunk_end_tokens: List[str] = [".", "\n"], + ): + """ + Compresses the given context. + + Args: + context (List[str]): List of context strings that form the basis of the prompt. + rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate + generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified, + it should be a float less than or equal to 1.0, representing the target compression rate. + target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target. + The actual number of tokens after compression should generally be less than the specified target_token, but there can + be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as + the sole criterion, overriding the rate. + target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target. + Only used in the coarse-to-fine compression. + context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0. + Only used in the coarse-to-fine compression. + context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression. + Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario. + force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+ return_word_label (bool, optional): Whether to return word with corresponding label. Default is False. + word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t". + label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ". + token_to_word (str, optional): How to convert token probability to word probability. Default is "mean". + force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is []. + force_reserve_digit (bool, optional): Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False. + drop_consecutive (bool, optinal): Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt. + Default is False. + chunk_end_tokens (List[str], optional): The early stop tokens for segmenting chunk. Default is [".", "\n"]. + Returns: + dict: A dictionary containing: + - "compressed_prompt" (str): The resulting compressed prompt. + - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. + - "fn_labeled_original_prompt" (str): original words along with their labels + indicating whether to reserve in compressed prompt, in the format (word label_sep label) + - "origin_tokens" (int): The original number of tokens in the input. + - "compressed_tokens" (int): The number of tokens in the compressed output. + - "ratio" (str): The compression ratio achieved, in a human-readable format. + - "rate" (str): The compression rate achieved, in a human-readable format. + - "saving" (str): Estimated savings in GPT-4 token usage. + + """ + assert len(force_tokens) <= self.max_force_token + token_map = {} + for i, t in enumerate(force_tokens): + if len(self.tokenizer.tokenize(t)) != 1: + token_map[t] = self.added_tokens[i] + chunk_end_tokens = copy.deepcopy(chunk_end_tokens) + for c in chunk_end_tokens: + if c in token_map: + chunk_end_tokens.append(token_map[c]) + chunk_end_tokens = set(chunk_end_tokens) + + if type(context) == str: + context = [context] + context = copy.deepcopy(context) + + if len(context) == 1 and use_context_level_filter: + use_context_level_filter = False + + n_original_token = 0 + context_chunked = [] + for i in range(len(context)): + n_original_token += self.get_token_length(context[i], use_oai_tokenizer=True) + for ori_token, new_token in token_map.items(): + context[i] = context[i].replace(ori_token, new_token) + context_chunked.append(self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens)) + + if use_context_level_filter: + # want use_context_level_filter but do not specify any parameters in context level? + # we will set context_level_rate = (rate + 1.0) / 2 if specify rate or target_token * 2 if specify target_token + if ( + target_context <= 0 + and context_level_rate >= 1.0 + and context_level_target_token <= 0 + ): + if target_token < 0 and rate < 1.0: + context_level_rate = ( + (rate + 1.0) / 2 if use_token_level_filter else rate + ) + print( + f"set context level compression rate to {context_level_rate}." + ) + if target_token >= 0: + context_level_target_token = ( + target_token * 2 if use_token_level_filter else target_token + ) + print( + f"set context level target token to {context_level_target_token}." 
+ ) + + if target_context >= 0: + context_level_rate = min(target_context / len(context), 1.0) + # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.') + if context_level_target_token >= 0: + context_level_rate = min( + context_level_target_token / n_original_token, 1.0 + ) + # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.') + + context_probs, context_words = self.__get_context_prob( + context_chunked, + token_to_word=token_to_word, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + ) + + threshold = np.percentile( + context_probs, int(100 * (1 - context_level_rate)) + ) + + reserved_context = [] + context_label = [False] * len(context_probs) + for i, p in enumerate(context_probs): + if p >= threshold or ( + force_context_ids is not None and i in force_context_ids + ): + reserved_context.append(context_chunked[i]) + context_label[i] = True + n_reserved_token = 0 + for chunks in reserved_context: + for c in chunks: + n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True) + if target_token >= 0: + rate = min(target_token / n_reserved_token, 1.0) + print( + f"override compression rate to {rate} because you specified target_token = {target_token}." + ) + + if use_token_level_filter: + compressed_context, word_list, word_label_list = self.__compress( + reserved_context, + reduce_rate=max(0, 1 - rate), + token_to_word=token_to_word, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + drop_consecutive=drop_consecutive, + ) + else: + compressed_context, word_list, word_label_list = self.__compress( + reserved_context, + reduce_rate=0, + token_to_word=token_to_word, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + drop_consecutive=drop_consecutive, + ) + print( + "return the original text because you specify use_token_level_filter=False" + ) + + n_compressed_token = 0 + for c in compressed_context: + n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True) + saving = (n_original_token - n_compressed_token) * 0.06 / 1000 + ratio = ( + 1 if n_compressed_token == 0 else n_original_token / n_compressed_token + ) + res = { + "compressed_prompt": "\n\n".join(compressed_context), + "compressed_prompt_list": compressed_context, + "origin_tokens": n_original_token, + "compressed_tokens": n_compressed_token, + "ratio": f"{ratio:.1f}x", + "rate": f"{1 / ratio * 100:.1f}%", + "saving": f", Saving ${saving:.1f} in GPT-4.", + } + if return_word_label: + words = [] + labels = [] + j = 0 + for i in range(len(context)): + if context_label[i]: + words.extend(word_list[j]) + labels.extend(word_label_list[j]) + j += 1 + else: + words.extend(context_words[i]) + labels.extend([0] * len(context_words[i])) + word_label_lines = word_sep.join( + [f"{word}{label_sep}{label}" for word, label in zip(words, labels)] + ) + res["fn_labeled_original_prompt"] = word_label_lines + return res + + if target_token > 0: + rate = min(target_token / n_original_token, 1.0) + print( + f"override compression rate to {rate} \ + because you specified target_token = {target_token}." 
+ ) + + if use_token_level_filter: + compressed_context, word_list, word_label_list = self.__compress( + context_chunked, + reduce_rate=max(0, 1 - rate), + token_to_word=token_to_word, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + drop_consecutive=drop_consecutive, + ) + else: + compressed_context, word_list, word_label_list = self.__compress( + context_chunked, + reduce_rate=0, + token_to_word=token_to_word, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + drop_consecutive=drop_consecutive, + ) + print( + "return the original text because you specify use_token_level_filter=False" + ) + + n_compressed_token = 0 + for c in compressed_context: + n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True) + saving = (n_original_token - n_compressed_token) * 0.06 / 1000 + ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token + res = { + "compressed_prompt": "\n\n".join(compressed_context), + "compressed_prompt_list": compressed_context, + "origin_tokens": n_original_token, + "compressed_tokens": n_compressed_token, + "ratio": f"{ratio:.1f}x", + "rate": f"{1 / ratio * 100:.1f}%", + "saving": f", Saving ${saving:.1f} in GPT-4.", + } + if return_word_label: + words = [] + labels = [] + for w_list, l_list in zip(word_list, word_label_list): + words.extend(w_list) + labels.extend(l_list) + + # new_words = [] + # new_labels = [] + # for i in range(len(words)): + # word, label = words[i], labels[i] + # if word in string.punctuation: + # if labels[i-1] == 1 and label == 1 and i > 0: + # new_words[-1] += word + # else: + # new_words.append(word) + # new_labels.append(label) + # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)]) + + word_label_lines = word_sep.join( + [f"{word}{label_sep}{label}" for word, label in zip(words, labels)] + ) + res["fn_labeled_original_prompt"] = word_label_lines + return res + + def get_token_length(self, text: str, add_special_tokens: bool = True, use_oai_tokenizer: bool = False): + if use_oai_tokenizer: + return len(self.oai_tokenizer.encode(text)) + else: + return len( + self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids + ) + + def get_prefix_length(self, prefix: str, text: str): + possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1) + full_input_ids = self.tokenizer( + prefix + text[:100], add_special_tokens=False + ).input_ids + for i in range(possible_prefix_token, len(full_input_ids)): + cur_prefix = self.tokenizer.decode(full_input_ids[:i]) + if cur_prefix == prefix: + break + assert self.tokenizer.decode(full_input_ids[i:]) == text[:100] + return i + + def get_condition_ppl( + self, + text: str, + question: str, + condition_in_question: str = "none", + granularity: str = "sentence", + ): + if condition_in_question == "none": + return self.get_ppl(text, granularity=granularity) + elif condition_in_question == "before": + return self.get_ppl( + question + text, + granularity=granularity, + condition_mode="after", + condition_pos_id=self.get_token_length(question) - 1, + ) + elif condition_in_question == "after": + return self.get_ppl( + text + question, + granularity=granularity, + condition_mode="after", + condition_pos_id=self.get_token_length(text) - 1, + ) + + def get_dynamic_compression_ratio( + self, + context: list, + target_token: float, + iterative_size: int, + dynamic_ratio: list, + start: int, + seg_info: List[List[tuple]] = 
None, + ): + def get_ratio(base: float, delta: float): + return max(min(1, base + delta), 0) + + context_length = [self.get_token_length(ii, False) + 2 for ii in context] + if start: + context_length = context_length[1:] + tau = target_token / (sum(context_length) + 1) + res, idx, last, last_target = [], 0, 1, [] + while idx < len(context_length): + if last + context_length[idx] >= iterative_size: + last_target.append( + (iterative_size - last, get_ratio(tau, dynamic_ratio[idx])) + ) + res.append(last_target) + last = last + context_length[idx] - iterative_size + if last > iterative_size: + k = last // iterative_size + res.extend( + [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k + ) + last -= k * iterative_size + + last_target = ( + [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else [] + ) + else: + last += context_length[idx] + last_target.append( + (context_length[idx], get_ratio(tau, dynamic_ratio[idx])) + ) + idx += 1 + if last_target: + res.append(last_target) + return res + + def get_structured_dynamic_compression_ratio( + self, + context: list, + iterative_size: int, + dynamic_ratio: list, + start: int, + seg_info: List[List[tuple]] = None, + ): + if start: + pure_context = context[1:] + else: + pure_context = context + global_dynamic_rate, global_dynamic_compress, segments = [], [], [] + for context_idx, text in enumerate(pure_context): + text_seen = 0 + for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate( + seg_info[context_idx] + ): + seg_text = text[text_seen : text_seen + seg_len] + if ( + seg_idx == len(seg_info[context_idx]) - 1 + and context_idx != len(pure_context) - 1 + ): + seg_text += "\n\n" + segments.append(seg_text) + if seg_compress: + global_dynamic_rate.append(seg_rate) + else: + global_dynamic_rate.append(1.0) + global_dynamic_compress.append(seg_compress) + text_seen += seg_len + origin_text = "\n\n".join(pure_context) + assert len("".join(segments)) == len(origin_text) + assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress) + + text_input_ids = self.tokenizer( + "\n\n".join(context), add_special_tokens=False + ).input_ids[start:] + assert self.tokenizer.decode(text_input_ids) == origin_text + dynamic_compression_ratio = self.token_segment( + text_input_ids, + iterative_size, + segments, + global_dynamic_rate, + global_dynamic_compress, + ) + return dynamic_compression_ratio + + def token_segment( + self, + text_input_ids: List[int], + iterative_size: int, + segments: List[str], + global_dynamic_rate: List[float], + global_dynamic_compress: List[bool], + ): + decode_window = 3 + seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1 + dynamic_compression_rate, local_compresssion_rate = [], [] + for i in range(len(text_input_ids)): + if i < decode_window: + id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1] + else: + id_pre, id_cur = ( + text_input_ids[i - decode_window + 1 : i], + text_input_ids[i - decode_window + 1 : i + 1], + ) + cur_word = self.tokenizer.decode(id_cur)[ + len(self.tokenizer.decode(id_pre)) : + ] + cur_word_len = len(cur_word) + if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen: + possible_rate, possible_compress = [], [] + while ( + cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen + ): + possible_rate.append(global_dynamic_rate[seg_idx]) + possible_compress.append(global_dynamic_compress[seg_idx]) + cur_word_len -= len(segments[seg_idx]) - seg_seen + seg_idx += 1 + seg_seen = 0 + if cur_word_len: + 
possible_rate.append(global_dynamic_rate[seg_idx]) + possible_compress.append(global_dynamic_compress[seg_idx]) + new_rate = 1.0 if False in possible_compress else min(possible_rate) + else: + new_rate = global_dynamic_rate[seg_idx] + if new_rate != last_rate and i - token_seen_num: + local_compresssion_rate.append((i - token_seen_num, last_rate)) + token_seen_num = i + last_rate = new_rate + seg_seen += cur_word_len + if (i + 1) % iterative_size == 0: + if token_seen_num != i + 1: + local_compresssion_rate.append((i + 1 - token_seen_num, last_rate)) + token_seen_num = i + 1 + dynamic_compression_rate.append(local_compresssion_rate[:]) + local_compresssion_rate = [] + if token_seen_num != len(text_input_ids): + local_compresssion_rate.append( + (len(text_input_ids) - token_seen_num, last_rate) + ) + if local_compresssion_rate != []: + dynamic_compression_rate.append(local_compresssion_rate[:]) + return dynamic_compression_rate + + def control_context_budget( + self, + context: List[str], + context_tokens_length: List[int], + target_token: float, + force_context_ids: List[int] = None, + force_context_number: int = None, + question: str = "", + condition_in_question: str = "none", + reorder_context: str = "original", + dynamic_context_compression_ratio: float = 0.0, + rank_method: str = "longllmlingua", + context_budget: str = "+100", + context_segs: List[List[str]] = None, + context_segs_rate: List[List[float]] = None, + context_segs_compress: List[List[bool]] = None, + ): + demostrations_sort = self.get_rank_results( + context, + question, + rank_method, + condition_in_question, + context_tokens_length, + ) + + if target_token < 0: + target_token = 100 + target_token = eval("target_token" + context_budget) + res = [] + used = force_context_ids if force_context_ids is not None else [] + if context_segs is not None: + for idx, _ in enumerate(context): + if False in context_segs_compress[idx]: + used.append(idx) + + self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)]) + for idx, _ in demostrations_sort: + if idx >= len(context_tokens_length): + continue + target_token -= context_tokens_length[idx] + if idx not in used: + used.append(idx) + if target_token < 0 or ( + force_context_number is not None and len(res) >= force_context_number + ): + break + original_used = used + if reorder_context == "original": + used = sorted(used) + elif reorder_context == "two_stage": + l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [ + _ for idx, _ in enumerate(used) if idx % 2 == 1 + ] + used = l + r[::-1] + + if dynamic_context_compression_ratio > 0: + N = len(used) + dynamic_ratio = [ + i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0 + for i in range(-(N - 1), N, 2) + ][::-1] + dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)} + dynamic_ratio = [dynamic_ratio_map[i] for i in used] + else: + dynamic_ratio = [0.0] * len(used) + + res = [context[idx] for idx in used if idx < len(context)] + return res, dynamic_ratio, used + + def control_sentence_budget( + self, + context: List[str], + target_token: float, + keep_first_sentence: int = 0, + keep_last_sentence: int = 0, + keep_sentence_number: int = 0, + high_priority_bonus: int = 100, + token_budget_ratio: float = 1.4, + question: str = "", + condition_in_question: str = "none", + rank_method: str = "longllmlingua", + context_segs: List[List[str]] = None, + context_segs_rate: List[List[float]] = None, + context_segs_compress: List[List[bool]] = None, + ): + def 
keep_sentence(dem_idx: int, sent_keep: int): + idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep] + for idx in idxs: + sentence_ppl[idx] += high_priority_bonus + + def sync_sentence(segments, text): + seg_num = len(segments) + new_segments = [] + text_seen = 0 + seg_idx, cur_seg_seen = 0, 0 + for i, s in enumerate(text): + while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]: + if cur_seg_seen < len(segments[seg_idx]) - 1: + cur_seg_seen += 1 + continue + new_segments.append(text[text_seen:i]) + text_seen = i + seg_idx += 1 + cur_seg_seen = 0 + cur_seg_seen += 1 + if seg_idx == seg_num: + break + if cur_seg_seen == len(segments[seg_idx]): + new_segments.append(text[text_seen : i + 1]) + text_seen = i + 1 + seg_idx += 1 + cur_seg_seen = 0 + if text_seen < len(text): + new_segments.append(text[text_seen:]) + assert len("".join(new_segments)) == len(text) + return new_segments + + sentences = [nltk.sent_tokenize(c) for c in context] + dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0 + for idx_d, s in enumerate(sentences): + for _ in s: + dem_g[idx_d].add(idx) + s2de[idx] = idx_d + idx += 1 + + if context_segs is not None: + context_segs = [ + sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences) + ] + sen2seg_ratio = {} + idx = 0 + for idx_d, sentences_each_context in enumerate(sentences): + segments_length = [len(s) for s in context_segs[idx_d]] + seg_idx, cur_seg_seen = 0, 0 + for sentence in sentences_each_context: + sentence_seg_ratio = [] + remain = len(sentence) + while remain: + if segments_length[seg_idx] - cur_seg_seen <= remain: + new_seg_len = segments_length[seg_idx] - cur_seg_seen + sentence_seg_ratio.append( + ( + new_seg_len, + context_segs_rate[idx_d][seg_idx], + context_segs_compress[idx_d][seg_idx], + ) + ) + seg_idx += 1 + cur_seg_seen = 0 + remain -= new_seg_len + else: + sentence_seg_ratio.append( + ( + remain, + context_segs_rate[idx_d][seg_idx], + context_segs_compress[idx_d][seg_idx], + ) + ) + cur_seg_seen += remain + remain = 0 + sen2seg_ratio[idx] = sentence_seg_ratio + idx += 1 + + context_sentences = [s for ii in sentences for s in ii] + sentence_tokens_length = [ + self.get_token_length(sentence) for sentence in context_sentences + ] + N = len(context_sentences) + flags = list(range(len(context_sentences))) + if len(sentence_tokens_length) == 1: + return context + if rank_method == "longllmlingua": + sentence_ppl = [ + self.get_condition_ppl(sentence, question, condition_in_question) + .cpu() + .numpy() + .item() + for sentence in context_sentences + ] + if keep_first_sentence: + sentence_ppl[:keep_first_sentence] = [ + ii + high_priority_bonus + for ii in sentence_ppl[:keep_first_sentence] + ] + if keep_last_sentence: + sentence_ppl[-keep_last_sentence:] = [ + ii + high_priority_bonus + for ii in sentence_ppl[-keep_last_sentence:] + ] + if keep_sentence_number: + for dem_idx in range(len(sentences)): + keep_sentence(dem_idx, keep_sentence_number) + sort_direct = -1 if condition_in_question == "none" else 1 + sent_sort = sorted( + enumerate(sentence_ppl), key=lambda x: sort_direct * x[1] + ) + else: + sent_sort = self.get_rank_results( + context_sentences, + question, + rank_method, + condition_in_question, + [0] * len(context_sentences), + ) + + sentence_flags = [False] * N + if target_token < 0: + target_token = 100 + target_token *= token_budget_ratio + res = [] + for idx, _ in sent_sort: + idx = flags[idx] + target_token -= sentence_tokens_length[idx] + sentence_flags[idx] = True + if target_token < 0: 
+ break + + if context_segs is not None: + for idx in range(N): + preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]] + if False in preserved: + sentence_flags[idx] = True + + idx = 0 + res = [] + new_segments_info = [] + for s in sentences: + tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]] + res.append("".join(tmp)) + if context_segs is not None: + segment_ratio = [] + for ii in range(len(s)): + if sentence_flags[idx + ii]: + segment_ratio.extend(sen2seg_ratio[idx + ii]) + new_segments_info.append(segment_ratio) + idx += len(s) + if context_segs is not None: + new_segments_info = [ + self.concate_segment_info(segment_info) + for segment_info in new_segments_info + ] + return res, new_segments_info + + def get_compressed_input( + self, + loss, + input_ids, + attention_mask, + end=200, + iterative_size=200, + threshold=0.5, + keep_flag=None, + split_token_id: int = 13, + start: int = 0, + self_loss=None, + self_input_ids=None, + self_attention_mask=None, + ): + if self_loss is not None: + need_idx = torch.concat( + [ + loss[:start] > 0, + self_loss[: loss[start:].shape[0]] - loss[start:] > threshold, + loss[:1] > 0, + ] + ) + else: + need_idx = torch.concat([loss > threshold, loss[:1] > 0]) + need_idx[end:] = 1 + need_idx[: end - iterative_size] = 1 + loss = loss[need_idx[:-1]] + if self_loss is not None: + if need_idx.shape[0] < self_loss.shape[0] + start + 1: + need_idx = torch.cat( + [ + need_idx, + torch.ones( + self_loss.shape[0] - need_idx.shape[0] + start + 1, + dtype=torch.bool, + ).to(need_idx.device), + ] + ) + self_loss = self_loss[need_idx[start:-1]] + + if need_idx.shape[0] < input_ids.shape[1]: + need_idx = torch.cat( + [ + need_idx, + torch.ones( + input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool + ).to(need_idx.device), + ] + ) + elif need_idx.shape[0] > input_ids.shape[1]: + need_idx = need_idx[: input_ids.shape[1]] + + if keep_flag is not None: + need_idx[keep_flag == 1] = 1 + last = -1 + if keep_flag is not None: + for ii in range(max(0, end - iterative_size), end): + if need_idx[ii] != 1: + continue + now = input_ids[0][ii].detach().cpu().item() + if ( + now == split_token_id + and last == split_token_id + and keep_flag[ii].detach().cpu().item() == 0 + ): + need_idx[ii] = 0 + else: + last = now + compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0) + compressed_attention_mask = attention_mask[attention_mask == 1][ + need_idx + ].unsqueeze(0) + + if self_loss is not None: + self_compressed_input_ids = self_input_ids[self_attention_mask == 1][ + need_idx[start:] + ].unsqueeze(0) + self_compressed_attention_mask = self_attention_mask[ + self_attention_mask == 1 + ][need_idx[start:]].unsqueeze(0) + else: + self_compressed_input_ids, self_compressed_attention_mask = None, None + if keep_flag is not None: + if len(keep_flag) > len(need_idx): + keep_flag = torch.cat( + [ + keep_flag[:start], + keep_flag[start : len(need_idx) + start][need_idx], + keep_flag[start + len(need_idx) :], + ] + ) + else: + keep_flag = keep_flag[need_idx] + end -= (need_idx[:end] == 0).sum() + return ( + compressed_input_ids, + compressed_attention_mask, + keep_flag, + end, + loss, + self_loss, + self_compressed_input_ids, + self_compressed_attention_mask, + ) + + def get_estimate_threshold_base_distribution( + self, ppl, ratio: float, condition_flag: bool = False + ): + if ratio == 1.0: + return float("-inf") + ppl = ppl[ppl != 10000] + target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1)) + return ( + 
ppl.sort(descending=not condition_flag) + .values[target_token] + .detach() + .cpu() + .item() + ) + + def iterative_compress_prompt( + self, + context: List[str], + target_token: float, + iterative_size: int = 200, + keep_split: bool = False, + split_token_id: int = 13, + start: int = 0, + dynamic_ratio: list = None, + condition_compare: bool = False, + segments_info: List[List[tuple]] = None, + ): + if segments_info is None or segments_info == []: + iterative_ratios = self.get_dynamic_compression_ratio( + context, target_token, iterative_size, dynamic_ratio, start + ) + else: + iterative_ratios = self.get_structured_dynamic_compression_ratio( + context, iterative_size, dynamic_ratio, start, segments_info + ) + context = "\n\n".join(context) + tokenized_text = self.tokenizer( + context, return_tensors="pt", add_special_tokens=False + ) + input_ids = tokenized_text["input_ids"].to(self.device) + attention_mask = tokenized_text["attention_mask"].to(self.device) + + N = (attention_mask == 1).sum() + compressed_input_ids, compressed_attention_mask = input_ids, attention_mask + if condition_compare: + self_input_ids, self_attention_mask = ( + input_ids[:, start:], + attention_mask[:, start:], + ) + self_compressed_input_ids, self_compressed_attention_mask = ( + self_input_ids, + self_attention_mask, + ) + + end = min(iterative_size + start, compressed_input_ids.shape[1]) + threshold, keep_flag = None, None + if keep_split: + input_ids_numpy = input_ids.cpu().detach().numpy()[0] + N = len(input_ids_numpy) + keep_flag = [ + int( + ( + ii > 0 + and input_ids_numpy[ii] == split_token_id + and input_ids_numpy[ii - 1] == split_token_id + ) + or ( + ii < N - 1 + and input_ids_numpy[ii] == split_token_id + and input_ids_numpy[ii + 1] == split_token_id + ) + ) + for ii in range(N) + ] + keep_flag = torch.tensor(keep_flag).to(self.device) + past_key_values, past_loss, ready_end = None, None, 0 + self_past_key_values, self_past_loss, self_ready_end = None, None, 0 + pop_compressed_input_ids, pop_self_compressed_input_ids = None, None + idx = 0 + while end <= compressed_input_ids.shape[1]: + if end > self.max_position_embeddings and past_key_values is not None: + # KV-Cache Compression + e, s = end - self.max_position_embeddings, min( + self.cache_bos_num + start, self.max_position_embeddings + ) + if pop_compressed_input_ids is None: + pop_compressed_input_ids = compressed_input_ids[:, :e] + else: + pop_compressed_input_ids = torch.cat( + [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1 + ) + compressed_input_ids = compressed_input_ids[:, e:] + compressed_attention_mask = compressed_attention_mask[:, e:] + past_key_values = [ + [ + torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2), + torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2), + ] + for k, v in past_key_values + ] + if keep_flag is not None: + keep_flag = keep_flag[e:] + end, ready_end = end - e, ready_end - e + if condition_compare: + s = min(s, self_past_key_values[0][0].shape[2] - e) + self_ready_end -= e + if pop_self_compressed_input_ids is None: + pop_self_compressed_input_ids = self_compressed_input_ids[:, :e] + else: + pop_self_compressed_input_ids = torch.cat( + [ + pop_self_compressed_input_ids, + self_compressed_input_ids[:, :e], + ], + dim=-1, + ) + self_compressed_input_ids = self_compressed_input_ids[:, e:] + self_compressed_attention_mask = self_compressed_attention_mask[ + :, e: + ] + self_past_key_values = [ + [ + torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2), + torch.cat([v[..., :s, :], v[..., s 
+ e :, :]], dim=-2), + ] + for k, v in self_past_key_values + ] + + loss, past_key_values = self.get_ppl( + "", + "token", + compressed_input_ids, + compressed_attention_mask, + past_key_values=past_key_values, + return_kv=True, + end=end if idx else None, + ) + if loss.shape[0] == 0: + break + if past_loss is not None: + if end - 1 > len(past_loss): + past_loss = torch.cat( + [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]] + ) + past_loss[ready_end : end - 1] = loss + loss = past_loss + else: + past_loss = loss + if idx: + past_key_values = [ + [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]] + for k, v in past_key_values + ] + else: + past_key_values = None + + if condition_compare: + self_loss, self_past_key_values = self.get_ppl( + "", + "token", + self_compressed_input_ids, + self_compressed_attention_mask, + past_key_values=self_past_key_values, + return_kv=True, + end=end - start if idx else None, + ) + if self_past_loss is not None: + if end - start - 1 > len(self_past_loss): + self_past_loss = torch.cat( + [ + self_past_loss, + torch.zeros_like(self_loss)[ + : end - 1 - start - len(self_past_loss) + ], + ] + ) + self_past_loss[self_ready_end : end - start - 1] = self_loss + self_loss = self_past_loss + else: + self_past_loss = self_loss + if idx: + self_past_key_values = [ + [ + k[:, :, : end - iterative_size - start], + v[:, :, : end - iterative_size - start], + ] + for k, v in self_past_key_values + ] + else: + self_past_key_values = None + + self_ready_end = ( + end - start - iterative_size if not (start and idx == 0) else 0 + ) + ready_end = end - iterative_size if not (start and idx == 0) else 0 + + for delta_end, ratio in iterative_ratios[idx]: + loss = past_loss + if condition_compare: + self_loss = self_past_loss + threshold = self.get_estimate_threshold_base_distribution( + self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False + ) + else: + threshold = self.get_estimate_threshold_base_distribution( + loss, ratio, False + ) + + ( + compressed_input_ids, + compressed_attention_mask, + keep_flag, + end, + past_loss, + self_past_loss, + self_compressed_input_ids, + self_compressed_attention_mask, + ) = self.get_compressed_input( + loss, + compressed_input_ids, + compressed_attention_mask, + end - iterative_size + delta_end, + iterative_size=delta_end, + threshold=threshold, + keep_flag=keep_flag, + split_token_id=split_token_id, + start=start, + self_loss=self_loss if condition_compare else None, + self_input_ids=( + self_compressed_input_ids if condition_compare else None + ), + self_attention_mask=( + self_compressed_attention_mask if condition_compare else None + ), + ) + end += iterative_size + idx += 1 + if pop_compressed_input_ids is not None: + compressed_input_ids = torch.cat( + [pop_compressed_input_ids, compressed_input_ids], dim=-1 + ) + return compressed_input_ids[:, start:], compressed_attention_mask[:, start:] + + def recover( + self, + original_prompt: str, + compressed_prompt: str, + response: str, + ): + def match_from_compressed(response_word): + response_input_ids = self.tokenizer( + response_word, add_special_tokens=False + )["input_ids"] + response_set, response_c = set(response_input_ids), defaultdict(list) + for idx in range(M): + if original_input_ids[idx] in response_set: + response_c[original_input_ids[idx]].append(idx) + res, res_min, res_c = None, float("inf"), 1 + n = len(response_input_ids) + for l in response_c[response_input_ids[0]]: + x, y, c = 0, l, 1 + for x in range(1, n): + idx = 
bisect.bisect_right(response_c[response_input_ids[x]], y) + if ( + idx >= len(response_c[response_input_ids[x]]) + or response_c[response_input_ids[x]][idx] - y > 10 + ): + continue + c += 1 + y = response_c[response_input_ids[x]][idx] + if c > res_c: + res_c = c + res_min = y - l + 1 + res = (l, y + 1) + elif c == res_c and y - l + 1 < res_min: + res_min = y - l + 1 + res = (l, y + 1) + + if res is None: + return response_word + # while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"): + # l -= 1 + # while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"): + # l -= 1 + return self.tokenizer.decode(original_input_ids[res[0] : res[1]]) + + response_words = response.split(" ") + + original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[ + "input_ids" + ] + N, M = len(response_words), len(original_input_ids) + recovered_response_words = [] + l = 0 + while l < N: + if response_words[l] not in compressed_prompt: + recovered_response_words.append(response_words[l]) + l += 1 + continue + r = l + while ( + r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt + ): + r += 1 + + match_words = match_from_compressed(" ".join(response_words[l : r + 1])) + recovered_response_words.append(match_words) + l = r + 1 + return " ".join(recovered_response_words) + + def get_rank_results( + self, + context: list, + question: str, + rank_method: str, + condition_in_question: str, + context_tokens_length: list, + ): + def get_distance_bm25(corpus, query): + from rank_bm25 import BM25Okapi + + tokenized_corpus = [doc.split(" ") for doc in corpus] + bm25 = BM25Okapi(tokenized_corpus) + tokenized_query = query.split(" ") + doc_scores = bm25.get_scores(tokenized_query) + idx = [(ii, 0) for ii in (-doc_scores).argsort()] + return idx + + def get_distance_gzip(corpus, query): + def get_score(x, y): + cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode())) + cxy = len(gzip.compress(f"{x} {y}".encode())) + return (cxy - min(cx, cy)) / max(cx, cy) + + import gzip + + doc_scores = [get_score(doc, query) for doc in corpus] + idx = [(ii, 0) for ii in np.argsort(doc_scores)] + return idx + + def get_distance_sentbert(corpus, query): + from sentence_transformers import SentenceTransformer, util + + if self.retrieval_model is None or self.retrieval_model_name != rank_method: + self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1") + self.retrieval_model_name = rank_method + doc_embeds = self.retrieval_model.encode(corpus) + query = self.retrieval_model.encode(query) + doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) + idx = [(ii, 0) for ii in np.argsort(doc_scores)] + return idx + + def get_distance_openai(corpus, query): + import openai + from sentence_transformers import util + + openai.api_key = self.open_api_config.get("api_key", "") + openai.api_base = self.open_api_config.get( + "api_base", "https://api.openai.com/v1" + ) + openai.api_type = self.open_api_config.get("api_type", "open_ai") + openai.api_version = self.open_api_config.get("api_version", "2023-05-15") + engine = self.open_api_config.get("engine", "text-embedding-ada-002") + + def get_embed(text): + return openai.Embedding.create( + input=[text.replace("\n", " ")], engine=engine + )["data"][0]["embedding"] + + doc_embeds = [get_embed(i) for i in corpus] + query = get_embed(query) + doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) + idx = [(ii, 0) for ii 
in np.argsort(doc_scores)] + return idx + + def get_distance_sentbert_bge(corpus, query): + from sentence_transformers import SentenceTransformer, util + + if self.retrieval_model is None or self.retrieval_model_name != rank_method: + self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5") + self.retrieval_model_name = rank_method + doc_embeds = self.retrieval_model.encode( + [i for i in corpus], normalize_embeddings=True + ) + query = self.retrieval_model.encode(query, normalize_embeddings=True) + doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) + idx = [(ii, 0) for ii in np.argsort(doc_scores)] + return idx + + def get_distance_bge_ranker(corpus, query): + from transformers import AutoModelForSequenceClassification, AutoTokenizer + + pairs = [[i, query] for i in corpus] + if self.retrieval_model is None or self.retrieval_model_name != rank_method: + tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") + model = ( + AutoModelForSequenceClassification.from_pretrained( + "BAAI/bge-reranker-large" + ) + .eval() + .to(self.device) + ) + self.retrieval_model = [tokenizer, model] + self.retrieval_model_name = rank_method + with torch.no_grad(): + inputs = self.retrieval_model[0]( + pairs, + padding=True, + truncation=True, + return_tensors="pt", + max_length=512, + ).to(self.device) + scores = ( + self.retrieval_model[1](**inputs, return_dict=True) + .logits.view( + -1, + ) + .float() + ) + idx = [(ii, 0) for ii in np.argsort(-scores.cpu())] + return idx + + def get_distance_bge_llmembedder(corpus, query): + from transformers import AutoModel, AutoTokenizer + + if self.retrieval_model is None or self.retrieval_model_name != rank_method: + tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder") + model = ( + AutoModel.from_pretrained("BAAI/llm-embedder") + .eval() + .to(self.device) + ) + self.retrieval_model = [tokenizer, model] + self.retrieval_model_name = rank_method + + instruction_qa_query = ( + "Represent this query for retrieving relevant documents: " + ) + instruction_qa_key = "Represent this document for retrieval: " + queries = [instruction_qa_query + query for _ in corpus] + keys = [instruction_qa_key + key for key in corpus] + with torch.no_grad(): + query_inputs = self.retrieval_model[0]( + queries, + padding=True, + truncation=True, + return_tensors="pt", + max_length=512, + ).to(self.device) + key_inputs = self.retrieval_model[0]( + keys, + padding=True, + truncation=True, + return_tensors="pt", + max_length=512, + ).to(self.device) + query_outputs = self.retrieval_model[1](**query_inputs) + key_outputs = self.retrieval_model[1](**key_inputs) + # CLS pooling + query_embeddings = query_outputs.last_hidden_state[:, 0] + key_embeddings = key_outputs.last_hidden_state[:, 0] + # Normalize + query_embeddings = torch.nn.functional.normalize( + query_embeddings, p=2, dim=1 + ) + key_embeddings = torch.nn.functional.normalize( + key_embeddings, p=2, dim=1 + ) + similarity = query_embeddings @ key_embeddings.T + idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())] + return idx + + def get_distance_jinza(corpus, query): + from numpy.linalg import norm + + from transformers import AutoModel + + def cos_sim(a, b): + return (a @ b.T) / (norm(a) * norm(b)) + + if self.retrieval_model is None or self.retrieval_model_name != rank_method: + model = ( + AutoModel.from_pretrained( + "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True + ) + .eval() + .to(self.device) + ) + self.retrieval_model = model + 
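+ # Cache the loaded jina embedding model on the compressor instance, and remember
+ # which rank_method it belongs to, so repeated ranking calls with rank_method="jinza"
+ # reuse the model instead of reloading its weights on every call.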
self.retrieval_model_name = rank_method + + doc_embeds = self.retrieval_model.encode(corpus) + query = self.retrieval_model.encode(query) + doc_scores = cos_sim(doc_embeds, query) + idx = [(ii, 0) for ii in np.argsort(-doc_scores)] + return idx + + def get_distance_voyageai(corpus, query): + import voyageai + from sentence_transformers import util + + voyageai.api_key = self.open_api_config.get("voyageai_api_key", "") + + def get_embed(text): + return voyageai.get_embedding(text, model="voyage-01") + + doc_embeds = [get_embed(i) for i in corpus] + query = get_embed(query) + doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1) + idx = [(ii, 0) for ii in np.argsort(doc_scores)] + return idx + + def get_distance_cohere(corpus, query): + import cohere + + api_key = self.open_api_config.get("cohere_api_key", "") + co = cohere.Client(api_key) + results = co.rerank( + model="rerank-english-v2.0", query=query, documents=corpus, top_n=20 + ) + c_map = {jj: ii for ii, jj in enumerate(corpus)} + doc_rank = [c_map[ii.document["text"]] for ii in results] + idx = [(ii, 0) for ii in doc_rank] + return idx + + def get_distance_longllmlingua(corpus, query): + context_ppl = [ + self.get_condition_ppl( + d, + query + + " We can get the answer to this question in the given documents.", + condition_in_question, + ) + - dl * 2 / 250 * 0 + for d, dl in zip(corpus, context_tokens_length) + ] + sort_direct = -1 if condition_in_question == "none" else 1 + ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1]) + return ys + + method = None + if rank_method == "bm25": + method = get_distance_bm25 + elif rank_method == "gzip": + method = get_distance_gzip + elif rank_method == "sentbert": + method = get_distance_sentbert + elif rank_method == "openai": + method = get_distance_openai + elif rank_method in ["longllmlingua", "llmlingua"]: + method = get_distance_longllmlingua + elif rank_method == "bge": + method = get_distance_sentbert_bge + elif rank_method == "bge_reranker": + method = get_distance_bge_ranker + elif rank_method == "bge_llmembedder": + method = get_distance_bge_llmembedder + elif rank_method == "jinza": + method = get_distance_jinza + elif rank_method == "voyageai": + method = get_distance_voyageai + elif rank_method == "cohere": + method = get_distance_cohere + return method(context, question) + + def segment_structured_context( + self, + context: List[str], + global_rate: float, + ): + new_context, context_segs, context_segs_rate, context_segs_compress = ( + [], + [], + [], + [], + ) + for text in context: + if not text.startswith(""): + text = text + "" + + # Regular expression to match content, allowing rate and compress in any order + pattern = r"([^<]+)" + matches = re.findall(pattern, text) + + # Extracting segment contents + segments = [match[4] for match in matches] + + # Extracting rate and compress, considering their possible positions + segs_rate = [ + float(match[0]) if match[0] else (float(match[2]) if match[2] else None) + for match in matches + ] + segs_compress = [ + ( + match[1] == "True" + if match[1] + else (match[3] == "True" if match[3] else None) + ) + for match in matches + ] + + segs_compress = [ + compress if compress is not None else True for compress in segs_compress + ] + segs_rate = [ + rate if rate else (global_rate if compress else 1.0) + for rate, compress in zip(segs_rate, segs_compress) + ] + assert ( + len(segments) == len(segs_rate) == len(segs_compress) + ), "The number of segments, rates, and compress flags should be the same." 
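+ # NOTE: the segment markup parsed above is assumed to follow the LLMLingua
+ # structured-prompt convention, e.g. "<llmlingua, rate=0.4, compress=True>segment text</llmlingua>".
+ # rate and compress may appear in either order or be omitted; omitted values fall back
+ # to compress=True and to global_rate (or 1.0 when compress is False), per the defaults above.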
+ assert all( + seg_rate <= 1.0 for seg_rate in segs_rate + ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]." + + new_context.append("".join(segments)) + context_segs.append(segments) + context_segs_rate.append(segs_rate) + context_segs_compress.append(segs_compress) + + return new_context, context_segs, context_segs_rate, context_segs_compress + + def concate_segment_info( + self, + segment_info: List[List[tuple]], + ): + new_segment_info = [] + for i, (seg_len, seg_ratio, seg_compress) in enumerate(segment_info): + if ( + new_segment_info + and new_segment_info[-1][1] == seg_ratio + and new_segment_info[-1][2] == seg_compress + ): + new_segment_info[-1] = ( + new_segment_info[-1][0] + seg_len, + seg_ratio, + seg_compress, + ) + else: + new_segment_info.append((seg_len, seg_ratio, seg_compress)) + return new_segment_info + + def __get_context_prob( + self, + context_list: list, + token_to_word="mean", + force_tokens: List[str]=[], + token_map: dict={}, + force_reserve_digit: bool=False, + ): + chunk_list = [] + for chunks in context_list: + for c in chunks: + chunk_list.append(c) + + dataset = TokenClfDataset( + chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len + ) + dataloader = DataLoader( + dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False + ) + + chunk_probs = [] + chunk_words = [] + with torch.no_grad(): + for batch in dataloader: + ids = batch["ids"].to(self.device, dtype=torch.long) + mask = batch["mask"].to(self.device, dtype=torch.long) == 1 + + outputs = self.model(input_ids=ids, attention_mask=mask) + loss, logits = outputs.loss, outputs.logits + probs = F.softmax(logits, dim=-1) + + for j in range(ids.shape[0]): + _probs = probs[j, :, 1] + _ids = ids[j] + _mask = mask[j] + + active_probs = torch.masked_select(_probs, _mask) + active_ids = torch.masked_select(_ids, _mask) + + tokens = self.tokenizer.convert_ids_to_tokens( + active_ids.squeeze().tolist() + ) + token_probs = [prob for prob in active_probs.cpu().numpy()] + + ( + words, + valid_token_probs, + valid_token_probs_no_force, + ) = self.__merge_token_to_word( + tokens, + token_probs, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + ) + word_probs_no_force = self.__token_prob_to_word_prob( + valid_token_probs_no_force, convert_mode=token_to_word + ) + + if "xlm-roberta-large" in self.model_name: + for i in range(len(words)): + words[i] = words[i].lstrip("▁") + chunk_words.append(words) + chunk_probs.append(word_probs_no_force) + + prev_idx = 0 + context_probs = [] + context_words = [] + for chunk_list in context_list: + n_chunk = len(chunk_list) + context_probs.append([]) + context_words.append([]) + for i in range(n_chunk): + context_probs[-1].extend(chunk_probs[prev_idx + i]) + context_words[-1].extend(chunk_words[prev_idx + i]) + prev_idx = prev_idx + n_chunk + context_probs = [sum(probs) / len(probs) for probs in context_probs] + return context_probs, context_words + + def __chunk_context(self, origin_text, chunk_end_tokens): + origin_list = [] + origin_tokens = self.tokenizer.tokenize(origin_text) + n = len(origin_tokens) + st = 0 + while st < n: + if st + self.max_seq_len > n - 1: + chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n]) + origin_list.append(chunk) + break + else: + ed = st + self.max_seq_len + for j in range(0, ed - st): + if origin_tokens[ed - j] in chunk_end_tokens: + ed = ed - j + break + chunk = 
self.tokenizer.convert_tokens_to_string( + origin_tokens[st : ed + 1] + ) + origin_list.append(chunk) + st = ed + 1 + return origin_list + + def __merge_token_to_word(self, tokens, token_probs, force_tokens, token_map, force_reserve_digit): + words = [] + word_probs = [] + word_probs_no_force = [] + + for token, prob in zip(tokens, token_probs): + if token in self.special_tokens: + continue + # add a new word + elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map): + pure_token = get_pure_token(token, self.model_name) + prob_no_force = prob + if pure_token in force_tokens or pure_token in set(token_map.values()): + prob=1.0 + token = replace_added_token(token, token_map) + words.append(token) + word_probs.append( + [ + 1.0 + if force_reserve_digit + and bool(re.search(r"\d", token)) + else prob + ] + ) + word_probs_no_force.append([prob_no_force]) + # concatenate with previous token + else: + pure_token = get_pure_token(token, self.model_name) + words[-1] += pure_token + word_probs[-1].append( + 1.0 + if force_reserve_digit + and bool(re.search(r"\d", token)) + else prob + ) + word_probs_no_force[-1].append(prob_no_force) + + return words, word_probs, word_probs_no_force + + def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"): + if convert_mode == "mean": + word_probs = [sum(p) / len(p) for p in token_probs] + elif convert_mode == "first": + word_probs = [p[0] for p in token_probs] + else: + raise NotImplementedError() + + return word_probs + + def __compress( + self, + context_list: list, + reduce_rate: float=0.5, + token_to_word: str="mean", + force_tokens: List[str]=[], + token_map: dict={}, + force_reserve_digit: bool=False, + drop_consecutive: bool=False, + ): + def split_string_to_words(input_string): + pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]' + result = re.findall(pattern, input_string) + return result + # print(force_tokens, token_map, force_reserve_digit, drop_consecutive) + if reduce_rate <= 0: + words, word_labels = [], [] + for i in range(len(context_list)): + chunk_list = context_list[i] + chunk_words = [] + chunk_word_labels = [] + for j in range(len(chunk_list)): + # replace to original token + for ori_token, new_token in token_map.items(): + chunk_list[j] = chunk_list[j].replace(new_token, ori_token) + ws = split_string_to_words(chunk_list[j]) + chunk_words.extend(ws) + chunk_word_labels.extend([1 for _ in range(len(ws))]) + context_list[i] = "".join(chunk_list) + words.append(chunk_words) + word_labels.append(chunk_word_labels) + return context_list, words, word_labels + + chunk_list = [] + for chunks in context_list: + for c in chunks: + chunk_list.append(c) + + dataset = TokenClfDataset( + chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len + ) + dataloader = DataLoader( + dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False + ) + + compressed_chunk_list = [] + word_list = [] + word_label_list = [] + with torch.no_grad(): + for batch in dataloader: + ids = batch["ids"].to(self.device, dtype=torch.long) + mask = batch["mask"].to(self.device, dtype=torch.long) == 1 + + outputs = self.model(input_ids=ids, attention_mask=mask) + loss, logits = outputs.loss, outputs.logits + probs = F.softmax(logits, dim=-1) + + for j in range(ids.shape[0]): + chunk_probs = probs[j, :, 1] + chunk_ids = ids[j] + chunk_mask = mask[j] + + active_probs = torch.masked_select(chunk_probs, chunk_mask) + active_ids = torch.masked_select(chunk_ids, chunk_mask) + + tokens = self.tokenizer.convert_ids_to_tokens( + 
active_ids.squeeze().tolist() + ) + token_probs = [prob for prob in active_probs.cpu().numpy()] + + words, valid_token_probs, _ = self.__merge_token_to_word( + tokens=tokens, + token_probs=token_probs, + force_tokens=force_tokens, + token_map=token_map, + force_reserve_digit=force_reserve_digit, + ) + word_probs = self.__token_prob_to_word_prob( + valid_token_probs, convert_mode=token_to_word + ) + + if drop_consecutive: + threshold = np.percentile(word_probs, int(100 * reduce_rate)) + is_token_between = False + prev = None + for i, (word, word_prob) in enumerate(zip(words, word_probs)): + if word in force_tokens: + if is_token_between: + is_token_between = False + elif not is_token_between and word == prev: + word_probs[i] = 0.0 + prev = word + else: + is_token_between |= word_prob > threshold + + # calculate compression ratio w.r.t. gpt-4 tokenizer + new_token_probs = [] + for word, word_prob in zip(words, word_probs): + num_token = len(self.oai_tokenizer.encode(word)) + new_token_probs.extend([word_prob for _ in range(num_token)]) + threshold = np.percentile( + new_token_probs, int(100 * reduce_rate + 1) + ) + + keep_words = [] + word_labels = [] + assert len(words) == len(word_probs) + for word, word_porb in zip(words, word_probs): + if word_porb > threshold: + if ( + drop_consecutive + and word in force_tokens + and len(keep_words) > 0 + and keep_words[-1] == word + ): + word_labels.append(0) + else: + keep_words.append(word) + word_labels.append(1) + else: + word_labels.append(0) + keep_str = self.tokenizer.convert_tokens_to_string(keep_words) + if "xlm-roberta-large" in self.model_name: + for i in range(len(words)): + words[i] = words[i].lstrip("▁") + + compressed_chunk_list.append(keep_str) + word_list.append(words[:]) + word_label_list.append(word_labels[:]) + + compressed_context_list = [] + original_word_list = [] + original_word_label_list = [] + prev_idx = 0 + for chunk_list in context_list: + n_chunk = len(chunk_list) + compressed_context_list.append( + "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk]) + ) + original_word_list.append([]) + original_word_label_list.append([]) + for i in range(n_chunk): + original_word_list[-1].extend(word_list[prev_idx + i]) + original_word_label_list[-1].extend(word_label_list[prev_idx + i]) + prev_idx = prev_idx + n_chunk + + return compressed_context_list, original_word_list, original_word_label_list
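+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, illustrative only. The checkpoint path below is a
+    # placeholder, not a pinned model id; point it at an llmlingua-2 style
+    # token-classification checkpoint. Only options already documented in this
+    # class (use_context_level_filter, target_token) are used here.
+    compressor = PromptCompressor(
+        model_name="path/to/llmlingua-2-checkpoint",
+        use_llmlingua2=True,
+    )
+    demo_context = [
+        "The committee reviewed the quarterly budget and approved two new hires.",
+        "A follow-up meeting was scheduled to finalize the vendor contracts.",
+    ]
+    result = compressor.compress_prompt(
+        demo_context,
+        use_context_level_filter=True,
+        target_token=30,
+    )
+    print(result["compressed_prompt"])
+    # After querying an LLM with the compressed prompt, recover() can map spans of
+    # the response back to wording from the original prompt, e.g.:
+    # restored = compressor.recover("\n\n".join(demo_context), result["compressed_prompt"], llm_response)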