File size: 7,236 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""Implementation of splitting text that looks at characters.

Recursively tries to split by different characters to find one that works.

The implementation below is forked from the LangChain project with the MIT license below.
See `RecursiveCharacterTextSplitter` in
https://github.com/hwchase17/langchain/blob/master/langchain/text_splitter.py
"""

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from typing import Any, Callable, Iterable, Optional

from pydantic import validator
from typing_extensions import override

from ...data.dataset_utils import lilac_span
from ...schema import Item, RichData
from ...utils import log
from ..signal import TextSplitterSignal

# A piece of text paired with its (start, end) character span in the source.
TextChunk = tuple[str, tuple[int, int]]

# Separators tried in priority order: paragraph break, newline, space, and
# finally '' (which splits into individual characters as a last resort).
DEFAULT_SEPARATORS = ['\n\n', '\n', ' ', '']
# Default target maximum chunk length (as measured by the length function).
CHUNK_SIZE = 400
# Default overlap carried between consecutive chunks.
CHUNK_OVERLAP = 50


class ChunkSplitter(TextSplitterSignal):
  """Recursively split documents by different characters to find one that works."""

  name = 'chunk'
  display_name = 'Chunk Splitter'

  # Target maximum chunk length, measured by `_length_function`.
  chunk_size: int = CHUNK_SIZE
  # Overlap carried between consecutive chunks; must not exceed `chunk_size`.
  chunk_overlap: int = CHUNK_OVERLAP
  # Separators tried in priority order; '' means split per character.
  separators: list[str] = DEFAULT_SEPARATORS

  # How chunk lengths are measured. Underscore-prefixed, so pydantic treats it
  # as a private attribute rather than a model field.
  _length_function: Callable[[str], int] = len

  @validator('chunk_overlap')
  def check_overlap_smaller_than_chunk(cls, chunk_overlap: int, values: dict[str, Any]) -> int:
    """Check that the chunk overlap is smaller than the chunk size."""
    # NOTE(review): pydantic v1 populates `values` only with fields that have
    # already validated; this assumes `chunk_size` passed validation — if it
    # failed, this lookup would raise KeyError. Confirm acceptable.
    chunk_size: int = values['chunk_size']
    if chunk_overlap > chunk_size:
      raise ValueError(f'Got a larger chunk overlap ({chunk_overlap}) than chunk size '
                       f'({chunk_size}), should be smaller.')
    return chunk_overlap

  @validator('separators')
  def check_separators_are_strings(cls, separators: list[str]) -> list[str]:
    """Check that the separators are strings."""
    # Copy the incoming list (avoids aliasing the caller's list) and fall back
    # to the defaults when it is empty.
    separators = list(separators) or DEFAULT_SEPARATORS
    for sep in separators:
      if not isinstance(sep, str):
        raise ValueError(f'Got separator {sep} that is not a string.')
    return separators

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield, for each input, a list of chunk span items (or None)."""
    for text in data:
      # Non-string rows (e.g. nulls) yield None rather than raising.
      if not isinstance(text, str):
        yield None
        continue

      chunks = split_text(text, self.chunk_size, self.chunk_overlap, self.separators,
                          self._length_function)
      if not chunks:
        yield None
        continue

      # Only the spans are emitted; chunk text is recoverable from the source.
      yield [lilac_span(start, end) for _, (start, end) in chunks]


def _sep_split(text: str, separator: str) -> list[TextChunk]:
  if separator == '':
    # We need to split by char.
    return [(letter, (i, i + 1)) for i, letter in enumerate(text)]

  offset = 0
  chunks: list[TextChunk] = []
  end_index = text.find(separator, offset)

  while end_index >= 0:
    chunks.append((text[offset:end_index], (offset, end_index)))
    offset = end_index + len(separator)
    end_index = text.find(separator, offset)

  # Append the last chunk.
  chunks.append((text[offset:], (offset, len(text))))

  return chunks


def split_text(text: str,
               chunk_size: int = CHUNK_SIZE,
               chunk_overlap: int = CHUNK_OVERLAP,
               separators: Optional[list[str]] = None,
               length_function: Callable[[str], int] = len) -> list[TextChunk]:
  """Split incoming text and return chunks with their character spans.

  Picks the highest-priority separator that occurs in `text`, splits on it,
  and re-merges the pieces into chunks of at most `chunk_size` (as measured by
  `length_function`), carrying roughly `chunk_overlap` of trailing context
  between consecutive chunks. Pieces still longer than `chunk_size` are split
  again recursively with the same separator list.

  Args:
    text: The text to split.
    chunk_size: Target maximum chunk length.
    chunk_overlap: Desired overlap between consecutive chunks; should be
      smaller than `chunk_size`.
    separators: Separators to try, in priority order; '' splits per character.
      Defaults to `DEFAULT_SEPARATORS` (a `None` sentinel is used instead of
      the module-level list to avoid a mutable default argument).
    length_function: Measures the length of a piece of text; defaults to
      `len`.

  Returns:
    A list of `(chunk_text, (start, end))` tuples whose spans index `text`.
  """
  if separators is None:
    separators = DEFAULT_SEPARATORS

  def _merge_splits(splits: Iterable[TextChunk], separator: str) -> list[TextChunk]:
    # Combine the small pieces into chunks close to `chunk_size` without
    # exceeding it, keeping `chunk_overlap` worth of trailing pieces.
    separator_len = length_function(separator)

    docs: list[TextChunk] = []
    current_doc: list[TextChunk] = []
    total = 0
    for chunk in splits:
      text_chunk, _ = chunk
      _len = length_function(text_chunk)
      # Would adding this piece (plus a joining separator) overflow the chunk?
      if (total + _len + (separator_len if len(current_doc) > 0 else 0) > chunk_size):
        if total > chunk_size:
          log(f'Created a chunk of size {total}, '
              f'which is longer than the specified {chunk_size}')
        if len(current_doc) > 0:
          doc = _join_chunks(current_doc, separator)
          if doc is not None:
            docs.append(doc)
          # Keep on popping if:
          # - we have a larger chunk than in the chunk overlap
          # - or if we still have any chunks and the length is long
          while total > chunk_overlap or (
              total + _len +
            (separator_len if len(current_doc) > 0 else 0) > chunk_size and total > 0):
            total -= length_function(current_doc[0][0]) + (
              separator_len if len(current_doc) > 1 else 0)
            current_doc = current_doc[1:]
      current_doc.append(chunk)
      total += _len + (separator_len if len(current_doc) > 1 else 0)
    # Flush whatever remains in the buffer.
    doc = _join_chunks(current_doc, separator)
    if doc is not None:
      docs.append(doc)
    return docs

  final_chunks: list[TextChunk] = []
  # Get appropriate separator to use: the first one present in the text, with
  # the last separator as the fallback ('' always matches, per character).
  separator = separators[-1]
  for _s in separators:
    if _s == '':
      separator = _s
      break
    if _s in text:
      separator = _s
      break
  # Now that we have the separator, split the text.
  splits = _sep_split(text, separator)
  # Now go merging things, recursively splitting longer texts.
  good_splits: list[TextChunk] = []
  for chunk in splits:
    text_chunk, (start, _) = chunk
    if length_function(text_chunk) < chunk_size:
      good_splits.append(chunk)
    else:
      if good_splits:
        merged_text = _merge_splits(good_splits, separator)
        final_chunks.extend(merged_text)
        good_splits = []
      other_chunks = split_text(text_chunk, chunk_size, chunk_overlap, separators, length_function)
      # Recursive spans are relative to `text_chunk`; shift them back into the
      # coordinates of the original `text`.
      other_chunks = [(t, (s + start, e + start)) for t, (s, e) in other_chunks]
      final_chunks.extend(other_chunks)
  if good_splits:
    merged_text = _merge_splits(good_splits, separator)
    final_chunks.extend(merged_text)
  return final_chunks


def _join_chunks(chunks: list[TextChunk], separator: str) -> Optional[TextChunk]:
  text = separator.join([text for text, _ in chunks])
  text = text.strip()
  if text == '':
    return None

  _, (first_span_start, _) = chunks[0]
  _, (_, last_span_end) = chunks[-1]
  return (text, (first_span_start, last_span_end))