import logging
from enum import Enum
import tiktoken
import re
from typing import Any, Optional, Union, Collection, AbstractSet, Literal, List
from langchain.text_splitter import TextSplitter
import random
import string
from itertools import chain
import json

from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()


class IntentType(Enum):
    Emotion = "Emotion"
    Knowledge = "Knowledge"


def select_language_desc(
    preferred_language,
    default_desc="Identify the language of the provided Hint. Your response must be in the same language.",
):
    custom_desc = "You must respond in {}."
    if isinstance(preferred_language, str) and "/" in preferred_language:
        native, es = preferred_language.split("/")
        logging.info(f"Native: {native}, ES: {es}")
        return custom_desc.format(es)
    else:
        logging.info(
            "Error: preferred_language is not in the correct format. It should be 'native/es'."
        )
        return default_desc
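
# Illustrative examples (a sketch; preferred_language is expected in "native/english-name" form):
#   select_language_desc("Español/Spanish") -> "You must respond in Spanish."
#   select_language_desc("Spanish")         -> the default language-detection instruction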

def cal_upperbound(
    model_limit: int = 4096,
    generage_limit: int = 512,
    tolerance: int = 500,
    raw: str = "",
    model_name: str = "gpt-3.5-turbo",
) -> int:
    """
    :param model_limit: Maximum token count for the underlying model call
    :param generage_limit: Token budget reserved for the model's generated output
    :param tolerance: Error tolerance buffer
    :param raw: system prompt and raw content
    :param model_name: model whose tokenizer is used for counting
    :return: remaining token budget for additional content, floored at 0
    """
    if model_name is not None:
        if model_name in tiktoken.model.MODEL_TO_ENCODING:
            enc = tiktoken.encoding_for_model(model_name)
            logging.info(f"Successfully initialized tokenizer for model: {model_name}")
        else:
            enc = tiktoken.get_encoding("cl100k_base")
            logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    raw_token = len(enc.encode(raw))
    upper_bound = model_limit - raw_token - tolerance - generage_limit
    if upper_bound < 0:
        logging.info(f"raw content is too long: {raw_token}")
        return 0
    return upper_bound
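
# Illustrative example (token counts depend on the selected tiktoken encoding;
# system_prompt is a hypothetical variable):
#   cal_upperbound(model_limit=4096, generage_limit=512, tolerance=500, raw=system_prompt)
#   == max(0, 4096 - len(enc.encode(system_prompt)) - 500 - 512)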

def equidistant_filter(chunks, separator, filtered_chunks_n=6):
    # Select the first and last two chunks, and sample the remaining chunks evenly from the middle
    gap = (len(chunks) - 2) / (filtered_chunks_n - 2)
    indexes = [
        int(gap * i)
        for i in range(int(len(chunks) / gap) + 1)
        if (gap * i < len(chunks) - 2)
    ]
    filtered_chunks = [chunks[i] for i in indexes]
    filtered_chunks.append(separator.join(chunks[-2:]))
    return filtered_chunks
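
# Illustrative trace (assuming 10 chunks and the default filtered_chunks_n=6):
#   gap = (10 - 2) / (6 - 2) = 2.0 -> indexes = [0, 2, 4, 6]
#   result = [chunks[0], chunks[2], chunks[4], chunks[6], chunks[8] + separator + chunks[9]]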

def tab_or_space_replacement(match):
    # If the matched string contains a tab character, replace it with a single tab; otherwise replace it with a single space
    return "\t" if "\t" in match.group() else " "


def text_filter(text: str) -> str:
    pattern_tab_space = "[ \t]{3,}"
    pattern_wordwrap = "[\n\f\r\v]{3,}"
    # Replace runs of three or more spaces or tabs with a single space or tab
    replaced_text = re.sub(pattern_tab_space, tab_or_space_replacement, text)
    # Replace runs of three or more \n (newline), \f (form feed), \r (carriage return), or \v (vertical tab) characters with two newlines
    replaced_text = re.sub(pattern_wordwrap, "\n\n", replaced_text)
    return replaced_text
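
# Illustrative example:
#   text_filter("a    b\n\n\n\nc") -> "a b\n\nc"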

ALLOW_SPECIAL_TOKEN = {"<|endofprompt|>", "<|endoftext|>"}


def find_sublist_indices(main_list, sublist):
    indices = []
    length = len(sublist)
    for i in range(len(main_list) - length + 1):
        if main_list[i : i + length] == sublist:
            indices.append((i, i + length))
    return indices
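
# Illustrative example (used later to keep URL token spans intact when force-splitting):
#   find_sublist_indices([1, 2, 3, 2, 3], [2, 3]) -> [(1, 3), (3, 5)]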

class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance
        if model_name is not None:
            if model_name in tiktoken.model.MODEL_TO_ENCODING:
                enc = tiktoken.encoding_for_model(model_name)
                logging.info(f"Successfully initialized tokenizer for model: {model_name}")
            else:
                enc = tiktoken.get_encoding(encoding_name)
                logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: {encoding_name}")
        else:
            enc = tiktoken.get_encoding(encoding_name)
            logging.info(f"No model specified, using default tokenizer: {encoding_name}")
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # Filter runs of whitespace characters out of the input text to increase the proportion of effective content within each chunk
        text = text_filter(text)
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        start_idx = 0
        while start_idx < len(input_ids):
            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
            chunk_ids = input_ids[start_idx:cur_idx]
            s = self._tokenizer.decode(chunk_ids).strip()
            if s:
                s = self._cut_meaningless_head_tail(s)
                if s:
                    splits.append(s)
            start_idx += self._chunk_size - self._chunk_overlap
        logging.debug("finished split_text(): %s splits", len(splits))
        return splits

    def _cut_meaningless_head_tail(self, text: str) -> str:
        # Only split when there are multiple newlines, as parsed PDF/Word files often contain false newlines
        sentences = re.split(r"\. |! |\? |。|!|?|\n+ *\n+", text)
        if len(sentences) < 2:
            return text
        head = sentences[0]
        body = ". ".join(sentences[1:-1])
        tail = sentences[-1]
        head_len = len(
            self._tokenizer.encode(
                head,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        body_len = len(
            self._tokenizer.encode(
                body,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        tail_len = len(
            self._tokenizer.encode(
                tail,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
        parts = []
        # Use length to roughly estimate the impact of discarding the head or tail; if the impact is not significant, discard it
        # Rough estimate: Chinese 20 tokens, 8 characters; English 10 tokens, 30 characters
        if head_len >= 20 or len(head) >= 30:
            parts.append(head)
        if body_len > 0:
            parts.append(body)
        if tail_len >= 20 or len(tail) >= 30:
            parts.append(tail)
        res = "\n".join(parts)
        logger.info(
            "_cut_meaningless_head_tail() removes redundant sentence heads/tails from chunks, before cut: %s characters, after cut: %s characters",
            len(text),
            len(res),
        )
        return res
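
# Illustrative usage of TokenTextSplitter (a sketch; chunk_size and chunk_overlap are the
# token-based limits forwarded to langchain's TextSplitter base class):
#   splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_text(long_document_text)  # chunks of roughly <= 512 tokens each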

def chunk_filter(
    chunks, filter, filtered_chunks_n=6, separator="\n", spacer="\n……\n……\n……\n"
):
    if len(chunks) <= filtered_chunks_n:
        return separator.join(chunks)
    return spacer.join(filter(chunks, separator, filtered_chunks_n))


def get_safe_content_turncate(content, model_name="gpt-3.5-turbo", max_tokens=3300):
    if model_name is not None:
        if model_name in tiktoken.model.MODEL_TO_ENCODING:
            enc = tiktoken.encoding_for_model(model_name)
            logging.info(f"Successfully initialized tokenizer for model: {model_name}")
        else:
            enc = tiktoken.get_encoding("cl100k_base")
            logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
    else:
        enc = tiktoken.get_encoding("cl100k_base")
        logging.info("No model specified, using default tokenizer: cl100k_base")
    logging.warning(
        "get_safe_content_turncate(): current model maximum input length is %s, current input length is %s",
        max_tokens,
        len(enc.encode(content)),
    )
    if len(enc.encode(content)) > max_tokens:
        content = enc.decode(enc.encode(content)[:max_tokens])
    return content
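
# Illustrative examples:
#   chunk_filter(["a", "b", "c"], equidistant_filter) -> "a\nb\nc"   (few chunks: just joined)
#   get_safe_content_turncate(content, max_tokens=3300) -> content, truncated to its first
#   3300 tokens if it is longer than that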

class DataType(Enum):
    DOCUMENT = "DOCUMENT"
    WEBSITE = "WEBSITE"
    IMAGE = "IMAGE"
    TABLE = "TABLE"
    AUDIO = "AUDIO"
    TEXT = "TEXT"

    @staticmethod
    def extra_values_map():
        return {
            "SHORT_AUDIO": "AUDIO",
        }

    @classmethod
    def _missing_(cls, value):
        # Try to map the value to a primary member name via the extra value mapping
        extra_map = cls.extra_values_map()
        if value in extra_map:
            value = extra_map[value]
            return cls.__members__.get(value)
        # If not found, return DOCUMENT by default
        logging.error("DataType._missing_(): Could not find corresponding DataType enum value: %s", value)
        return cls.DOCUMENT
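
# Illustrative examples:
#   DataType("SHORT_AUDIO") -> DataType.AUDIO      (mapped via extra_values_map)
#   DataType("UNKNOWN")     -> DataType.DOCUMENT   (fallback, with an error log)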

def get_urls(string):
    url_arr = []
    if not string:
        return url_arr
    pattern = re.compile(
        r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;\u4e00-\u9fa5]+[-A-Za-z0-9+&@#/%=~_|]"
    )
    matcher = pattern.finditer(string)
    for match in matcher:
        url_arr.append(match.group())
    sorted_url_arr = sorted(set(url_arr), key=len, reverse=True)
    return sorted_url_arr


def get_random_string(s_length: int) -> str:
    # Generate a random string of ASCII letters and digits
    letters = string.ascii_letters + string.digits
    return "".join(random.choice(letters) for _ in range(s_length))


def get_random_strings(n: int, s_length: int) -> List[str]:
    unique_strings = set()
    while len(unique_strings) < n:
        unique_strings.add(get_random_string(s_length))
    return list(unique_strings)


def encode_urls(text, random_string_len: int = 16):
    urls = get_urls(text)
    random_strings = get_random_strings(len(urls), random_string_len)
    url2string_dict = dict(zip(urls, random_strings))
    string2url_dict = dict(zip(random_strings, urls))
    for url, random_string in url2string_dict.items():
        text = text.replace(url, random_string)
    return text, string2url_dict


def decode_urls(text, string2url_dict):
    for random_string, url in string2url_dict.items():
        text = text.replace(random_string, url)
    return text
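
# Illustrative round trip (the placeholder is a random 16-character string, so the exact
# encoded text varies between runs):
#   encoded, mapping = encode_urls("see https://example.com/a?x=1 for details")
#   decode_urls(encoded, mapping) == "see https://example.com/a?x=1 for details"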

class TokenParagraphSplitter(TextSplitter):
    """Performs additional processing tailored to the characteristics of the business data. This includes:
    1. Keeping complete fragments as independent chunks helps improve the information focus of each chunk. Complete fragments are mainly delimited by a period plus a newline.
    2. When a complete fragment is too long, split it into sentences and combine the sentences into chunks that satisfy the window size limit.
    3. If a single sentence is too long, split it directly at token granularity.
    """

    line_break_characters = ["\n", "\f", "\r", "\v"]
    whitespace_characters = [" ", "\t"]
    sentence_terminators = [
        ".",
        "!",
        "?",
        "。",
        "!",
        "?",
        "……",
        "...",
    ] + line_break_characters
    paired_punctuation = [
        ("(", ")"),
        ("[", "]"),
        ("{", "}"),
        ("<", ">"),
        ("“", "”"),
        ("‘", "’"),
        ("《", "》"),
        ("【", "】"),
    ]
    intra_sentence_delimiters = [",", ",", ";", ";"] + whitespace_characters

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed to use TokenParagraphSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance
        self._tokenizer = tiktoken.get_encoding(encoding_name)
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        chunks = []
        # Clean up abnormal whitespace in the text, e.g. replace 3 or more consecutive \n with \n\n
        text = text_filter(text)
        # Replace URLs in the text so that symbols such as . / ? inside URLs do not interfere with sentence splitting
        text, string2url_dict = encode_urls(text)
        url_strings = list(string2url_dict.keys())
        # Split into paragraphs according to the paragraph rules
        paragraphs = self._split_to_paragraphs(
            text, min_paragraph_length=self._chunk_size // 2
        )
        for i, paragraph in enumerate(paragraphs):
            splits = self._split_to_chunks(paragraph, url_strings)
            logging.debug(
                "paragraph %s/%s %s characters: %s",
                i + 1,
                len(paragraphs),
                len(paragraph),
                paragraph,
            )
            logging.debug(
                "paragraph %s/%s split into %s chunks: %s",
                i + 1,
                len(paragraphs),
                len(splits),
                splits,
            )
            chunks.extend(splits)
        chunks = [decode_urls(chunk, string2url_dict) for chunk in chunks]
        return chunks

    def _split_to_chunks(self, text: str, url_strings: List[str] = []) -> List[str]:
        sentences = self._split_to_sentences(text, url_strings)
        chunks = self._merge_sentences_into_chunks(
            sentences, min_chunk_size=self._chunk_size // 2
        )
        return chunks

    def _split_to_paragraphs(
        self, text: str, min_paragraph_length: int = 0
    ) -> List[str]:
        r"""Currently splits the original document into paragraphs directly based on the \n[any whitespace]\n rule."""
        line_break_characters = "".join(self.line_break_characters)
        whitespace_characters = "".join(self.whitespace_characters)
        paragraphs = re.split(
            f"([{line_break_characters}]+[{whitespace_characters}]*[{line_break_characters}])+",
            text,
        )
        if len(paragraphs) % 2 == 1:
            paragraphs = [""] + paragraphs
        paragraphs = [
            (paragraphs[i], paragraphs[i + 1])
            for i in range(0, len(paragraphs), 2)
            if (paragraphs[i] + paragraphs[i + 1]).strip()
        ]
        if not paragraphs:
            return []
        new_paragraphs = []
        cur_paragraph, cur_paragraph_len = "", 0
        # merge short or broken paragraphs
        for sep, paragraph in paragraphs:
            if cur_paragraph_len >= min_paragraph_length and any(
                cur_paragraph.endswith(sym) for sym in self.sentence_terminators
            ):
                new_paragraphs.append(cur_paragraph.strip())
                cur_paragraph, cur_paragraph_len = "", 0
            cur_paragraph_len += len(self._tokenizer.encode(sep + paragraph))
            cur_paragraph += sep + paragraph
        if cur_paragraph:
            new_paragraphs.append(cur_paragraph.strip())
        return new_paragraphs

    def _split_to_sentences(self, text: str, url_strings: List[str] = []) -> List[str]:
        # Use capture groups to preserve sentence separators
        pattern = (
            f"({'|'.join(re.escape(symbol) for symbol in self.sentence_terminators)})+"
        )
        parts = re.split(pattern, text)
        # Merge the parts in steps of two so that the punctuation is appended to the end of the sentence it belongs to
        if len(parts) % 2 == 1:
            parts.append("")
        sentences = ["".join(parts[i : i + 2]) for i in range(0, len(parts), 2)]
        sentences = [s for s in sentences if s.strip()]
        if not sentences:
            return []
        # Repair fragmented sentences, mainly for special cases such as numeric indices and floating-point numbers that may have been split apart
        sentences = self.recombine_broken_sentences(sentences)
        # Split sentences that are still too long into smaller pieces (see _force_split_to_chunks)
        sentences_list = [
            self._force_split_to_chunks(s, url_strings) for s in sentences
        ]
        sentences = list(chain.from_iterable(sentences_list))
        return sentences

    def recombine_broken_sentences(self, sentences: List[str]) -> List[str]:
        """Repair fragmented sentences, mainly for special cases such as numeric indices and floating-point numbers that may have been split apart."""
        if len(sentences) < 2:
            return sentences
        open_symbols_dict = {
            open_sym: close_sym for open_sym, close_sym in self.paired_punctuation
        }
        close_symbols_dict = {
            close_sym: open_sym for open_sym, close_sym in self.paired_punctuation
        }
        new_sentences = []
        cur_sentences = ""
        unmatched_symbol = []
        for sent in sentences:
            # If the current sentence is not empty, does not meet the predefined merge conditions, and has no punctuation still waiting to be matched ([, (, {, etc.), consider the sentence complete
            if cur_sentences.strip() and not (
                self.check_merge(cur_sentences, sent) or unmatched_symbol
            ):
                new_sentences.append(cur_sentences)
                cur_sentences = ""
            for c in sent:
                if c in open_symbols_dict:
                    unmatched_symbol.append(c)
                elif c in close_symbols_dict:
                    if (
                        unmatched_symbol
                        and unmatched_symbol[-1] == close_symbols_dict[c]
                    ):
                        unmatched_symbol.pop()
                # By default, the current sentence ends when a newline-like character appears
                if c in self.line_break_characters:
                    unmatched_symbol = []
                    if cur_sentences.strip():
                        new_sentences.append(cur_sentences)
                        cur_sentences = ""
                cur_sentences += c
        if cur_sentences:
            new_sentences.append(cur_sentences)
        return new_sentences

    def check_merge(self, pre_sen, cur_sen):
        if len(pre_sen) > 1 and len(cur_sen) > 0:
            # A decimal point in the middle of a floating-point number
            if pre_sen[-1] == "." and pre_sen[-2].isdigit() and cur_sen[0].isdigit():
                return True
            # A numeric index at the beginning of a line, such as 1. *****\n2. *****
            if (
                pre_sen[-1] == "."
                and pre_sen[-2].isdigit()
                and cur_sen[0] not in self.line_break_characters
            ):
                return True
            # In markdown, ! followed by [ may be an image link
            if (
                pre_sen[-1] == "!"
                and pre_sen[-2] in self.line_break_characters
                and cur_sen[0] == "["
            ):
                return True
        return False

    def _merge_sentences_into_chunks(
        self, sentences: List[str], min_chunk_size: int = 200
    ) -> List[str]:
        """Assemble chunks according to chunk_size and overlap. Note that the caller guarantees that a single sentence never exceeds chunk_size."""
        if not sentences:
            return []
        n_tokens = [
            len(
                self._tokenizer.encode(
                    sentence,
                    allowed_special=self._allowed_special,
                    disallowed_special=self._disallowed_special,
                )
            )
            for sentence in sentences
        ]
        chunks = []
        start_idx = 0
        end_idx = start_idx + 1
        cur_token_num = n_tokens[start_idx]
        while start_idx < len(n_tokens):
            # The tail pointer has reached the end
            if end_idx >= len(n_tokens):
                chunk = "".join(sentences[start_idx:end_idx])
                logging.debug(
                    "sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
                    start_idx,
                    end_idx,
                    sum(n_tokens[start_idx:end_idx]),
                    cur_token_num,
                )
                chunks.append(chunk)
                break
            else:
                # Adding the next sentence will not exceed chunk_size: keep including new sentences
                if cur_token_num + n_tokens[end_idx] <= self._chunk_size:
                    cur_token_num += n_tokens[end_idx]
                    end_idx += 1
                # Adding the next sentence would exceed chunk_size: assemble the current chunk and move on to the next one
                else:
                    chunk = "".join(sentences[start_idx:end_idx])
                    logging.debug(
                        "sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
                        start_idx,
                        end_idx,
                        sum(n_tokens[start_idx:end_idx]),
                        cur_token_num,
                    )
                    chunks.append(chunk)
                    # Next chunk: end_idx moves at least one position forward; start_idx allows overlap
                    end_idx = end_idx + 1
                    # Find a new starting point for start_idx that does not exceed the overlap
                    new_start_idx = end_idx - 1
                    overlap = 0
                    new_cur_token_num = n_tokens[new_start_idx]
                    while new_start_idx > start_idx + 1:
                        if (
                            overlap + n_tokens[new_start_idx - 1] >= self._chunk_overlap
                            or new_cur_token_num >= self._chunk_size
                        ):
                            break
                        new_start_idx -= 1
                        overlap += n_tokens[new_start_idx]
                        new_cur_token_num += n_tokens[new_start_idx]
                    start_idx = new_start_idx
                    cur_token_num = new_cur_token_num
        if len(chunks) > 1 and len(chunks[-1]) < min_chunk_size:
            logging.warning(
                "The last chunk length %s is less than %s, merging it into the previous chunk",
                len(chunks[-1]),
                min_chunk_size,
            )
            last_chunk = chunks.pop()
            chunks[-1] += last_chunk
        chunks = [chunk for chunk in chunks if chunk.strip()]
        return chunks
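
    # Illustrative behaviour (assuming chunk_size=10, chunk_overlap=3 and sentences of
    # 2 tokens each): the first chunk holds sentences 0-4; the next chunk starts again at
    # sentence 4, so the shared overlap stays below chunk_overlap tokens.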

    def _force_split_to_chunks(
        self, text: str, url_strings: List[str] = []
    ) -> List[str]:
        """If a single sentence is too long, it can only be split forcibly: split by punctuation within the sentence, trying to keep links and other data that require complete information intact."""
        # TODO: In the future, consider refining the forced splitting logic, e.g. smarter splitting by punctuation within the sentence while preserving links and other data that require complete information
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        if len(input_ids) < self._chunk_size:
            return [text]
        if text[-1] not in self.sentence_terminators + self.intra_sentence_delimiters:
            text += self.sentence_terminators[0]
        cur_sentence, cur_sentence_len = "", 0
        sub_sentence = ""
        for c in text:
            sub_sentence += c
            if c in self.intra_sentence_delimiters + self.sentence_terminators:
                sub_sentence_len = len(self._tokenizer.encode(sub_sentence))
                if (
                    cur_sentence_len + sub_sentence_len
                    > self._chunk_size - self._chunk_overlap
                ):
                    if cur_sentence:
                        splits.append(cur_sentence)
                        cur_sentence, cur_sentence_len = sub_sentence, sub_sentence_len
                    else:
                        # This means sub_sentence itself is too long; fall back to forced token-based splitting
                        _splits = self.safe_split(sub_sentence, url_strings)
                        splits.extend(_splits[:-1])
                        cur_sentence = _splits[-1]
                        cur_sentence_len = len(self._tokenizer.encode(cur_sentence))
                else:
                    cur_sentence += sub_sentence
                    cur_sentence_len += sub_sentence_len
                sub_sentence = ""
        if cur_sentence:
            splits.append(cur_sentence)
        return splits

    def safe_split(self, sub_sentence: str, url_strings: List[str] = []) -> List[str]:
        sub_sentence_tokens = self._tokenizer.encode(sub_sentence)
        # Find the token-position intervals of all strings in url_strings
        url_string_intervals = []
        for url_string in url_strings:
            encoded_url_string = self._tokenizer.encode(url_string)
            # Use find_sublist_indices to find all position intervals
            url_string_intervals.extend(
                find_sublist_indices(sub_sentence_tokens, encoded_url_string)
            )
        _splits = []
        i = 0
        while i < len(sub_sentence_tokens):
            if i + self._chunk_size >= len(sub_sentence_tokens):
                slice_end = len(sub_sentence_tokens)
            else:
                slice_end = i + self._chunk_size - self._chunk_overlap
            # If the split boundary falls inside an important string interval, extend it to cover that interval
            for s_begin, s_end in url_string_intervals:
                if i < s_end <= slice_end or i < s_begin < slice_end:
                    slice_end = max(slice_end, s_end)
            # Split and record the current chunk
            _splits.append(self._tokenizer.decode(sub_sentence_tokens[i:slice_end]))
            # Move to the starting point of the next chunk
            i = slice_end
        return _splits
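
# Illustrative usage of TokenParagraphSplitter (a sketch; the splitter keeps complete
# paragraphs and sentences together where possible while respecting the token limits):
#   paragraph_splitter = TokenParagraphSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = paragraph_splitter.split_text(long_document_text)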

def get_summarize_title_keywords(responses):
    # Clean the LLM-generated content to obtain the summarized titles, abstracts, and keywords
    pattern = re.compile(r"\{.*(\}|\]|\,)", re.DOTALL)
    gen_texts = [each.choices[0].message.content for each in responses]
    logging.info("gen_texts: %s", gen_texts)
    results = []
    for res in gen_texts:
        try:
            # Match against the pattern
            matches = list(pattern.finditer(res))
            if not matches:
                results.append(("", "", []))
            else:
                answer = matches[0].group(0)
                content = answer.strip().strip(",")
                content += "]" * (content.count("[") - content.count("]"))
                content += "}" * (content.count("{") - content.count("}"))
                d = json.loads(content)
                results.append(
                    (d.get("title", ""), d.get("summary", ""), d.get("keywords", []))
                )
        except json.JSONDecodeError:
            logging.warning("JSON parsing failed, returning empty title/summary/keywords")
            results.append(("", "", []))
    return results
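

if __name__ == "__main__":
    # Minimal smoke test for get_summarize_title_keywords (an illustrative sketch, not part
    # of the library API): it fakes an OpenAI-style chat completion response object whose
    # message content is the JSON the function expects.
    from types import SimpleNamespace

    fake_response = SimpleNamespace(
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(
                    content='{"title": "T", "summary": "S", "keywords": ["k1", "k2"]}'
                )
            )
        ]
    )
    print(get_summarize_title_keywords([fake_response]))  # expected: [('T', 'S', ['k1', 'k2'])]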