"""Functionality for splitting text.""" | |
from __future__ import annotations | |
import copy | |
import logging | |
from abc import ABC, abstractmethod | |
from typing import ( | |
AbstractSet, | |
Any, | |
Callable, | |
Collection, | |
Iterable, | |
List, | |
Literal, | |
Optional, | |
Union, | |
) | |
from langchain.docstore.document import Document | |
logger = logging.getLogger() | |

class TextSplitter(ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
    ):
        """Create a new TextSplitter."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(text):
                new_doc = Document(
                    page_content=chunk, metadata=copy.deepcopy(_metadatas[i])
                )
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents."""
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.create_documents(texts, metadatas)
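
    # Usage sketch (an illustrative example, not part of the API:
    # CharacterTextSplitter is a concrete subclass defined further below, and
    # the texts and file names are made up):
    #
    #     splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    #     docs = splitter.create_documents(
    #         ["first long text ...", "second long text ..."],
    #         metadatas=[{"source": "a.txt"}, {"source": "b.txt"}],
    #     )
    #
    # Each chunk carries a deep copy of its source text's metadata, so
    # mutating one chunk's metadata cannot affect sibling chunks.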

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep popping splits off the front while:
                    # - the retained text is longer than the chunk overlap, or
                    # - adding the next split would still overflow the chunk
                    #   size (and there is anything left to pop)
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
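
    # Worked example of the merge loop above (values chosen for illustration):
    # with chunk_size=10, chunk_overlap=5, separator=" " and
    # splits=["abcd", "efgh", "ijkl"]:
    #   - "abcd" and "efgh" fit together (4 + 1 + 4 = 9 <= 10); adding "ijkl"
    #     would give 14 > 10, so "abcd efgh" is emitted.
    #   - "abcd" is popped (total drops to 4, within the overlap of 5),
    #     "efgh" is retained, and "ijkl" is appended.
    #   - The result is ["abcd efgh", "efgh ijkl"]: "efgh" is the overlap.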

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses a HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls,
        encoding_name: str = "gpt2",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TextSplitter:
        """Text splitter that uses a tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to use from_tiktoken_encoder. "
                "Please install it with `pip install tiktoken`."
            )

        # Create an encoder instance for the requested encoding.
        enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                    **kwargs,
                )
            )

        return cls(length_function=_tiktoken_encoder, **kwargs)
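
# Usage sketch for the constructors above (an illustrative example:
# CharacterTextSplitter is defined below, and the tokenizer name is made up):
#
#     splitter = CharacterTextSplitter.from_tiktoken_encoder(
#         encoding_name="gpt2", chunk_size=100, chunk_overlap=0
#     )
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     splitter = CharacterTextSplitter.from_huggingface_tokenizer(tok)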

class CharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        if self._separator:
            splits = text.split(self._separator)
        else:
            splits = list(text)
        return self._merge_splits(splits, self._separator)
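
# Usage sketch (the input string is illustrative):
#
#     splitter = CharacterTextSplitter(
#         separator="\n\n", chunk_size=20, chunk_overlap=0
#     )
#     splitter.split_text("para one\n\npara two\n\npara three")
#     # -> ["para one\n\npara two", "para three"]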

class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # Create an encoder instance for the requested encoding.
        self._tokenizer = tiktoken.get_encoding(encoding_name)
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        start_idx = 0
        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
        while start_idx < len(input_ids):
            splits.append(self._tokenizer.decode(chunk_ids))
            # Slide the window forward, keeping chunk_overlap tokens of context.
            start_idx += self._chunk_size - self._chunk_overlap
            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
            chunk_ids = input_ids[start_idx:cur_idx]
        return splits
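
# Usage sketch (note that chunk_size and chunk_overlap are measured in tokens
# here, not characters; `long_text` is a placeholder):
#
#     splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=64)
#     chunks = splitter.split_text(long_text)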

class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, separator)
            final_chunks.extend(merged_text)
        return final_chunks
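
# Usage sketch: with the default separators, "\n\n" is tried first, then
# "\n", then " ", then individual characters, and any piece that is still
# too long is recursively re-split with the next separator that applies.
#
#     splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
#     chunks = splitter.split_text(long_text)  # `long_text` is a placeholder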

class NLTKTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at sentences using NLTK."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text)
        return self._merge_splits(splits, self._separator)
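
# Usage sketch (sent_tokenize also needs NLTK's "punkt" sentence model, e.g.
# via `nltk.download("punkt")`, which the import check above does not verify):
#
#     splitter = NLTKTextSplitter(chunk_size=1000)
#     chunks = splitter.split_text(long_text)  # `long_text` is a placeholder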

class SpacyTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at sentences using Spacy."""

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ):
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        self._tokenizer = spacy.load(pipeline)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = (str(s) for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)
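
# Usage sketch (the pipeline must be downloaded separately, e.g.
# `python -m spacy download en_core_web_sm`, or spacy.load will fail):
#
#     splitter = SpacyTextSplitter(pipeline="en_core_web_sm")
#     chunks = splitter.split_text(long_text)  # `long_text` is a placeholder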

class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any):
        """Initialize a MarkdownTextSplitter."""
        separators = [
            # First, try to split along Markdown headings (starting with level 2)
            "\n## ",
            "\n### ",
            "\n#### ",
            "\n##### ",
            "\n###### ",
            # Note: the alternative setext syntax for headings, e.g.
            #   Heading level 2
            #   ---------------
            # is not handled here.
            # End of code block
            "```\n\n",
            # Horizontal lines
            "\n\n***\n\n",
            "\n\n---\n\n",
            "\n\n___\n\n",
            # Note: horizontal lines made of runs of *more* than three
            # *, -, or _ characters are not handled here.
            "\n\n",
            "\n",
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
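
# Usage sketch (the document string is illustrative):
#
#     md = "## Intro\nSome text.\n\n## Details\nMore text."
#     MarkdownTextSplitter(chunk_size=40, chunk_overlap=0).split_text(md)
#     # Heading separators such as "\n## " are tried before blank lines,
#     # so chunks tend to break at section boundaries.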

class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along LaTeX-formatted layout elements."""

    def __init__(self, **kwargs: Any):
        """Initialize a LatexTextSplitter."""
        separators = [
            # First, try to split along LaTeX sections
            "\n\\chapter{",
            "\n\\section{",
            "\n\\subsection{",
            "\n\\subsubsection{",
            # Now split by environments
            "\n\\begin{enumerate}",
            "\n\\begin{itemize}",
            "\n\\begin{description}",
            "\n\\begin{list}",
            "\n\\begin{quote}",
            "\n\\begin{quotation}",
            "\n\\begin{verse}",
            "\n\\begin{verbatim}",
            # Now split by math environments
            "\n\\begin{align}",
            "$$",
            "$",
            # Now split by the normal type of lines
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
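
# Usage sketch (the snippet is illustrative):
#
#     tex = "\\section{Intro}\nSome text.\n\\section{Method}\nMore text."
#     LatexTextSplitter(chunk_size=40, chunk_overlap=0).split_text(tex)
#     # Sectioning commands are tried before math and whitespace separators.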

class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any):
        """Initialize a PythonCodeTextSplitter."""
        separators = [
            # First, try to split along class definitions
            "\nclass ",
            "\ndef ",
            # Note: only tab-indented methods are matched here; methods
            # indented with spaces fall through to the generic separators.
            "\n\tdef ",
            # Now split by the normal type of lines
            "\n\n",
            "\n",
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
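
# Usage sketch (the source string is illustrative):
#
#     src = "class Foo:\n\tdef bar(self):\n\t\tpass\n\ndef baz():\n\tpass\n"
#     PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0).split_text(src)
#     # "\nclass " and "\ndef " are tried first, so chunks tend to start at
#     # top-level definitions.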