AlexanderKazakov
committed on
Commit
•
10ddae5
1
Parent(s):
34b78ab
configurable chunking and embedding
Browse files- gradio_app/app.py +51 -21
- gradio_app/backend/semantic_search.py +0 -14
- prep_scripts/lancedb_setup.py +21 -3
- prep_scripts/markdown_to_text.py +32 -0
- settings.py +6 -3
gradio_app/app.py
CHANGED
@@ -10,12 +10,13 @@ from time import perf_counter
|
|
10 |
|
11 |
import gradio as gr
|
12 |
import markdown
|
|
|
13 |
from jinja2 import Environment, FileSystemLoader
|
14 |
|
15 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
16 |
from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
|
17 |
from gradio_app.backend.query_llm import *
|
18 |
-
from gradio_app.backend.
|
19 |
|
20 |
from settings import *
|
21 |
|
@@ -30,6 +31,8 @@ env = Environment(loader=FileSystemLoader('gradio_app/templates'))
|
|
30 |
context_template = env.get_template('context_template.j2')
|
31 |
context_html_template = env.get_template('context_html_template.j2')
|
32 |
|
|
|
|
|
33 |
# Examples
|
34 |
examples = [
|
35 |
'What is BERT?',
|
@@ -46,7 +49,7 @@ def add_text(history, text):
|
|
46 |
return history, gr.Textbox(value="", interactive=False)
|
47 |
|
48 |
|
49 |
-
def bot(history, llm, cross_enc):
|
50 |
history[-1][1] = ""
|
51 |
query = history[-1][0]
|
52 |
|
@@ -55,27 +58,33 @@ def bot(history, llm, cross_enc):
|
|
55 |
|
56 |
logger.info('Retrieving documents...')
|
57 |
gr.Info('Start documents retrieval ...')
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
query_vec = embedder.embed(query)[0]
|
61 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
|
62 |
-
|
63 |
-
|
|
|
64 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
65 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
66 |
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
|
67 |
|
68 |
-
|
69 |
-
logger.info(f'Finished Retrieving documents in {round(
|
70 |
|
71 |
logger.info('Reranking documents...')
|
72 |
gr.Info('Start documents reranking ...')
|
73 |
-
|
74 |
|
75 |
documents = rerank_with_cross_encoder(cross_enc, documents, query)
|
76 |
|
77 |
-
|
78 |
-
logger.info(f'Finished Reranking documents in {round(
|
79 |
|
80 |
msg_constructor = get_message_constructor(llm)
|
81 |
while len(documents) != 0:
|
@@ -91,11 +100,14 @@ def bot(history, llm, cross_enc):
|
|
91 |
raise gr.Error('Model context length exceeded, reload the page')
|
92 |
|
93 |
llm_gen = get_llm_generator(llm)
|
|
|
|
|
94 |
for part in llm_gen(messages):
|
95 |
history[-1][1] += part
|
96 |
yield history, context_html
|
97 |
else:
|
98 |
-
|
|
|
99 |
|
100 |
|
101 |
with gr.Blocks() as demo:
|
@@ -109,7 +121,7 @@ with gr.Blocks() as demo:
|
|
109 |
bubble_full_width=False,
|
110 |
show_copy_button=True,
|
111 |
show_share_button=True,
|
112 |
-
height=
|
113 |
)
|
114 |
|
115 |
with gr.Row():
|
@@ -121,14 +133,22 @@ with gr.Blocks() as demo:
|
|
121 |
)
|
122 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
123 |
|
124 |
-
|
125 |
choices=[
|
126 |
-
"
|
127 |
-
"
|
128 |
-
"GeneZC/MiniChat-3B",
|
129 |
],
|
130 |
-
value="
|
131 |
-
label='
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
)
|
133 |
|
134 |
cross_enc_name = gr.Radio(
|
@@ -141,6 +161,16 @@ with gr.Blocks() as demo:
|
|
141 |
label='Cross-Encoder'
|
142 |
)
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# Examples
|
145 |
gr.Examples(examples, input_textbox)
|
146 |
|
@@ -151,7 +181,7 @@ with gr.Blocks() as demo:
|
|
151 |
txt_msg = txt_btn.click(
|
152 |
add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
|
153 |
).then(
|
154 |
-
bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html]
|
155 |
)
|
156 |
|
157 |
# Turn it back on
|
@@ -159,7 +189,7 @@ with gr.Blocks() as demo:
|
|
159 |
|
160 |
# Turn off interactivity while generating if you hit enter
|
161 |
txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
|
162 |
-
bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html])
|
163 |
|
164 |
# Turn it back on
|
165 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
|
|
|
10 |
|
11 |
import gradio as gr
|
12 |
import markdown
|
13 |
+
import lancedb
|
14 |
from jinja2 import Environment, FileSystemLoader
|
15 |
|
16 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
17 |
from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
|
18 |
from gradio_app.backend.query_llm import *
|
19 |
+
from gradio_app.backend.embedders import EmbedderFactory
|
20 |
|
21 |
from settings import *
|
22 |
|
|
|
31 |
context_template = env.get_template('context_template.j2')
|
32 |
context_html_template = env.get_template('context_html_template.j2')
|
33 |
|
34 |
+
db = lancedb.connect(LANCEDB_DIRECTORY)
|
35 |
+
|
36 |
# Examples
|
37 |
examples = [
|
38 |
'What is BERT?',
|
|
|
49 |
return history, gr.Textbox(value="", interactive=False)
|
50 |
|
51 |
|
52 |
+
def bot(history, llm, cross_enc, chunk, embed):
|
53 |
history[-1][1] = ""
|
54 |
query = history[-1][0]
|
55 |
|
|
|
58 |
|
59 |
logger.info('Retrieving documents...')
|
60 |
gr.Info('Start documents retrieval ...')
|
61 |
+
t = perf_counter()
|
62 |
+
|
63 |
+
table_name = f'{LANCEDB_TABLE_NAME}_{chunk}_{embed}'
|
64 |
+
table = db.open_table(table_name)
|
65 |
+
|
66 |
+
embedder = EmbedderFactory.get_embedder(embed)
|
67 |
|
68 |
+
query_vec = embedder.embed([query])[0]
|
69 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
|
70 |
+
top_k_rank = TOP_K_RANK if cross_enc is not None else TOP_K_RERANK
|
71 |
+
documents = documents.limit(top_k_rank).to_list()
|
72 |
+
thresh_dist = thresh_distances[embed]
|
73 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
74 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
75 |
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
|
76 |
|
77 |
+
t = perf_counter() - t
|
78 |
+
logger.info(f'Finished Retrieving documents in {round(t, 2)} seconds...')
|
79 |
|
80 |
logger.info('Reranking documents...')
|
81 |
gr.Info('Start documents reranking ...')
|
82 |
+
t = perf_counter()
|
83 |
|
84 |
documents = rerank_with_cross_encoder(cross_enc, documents, query)
|
85 |
|
86 |
+
t = perf_counter() - t
|
87 |
+
logger.info(f'Finished Reranking documents in {round(t, 2)} seconds...')
|
88 |
|
89 |
msg_constructor = get_message_constructor(llm)
|
90 |
while len(documents) != 0:
|
|
|
100 |
raise gr.Error('Model context length exceeded, reload the page')
|
101 |
|
102 |
llm_gen = get_llm_generator(llm)
|
103 |
+
logger.info('Generating answer...')
|
104 |
+
t = perf_counter()
|
105 |
for part in llm_gen(messages):
|
106 |
history[-1][1] += part
|
107 |
yield history, context_html
|
108 |
else:
|
109 |
+
t = perf_counter() - t
|
110 |
+
logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')
|
111 |
|
112 |
|
113 |
with gr.Blocks() as demo:
|
|
|
121 |
bubble_full_width=False,
|
122 |
show_copy_button=True,
|
123 |
show_share_button=True,
|
124 |
+
height=500,
|
125 |
)
|
126 |
|
127 |
with gr.Row():
|
|
|
133 |
)
|
134 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
135 |
|
136 |
+
chunk_name = gr.Radio(
|
137 |
choices=[
|
138 |
+
"md",
|
139 |
+
"txt",
|
|
|
140 |
],
|
141 |
+
value="md",
|
142 |
+
label='Chunking policy'
|
143 |
+
)
|
144 |
+
|
145 |
+
embed_name = gr.Radio(
|
146 |
+
choices=[
|
147 |
+
"text-embedding-ada-002",
|
148 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
149 |
+
],
|
150 |
+
value="text-embedding-ada-002",
|
151 |
+
label='Embedder'
|
152 |
)
|
153 |
|
154 |
cross_enc_name = gr.Radio(
|
|
|
161 |
label='Cross-Encoder'
|
162 |
)
|
163 |
|
164 |
+
llm_name = gr.Radio(
|
165 |
+
choices=[
|
166 |
+
"gpt-3.5-turbo",
|
167 |
+
"mistralai/Mistral-7B-Instruct-v0.1",
|
168 |
+
"GeneZC/MiniChat-3B",
|
169 |
+
],
|
170 |
+
value="gpt-3.5-turbo",
|
171 |
+
label='LLM'
|
172 |
+
)
|
173 |
+
|
174 |
# Examples
|
175 |
gr.Examples(examples, input_textbox)
|
176 |
|
|
|
181 |
txt_msg = txt_btn.click(
|
182 |
add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
|
183 |
).then(
|
184 |
+
bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html]
|
185 |
)
|
186 |
|
187 |
# Turn it back on
|
|
|
189 |
|
190 |
# Turn off interactivity while generating if you hit enter
|
191 |
txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
|
192 |
+
bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html])
|
193 |
|
194 |
# Turn it back on
|
195 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
|
gradio_app/backend/semantic_search.py
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import lancedb
|
3 |
-
|
4 |
-
from gradio_app.backend.embedders import EmbedderFactory
|
5 |
-
from settings import *
|
6 |
-
|
7 |
-
|
8 |
-
# Setting up the logging
|
9 |
-
logging.basicConfig(level=logging.INFO)
|
10 |
-
logger = logging.getLogger(__name__)
|
11 |
-
embedder = EmbedderFactory.get_embedder(EMBED_NAME)
|
12 |
-
|
13 |
-
db = lancedb.connect(LANCEDB_DIRECTORY)
|
14 |
-
table = db.open_table(LANCEDB_TABLE_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prep_scripts/lancedb_setup.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import shutil
|
|
|
2 |
|
3 |
import lancedb
|
4 |
import openai
|
@@ -18,7 +19,7 @@ with open('data/openaikey.txt') as f:
|
|
18 |
openai.api_key = OPENAI_KEY
|
19 |
|
20 |
|
21 |
-
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
22 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
23 |
batch_size = 32
|
24 |
|
@@ -27,7 +28,8 @@ schema = pa.schema([
|
|
27 |
pa.field(TEXT_COLUMN_NAME, pa.string()),
|
28 |
pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
|
29 |
])
|
30 |
-
|
|
|
31 |
|
32 |
input_dir = Path(MARKDOWN_SOURCE_DIR)
|
33 |
files = list(input_dir.rglob("*"))
|
@@ -45,15 +47,21 @@ for file in files:
|
|
45 |
with open(file, encoding='utf-8') as f:
|
46 |
f = f.read()
|
47 |
f = remove_comments(f)
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
chunks.extend((chunk, os.path.abspath(file)) for chunk in f)
|
50 |
|
51 |
from matplotlib import pyplot as plt
|
52 |
plt.hist([len(c) for c, d in chunks], bins=100)
|
|
|
53 |
plt.show()
|
54 |
|
55 |
embedder = EmbedderFactory.get_embedder(EMBED_NAME)
|
56 |
|
|
|
57 |
for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
58 |
texts, doc_paths = [], []
|
59 |
for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
|
@@ -61,14 +69,24 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
|
61 |
texts.append(text)
|
62 |
doc_paths.append(doc_path)
|
63 |
|
|
|
64 |
encoded = embedder.embed(texts)
|
|
|
|
|
65 |
df = pd.DataFrame({
|
66 |
VECTOR_COLUMN_NAME: encoded,
|
67 |
TEXT_COLUMN_NAME: texts,
|
68 |
DOCUMENT_PATH_COLUMN_NAME: doc_paths,
|
69 |
})
|
70 |
|
|
|
71 |
tbl.add(df)
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
|
74 |
|
|
|
1 |
import shutil
|
2 |
+
import time
|
3 |
|
4 |
import lancedb
|
5 |
import openai
|
|
|
19 |
openai.api_key = OPENAI_KEY
|
20 |
|
21 |
|
22 |
+
# shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
23 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
24 |
batch_size = 32
|
25 |
|
|
|
28 |
pa.field(TEXT_COLUMN_NAME, pa.string()),
|
29 |
pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
|
30 |
])
|
31 |
+
table_name = f'{LANCEDB_TABLE_NAME}_{CHUNK_POLICY}_{EMBED_NAME}'
|
32 |
+
tbl = db.create_table(table_name, schema=schema, mode="overwrite")
|
33 |
|
34 |
input_dir = Path(MARKDOWN_SOURCE_DIR)
|
35 |
files = list(input_dir.rglob("*"))
|
|
|
47 |
with open(file, encoding='utf-8') as f:
|
48 |
f = f.read()
|
49 |
f = remove_comments(f)
|
50 |
+
if CHUNK_POLICY == "txt":
|
51 |
+
f = md2txt_then_split(f)
|
52 |
+
else:
|
53 |
+
assert CHUNK_POLICY == "md"
|
54 |
+
f = split_markdown(f)
|
55 |
chunks.extend((chunk, os.path.abspath(file)) for chunk in f)
|
56 |
|
57 |
from matplotlib import pyplot as plt
|
58 |
plt.hist([len(c) for c, d in chunks], bins=100)
|
59 |
+
plt.title(table_name)
|
60 |
plt.show()
|
61 |
|
62 |
embedder = EmbedderFactory.get_embedder(EMBED_NAME)
|
63 |
|
64 |
+
time_embed, time_ingest = [], []
|
65 |
for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
66 |
texts, doc_paths = [], []
|
67 |
for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
|
|
|
69 |
texts.append(text)
|
70 |
doc_paths.append(doc_path)
|
71 |
|
72 |
+
t = time.perf_counter()
|
73 |
encoded = embedder.embed(texts)
|
74 |
+
time_embed.append(time.perf_counter() - t)
|
75 |
+
|
76 |
df = pd.DataFrame({
|
77 |
VECTOR_COLUMN_NAME: encoded,
|
78 |
TEXT_COLUMN_NAME: texts,
|
79 |
DOCUMENT_PATH_COLUMN_NAME: doc_paths,
|
80 |
})
|
81 |
|
82 |
+
t = time.perf_counter()
|
83 |
tbl.add(df)
|
84 |
+
time_ingest.append(time.perf_counter() - t)
|
85 |
+
|
86 |
+
|
87 |
+
time_embed = sum(time_embed)
|
88 |
+
time_ingest = sum(time_ingest)
|
89 |
+
print(f'Embedding: {time_embed}, Ingesting: {time_ingest}')
|
90 |
|
91 |
|
92 |
|
prep_scripts/markdown_to_text.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
import os
|
2 |
import re
|
3 |
|
|
|
|
|
|
|
4 |
from settings import *
|
5 |
|
6 |
|
@@ -95,3 +98,32 @@ def split_markdown(md):
|
|
95 |
|
96 |
return res
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
|
4 |
+
from bs4 import BeautifulSoup
|
5 |
+
from markdown import markdown
|
6 |
+
|
7 |
from settings import *
|
8 |
|
9 |
|
|
|
98 |
|
99 |
return res
|
100 |
|
101 |
+
|
102 |
+
def markdown_to_text(markdown_string):
    """Convert a markdown string to plain text.

    The markdown is rendered to HTML first and the text nodes are then
    pulled out with BeautifulSoup; a few regex passes clean up what is left.

    Args:
        markdown_string: Markdown source text.

    Returns:
        The extracted plain text, with runs of blank lines collapsed to a
        single blank line.
    """
    # md -> html -> text since BeautifulSoup can extract text cleanly
    html = markdown(markdown_string)

    # Drop HTML comments. The match is non-greedy so that each comment is
    # removed on its own: a greedy pattern would also delete all real content
    # sitting between the first "<!--" and the last "-->" in the document.
    html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
    # Strip the language tag that ends up glued to the start of bash snippets.
    html = re.sub(r'<code>bash', '<code>', html)

    # extract text
    soup = BeautifulSoup(html, "html.parser")
    text = ''.join(soup.find_all(string=True))

    # Remove code-fence markers that survived the HTML conversion.
    text = re.sub(r'```(py|diff|python)', '', text)
    text = re.sub(r'```\n', '\n', text)
    # NOTE(review): this deletes from every "- " to the end of its line,
    # wherever it occurs - presumably meant to drop list items; confirm that
    # mid-line matches are intended.
    text = re.sub(r'- .*', '', text)
    text = text.replace('...', '')
    # Collapse two or more consecutive newlines into exactly one blank line.
    text = re.sub(r'\n{2,}', '\n\n', text)

    return text
|
122 |
+
|
123 |
+
|
124 |
+
def md2txt_then_split(md):
    """Convert markdown to plain text, then split it with ``split_content``.

    Args:
        md: Markdown source text.

    Returns:
        Whatever ``split_content`` returns for the plain-text version of
        ``md`` (presumably a sequence of text chunks - confirm against
        ``split_content``, which is defined elsewhere in this module).
    """
    return split_content(markdown_to_text(md))
|
127 |
+
|
128 |
+
|
129 |
+
|
settings.py
CHANGED
@@ -5,8 +5,11 @@ VECTOR_COLUMN_NAME = "embedding"
|
|
5 |
TEXT_COLUMN_NAME = "text"
|
6 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
TOP_K_RANK = 50
|
12 |
TOP_K_RERANK = 5
|
@@ -28,5 +31,5 @@ context_lengths = {
|
|
28 |
"gpt-3.5-turbo": 4096,
|
29 |
"sentence-transformers/all-MiniLM-L6-v2": 128,
|
30 |
"thenlper/gte-large": 512,
|
31 |
-
"text-embedding-ada-002": 8191,
|
32 |
}
|
|
|
5 |
TEXT_COLUMN_NAME = "text"
|
6 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
7 |
|
8 |
+
CHUNK_POLICY = "md"
|
9 |
+
# CHUNK_POLICY = "txt"
|
10 |
+
|
11 |
+
EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
12 |
+
# EMBED_NAME = "text-embedding-ada-002"
|
13 |
|
14 |
TOP_K_RANK = 50
|
15 |
TOP_K_RERANK = 5
|
|
|
31 |
"gpt-3.5-turbo": 4096,
|
32 |
"sentence-transformers/all-MiniLM-L6-v2": 128,
|
33 |
"thenlper/gte-large": 512,
|
34 |
+
"text-embedding-ada-002": 1000, # actual context length is 8191, but it's too much
|
35 |
}
|