# Source: webui/langchain/text_splitter.py
# Author: zhangyi617 - commit 129cd69 ("Upload folder using huggingface_hub")
"""**Text Splitters** are classes for splitting text.
**Class hierarchy:**
.. code-block::
BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
RecursiveCharacterTextSplitter --> <name>TextSplitter
Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
**Main helpers:**
.. code-block::
Document, Tokenizer, Language, LineType, HeaderType
""" # noqa: E501
from __future__ import annotations
import asyncio
import copy
import logging
import pathlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
AbstractSet,
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
)
import requests
from langchain_core.documents import BaseDocumentTransformer
from langchain.docstore.document import Document
# Module-level logger; TextSplitter._merge_splits uses it to warn when a
# produced chunk is longer than the configured chunk size.
logger = logging.getLogger(__name__)
# Type variable bound to TextSplitter so alternate constructors such as
# `from_tiktoken_encoder` are typed as returning the subclass they were
# invoked on.
TS = TypeVar("TS", bound="TextSplitter")
def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any:  # Any: avoids importing spacy at module load
    """Build a spaCy pipeline used for sentence segmentation.

    spaCy is imported lazily so this module can be imported without it.

    Args:
        pipeline: ``"sentencizer"`` to build a blank English pipeline with the
            rule-based ``sentencizer`` component added; any other value is
            passed to ``spacy.load`` with the ``ner`` and ``tagger``
            components excluded.

    Returns:
        The spaCy pipeline object (callable on text).

    Raises:
        ImportError: If spaCy is not installed. The original import error is
            chained as ``__cause__``.
    """
    try:
        import spacy
    except ImportError as err:
        raise ImportError(
            "Spacy is not installed, please install it with `pip install spacy`."
        ) from err
    if pipeline == "sentencizer":
        from spacy.lang.en import English

        sentencizer = English()
        sentencizer.add_pipe("sentencizer")
    else:
        sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
    return sentencizer
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement :meth:`split_text`. This base class provides document
    creation, merging of small splits into size-bounded chunks, and alternate
    constructors that measure length with a tokenizer.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                ``length_function``. Must be positive.
            chunk_overlap: Overlap between consecutive chunks, in the same units
                as ``chunk_size``. Must be non-negative and no larger than
                ``chunk_size``.
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document

        Raises:
            ValueError: If ``chunk_size`` is not positive, ``chunk_overlap`` is
                negative, or ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
        if chunk_overlap < 0:
            # A negative overlap would make _merge_splits pop past the end of
            # its buffer and fail with an IndexError.
            raise ValueError(f"chunk_overlap must be >= 0, got {chunk_overlap}")
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Args:
            texts: Texts to split.
            metadatas: Optional metadata dicts, one per text. Each chunk receives
                a deep copy of its source text's metadata.

        Returns:
            One Document per chunk, in input order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # Search just past the previous chunk's start so successive
                    # chunks map to increasing positions.
                    # NOTE(review): with repetitive text this can resolve to an
                    # earlier occurrence than the chunk's true position, and it
                    # yields -1 if the chunk text is not found verbatim.
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents, carrying each document's metadata onto its chunks."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator``; return None if the result is empty."""
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily combine small splits into chunks of at most ``chunk_size``.

        Consecutive chunks share a tail of at most ``chunk_overlap`` (in
        ``length_function`` units, separators included). A single split that is
        already larger than ``chunk_size`` is emitted as its own chunk, and a
        warning is logged.
        """
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the "
                        "specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: A ``transformers`` ``PreTrainedTokenizerBase`` instance.
            **kwargs: Forwarded to the splitter's constructor.

        Raises:
            ValueError: If ``transformers`` is not installed, or ``tokenizer`` is
                not a ``PreTrainedTokenizerBase``.
        """
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError as err:
            # ValueError (not ImportError) is kept for compatibility with
            # existing callers.
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from err
        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "Tokenizer received was not an instance of PreTrainedTokenizerBase"
            )

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.encode(text))

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: tiktoken encoding used when ``model_name`` is None.
            model_name: If given, the encoding is looked up for this model.
            allowed_special: Forwarded to ``enc.encode``.
            disallowed_special: Forwarded to ``enc.encode``.
            **kwargs: Forwarded to the splitter's constructor. For
                TokenTextSplitter subclasses, the encoder settings above are
                forwarded as well.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them.

        The synchronous ``transform_documents`` runs in the event loop's
        default executor.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.transform_documents, **kwargs), documents
        )
class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter.

        Args:
            separator: String to split on, or a regular expression when
                ``is_separator_regex`` is true.
            is_separator_regex: Treat ``separator`` as a regex instead of a
                literal string.
            **kwargs: Forwarded to :class:`TextSplitter`.
        """
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)

        # A zero-width lookaround consumes no characters, so there is nothing to
        # put back between pieces. Re-inserting the raw pattern (e.g. "(?=\n)")
        # would inject regex syntax into the chunks.
        # NOTE(review): other regex separators are still re-inserted verbatim
        # as pattern text when keep_separator is False.
        is_lookaround = self._is_separator_regex and self._separator.startswith(
            ("(?=", "(?!", "(?<=", "(?<!")
        )
        merge_separator = (
            "" if self._keep_separator or is_lookaround else self._separator
        )
        return self._merge_splits(splits, merge_separator)
class LineType(TypedDict):
    """Line type as typed dict.

    A run of Markdown text, together with the header metadata in effect where
    it appeared. Built by ``MarkdownHeaderTextSplitter.split_text``.
    """

    # Maps each tracked header's metadata name to that header's text.
    metadata: Dict[str, str]
    # The text itself; consecutive source lines are joined with "\n".
    content: str
class HeaderType(TypedDict):
    """Header type as typed dict.

    One entry of the header stack kept by ``MarkdownHeaderTextSplitter.split_text``.
    """

    # Number of "#" characters in the matched header prefix.
    level: int
    # Metadata key under which this header's text is recorded.
    name: str
    # Header text with the prefix removed and surrounding whitespace stripped.
    data: str
class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: List[Tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track, as
                ``(prefix, metadata_name)`` pairs, e.g. ``("##", "Header 2")``.
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on, order by prefix length
        # (longest first) so "##" is tested before "#".
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Document]:
        """Combine consecutive lines with identical metadata into chunks.

        Args:
            lines: Line of text / associated header metadata. The dicts are not
                modified.

        Returns:
            One Document per run of equal metadata. Contents within a run are
            joined with " \\n".
        """
        aggregated_chunks: List[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # Same metadata as the previous chunk: extend its content.
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Store a shallow copy so the "+=" above never mutates the
                # caller's line dicts.
                aggregated_chunks.append(LineType(**line))

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> List[Document]:
        """Split markdown file

        Args:
            text: Markdown file

        Returns:
            Documents whose metadata maps each tracked header name to the text
            of the enclosing header. Header lines themselves are not included
            in the content. Lines inside fenced code blocks (``` or ~~~) are
            never treated as headers.
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: List[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: List[str] = []
        current_metadata: Dict[str, str] = {}
        # Keep track of the nested header structure
        header_stack: List[HeaderType] = []
        initial_metadata: Dict[str, str] = {}

        in_code_block = False
        # Fence that opened the current code block; only it can close the block.
        opening_fence = ""

        for line in lines:
            stripped_line = line.strip()

            if not in_code_block:
                # A line with exactly one ``` opens a block; "```code```" on a
                # single line is inline code and opens nothing.
                if stripped_line.startswith("```") and stripped_line.count("```") == 1:
                    in_code_block = True
                    opening_fence = "```"
                elif stripped_line.startswith("~~~"):
                    in_code_block = True
                    opening_fence = "~~~"
            elif stripped_line.startswith(opening_fence):
                in_code_block = False
                opening_fence = ""

            if in_code_block:
                current_content.append(stripped_line)
                continue

            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Header with no text OR header followed by a space: both mean
                # sep is being used as a header marker.
                if stripped_line.startswith(sep) and (
                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        current_header_level = sep.count("#")

                        # Pop headers of the same or deeper level: the new
                        # header closes their sections.
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            popped_header = header_stack.pop()
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        initial_metadata[name] = header["data"]

                    # Flush content gathered under the previous header.
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    # A blank line ends the current block of content.
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata;
        # aggregate these into chunks based on common metadata.
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
class ElementType(TypedDict):
    """Element type as typed dict.

    One chunk extracted by ``HTMLHeaderTextSplitter.split_text_from_file``.
    """

    # NOTE(review): populated with the `file` argument of split_text_from_file
    # (a file-like object or path), not necessarily a URL string.
    url: str
    # Concatenated text of the element's class="xpath" children.
    xpath: str
    # Concatenated text of the element's class="chunk" children.
    content: str
    # Header metadata name -> header text, taken from the element's
    # class="headers" children.
    metadata: Dict[str, str]
class HTMLHeaderTextSplitter:
    """
    Splitting HTML files based on specified headers.
    Requires lxml package.
    """

    def __init__(
        self,
        headers_to_split_on: List[Tuple[str, str]],
        return_each_element: bool = False,
    ):
        """Create a new HTMLHeaderTextSplitter.

        Args:
            headers_to_split_on: list of tuples of headers we want to track mapped to
                (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
                h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2")].
            return_each_element: Return each element w/ associated headers.
        """
        # Output element-by-element or aggregated into chunks w/ common headers
        self.return_each_element = return_each_element
        self.headers_to_split_on = sorted(headers_to_split_on)

    def aggregate_elements_to_chunks(
        self, elements: List[ElementType]
    ) -> List[Document]:
        """Combine consecutive elements with identical metadata into chunks.

        Args:
            elements: HTML element content with associated identifying info and
                metadata. The dicts are not modified.

        Returns:
            One Document per run of equal metadata. Contents within a run are
            joined with " \\n".
        """
        aggregated_chunks: List[ElementType] = []

        for element in elements:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == element["metadata"]
            ):
                # Same metadata as the previous chunk: extend its content.
                aggregated_chunks[-1]["content"] += " \n" + element["content"]
            else:
                # Store a shallow copy so the "+=" above never mutates the
                # caller's element dicts.
                aggregated_chunks.append(ElementType(**element))

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text_from_url(
        self, url: str, timeout: Optional[float] = 10
    ) -> List[Document]:
        """Split HTML from web URL

        Args:
            url: web URL
            timeout: Seconds to wait for the server before giving up (forwarded
                to ``requests.get``). ``None`` waits indefinitely.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.get(url, timeout=timeout)
        # Don't silently split an error page as if it were the document.
        response.raise_for_status()
        return self.split_text_from_file(BytesIO(response.content))

    def split_text(self, text: str) -> List[Document]:
        """Split HTML text string

        Args:
            text: HTML text
        """
        return self.split_text_from_file(StringIO(text))

    def split_text_from_file(self, file: Any) -> List[Document]:
        """Split HTML file

        Args:
            file: HTML file (anything ``lxml.etree.parse`` accepts)

        Raises:
            ImportError: If lxml is not installed.
        """
        try:
            from lxml import etree
        except ImportError as e:
            raise ImportError(
                "Unable to import lxml, please install with `pip install lxml`."
            ) from e
        # use lxml library to parse html document and return xml ElementTree
        parser = etree.HTMLParser()
        tree = etree.parse(file, parser)

        # document transformation for "structure-aware" chunking is handled with xsl.
        # see comments in html_chunks_with_headers.xslt for more detailed information.
        xslt_path = (
            pathlib.Path(__file__).parent
            / "document_transformers/xsl/html_chunks_with_headers.xslt"
        )
        xslt_tree = etree.parse(xslt_path)
        transform = etree.XSLT(xslt_tree)
        result = transform(tree)
        result_dom = etree.fromstring(str(result))

        # create filter and mapping for header metadata
        header_filter = {header[0] for header in self.headers_to_split_on}
        header_mapping = dict(self.headers_to_split_on)

        # map xhtml namespace prefix
        ns_map = {"h": "http://www.w3.org/1999/xhtml"}

        # build list of elements from DOM
        elements = []
        for element in result_dom.findall("*//*", ns_map):
            if element.findall("*[@class='headers']") or element.findall(
                "*[@class='chunk']"
            ):
                elements.append(
                    ElementType(
                        url=file,
                        # Elements without text have node.text None, which
                        # would make str.join raise TypeError.
                        xpath="".join(
                            node.text or ""
                            for node in element.findall("*[@class='xpath']", ns_map)
                        ),
                        content="".join(
                            node.text or ""
                            for node in element.findall("*[@class='chunk']", ns_map)
                        ),
                        # Add text of specified headers to metadata using header
                        # mapping.
                        metadata={
                            header_mapping[node.tag]: node.text
                            for node in element.findall(
                                "*[@class='headers']/*", ns_map
                            )
                            if node.tag in header_filter
                        },
                    )
                )

        if not self.return_each_element:
            return self.aggregate_elements_to_chunks(elements)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in elements
            ]
# Once Python 3.10+ is the minimum supported version this could become
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    """Immutable bundle of token-window settings and codec callables.

    Consumed by ``split_text_on_tokens``. ``encode`` turns text into token ids,
    ``decode`` turns a window of ids back into text, and each window is
    ``tokens_per_chunk`` long and starts ``tokens_per_chunk - chunk_overlap``
    ids after the previous one.
    """

    # Number of tokens shared by consecutive chunks.
    chunk_overlap: int
    # Upper bound on the number of tokens in a single chunk.
    tokens_per_chunk: int
    # Maps a list of token ids back to a string.
    decode: Callable[[List[int]], str]
    # Maps a string to a list of token ids.
    encode: Callable[[str], List[int]]
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded once. Windows of ``tokenizer.tokens_per_chunk`` ids,
    each starting ``tokens_per_chunk - chunk_overlap`` ids after the previous
    one, are decoded back into strings. Splitting stops as soon as a window
    reaches the end of the input, so no trailing chunk is produced that lies
    wholly inside the previous one.

    Args:
        text: Text to split.
        tokenizer: Supplies ``encode``/``decode`` and the window configuration.

    Returns:
        Decoded chunks in order; an empty list when ``text`` encodes to no ids.

    Raises:
        ValueError: If ``chunk_overlap`` is not smaller than
            ``tokens_per_chunk``, since the window would never advance.
    """
    step = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
    if step <= 0:
        raise ValueError(
            f"tokens_per_chunk ({tokenizer.tokens_per_chunk}) must be greater "
            f"than chunk_overlap ({tokenizer.chunk_overlap})"
        )
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    n_ids = len(input_ids)
    start_idx = 0
    while start_idx < n_ids:
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, n_ids)
        splits.append(tokenizer.decode(input_ids[start_idx:cur_idx]))
        if cur_idx == n_ids:
            # This window already covers the tail; another would be redundant.
            break
        start_idx += step
    return splits
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer.

    ``chunk_size`` and ``chunk_overlap`` are interpreted as tiktoken token counts.
    """

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            encoding_name: tiktoken encoding used when ``model_name`` is None.
            model_name: If given, the encoding is looked up for this model.
            allowed_special: Forwarded to the tiktoken ``encode`` call.
            disallowed_special: Forwarded to the tiktoken ``encode`` call.
            **kwargs: Forwarded to :class:`TextSplitter`.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into windows of at most ``chunk_size`` tokens.

        Consecutive windows overlap by ``chunk_overlap`` tokens.
        """

        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
class SentenceTransformersTokenTextSplitter(TextSplitter):
    """Splitting text to tokens using sentence model tokenizer."""

    def __init__(
        self,
        chunk_overlap: int = 50,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_overlap: Number of tokens shared by consecutive chunks.
            model_name: SentenceTransformer model whose tokenizer is used.
            tokens_per_chunk: Tokens per chunk. Defaults to the model's
                ``max_seq_length`` and may not exceed it.
            **kwargs: Forwarded to :class:`TextSplitter`.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
            ValueError: If ``tokens_per_chunk`` exceeds the model's limit.
        """
        super().__init__(**kwargs, chunk_overlap=chunk_overlap)

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            raise ImportError(
                "Could not import sentence_transformer python package. "
                "This is needed in order to use "
                "SentenceTransformersTokenTextSplitter. "
                "Please install it with `pip install sentence-transformers`."
            ) from err

        self.model_name = model_name
        self._model = SentenceTransformer(self.model_name)
        self.tokenizer = self._model.tokenizer
        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

    def _initialize_chunk_configuration(
        self, *, tokens_per_chunk: Optional[int]
    ) -> None:
        """Set ``tokens_per_chunk``, defaulting to and bounded by the model limit."""
        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)

        if tokens_per_chunk is None:
            self.tokens_per_chunk = self.maximum_tokens_per_chunk
        else:
            self.tokens_per_chunk = tokens_per_chunk

        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
            raise ValueError(
                f"The token limit of the models '{self.model_name}'"
                f" is: {self.maximum_tokens_per_chunk}."
                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                f" > maximum token limit."
            )

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into windows of at most ``tokens_per_chunk`` tokens.

        The first and last ids produced by the model tokenizer are dropped
        before windowing (presumably the start/stop special tokens, as the
        helper's name says; verify for tokenizers that do not add them).
        """

        def encode_strip_start_and_stop_token_ids(text_to_encode: str) -> List[int]:
            return self._encode(text_to_encode)[1:-1]

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self.tokens_per_chunk,
            decode=self.tokenizer.decode,
            encode=encode_strip_start_and_stop_token_ids,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def count_tokens(self, *, text: str) -> int:
        """Return the number of token ids for ``text``.

        The count includes the first and last ids that ``split_text`` drops.
        """
        return len(self._encode(text))

    # NOTE(review): 2**32 exceeds a signed 32-bit integer despite the name.
    # Truncation is disabled in _encode, so this only needs to be "large".
    _max_length_equal_32_bit_integer: int = 2**32

    def _encode(self, text: str) -> List[int]:
        """Encode ``text`` with the model tokenizer, without truncation."""
        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
            text,
            max_length=self._max_length_equal_32_bit_integer,
            truncation="do_not_truncate",
        )
        return token_ids_with_start_and_end_token_ids
class Language(str, Enum):
    """Enum of the programming languages.

    Each member selects a separator list in
    ``RecursiveCharacterTextSplitter.get_separators_for_language``. Because the
    enum subclasses ``str``, members compare equal to their string values.
    Member order is visible in that method's "not supported" error message.
    """

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"  # JavaScript
    TS = "ts"  # TypeScript
    PHP = "php"
    PROTO = "proto"  # Protocol Buffers
    PYTHON = "python"
    RST = "rst"  # reStructuredText
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"  # Solidity
    CSHARP = "csharp"
    COBOL = "cobol"
class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            separators: Separators to try, in priority order. Defaults to
                paragraph break, newline, space, then single characters.
            keep_separator: Whether matched separators stay in the output
                (prepended to the piece that follows them).
            is_separator_regex: Treat separators as regular expressions instead
                of literal strings.
            **kwargs: Forwarded to :class:`TextSplitter`.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks.

        Uses the first separator that occurs in ``text``. Pieces still too
        long are split recursively with the remaining, lower-priority
        separators.
        """
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # Flush the small pieces collected so far before handling the
                # oversized one, preserving order.
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` using the configured separators."""
        return self._split_text(text, self._separators)

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        """Build a splitter with ``language``'s separators, treated as regexes."""
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, is_separator_regex=True, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> List[str]:
        """Return a fresh list of separators for ``language``, highest priority first.

        The strings are written as regular expressions (``from_language`` uses
        them with ``is_separator_regex=True``).

        Raises:
            ValueError: If ``language`` has no separator list.
        """
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ", "\nint ", "\nfloat ", "\ndouble ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ", "\nvar ", "\nconst ", "\ntype ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.KOTLIN:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ", "\nprotected ", "\nprivate ", "\ninternal ",
                "\ncompanion ", "\nfun ", "\nval ", "\nvar ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.TS:
            return [
                "\nenum ", "\ninterface ", "\nnamespace ", "\ntype ",
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nfunction ", "\nconst ", "\nlet ", "\nvar ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message, service, enum and option definitions
                "\nmessage ", "\nservice ", "\nenum ", "\noption ",
                # Split along import statements and syntax declarations
                "\nimport ", "\nsyntax ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class and function definitions
                "\nclass ", "\ndef ", "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles (underlines of =, - or *)
                "\n=+\n", "\n-+\n", "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ", "\nclass ",
                # Split along control flow statements
                "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ", "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ", "\nconst ", "\nlet ",
                # Split along control flow statements. A duplicate "\nconst "
                # that used to follow "\nmatch " was removed: it could never
                # match anything the first occurrence had not already split on.
                "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ", "\nobject ",
                # Split along method definitions
                "\ndef ", "\nval ", "\nvar ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ", "\nstruct ", "\nenum ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along ATX headings (levels 1-6).
                # Setext headings (text underlined with === or ---) are not
                # handled here.
                "\n#{1,6} ",
                # End of code block
                "```\n",
                # Horizontal rules of three or more *, - or _ characters.
                # Rules written with spaces between the characters (e.g.
                # "* * *") are not matched.
                "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\\\chapter{", "\n\\\\section{", "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}",
                "\n\\\\begin{description}", "\n\\\\begin{list}",
                "\n\\\\begin{quote}", "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}", "\n\\\\begin{verbatim}",
                # Now split by math environments. The align entry previously
                # read "\n\\\begin{align}", where "\b" was a backspace
                # character, so it could never match "\begin{align}".
                "\n\\\\begin{align}",
                # NOTE(review): as regexes, "$$" and "$" are end-of-string
                # anchors, not literal dollar signs.
                "$$", "$",
                # Now split by the normal type of lines
                " ", "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body", "<div", "<p", "<br", "<li",
                "<h1", "<h2", "<h3", "<h4", "<h5", "<h6",
                "<span", "<table", "<tr", "<td", "<th", "<ul", "<ol",
                "<header", "<footer", "<nav",
                # Head
                "<head", "<style", "<script", "<meta", "<title",
                "",
            ]
        elif language == Language.CSHARP:
            return [
                "\ninterface ", "\nenum ", "\nimplements ", "\ndelegate ", "\nevent ",
                # Split along class definitions
                "\nclass ", "\nabstract ",
                # Split along method definitions
                "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nreturn ",
                # Split along control flow statements
                "\nif ", "\ncontinue ", "\nfor ", "\nforeach ", "\nwhile ",
                "\nswitch ", "\nbreak ", "\ncase ", "\nelse ",
                # Split by exceptions
                "\ntry ", "\nthrow ", "\nfinally ", "\ncatch ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ", "\nusing ",
                # Split along contract definitions
                "\ncontract ", "\ninterface ", "\nlibrary ",
                # Split along method definitions
                "\nconstructor ", "\ntype ", "\nfunction ", "\nevent ",
                "\nmodifier ", "\nerror ", "\nstruct ", "\nenum ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\ndo while ", "\nassembly ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.", "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.", "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.", "\nLINKAGE SECTION.", "\nFILE SECTION.",
                # Split along sections within ENVIRONMENT DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ", "\nCLOSE ", "\nREAD ", "\nWRITE ", "\nIF ", "\nELSE ",
                "\nMOVE ", "\nPERFORM ", "\nUNTIL ", "\nVARYING ", "\nACCEPT ",
                "\nDISPLAY ", "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n", " ", "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter.

        Args:
            separator: String used to join sentences when merging them into
                chunks.
            language: Language name passed to NLTK's ``sent_tokenize``.
            **kwargs: Forwarded to :class:`TextSplitter`.

        Raises:
            ImportError: If NLTK is not installed.
        """
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize
        except ImportError as err:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            ) from err
        self._tokenizer = sent_tokenize
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First split into sentences, then merge them into size-bounded chunks.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
class SpacyTextSplitter(TextSplitter):
    """Splitting text using Spacy package.

    Per default, Spacy's `en_core_web_sm` model is used. For a faster, but
    potentially less accurate splitting, you can use `pipeline='sentencizer'`.
    """

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ) -> None:
        """Initialize the spacy text splitter.

        Args:
            separator: String used to join sentences when merging them into
                chunks.
            pipeline: spaCy model name, or ``"sentencizer"`` for the rule-based
                English sentencizer.
            **kwargs: Forwarded to :class:`TextSplitter`.
        """
        super().__init__(**kwargs)
        self._tokenizer = _make_spacy_pipeline_for_splitting(pipeline)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Segment ``text`` into sentences, then merge them into chunks."""
        parsed = self._tokenizer(text)
        sentences = [sentence.text for sentence in parsed.sents]
        return self._merge_splits(sentences, self._separator)
# Kept for backwards compatibility. Unlike
# RecursiveCharacterTextSplitter.from_language, this does not set
# is_separator_regex=True, so separators are treated as literal strings
# unless the caller passes is_separator_regex.
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a PythonCodeTextSplitter.

        Args:
            **kwargs: Forwarded to RecursiveCharacterTextSplitter. ``separators``
                must not be given; it is fixed to the Python separator list.
        """
        python_separators = self.get_separators_for_language(Language.PYTHON)
        super().__init__(**kwargs, separators=python_separators)
class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextSplitter.

        Args:
            **kwargs: Forwarded to RecursiveCharacterTextSplitter. ``separators``
                must not be given; it is fixed to the Markdown separator list.
        """
        markdown_separators = self.get_separators_for_language(Language.MARKDOWN)
        super().__init__(**kwargs, separators=markdown_separators)
class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LatexTextSplitter.

        Args:
            **kwargs: Forwarded to RecursiveCharacterTextSplitter. ``separators``
                must not be given; it is fixed to the LaTeX separator list.
        """
        latex_separators = self.get_separators_for_language(Language.LATEX)
        super().__init__(**kwargs, separators=latex_separators)