AlexanderKazakov committed on
Commit
10ddae5
1 Parent(s): 34b78ab

configurable chunking and embedding

Browse files
gradio_app/app.py CHANGED
@@ -10,12 +10,13 @@ from time import perf_counter
10
 
11
  import gradio as gr
12
  import markdown
 
13
  from jinja2 import Environment, FileSystemLoader
14
 
15
  from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
16
  from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
17
  from gradio_app.backend.query_llm import *
18
- from gradio_app.backend.semantic_search import table, embedder
19
 
20
  from settings import *
21
 
@@ -30,6 +31,8 @@ env = Environment(loader=FileSystemLoader('gradio_app/templates'))
30
  context_template = env.get_template('context_template.j2')
31
  context_html_template = env.get_template('context_html_template.j2')
32
 
 
 
33
  # Examples
34
  examples = [
35
  'What is BERT?',
@@ -46,7 +49,7 @@ def add_text(history, text):
46
  return history, gr.Textbox(value="", interactive=False)
47
 
48
 
49
- def bot(history, llm, cross_enc):
50
  history[-1][1] = ""
51
  query = history[-1][0]
52
 
@@ -55,27 +58,33 @@ def bot(history, llm, cross_enc):
55
 
56
  logger.info('Retrieving documents...')
57
  gr.Info('Start documents retrieval ...')
58
- time = perf_counter()
 
 
 
 
 
59
 
60
- query_vec = embedder.embed(query)[0]
61
  documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
62
- documents = documents.limit(TOP_K_RANK).to_list()
63
- thresh_dist = thresh_distances[EMBED_NAME]
 
64
  thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
65
  documents = [d for d in documents if d['_distance'] <= thresh_dist]
66
  documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
67
 
68
- time = perf_counter() - time
69
- logger.info(f'Finished Retrieving documents in {round(time, 2)} seconds...')
70
 
71
  logger.info('Reranking documents...')
72
  gr.Info('Start documents reranking ...')
73
- time = perf_counter()
74
 
75
  documents = rerank_with_cross_encoder(cross_enc, documents, query)
76
 
77
- time = perf_counter() - time
78
- logger.info(f'Finished Reranking documents in {round(time, 2)} seconds...')
79
 
80
  msg_constructor = get_message_constructor(llm)
81
  while len(documents) != 0:
@@ -91,11 +100,14 @@ def bot(history, llm, cross_enc):
91
  raise gr.Error('Model context length exceeded, reload the page')
92
 
93
  llm_gen = get_llm_generator(llm)
 
 
94
  for part in llm_gen(messages):
95
  history[-1][1] += part
96
  yield history, context_html
97
  else:
98
- print('Finished generation stream.')
 
99
 
100
 
101
  with gr.Blocks() as demo:
@@ -109,7 +121,7 @@ with gr.Blocks() as demo:
109
  bubble_full_width=False,
110
  show_copy_button=True,
111
  show_share_button=True,
112
- height=600,
113
  )
114
 
115
  with gr.Row():
@@ -121,14 +133,22 @@ with gr.Blocks() as demo:
121
  )
122
  txt_btn = gr.Button(value="Submit text", scale=1)
123
 
124
- llm_name = gr.Radio(
125
  choices=[
126
- "gpt-3.5-turbo",
127
- "mistralai/Mistral-7B-Instruct-v0.1",
128
- "GeneZC/MiniChat-3B",
129
  ],
130
- value="gpt-3.5-turbo",
131
- label='LLM'
 
 
 
 
 
 
 
 
 
132
  )
133
 
134
  cross_enc_name = gr.Radio(
@@ -141,6 +161,16 @@ with gr.Blocks() as demo:
141
  label='Cross-Encoder'
142
  )
143
 
 
 
 
 
 
 
 
 
 
 
144
  # Examples
145
  gr.Examples(examples, input_textbox)
146
 
@@ -151,7 +181,7 @@ with gr.Blocks() as demo:
151
  txt_msg = txt_btn.click(
152
  add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
153
  ).then(
154
- bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html]
155
  )
156
 
157
  # Turn it back on
@@ -159,7 +189,7 @@ with gr.Blocks() as demo:
159
 
160
  # Turn off interactivity while generating if you hit enter
161
  txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
162
- bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html])
163
 
164
  # Turn it back on
165
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
10
 
11
  import gradio as gr
12
  import markdown
13
+ import lancedb
14
  from jinja2 import Environment, FileSystemLoader
15
 
16
  from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
17
  from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
18
  from gradio_app.backend.query_llm import *
19
+ from gradio_app.backend.embedders import EmbedderFactory
20
 
21
  from settings import *
22
 
 
31
  context_template = env.get_template('context_template.j2')
32
  context_html_template = env.get_template('context_html_template.j2')
33
 
34
+ db = lancedb.connect(LANCEDB_DIRECTORY)
35
+
36
  # Examples
37
  examples = [
38
  'What is BERT?',
 
49
  return history, gr.Textbox(value="", interactive=False)
50
 
51
 
52
+ def bot(history, llm, cross_enc, chunk, embed):
53
  history[-1][1] = ""
54
  query = history[-1][0]
55
 
 
58
 
59
  logger.info('Retrieving documents...')
60
  gr.Info('Start documents retrieval ...')
61
+ t = perf_counter()
62
+
63
+ table_name = f'{LANCEDB_TABLE_NAME}_{chunk}_{embed}'
64
+ table = db.open_table(table_name)
65
+
66
+ embedder = EmbedderFactory.get_embedder(embed)
67
 
68
+ query_vec = embedder.embed([query])[0]
69
  documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
70
+ top_k_rank = TOP_K_RANK if cross_enc is not None else TOP_K_RERANK
71
+ documents = documents.limit(top_k_rank).to_list()
72
+ thresh_dist = thresh_distances[embed]
73
  thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
74
  documents = [d for d in documents if d['_distance'] <= thresh_dist]
75
  documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
76
 
77
+ t = perf_counter() - t
78
+ logger.info(f'Finished Retrieving documents in {round(t, 2)} seconds...')
79
 
80
  logger.info('Reranking documents...')
81
  gr.Info('Start documents reranking ...')
82
+ t = perf_counter()
83
 
84
  documents = rerank_with_cross_encoder(cross_enc, documents, query)
85
 
86
+ t = perf_counter() - t
87
+ logger.info(f'Finished Reranking documents in {round(t, 2)} seconds...')
88
 
89
  msg_constructor = get_message_constructor(llm)
90
  while len(documents) != 0:
 
100
  raise gr.Error('Model context length exceeded, reload the page')
101
 
102
  llm_gen = get_llm_generator(llm)
103
+ logger.info('Generating answer...')
104
+ t = perf_counter()
105
  for part in llm_gen(messages):
106
  history[-1][1] += part
107
  yield history, context_html
108
  else:
109
+ t = perf_counter() - t
110
+ logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')
111
 
112
 
113
  with gr.Blocks() as demo:
 
121
  bubble_full_width=False,
122
  show_copy_button=True,
123
  show_share_button=True,
124
+ height=500,
125
  )
126
 
127
  with gr.Row():
 
133
  )
134
  txt_btn = gr.Button(value="Submit text", scale=1)
135
 
136
+ chunk_name = gr.Radio(
137
  choices=[
138
+ "md",
139
+ "txt",
 
140
  ],
141
+ value="md",
142
+ label='Chunking policy'
143
+ )
144
+
145
+ embed_name = gr.Radio(
146
+ choices=[
147
+ "text-embedding-ada-002",
148
+ "sentence-transformers/all-MiniLM-L6-v2",
149
+ ],
150
+ value="text-embedding-ada-002",
151
+ label='Embedder'
152
  )
153
 
154
  cross_enc_name = gr.Radio(
 
161
  label='Cross-Encoder'
162
  )
163
 
164
+ llm_name = gr.Radio(
165
+ choices=[
166
+ "gpt-3.5-turbo",
167
+ "mistralai/Mistral-7B-Instruct-v0.1",
168
+ "GeneZC/MiniChat-3B",
169
+ ],
170
+ value="gpt-3.5-turbo",
171
+ label='LLM'
172
+ )
173
+
174
  # Examples
175
  gr.Examples(examples, input_textbox)
176
 
 
181
  txt_msg = txt_btn.click(
182
  add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
183
  ).then(
184
+ bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html]
185
  )
186
 
187
  # Turn it back on
 
189
 
190
  # Turn off interactivity while generating if you hit enter
191
  txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
192
+ bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html])
193
 
194
  # Turn it back on
195
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
gradio_app/backend/semantic_search.py DELETED
@@ -1,14 +0,0 @@
1
- import logging
2
- import lancedb
3
-
4
- from gradio_app.backend.embedders import EmbedderFactory
5
- from settings import *
6
-
7
-
8
- # Setting up the logging
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
11
- embedder = EmbedderFactory.get_embedder(EMBED_NAME)
12
-
13
- db = lancedb.connect(LANCEDB_DIRECTORY)
14
- table = db.open_table(LANCEDB_TABLE_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prep_scripts/lancedb_setup.py CHANGED
@@ -1,4 +1,5 @@
1
  import shutil
 
2
 
3
  import lancedb
4
  import openai
@@ -18,7 +19,7 @@ with open('data/openaikey.txt') as f:
18
  openai.api_key = OPENAI_KEY
19
 
20
 
21
- shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
22
  db = lancedb.connect(LANCEDB_DIRECTORY)
23
  batch_size = 32
24
 
@@ -27,7 +28,8 @@ schema = pa.schema([
27
  pa.field(TEXT_COLUMN_NAME, pa.string()),
28
  pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
29
  ])
30
- tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
 
31
 
32
  input_dir = Path(MARKDOWN_SOURCE_DIR)
33
  files = list(input_dir.rglob("*"))
@@ -45,15 +47,21 @@ for file in files:
45
  with open(file, encoding='utf-8') as f:
46
  f = f.read()
47
  f = remove_comments(f)
48
- f = split_markdown(f)
 
 
 
 
49
  chunks.extend((chunk, os.path.abspath(file)) for chunk in f)
50
 
51
  from matplotlib import pyplot as plt
52
  plt.hist([len(c) for c, d in chunks], bins=100)
 
53
  plt.show()
54
 
55
  embedder = EmbedderFactory.get_embedder(EMBED_NAME)
56
 
 
57
  for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
58
  texts, doc_paths = [], []
59
  for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
@@ -61,14 +69,24 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
61
  texts.append(text)
62
  doc_paths.append(doc_path)
63
 
 
64
  encoded = embedder.embed(texts)
 
 
65
  df = pd.DataFrame({
66
  VECTOR_COLUMN_NAME: encoded,
67
  TEXT_COLUMN_NAME: texts,
68
  DOCUMENT_PATH_COLUMN_NAME: doc_paths,
69
  })
70
 
 
71
  tbl.add(df)
 
 
 
 
 
 
72
 
73
 
74
 
 
1
  import shutil
2
+ import time
3
 
4
  import lancedb
5
  import openai
 
19
  openai.api_key = OPENAI_KEY
20
 
21
 
22
+ # shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
23
  db = lancedb.connect(LANCEDB_DIRECTORY)
24
  batch_size = 32
25
 
 
28
  pa.field(TEXT_COLUMN_NAME, pa.string()),
29
  pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
30
  ])
31
+ table_name = f'{LANCEDB_TABLE_NAME}_{CHUNK_POLICY}_{EMBED_NAME}'
32
+ tbl = db.create_table(table_name, schema=schema, mode="overwrite")
33
 
34
  input_dir = Path(MARKDOWN_SOURCE_DIR)
35
  files = list(input_dir.rglob("*"))
 
47
  with open(file, encoding='utf-8') as f:
48
  f = f.read()
49
  f = remove_comments(f)
50
+ if CHUNK_POLICY == "txt":
51
+ f = md2txt_then_split(f)
52
+ else:
53
+ assert CHUNK_POLICY == "md"
54
+ f = split_markdown(f)
55
  chunks.extend((chunk, os.path.abspath(file)) for chunk in f)
56
 
57
  from matplotlib import pyplot as plt
58
  plt.hist([len(c) for c, d in chunks], bins=100)
59
+ plt.title(table_name)
60
  plt.show()
61
 
62
  embedder = EmbedderFactory.get_embedder(EMBED_NAME)
63
 
64
+ time_embed, time_ingest = [], []
65
  for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
66
  texts, doc_paths = [], []
67
  for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
 
69
  texts.append(text)
70
  doc_paths.append(doc_path)
71
 
72
+ t = time.perf_counter()
73
  encoded = embedder.embed(texts)
74
+ time_embed.append(time.perf_counter() - t)
75
+
76
  df = pd.DataFrame({
77
  VECTOR_COLUMN_NAME: encoded,
78
  TEXT_COLUMN_NAME: texts,
79
  DOCUMENT_PATH_COLUMN_NAME: doc_paths,
80
  })
81
 
82
+ t = time.perf_counter()
83
  tbl.add(df)
84
+ time_ingest.append(time.perf_counter() - t)
85
+
86
+
87
+ time_embed = sum(time_embed)
88
+ time_ingest = sum(time_ingest)
89
+ print(f'Embedding: {time_embed}, Ingesting: {time_ingest}')
90
 
91
 
92
 
prep_scripts/markdown_to_text.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
  import re
3
 
 
 
 
4
  from settings import *
5
 
6
 
@@ -95,3 +98,32 @@ def split_markdown(md):
95
 
96
  return res
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
 
4
+ from bs4 import BeautifulSoup
5
+ from markdown import markdown
6
+
7
  from settings import *
8
 
9
 
 
98
 
99
  return res
100
 
101
+
102
+ def markdown_to_text(markdown_string):
103
+ """ Converts a markdown string to plaintext """
104
+
105
+ # md -> html -> text since BeautifulSoup can extract text cleanly
106
+ html = markdown(markdown_string)
107
+
108
+ html = re.sub(r'<!--((.|\n)*)-->', '', html)
109
+ html = re.sub('<code>bash', '<code>', html)
110
+
111
+ # extract text
112
+ soup = BeautifulSoup(html, "html.parser")
113
+ text = ''.join(soup.findAll(string=True))
114
+
115
+ text = re.sub('```(py|diff|python)', '', text)
116
+ text = re.sub('```\n', '\n', text)
117
+ text = re.sub('- .*', '', text)
118
+ text = text.replace('...', '')
119
+ text = re.sub('\n(\n)+', '\n\n', text)
120
+
121
+ return text
122
+
123
+
124
+ def md2txt_then_split(md):
125
+ txt = markdown_to_text(md)
126
+ return split_content(txt)
127
+
128
+
129
+
settings.py CHANGED
@@ -5,8 +5,11 @@ VECTOR_COLUMN_NAME = "embedding"
5
  TEXT_COLUMN_NAME = "text"
6
  DOCUMENT_PATH_COLUMN_NAME = "document_path"
7
 
8
- # EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
9
- EMBED_NAME = "text-embedding-ada-002"
 
 
 
10
 
11
  TOP_K_RANK = 50
12
  TOP_K_RERANK = 5
@@ -28,5 +31,5 @@ context_lengths = {
28
  "gpt-3.5-turbo": 4096,
29
  "sentence-transformers/all-MiniLM-L6-v2": 128,
30
  "thenlper/gte-large": 512,
31
- "text-embedding-ada-002": 8191,
32
  }
 
5
  TEXT_COLUMN_NAME = "text"
6
  DOCUMENT_PATH_COLUMN_NAME = "document_path"
7
 
8
+ CHUNK_POLICY = "md"
9
+ # CHUNK_POLICY = "txt"
10
+
11
+ EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
+ # EMBED_NAME = "text-embedding-ada-002"
13
 
14
  TOP_K_RANK = 50
15
  TOP_K_RERANK = 5
 
31
  "gpt-3.5-turbo": 4096,
32
  "sentence-transformers/all-MiniLM-L6-v2": 128,
33
  "thenlper/gte-large": 512,
34
+ "text-embedding-ada-002": 1000, # actual context length is 8191, but it's too much
35
  }