m-ric HF staff commited on
Commit
40a40cb
1 Parent(s): 51c0840

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -26
app.py CHANGED
@@ -7,6 +7,8 @@ from langchain.text_splitter import (
7
  LABEL_TEXTSPLITTER = "LangChain's CharacterTextSplitter"
8
  LABEL_RECURSIVE = "Langchain's RecursiveCharacterTextSplitter"
9
 
 
 
10
  def extract_separators_from_string(separators_str):
11
  try:
12
  separators = separators_str[1:-1].split(", ")
@@ -18,36 +20,55 @@ def extract_separators_from_string(separators_str):
18
  Please type it in the correct format: "['separator_1', 'separator_2', etc]"
19
  """)
20
 
21
- def change_split_selection(text, slider_count, split_selection, separator_selection):
22
  print("Updating separator selection interactivity:")
23
  return (
24
  gr.Textbox.update(visible=(split_selection==LABEL_RECURSIVE)),
25
- chunk(text, slider_count, split_selection, separator_selection)
26
  )
27
 
28
- def chunk(text, length, splitter_selection, separators_str):
29
  separators = extract_separators_from_string(separators_str)
30
 
31
  if splitter_selection == LABEL_TEXTSPLITTER:
32
- text_splitter = CharacterTextSplitter(
33
- separator="",
34
- chunk_size=length,
35
- chunk_overlap=0,
36
- length_function=len,
37
- is_separator_regex=False,
38
- )
39
- splits = text_splitter.create_documents([text])
40
- text_splits = [split.page_content for split in splits]
 
 
 
 
 
 
 
 
41
  elif splitter_selection == LABEL_RECURSIVE:
42
- text_splitter = RecursiveCharacterTextSplitter(
43
- chunk_size=length,
44
- chunk_overlap=0,
45
- length_function=len,
46
- add_start_index=True,
47
- separators=separators,
48
- )
49
- splits = text_splitter.create_documents([text])
50
- text_splits = [split.page_content for split in splits]
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  output = [(split, str(i)) for i, split in enumerate(text_splits)]
53
  return output
@@ -105,7 +126,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="#textbox_id {color: red; font-samily
105
  "Character count",
106
  "Token count",
107
  ],
108
- value="Token count",
109
  label="Length count",
110
  info="How should we count our chunk lengths?",
111
  )
@@ -119,22 +140,22 @@ with gr.Blocks(theme=gr.themes.Soft(), css="#textbox_id {color: red; font-samily
119
  )
120
  text.change(
121
  fn=chunk,
122
- inputs=[text, slider_count, split_selection, separator_selection],
123
  outputs=out,
124
  )
125
  length_unit_selection.change(
126
  fn=chunk,
127
- inputs=[text, slider_count, split_selection, separator_selection],
128
  outputs=out,
129
  )
130
  split_selection.change(
131
  fn=change_split_selection,
132
- inputs=[text, slider_count, split_selection, separator_selection],
133
  outputs=[separator_selection, out],
134
  )
135
  slider_count.change(
136
  fn=chunk,
137
- inputs=[text, slider_count, split_selection, separator_selection],
138
  outputs=out,
139
  )
140
  demo.launch()
 
7
  LABEL_TEXTSPLITTER = "LangChain's CharacterTextSplitter"
8
  LABEL_RECURSIVE = "Langchain's RecursiveCharacterTextSplitter"
9
 
10
+ bert_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
11
+
12
  def extract_separators_from_string(separators_str):
13
  try:
14
  separators = separators_str[1:-1].split(", ")
 
20
  Please type it in the correct format: "['separator_1', 'separator_2', etc]"
21
  """)
22
 
23
+ def change_split_selection(text, slider_count, split_selection, separator_selection, length_unit_selection):
24
  print("Updating separator selection interactivity:")
25
  return (
26
  gr.Textbox.update(visible=(split_selection==LABEL_RECURSIVE)),
27
+ chunk(text, slider_count, split_selection, separator_selection, length_unit_selection)
28
  )
29
 
30
+ def chunk(text, length, splitter_selection, separators_str, length_unit_selection):
31
  separators = extract_separators_from_string(separators_str)
32
 
33
  if splitter_selection == LABEL_TEXTSPLITTER:
34
+ if "token" in length_unit_selection.lower():
35
+ text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
36
+ AutoTokenizer.from_pretrained(tokenizer_name),
37
+ separator="",
38
+ chunk_size=length,
39
+ chunk_overlap=0,
40
+ length_function=len,
41
+ is_separator_regex=False,
42
+ )
43
+ else:
44
+ text_splitter = CharacterTextSplitter(
45
+ separator="",
46
+ chunk_size=length,
47
+ chunk_overlap=0,
48
+ length_function=len,
49
+ is_separator_regex=False,
50
+ )
51
  elif splitter_selection == LABEL_RECURSIVE:
52
+ if "token" in length_unit_selection.lower():
53
+ text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
54
+ AutoTokenizer.from_pretrained(tokenizer_name),
55
+ chunk_size=chunk_size,
56
+ chunk_overlap=0,
57
+ add_start_index=True,
58
+ strip_whitespace=False,
59
+ separators=separators,
60
+ )
61
+ else:
62
+ text_splitter = RecursiveCharacterTextSplitter(
63
+ chunk_size=length,
64
+ chunk_overlap=0,
65
+ length_function=len,
66
+ add_start_index=True,
67
+ strip_whitespace=False,
68
+ separators=separators,
69
+ )
70
+ splits = text_splitter.create_documents([text])
71
+ text_splits = [split.page_content for split in splits]
72
 
73
  output = [(split, str(i)) for i, split in enumerate(text_splits)]
74
  return output
 
126
  "Character count",
127
  "Token count",
128
  ],
129
+ value=["Character count", "Token count (BERT tokens)"],
130
  label="Length count",
131
  info="How should we count our chunk lengths?",
132
  )
 
140
  )
141
  text.change(
142
  fn=chunk,
143
+ inputs=[text, slider_count, split_selection, separator_selection, length_unit_selection],
144
  outputs=out,
145
  )
146
  length_unit_selection.change(
147
  fn=chunk,
148
+ inputs=[text, slider_count, split_selection, separator_selection, length_unit_selection],
149
  outputs=out,
150
  )
151
  split_selection.change(
152
  fn=change_split_selection,
153
+ inputs=[text, slider_count, split_selection, separator_selection, length_unit_selection],
154
  outputs=[separator_selection, out],
155
  )
156
  slider_count.change(
157
  fn=chunk,
158
+ inputs=[text, slider_count, split_selection, separator_selection, length_unit_selection],
159
  outputs=out,
160
  )
161
  demo.launch()