xangma
commited on
Commit
•
0f7b25d
1
Parent(s):
80b4f00
latest
Browse files- .gitignore +1 -1
- app.py +115 -66
- chain.py +14 -3
- ingest.py +147 -78
- requirements.txt +2 -1
.gitignore
CHANGED
@@ -4,4 +4,4 @@ downloaded/*
|
|
4 |
__pycache__/*
|
5 |
launch.json
|
6 |
.DS_Store
|
7 |
-
devcode
|
|
|
4 |
__pycache__/*
|
5 |
launch.json
|
6 |
.DS_Store
|
7 |
+
*devcode*
|
app.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
# chat-pykg/app.py
|
2 |
import datetime
|
|
|
3 |
import os
|
4 |
import random
|
5 |
import shutil
|
6 |
import string
|
|
|
7 |
|
8 |
import chromadb
|
9 |
import gradio as gr
|
@@ -13,10 +15,26 @@ from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
|
|
13 |
from langchain.vectorstores import Chroma
|
14 |
|
15 |
from chain import get_new_chain1
|
16 |
-
from ingest import ingest_docs
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def randomword(length):
|
22 |
letters = string.ascii_lowercase
|
@@ -25,23 +43,27 @@ def randomword(length):
|
|
25 |
def change_tab():
|
26 |
return gr.Tabs.update(selected=0)
|
27 |
|
28 |
-
def merge_collections(collection_load_names, vs_state):
|
|
|
|
|
|
|
|
|
29 |
merged_documents = []
|
30 |
merged_embeddings = []
|
31 |
for collection_name in collection_load_names:
|
32 |
chroma_obj_get = chromadb.Client(Settings(
|
33 |
chroma_db_impl="duckdb+parquet",
|
34 |
-
persist_directory=
|
35 |
anonymized_telemetry = True
|
36 |
))
|
37 |
if collection_name == '':
|
38 |
continue
|
39 |
-
collection_obj = chroma_obj_get.get_collection(collection_name, embedding_function=
|
40 |
collection = collection_obj.get(include=["metadatas", "documents", "embeddings"])
|
41 |
for i in range(len(collection['documents'])):
|
42 |
merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
|
43 |
merged_embeddings.append(collection['embeddings'][i])
|
44 |
-
merged_vectorstore = Chroma(collection_name="temp", embedding_function=
|
45 |
merged_vectorstore.add_documents(documents=merged_documents, embeddings=merged_embeddings)
|
46 |
return merged_vectorstore
|
47 |
|
@@ -64,28 +86,38 @@ def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox,
|
|
64 |
else:
|
65 |
return agent
|
66 |
|
67 |
-
def delete_collection(all_collections_state, collections_viewer):
|
|
|
|
|
|
|
68 |
client = chromadb.Client(Settings(
|
69 |
chroma_db_impl="duckdb+parquet",
|
70 |
-
persist_directory=
|
71 |
))
|
72 |
for collection in collections_viewer:
|
73 |
try:
|
74 |
client.delete_collection(collection)
|
75 |
all_collections_state.remove(collection)
|
76 |
collections_viewer.remove(collection)
|
77 |
-
except:
|
78 |
-
|
|
|
79 |
return all_collections_state, collections_viewer
|
80 |
|
81 |
-
def delete_all_collections(all_collections_state):
|
82 |
-
|
|
|
|
|
|
|
83 |
return []
|
84 |
|
85 |
-
def list_collections(all_collections_state):
|
|
|
|
|
|
|
86 |
client = chromadb.Client(Settings(
|
87 |
chroma_db_impl="duckdb+parquet",
|
88 |
-
persist_directory=
|
89 |
))
|
90 |
collection_names = [[c.name][0] for c in client.list_collections()]
|
91 |
return collection_names
|
@@ -94,9 +126,12 @@ def update_checkboxgroup(all_collections_state):
|
|
94 |
new_options = [i for i in all_collections_state]
|
95 |
return gr.CheckboxGroup.update(choices=new_options)
|
96 |
|
97 |
-
def
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
100 |
|
101 |
def clear_chat(chatbot, history):
|
102 |
return [], []
|
@@ -110,12 +145,6 @@ def chat(inp, history, agent):
|
|
110 |
if agent == 'no_vectorstore':
|
111 |
history.append((inp, "Please ingest some package docs to use"))
|
112 |
return history, history
|
113 |
-
if agent == 'all_collections' and inp != []:
|
114 |
-
history.append(("", f"Current vectorstores: {inp}"))
|
115 |
-
return history, history
|
116 |
-
if agent == 'all_vs_deleted':
|
117 |
-
history.append((inp, "All vectorstores deleted"))
|
118 |
-
return history, history
|
119 |
else:
|
120 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
121 |
print("inp: " + inp)
|
@@ -126,10 +155,10 @@ def chat(inp, history, agent):
|
|
126 |
print(history)
|
127 |
return history, history
|
128 |
|
129 |
-
block = gr.Blocks(css=".gradio-container {background-color: system;}")
|
130 |
|
131 |
with block:
|
132 |
-
gr.Markdown("<
|
133 |
with gr.Tabs() as tabs:
|
134 |
with gr.TabItem("Chat", id=0):
|
135 |
with gr.Row():
|
@@ -139,22 +168,26 @@ with block:
|
|
139 |
lines=1,
|
140 |
type="password",
|
141 |
)
|
142 |
-
model_selector = gr.Dropdown(
|
143 |
-
|
|
|
|
|
|
|
|
|
144 |
k_textbox = gr.Textbox(
|
145 |
placeholder="k: Number of search results to consider",
|
146 |
label="Search Results k:",
|
147 |
show_label=True,
|
148 |
lines=1,
|
|
|
149 |
)
|
150 |
-
k_textbox.value = "20"
|
151 |
max_tokens_textbox = gr.Textbox(
|
152 |
placeholder="max_tokens: Maximum number of tokens to generate",
|
153 |
label="max_tokens",
|
154 |
show_label=True,
|
155 |
lines=1,
|
|
|
156 |
)
|
157 |
-
max_tokens_textbox.value="1000"
|
158 |
chatbot = gr.Chatbot()
|
159 |
with gr.Row():
|
160 |
clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
|
@@ -167,7 +200,7 @@ with block:
|
|
167 |
gr.Examples(
|
168 |
examples=[
|
169 |
"What does this code do?",
|
170 |
-
"
|
171 |
],
|
172 |
inputs=message,
|
173 |
)
|
@@ -178,35 +211,41 @@ with block:
|
|
178 |
The source code is split/broken down into many document objects using langchain's pythoncodetextsplitter, which apparently tries to keep whole functions etc. together. This means that each file in the source code is split into many smaller documents, and the k value is the number of documents to consider when searching for the most similar documents to the question. With gpt-3.5-turbo, k=10 seems to work well, but with gpt-4, k=20 seems to work better.
|
179 |
The model's memory is set to 5 messages, but I haven't tested with gpt-3.5-turbo yet to see if it works well. It seems to work well with gpt-4."""
|
180 |
)
|
181 |
-
with gr.TabItem("
|
182 |
with gr.Row():
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
)
|
199 |
-
|
200 |
-
|
201 |
-
with gr.Row():
|
202 |
-
gr.HTML('<center>See the <a href=https://python.langchain.com/en/latest/reference/modules/text_splitter.html>Langchain textsplitter docs</a></center>')
|
203 |
-
with gr.Column(scale=2):
|
204 |
-
collections_viewer = gr.CheckboxGroup(choices=[], label='Collections_viewer', show_label=True)
|
205 |
-
with gr.Column(scale=1):
|
206 |
-
load_collections_button = gr.Button(value="Load collection(s) to chat!", variant="secondary").style(full_width=False)
|
207 |
-
get_all_collection_names_button = gr.Button(value="List all saved collections", variant="secondary").style(full_width=False)
|
208 |
-
delete_collections_button = gr.Button(value="Delete selected saved collections", variant="secondary").style(full_width=False)
|
209 |
-
delete_all_collections_button = gr.Button(value="Delete all saved collections", variant="secondary").style(full_width=False)
|
210 |
gr.HTML(
|
211 |
"<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
|
212 |
)
|
@@ -216,25 +255,35 @@ with block:
|
|
216 |
vs_state = gr.State()
|
217 |
all_collections_state = gr.State()
|
218 |
chat_state = gr.State()
|
|
|
|
|
219 |
|
220 |
submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
221 |
message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
222 |
|
223 |
-
load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
|
224 |
-
make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
225 |
-
delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
226 |
-
delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
227 |
-
get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
228 |
clear_btn.click(clear_chat, inputs = [chatbot, history_state], outputs = [chatbot, history_state])
|
229 |
# Whenever chain parameters change, destroy the agent.
|
230 |
-
input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
|
231 |
output_list = [agent_state]
|
232 |
for input_item in input_list:
|
233 |
input_item.change(
|
234 |
-
|
235 |
inputs=output_list,
|
236 |
outputs=output_list,
|
237 |
)
|
238 |
-
all_collections_state.value = list_collections(all_collections_state)
|
239 |
block.load(update_checkboxgroup, inputs = all_collections_state, outputs = collections_viewer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
block.launch(debug=True)
|
|
|
1 |
# chat-pykg/app.py
|
2 |
import datetime
|
3 |
+
import logging
|
4 |
import os
|
5 |
import random
|
6 |
import shutil
|
7 |
import string
|
8 |
+
import sys
|
9 |
|
10 |
import chromadb
|
11 |
import gradio as gr
|
|
|
15 |
from langchain.vectorstores import Chroma
|
16 |
|
17 |
from chain import get_new_chain1
|
18 |
+
from ingest import embedding_chooser, ingest_docs
|
19 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
20 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
21 |
|
22 |
+
class LogTextboxHandler(logging.StreamHandler):
|
23 |
+
def __init__(self, textbox):
|
24 |
+
super().__init__()
|
25 |
+
self.textbox = textbox
|
26 |
+
|
27 |
+
def emit(self, record):
|
28 |
+
log_entry = self.format(record)
|
29 |
+
self.textbox.value += f"{log_entry}\n"
|
30 |
+
|
31 |
+
def toggle_log_textbox(log_textbox_state):
|
32 |
+
toggle_visibility = not log_textbox_state
|
33 |
+
log_textbox_state = not log_textbox_state
|
34 |
+
return log_textbox_state,gr.update(visible=toggle_visibility)
|
35 |
+
|
36 |
+
def update_textbox(full_log):
|
37 |
+
return gr.update(value=full_log)
|
38 |
|
39 |
def randomword(length):
|
40 |
letters = string.ascii_lowercase
|
|
|
43 |
def change_tab():
|
44 |
return gr.Tabs.update(selected=0)
|
45 |
|
46 |
+
def merge_collections(collection_load_names, vs_state, embedding_radio):
|
47 |
+
if type(embedding_radio) == gr.Radio:
|
48 |
+
embedding_radio = embedding_radio.value
|
49 |
+
persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
|
50 |
+
embedding_function = embedding_chooser(embedding_radio)
|
51 |
merged_documents = []
|
52 |
merged_embeddings = []
|
53 |
for collection_name in collection_load_names:
|
54 |
chroma_obj_get = chromadb.Client(Settings(
|
55 |
chroma_db_impl="duckdb+parquet",
|
56 |
+
persist_directory=persist_directory,
|
57 |
anonymized_telemetry = True
|
58 |
))
|
59 |
if collection_name == '':
|
60 |
continue
|
61 |
+
collection_obj = chroma_obj_get.get_collection(collection_name, embedding_function=embedding_function)
|
62 |
collection = collection_obj.get(include=["metadatas", "documents", "embeddings"])
|
63 |
for i in range(len(collection['documents'])):
|
64 |
merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
|
65 |
merged_embeddings.append(collection['embeddings'][i])
|
66 |
+
merged_vectorstore = Chroma(collection_name="temp", embedding_function=embedding_function)
|
67 |
merged_vectorstore.add_documents(documents=merged_documents, embeddings=merged_embeddings)
|
68 |
return merged_vectorstore
|
69 |
|
|
|
86 |
else:
|
87 |
return agent
|
88 |
|
89 |
+
def delete_collection(all_collections_state, collections_viewer, embedding_radio):
|
90 |
+
if type(embedding_radio) == gr.Radio:
|
91 |
+
embedding_radio = embedding_radio.value
|
92 |
+
persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
|
93 |
client = chromadb.Client(Settings(
|
94 |
chroma_db_impl="duckdb+parquet",
|
95 |
+
persist_directory=persist_directory # Optional, defaults to .chromadb/ in the current directory
|
96 |
))
|
97 |
for collection in collections_viewer:
|
98 |
try:
|
99 |
client.delete_collection(collection)
|
100 |
all_collections_state.remove(collection)
|
101 |
collections_viewer.remove(collection)
|
102 |
+
except Exception as e:
|
103 |
+
logging.error(e)
|
104 |
+
|
105 |
return all_collections_state, collections_viewer
|
106 |
|
107 |
+
def delete_all_collections(all_collections_state, embedding_radio):
|
108 |
+
if type(embedding_radio) == gr.Radio:
|
109 |
+
embedding_radio = embedding_radio.value
|
110 |
+
persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
|
111 |
+
shutil.rmtree(persist_directory)
|
112 |
return []
|
113 |
|
114 |
+
def list_collections(all_collections_state, embedding_radio):
|
115 |
+
if type(embedding_radio) == gr.Radio:
|
116 |
+
embedding_radio = embedding_radio.value
|
117 |
+
persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
|
118 |
client = chromadb.Client(Settings(
|
119 |
chroma_db_impl="duckdb+parquet",
|
120 |
+
persist_directory=persist_directory # Optional, defaults to .chromadb/ in the current directory
|
121 |
))
|
122 |
collection_names = [[c.name][0] for c in client.list_collections()]
|
123 |
return collection_names
|
|
|
126 |
new_options = [i for i in all_collections_state]
|
127 |
return gr.CheckboxGroup.update(choices=new_options)
|
128 |
|
129 |
+
def update_log_textbox(full_log):
|
130 |
+
return gr.Textbox.update(value=full_log)
|
131 |
+
|
132 |
+
def destroy_state(state):
|
133 |
+
state = None
|
134 |
+
return state
|
135 |
|
136 |
def clear_chat(chatbot, history):
|
137 |
return [], []
|
|
|
145 |
if agent == 'no_vectorstore':
|
146 |
history.append((inp, "Please ingest some package docs to use"))
|
147 |
return history, history
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
else:
|
149 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
150 |
print("inp: " + inp)
|
|
|
155 |
print(history)
|
156 |
return history, history
|
157 |
|
158 |
+
block = gr.Blocks(title = "chat-pykg", analytics_enabled = False, css=".gradio-container {background-color: system;}")
|
159 |
|
160 |
with block:
|
161 |
+
gr.Markdown("<h1><center>chat-pykg</center></h1>")
|
162 |
with gr.Tabs() as tabs:
|
163 |
with gr.TabItem("Chat", id=0):
|
164 |
with gr.Row():
|
|
|
168 |
lines=1,
|
169 |
type="password",
|
170 |
)
|
171 |
+
model_selector = gr.Dropdown(
|
172 |
+
choices=["gpt-3.5-turbo", "gpt-4", "other"],
|
173 |
+
label="Model",
|
174 |
+
show_label=True,
|
175 |
+
value = "gpt-3.5-turbo"
|
176 |
+
)
|
177 |
k_textbox = gr.Textbox(
|
178 |
placeholder="k: Number of search results to consider",
|
179 |
label="Search Results k:",
|
180 |
show_label=True,
|
181 |
lines=1,
|
182 |
+
value="20",
|
183 |
)
|
|
|
184 |
max_tokens_textbox = gr.Textbox(
|
185 |
placeholder="max_tokens: Maximum number of tokens to generate",
|
186 |
label="max_tokens",
|
187 |
show_label=True,
|
188 |
lines=1,
|
189 |
+
value="1000",
|
190 |
)
|
|
|
191 |
chatbot = gr.Chatbot()
|
192 |
with gr.Row():
|
193 |
clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
|
|
|
200 |
gr.Examples(
|
201 |
examples=[
|
202 |
"What does this code do?",
|
203 |
+
"I want to change the chat-pykg app to have a log viewer, where the user can see what python is doing in the background. How could I do that?",
|
204 |
],
|
205 |
inputs=message,
|
206 |
)
|
|
|
211 |
The source code is split/broken down into many document objects using langchain's pythoncodetextsplitter, which apparently tries to keep whole functions etc. together. This means that each file in the source code is split into many smaller documents, and the k value is the number of documents to consider when searching for the most similar documents to the question. With gpt-3.5-turbo, k=10 seems to work well, but with gpt-4, k=20 seems to work better.
|
212 |
The model's memory is set to 5 messages, but I haven't tested with gpt-3.5-turbo yet to see if it works well. It seems to work well with gpt-4."""
|
213 |
)
|
214 |
+
with gr.TabItem("Repository Selector/Manager", id=1):
|
215 |
with gr.Row():
|
216 |
+
collections_viewer = gr.CheckboxGroup(choices=[], label='Repository Viewer', show_label=True)
|
217 |
+
with gr.Row():
|
218 |
+
load_collections_button = gr.Button(value="Load respositories to chat!", variant="secondary")#.style(full_width=False)
|
219 |
+
get_all_collection_names_button = gr.Button(value="List all saved repositories", variant="secondary")#.style(full_width=False)
|
220 |
+
delete_collections_button = gr.Button(value="Delete selected saved repositories", variant="secondary")#.style(full_width=False)
|
221 |
+
delete_all_collections_button = gr.Button(value="Delete all saved repositories", variant="secondary")#.style(full_width=False)
|
222 |
+
with gr.TabItem("Get New Repositories", id=2):
|
223 |
+
with gr.Row():
|
224 |
+
all_collections_to_get = gr.List(headers=['Repository URL', 'Folders'], row_count=3, col_count=2, label='Repositories to get', show_label=True, interactive=True, max_cols=2, max_rows=3)
|
225 |
+
make_collections_button = gr.Button(value="Get new repositories", variant="secondary").style(full_width=False)
|
226 |
+
with gr.Row():
|
227 |
+
chunk_size_textbox = gr.Textbox(
|
228 |
+
placeholder="Chunk size",
|
229 |
+
label="Chunk size",
|
230 |
+
show_label=True,
|
231 |
+
lines=1,
|
232 |
+
value="1000"
|
233 |
+
)
|
234 |
+
chunk_overlap_textbox = gr.Textbox(
|
235 |
+
placeholder="Chunk overlap",
|
236 |
+
label="Chunk overlap",
|
237 |
+
show_label=True,
|
238 |
+
lines=1,
|
239 |
+
value="0"
|
240 |
+
)
|
241 |
+
embedding_radio = gr.Radio(
|
242 |
+
choices = ['Sentence Transformers', 'OpenAI'],
|
243 |
+
label="Embedding Options",
|
244 |
+
show_label=True,
|
245 |
+
value='Sentence Transformers'
|
246 |
)
|
247 |
+
with gr.Row():
|
248 |
+
gr.HTML('<center>See the <a href=https://python.langchain.com/en/latest/reference/modules/text_splitter.html>Langchain textsplitter docs</a></center>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
gr.HTML(
|
250 |
"<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
|
251 |
)
|
|
|
255 |
vs_state = gr.State()
|
256 |
all_collections_state = gr.State()
|
257 |
chat_state = gr.State()
|
258 |
+
debug_state = gr.State()
|
259 |
+
debug_state.value = False
|
260 |
|
261 |
submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
262 |
message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
263 |
|
264 |
+
load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state, embedding_radio], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
|
265 |
+
make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox, embedding_radio, debug_state], outputs=[all_collections_state, all_collections_to_get], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
266 |
+
delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer, embedding_radio], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
267 |
+
delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state, embedding_radio], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
268 |
+
get_all_collection_names_button.click(list_collections, inputs=[all_collections_state, embedding_radio], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
269 |
clear_btn.click(clear_chat, inputs = [chatbot, history_state], outputs = [chatbot, history_state])
|
270 |
# Whenever chain parameters change, destroy the agent.
|
271 |
+
input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, embedding_radio]
|
272 |
output_list = [agent_state]
|
273 |
for input_item in input_list:
|
274 |
input_item.change(
|
275 |
+
destroy_state,
|
276 |
inputs=output_list,
|
277 |
outputs=output_list,
|
278 |
)
|
279 |
+
all_collections_state.value = list_collections(all_collections_state, embedding_radio)
|
280 |
block.load(update_checkboxgroup, inputs = all_collections_state, outputs = collections_viewer)
|
281 |
+
log_textbox_handler = LogTextboxHandler(gr.TextArea(interactive=False, placeholder="Logs will appear here...", visible=False))
|
282 |
+
log_textbox = log_textbox_handler.textbox
|
283 |
+
logging.getLogger().addHandler(log_textbox_handler)
|
284 |
+
log_textbox_visibility_state = gr.State()
|
285 |
+
log_textbox_visibility_state.value = False
|
286 |
+
log_toggle_button = gr.Button("Toggle Log", variant="secondary")
|
287 |
+
log_toggle_button.click(toggle_log_textbox, inputs=[log_textbox_visibility_state], outputs=[log_textbox_visibility_state,log_textbox])
|
288 |
+
block.queue(concurrency_count=40)
|
289 |
block.launch(debug=True)
|
chain.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
# chat-pykg/chain.py
|
2 |
-
|
|
|
3 |
from langchain.chains.base import Chain
|
4 |
from langchain import HuggingFaceHub
|
5 |
from langchain.chains.question_answering import load_qa_chain
|
@@ -10,12 +11,21 @@ from langchain.chains.llm import LLMChain
|
|
10 |
from langchain.callbacks.base import CallbackManager
|
11 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
12 |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
15 |
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
16 |
|
17 |
def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
|
18 |
|
|
|
|
|
|
|
|
|
19 |
template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
|
20 |
You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
|
21 |
Do NOT make up any hyperlinks that are not in the code.
|
@@ -34,8 +44,8 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -
|
|
34 |
llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")#, model_kwargs={"temperature":0, "max_length":64})
|
35 |
doc_chain_llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")
|
36 |
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
37 |
-
doc_chain = load_qa_chain(doc_chain_llm, chain_type="stuff", prompt=QA_PROMPT)
|
38 |
-
|
39 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
40 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
41 |
retriever = vectorstore.as_retriever(search_type="similarity")
|
@@ -45,5 +55,6 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -
|
|
45 |
retriever.search_kwargs = {"k": 10}
|
46 |
qa = ConversationalRetrievalChain(
|
47 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator)
|
|
|
48 |
|
49 |
return qa
|
|
|
1 |
# chat-pykg/chain.py
|
2 |
+
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
|
3 |
+
from pydantic import Extra, Field, root_validator
|
4 |
from langchain.chains.base import Chain
|
5 |
from langchain import HuggingFaceHub
|
6 |
from langchain.chains.question_answering import load_qa_chain
|
|
|
11 |
from langchain.callbacks.base import CallbackManager
|
12 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
13 |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
|
14 |
+
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
15 |
+
from langchain.chains.llm import LLMChain
|
16 |
+
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
17 |
+
from langchain.prompts.prompt import PromptTemplate
|
18 |
+
|
19 |
|
20 |
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
21 |
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
22 |
|
23 |
def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
|
24 |
|
25 |
+
# def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
26 |
+
# docs = self.retriever.vectorstore._collection.query(question, n_results=self.retriever.search_kwargs["k"], where = {"source":{"$contains":"search_string"}}, where_document = {"$contains":"search_string"})
|
27 |
+
# return self._reduce_tokens_below_limit(docs)
|
28 |
+
|
29 |
template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
|
30 |
You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
|
31 |
Do NOT make up any hyperlinks that are not in the code.
|
|
|
44 |
llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")#, model_kwargs={"temperature":0, "max_length":64})
|
45 |
doc_chain_llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")
|
46 |
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
47 |
+
doc_chain = load_qa_chain(doc_chain_llm, chain_type="stuff", prompt=QA_PROMPT)#, document_prompt = PromptTemplate(input_variables=["source", "page_content"], template="{source}\n{page_content}"))
|
48 |
+
|
49 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
50 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
51 |
retriever = vectorstore.as_retriever(search_type="similarity")
|
|
|
55 |
retriever.search_kwargs = {"k": 10}
|
56 |
qa = ConversationalRetrievalChain(
|
57 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator)
|
58 |
+
# qa._get_docs = _get_docs.__get__(qa, ConversationalRetrievalChain)
|
59 |
|
60 |
return qa
|
ingest.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
# chat-pykg/ingest.py
|
2 |
import tempfile
|
|
|
3 |
from langchain.document_loaders import SitemapLoader, ReadTheDocsLoader, TextLoader
|
4 |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
|
5 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter
|
6 |
from langchain.vectorstores.faiss import FAISS
|
7 |
import os
|
8 |
from langchain.vectorstores import Chroma
|
@@ -10,8 +11,12 @@ import shutil
|
|
10 |
from pathlib import Path
|
11 |
import subprocess
|
12 |
import chromadb
|
13 |
-
|
14 |
-
import
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# class CachedChroma(Chroma, ABC):
|
17 |
# """
|
@@ -65,6 +70,62 @@ import chromadb.utils.embedding_functions as ef
|
|
65 |
# )
|
66 |
# raise ValueError("Either documents or collection_name must be specified.")
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def get_text(content):
|
69 |
relevant_part = content.find("div", {"class": "markdown"})
|
70 |
if relevant_part is not None:
|
@@ -72,83 +133,96 @@ def get_text(content):
|
|
72 |
else:
|
73 |
return ""
|
74 |
|
75 |
-
def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
all_docs = []
|
78 |
-
folders=[]
|
79 |
-
documents = []
|
80 |
shutil.rmtree('downloaded/', ignore_errors=True)
|
81 |
known_exts = ["py", "md"]
|
|
|
82 |
py_splitter = PythonCodeTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
|
83 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
|
84 |
md_splitter = MarkdownTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
paths_by_ext = {}
|
87 |
docs_by_ext = {}
|
88 |
for ext in known_exts + ["other"]:
|
89 |
docs_by_ext[ext] = []
|
90 |
paths_by_ext[ext] = []
|
91 |
-
|
92 |
-
if
|
93 |
-
|
94 |
-
|
95 |
-
if len(url) > 1:
|
96 |
-
folder = url.split('.')[1]
|
97 |
-
else:
|
98 |
-
folder = '.'
|
99 |
else:
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
if url[0] == '/':
|
104 |
-
url = url[1:]
|
105 |
-
org = url.split('/')[0]
|
106 |
-
repo = url.split('/')[1]
|
107 |
repo_url = f"https://github.com/{org}/{repo}.git"
|
108 |
-
|
109 |
-
folder = '/'.join(url.split('/')[2:])
|
110 |
-
if folder[-1] == '/':
|
111 |
-
folder = folder[:-1]
|
112 |
-
if folder:
|
113 |
-
with tempfile.TemporaryDirectory() as temp_dir:
|
114 |
-
temp_path = Path(temp_dir)
|
115 |
-
|
116 |
-
# Initialize the Git repository
|
117 |
-
subprocess.run(["git", "init"], cwd=temp_path)
|
118 |
-
|
119 |
-
# Add the remote repository
|
120 |
-
subprocess.run(["git", "remote", "add", "-f", "origin", repo_url], cwd=temp_path)
|
121 |
-
|
122 |
-
# Enable sparse-checkout
|
123 |
-
subprocess.run(["git", "config", "core.sparseCheckout", "true"], cwd=temp_path)
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
for file in files:
|
140 |
-
|
141 |
-
rel_file_path = os.path.relpath(file_path, local_repo_path_1)
|
142 |
-
ext = rel_file_path.split('.')[-1]
|
143 |
-
if rel_file_path.startswith('.'):
|
144 |
continue
|
|
|
|
|
145 |
try:
|
146 |
-
if
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
149 |
else:
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
152 |
except Exception as e:
|
153 |
continue
|
154 |
for ext in docs_by_ext.keys():
|
@@ -157,25 +231,20 @@ def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
|
|
157 |
if ext == "md":
|
158 |
documents += md_splitter.split_documents(docs_by_ext[ext])
|
159 |
# else:
|
160 |
-
# documents += text_splitter.split_documents(docs_by_ext[ext]
|
161 |
all_docs += documents
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
167 |
collection.persist()
|
168 |
-
all_collections_state.append(
|
169 |
-
|
170 |
-
|
171 |
-
# merged_vectorstore = Chroma.from_documents(persist_directory=".persisted_data", documents=documents, embedding=embeddings, collection_name='merged_collections')
|
172 |
-
# #vectorstore = FAISS.from_documents(documents, embeddings)
|
173 |
-
# # # Save vectorstore
|
174 |
-
# # with open("vectorstore.pkl", "wb") as f:
|
175 |
-
# # pickle.dump(vectorstore. , f)
|
176 |
-
|
177 |
-
# return merged_vectorstore
|
178 |
-
|
179 |
|
180 |
if __name__ == "__main__":
|
181 |
ingest_docs()
|
|
|
1 |
# chat-pykg/ingest.py
|
2 |
import tempfile
|
3 |
+
import gradio as gr
|
4 |
from langchain.document_loaders import SitemapLoader, ReadTheDocsLoader, TextLoader
|
5 |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
|
6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter, TextSplitter
|
7 |
from langchain.vectorstores.faiss import FAISS
|
8 |
import os
|
9 |
from langchain.vectorstores import Chroma
|
|
|
11 |
from pathlib import Path
|
12 |
import subprocess
|
13 |
import chromadb
|
14 |
+
import magic
|
15 |
+
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
|
16 |
+
from pydantic import Extra, Field, root_validator
|
17 |
+
import logging
|
18 |
+
logger = logging.getLogger()
|
19 |
+
from langchain.docstore.document import Document
|
20 |
|
21 |
# class CachedChroma(Chroma, ABC):
|
22 |
# """
|
|
|
70 |
# )
|
71 |
# raise ValueError("Either documents or collection_name must be specified.")
|
72 |
|
73 |
+
def embedding_chooser(embedding_radio):
    """Map the UI embedding selection to a LangChain embedding object.

    Args:
        embedding_radio: the selected backend name ("Sentence Transformers"
            or "OpenAI"); any other value falls back to sentence-transformers.

    Returns:
        An ``OpenAIEmbeddings`` instance for "OpenAI", otherwise
        ``HuggingFaceEmbeddings``.
    """
    # The original if/elif/else had identical "Sentence Transformers" and
    # fallback branches; collapsed into a single conditional.
    if embedding_radio == "OpenAI":
        return OpenAIEmbeddings()
    return HuggingFaceEmbeddings()
|
81 |
+
|
82 |
+
# Monkeypatch pending PR
|
83 |
+
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
    """Merge small ``splits`` into chunks no longer than ``self._chunk_size``.

    Monkeypatched onto a ``TextSplitter`` instance (see ``ingest_docs``).
    Per the "Monkeypatch pending PR" note, this mirrors a pending upstream
    langchain change: pieces after the first are appended WITH their leading
    separator, so separators are preserved in the output chunks — keep this
    body in sync with upstream.

    Args:
        splits: the small pieces produced by the splitter.
        separator: the string the pieces were split on.

    Returns:
        Merged chunk strings, each at most ``self._chunk_size`` long as
        measured by ``self._length_function``, with up to
        ``self._chunk_overlap`` of trailing overlap carried between chunks.
    """
    # We now want to combine these smaller pieces into medium size
    # chunks to send to the LLM.
    separator_len = self._length_function(separator)

    docs = []
    current_doc: List[str] = []
    total = 0
    for index, d in enumerate(splits):
        _len = self._length_function(d)
        # Would adding this piece (plus a separator, if the running chunk is
        # non-empty) overflow the chunk size?
        if (
            total + _len + (separator_len if len(current_doc) > 0 else 0)
            > self._chunk_size
        ):
            if total > self._chunk_size:
                logger.warning(
                    f"Created a chunk of size {total}, "
                    f"which is longer than the specified {self._chunk_size}"
                )
            if len(current_doc) > 0:
                # Emit the current chunk before starting a new one.
                doc = self._join_docs(current_doc, separator)
                if doc is not None:
                    docs.append(doc)
                # Keep on popping if:
                # - we have a larger chunk than in the chunk overlap
                # - or if we still have any chunks and the length is long
                while total > self._chunk_overlap or (
                    total + _len + (separator_len if len(current_doc) > 0 else 0)
                    > self._chunk_size
                    and total > 0
                ):
                    # Drop pieces from the front until the carried-over
                    # overlap fits alongside the incoming piece.
                    total -= self._length_function(current_doc[0]) + (
                        separator_len if len(current_doc) > 1 else 0
                    )
                    current_doc = current_doc[1:]

        # Preserve the separator by prefixing it onto every piece after the
        # first (this is the behaviour change vs. the released implementation).
        if index > 0:
            current_doc.append(separator + d)
        else:
            current_doc.append(d)
        total += _len + (separator_len if len(current_doc) > 1 else 0)
    # Flush whatever remains as the final chunk.
    doc = self._join_docs(current_doc, separator)
    if doc is not None:
        docs.append(doc)
    return docs
|
128 |
+
|
129 |
def get_text(content):
|
130 |
relevant_part = content.find("div", {"class": "markdown"})
|
131 |
if relevant_part is not None:
|
|
|
133 |
else:
|
134 |
return ""
|
135 |
|
136 |
+
def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap, embedding_radio, debug=False):
    """Ingest GitHub repos (or local folders) into per-repo Chroma collections.

    Args:
        all_collections_state: list of collection names already ingested; new
            collection names are appended to it in place.
        urls: list of ``[org/repo, comma-separated-folders]`` rows from the UI.
            A leading ``/`` or ``.`` marks a local path instead of a repo.
        chunk_size: chunk size (in characters) for the text splitters.
        chunk_overlap: overlap between consecutive chunks.
        embedding_radio: selected embedding backend name (or a ``gr.Radio``).
        debug: if True, clone repos under ``.downloaded/`` instead of a
            temporary directory.

    Returns:
        ``(all_collections_state, gr.update(value=cleared_list))`` — the
        updated collection list and a Gradio update clearing ingested rows.
    """
    cleared_list = urls.copy()

    def sanitize_folder_name(folder_name):
        # Strip whitespace and any trailing slash; '' means the repo root.
        if folder_name != '':
            folder_name = folder_name.strip().rstrip('/')
        else:
            folder_name = '.'  # current directory
        return folder_name

    def is_hidden(path):
        # Dotfiles / dot-directories are skipped during the walk.
        return os.path.basename(path).startswith('.')

    embedding_function = embedding_chooser(embedding_radio)
    all_docs = []
    shutil.rmtree('downloaded/', ignore_errors=True)
    known_exts = ["py", "md"]

    # Initialize text splitters; the python splitter gets the monkeypatched
    # _merge_splits so that separators are preserved inside the chunks.
    py_splitter = PythonCodeTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
    md_splitter = MarkdownTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
    py_splitter._merge_splits = _merge_splits.__get__(py_splitter, TextSplitter)

    # Normalise the input rows into [org/repo, [folder, ...]].
    urls = [[url.strip(), [sanitize_folder_name(folder) for folder in url_folders.split(',')]]
            for url, url_folders in urls]
    for j, (orgrepo, repo_folders) in enumerate(urls):
        if orgrepo == '':
            continue
        if orgrepo.replace('/', '-') in all_collections_state:
            logging.info(f"Skipping {orgrepo} as it is already in the database")
            continue
        documents = []
        docs_by_ext = {ext: [] for ext in known_exts + ["other"]}
        paths_by_ext = {ext: [] for ext in known_exts + ["other"]}

        if orgrepo[0] == '/' or orgrepo[0] == '.':
            # Ingest a local folder (path given relative to the app root).
            local_repo_path = sanitize_folder_name(orgrepo[1:])
        else:
            # Ingest a remote GitHub repo via a sparse checkout of only the
            # requested folders.
            org = orgrepo.split('/')[0]
            repo = orgrepo.split('/')[1]
            repo_url = f"https://github.com/{org}/{repo}.git"
            if debug:
                local_repo_path = os.path.join('.downloaded', orgrepo)
                # subprocess.run fails on a nonexistent cwd — create it first
                # (the original only created the dir implicitly via tempfile).
                os.makedirs(local_repo_path, exist_ok=True)
            else:
                local_repo_path = tempfile.mkdtemp()
            # Initialize the Git repository
            subprocess.run(["git", "init"], cwd=local_repo_path)
            # Add the remote repository
            subprocess.run(["git", "remote", "add", "-f", "origin", repo_url], cwd=local_repo_path)
            # Enable sparse-checkout
            subprocess.run(["git", "config", "core.sparseCheckout", "true"], cwd=local_repo_path)
            # Specify the folders to checkout
            subprocess.run(["git", "sparse-checkout", "set"] + list(repo_folders), cwd=local_repo_path)
            # The default branch may be either 'main' or 'master'; fall back
            # on ANY checkout failure (was '== 1', which missed other codes).
            res = subprocess.run(["git", "checkout", "main"], cwd=local_repo_path)
            if res.returncode != 0:
                res = subprocess.run(["git", "checkout", "master"], cwd=local_repo_path)

        # Iterate through files and process them
        if local_repo_path == '.':
            # Ingesting this app's own source tree.
            orgrepo = 'chat-pykg'
        for root, dirs, files in os.walk(local_repo_path):
            dirs[:] = [d for d in dirs if not is_hidden(d)]  # Ignore hidden directories
            for file in files:
                if is_hidden(file):
                    continue
                file_path = os.path.join(root, file)
                rel_file_path = os.path.relpath(file_path, local_repo_path)
                try:
                    if '.' not in rel_file_path:
                        # Extensionless file: sniff the mime type instead.
                        inferred_filetype = magic.from_file(file_path, mime=True)
                        if "python" in inferred_filetype or "text/plain" in inferred_filetype:
                            ext = "py"
                        else:
                            ext = "other"
                    else:
                        ext = rel_file_path.split('.')[-1]
                    if docs_by_ext.get(ext) is None:
                        ext = "other"
                    doc = TextLoader(os.path.join(local_repo_path, rel_file_path)).load()[0]
                    doc.metadata["source"] = os.path.join(orgrepo, rel_file_path)
                    docs_by_ext[ext].append(doc)
                    paths_by_ext[ext].append(rel_file_path)
                except Exception as e:
                    # Best-effort: skip unreadable/binary files, but say why
                    # (the original swallowed the error silently).
                    logging.debug(f"Skipping {file_path}: {e}")
                    continue
        for ext in docs_by_ext:
            # NOTE(review): two lines were elided from the reviewed diff here;
            # the 'py' branch below is reconstructed (py_splitter is otherwise
            # unused) — confirm against the full file.
            if ext == "py":
                documents += py_splitter.split_documents(docs_by_ext[ext])
            if ext == "md":
                documents += md_splitter.split_documents(docs_by_ext[ext])
            # else:
            #     documents += text_splitter.split_documents(docs_by_ext[ext])
        all_docs += documents
        # For each document, prepend the source path so it survives into the
        # embedded text.
        for doc in documents:
            doc.page_content = f'# source:{doc.metadata["source"]}\n{doc.page_content}'
        if isinstance(embedding_radio, gr.Radio):
            embedding_radio = embedding_radio.value
        persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ', '_'))
        collection_name = orgrepo.replace('/', '-')
        collection = Chroma.from_documents(documents=documents, collection_name=collection_name, embedding=embedding_function, persist_directory=persist_directory)
        collection.persist()
        all_collections_state.append(collection_name)
        cleared_list[j][0], cleared_list[j][1] = '', ''
    return all_collections_state, gr.update(value=cleared_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
249 |
if __name__ == "__main__":
|
250 |
ingest_docs()
|
requirements.txt
CHANGED
@@ -6,4 +6,5 @@ Flask
|
|
6 |
transformers
|
7 |
gradio
|
8 |
chromadb
|
9 |
-
sentence_transformers
|
|
|
|
6 |
transformers
|
7 |
gradio
|
8 |
chromadb
|
9 |
+
sentence_transformers
|
10 |
+
python-magic
|