xangma committed on
Commit 0f7b25d
1 Parent(s): 80b4f00
Files changed (5):
  1. .gitignore +1 -1
  2. app.py +115 -66
  3. chain.py +14 -3
  4. ingest.py +147 -78
  5. requirements.txt +2 -1
.gitignore CHANGED
@@ -4,4 +4,4 @@ downloaded/*
 __pycache__/*
 launch.json
 .DS_Store
-devcode.py
+*devcode*
app.py CHANGED
@@ -1,9 +1,11 @@
 # chat-pykg/app.py
 import datetime
+import logging
 import os
 import random
 import shutil
 import string
+import sys
 
 import chromadb
 import gradio as gr
@@ -13,10 +15,26 @@ from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
 from langchain.vectorstores import Chroma
 
 from chain import get_new_chain1
-from ingest import ingest_docs
-
-# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
+from ingest import embedding_chooser, ingest_docs
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
+
+class LogTextboxHandler(logging.StreamHandler):
+    def __init__(self, textbox):
+        super().__init__()
+        self.textbox = textbox
+
+    def emit(self, record):
+        log_entry = self.format(record)
+        self.textbox.value += f"{log_entry}\n"
+
+def toggle_log_textbox(log_textbox_state):
+    toggle_visibility = not log_textbox_state
+    log_textbox_state = not log_textbox_state
+    return log_textbox_state, gr.update(visible=toggle_visibility)
+
+def update_textbox(full_log):
+    return gr.update(value=full_log)
 
 def randomword(length):
     letters = string.ascii_lowercase
@@ -25,23 +43,27 @@ def randomword(length):
 def change_tab():
     return gr.Tabs.update(selected=0)
 
-def merge_collections(collection_load_names, vs_state):
+def merge_collections(collection_load_names, vs_state, embedding_radio):
+    if type(embedding_radio) == gr.Radio:
+        embedding_radio = embedding_radio.value
+    persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
+    embedding_function = embedding_chooser(embedding_radio)
     merged_documents = []
     merged_embeddings = []
     for collection_name in collection_load_names:
         chroma_obj_get = chromadb.Client(Settings(
             chroma_db_impl="duckdb+parquet",
-            persist_directory=".persisted_data",
+            persist_directory=persist_directory,
             anonymized_telemetry = True
         ))
         if collection_name == '':
             continue
-        collection_obj = chroma_obj_get.get_collection(collection_name, embedding_function=HuggingFaceEmbeddings())
+        collection_obj = chroma_obj_get.get_collection(collection_name, embedding_function=embedding_function)
         collection = collection_obj.get(include=["metadatas", "documents", "embeddings"])
         for i in range(len(collection['documents'])):
             merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
             merged_embeddings.append(collection['embeddings'][i])
-    merged_vectorstore = Chroma(collection_name="temp", embedding_function=HuggingFaceEmbeddings())
+    merged_vectorstore = Chroma(collection_name="temp", embedding_function=embedding_function)
     merged_vectorstore.add_documents(documents=merged_documents, embeddings=merged_embeddings)
     return merged_vectorstore
 
@@ -64,28 +86,38 @@ def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox,
     else:
         return agent
 
-def delete_collection(all_collections_state, collections_viewer):
+def delete_collection(all_collections_state, collections_viewer, embedding_radio):
+    if type(embedding_radio) == gr.Radio:
+        embedding_radio = embedding_radio.value
+    persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
     client = chromadb.Client(Settings(
         chroma_db_impl="duckdb+parquet",
-        persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
+        persist_directory=persist_directory # Optional, defaults to .chromadb/ in the current directory
     ))
     for collection in collections_viewer:
         try:
             client.delete_collection(collection)
             all_collections_state.remove(collection)
             collections_viewer.remove(collection)
-        except:
-            continue
+        except Exception as e:
+            logging.error(e)
+
     return all_collections_state, collections_viewer
 
-def delete_all_collections(all_collections_state):
-    shutil.rmtree(".persisted_data")
+def delete_all_collections(all_collections_state, embedding_radio):
+    if type(embedding_radio) == gr.Radio:
+        embedding_radio = embedding_radio.value
+    persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
+    shutil.rmtree(persist_directory)
    return []
 
-def list_collections(all_collections_state):
+def list_collections(all_collections_state, embedding_radio):
+    if type(embedding_radio) == gr.Radio:
+        embedding_radio = embedding_radio.value
+    persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
     client = chromadb.Client(Settings(
         chroma_db_impl="duckdb+parquet",
-        persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
+        persist_directory=persist_directory # Optional, defaults to .chromadb/ in the current directory
     ))
     collection_names = [[c.name][0] for c in client.list_collections()]
     return collection_names
@@ -94,9 +126,12 @@ def update_checkboxgroup(all_collections_state):
     new_options = [i for i in all_collections_state]
     return gr.CheckboxGroup.update(choices=new_options)
 
-def destroy_agent(agent):
-    agent = None
-    return agent
+def update_log_textbox(full_log):
+    return gr.Textbox.update(value=full_log)
+
+def destroy_state(state):
+    state = None
+    return state
 
 def clear_chat(chatbot, history):
     return [], []
@@ -110,12 +145,6 @@ def chat(inp, history, agent):
     if agent == 'no_vectorstore':
         history.append((inp, "Please ingest some package docs to use"))
         return history, history
-    if agent == 'all_collections' and inp != []:
-        history.append(("", f"Current vectorstores: {inp}"))
-        return history, history
-    if agent == 'all_vs_deleted':
-        history.append((inp, "All vectorstores deleted"))
-        return history, history
     else:
         print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
         print("inp: " + inp)
@@ -126,10 +155,10 @@ def chat(inp, history, agent):
     print(history)
     return history, history
 
-block = gr.Blocks(css=".gradio-container {background-color: system;}")
+block = gr.Blocks(title = "chat-pykg", analytics_enabled = False, css=".gradio-container {background-color: system;}")
 
 with block:
-    gr.Markdown("<h3><center>chat-pykg</center></h3>")
+    gr.Markdown("<h1><center>chat-pykg</center></h1>")
     with gr.Tabs() as tabs:
         with gr.TabItem("Chat", id=0):
             with gr.Row():
@@ -139,22 +168,26 @@ with block:
                     lines=1,
                     type="password",
                 )
-                model_selector = gr.Dropdown(["gpt-3.5-turbo", "gpt-4", "other"], label="Model", show_label=True)
-                model_selector.value = "gpt-3.5-turbo"
+                model_selector = gr.Dropdown(
+                    choices=["gpt-3.5-turbo", "gpt-4", "other"],
+                    label="Model",
+                    show_label=True,
+                    value = "gpt-3.5-turbo"
+                )
                 k_textbox = gr.Textbox(
                     placeholder="k: Number of search results to consider",
                     label="Search Results k:",
                     show_label=True,
                     lines=1,
+                    value="20",
                 )
-                k_textbox.value = "20"
                 max_tokens_textbox = gr.Textbox(
                     placeholder="max_tokens: Maximum number of tokens to generate",
                     label="max_tokens",
                     show_label=True,
                     lines=1,
+                    value="1000",
                 )
-                max_tokens_textbox.value="1000"
            chatbot = gr.Chatbot()
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
@@ -167,7 +200,7 @@ with block:
     gr.Examples(
         examples=[
             "What does this code do?",
-            "Where is this specific method in the source code and why is it broken?"
+            "I want to change the chat-pykg app to have a log viewer, where the user can see what python is doing in the background. How could I do that?",
         ],
         inputs=message,
     )
@@ -178,35 +211,41 @@ with block:
     The source code is split/broken down into many document objects using langchain's pythoncodetextsplitter, which apparently tries to keep whole functions etc. together. This means that each file in the source code is split into many smaller documents, and the k value is the number of documents to consider when searching for the most similar documents to the question. With gpt-3.5-turbo, k=10 seems to work well, but with gpt-4, k=20 seems to work better.
     The model's memory is set to 5 messages, but I haven't tested with gpt-3.5-turbo yet to see if it works well. It seems to work well with gpt-4."""
     )
-    with gr.TabItem("Collections manager", id=1):
+    with gr.TabItem("Repository Selector/Manager", id=1):
         with gr.Row():
-            with gr.Column(scale=2):
-                all_collections_to_get = gr.List(headers=['New Collections to make'],row_count=3, label='Collections_to_get', show_label=True, interactive=True, max_cols=1, max_rows=3)
-                make_collections_button = gr.Button(value="Make new collection(s)", variant="secondary").style(full_width=False)
-                with gr.Row():
-                    chunk_size_textbox = gr.Textbox(
-                        placeholder="Chunk size",
-                        label="Chunk size",
-                        show_label=True,
-                        lines=1,
-                    )
-                    chunk_overlap_textbox = gr.Textbox(
-                        placeholder="Chunk overlap",
-                        label="Chunk overlap",
-                        show_label=True,
-                        lines=1,
-                    )
-                chunk_size_textbox.value = "1000"
-                chunk_overlap_textbox.value = "0"
-                with gr.Row():
-                    gr.HTML('<center>See the <a href=https://python.langchain.com/en/latest/reference/modules/text_splitter.html>Langchain textsplitter docs</a></center>')
-            with gr.Column(scale=2):
-                collections_viewer = gr.CheckboxGroup(choices=[], label='Collections_viewer', show_label=True)
-            with gr.Column(scale=1):
-                load_collections_button = gr.Button(value="Load collection(s) to chat!", variant="secondary").style(full_width=False)
-                get_all_collection_names_button = gr.Button(value="List all saved collections", variant="secondary").style(full_width=False)
-                delete_collections_button = gr.Button(value="Delete selected saved collections", variant="secondary").style(full_width=False)
-                delete_all_collections_button = gr.Button(value="Delete all saved collections", variant="secondary").style(full_width=False)
+            collections_viewer = gr.CheckboxGroup(choices=[], label='Repository Viewer', show_label=True)
+        with gr.Row():
+            load_collections_button = gr.Button(value="Load respositories to chat!", variant="secondary")#.style(full_width=False)
+            get_all_collection_names_button = gr.Button(value="List all saved repositories", variant="secondary")#.style(full_width=False)
+            delete_collections_button = gr.Button(value="Delete selected saved repositories", variant="secondary")#.style(full_width=False)
+            delete_all_collections_button = gr.Button(value="Delete all saved repositories", variant="secondary")#.style(full_width=False)
+    with gr.TabItem("Get New Repositories", id=2):
+        with gr.Row():
+            all_collections_to_get = gr.List(headers=['Repository URL', 'Folders'], row_count=3, col_count=2, label='Repositories to get', show_label=True, interactive=True, max_cols=2, max_rows=3)
+            make_collections_button = gr.Button(value="Get new repositories", variant="secondary").style(full_width=False)
+        with gr.Row():
+            chunk_size_textbox = gr.Textbox(
+                placeholder="Chunk size",
+                label="Chunk size",
+                show_label=True,
+                lines=1,
+                value="1000"
+            )
+            chunk_overlap_textbox = gr.Textbox(
+                placeholder="Chunk overlap",
+                label="Chunk overlap",
+                show_label=True,
+                lines=1,
+                value="0"
+            )
+            embedding_radio = gr.Radio(
+                choices = ['Sentence Transformers', 'OpenAI'],
+                label="Embedding Options",
+                show_label=True,
+                value='Sentence Transformers'
+            )
+        with gr.Row():
+            gr.HTML('<center>See the <a href=https://python.langchain.com/en/latest/reference/modules/text_splitter.html>Langchain textsplitter docs</a></center>')
     gr.HTML(
         "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
     )
@@ -216,25 +255,35 @@ with block:
     vs_state = gr.State()
     all_collections_state = gr.State()
     chat_state = gr.State()
+    debug_state = gr.State()
+    debug_state.value = False
 
     submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
     message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
 
-    load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
-    make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
-    delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
-    delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
-    get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
+    load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state, embedding_radio], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
+    make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox, embedding_radio, debug_state], outputs=[all_collections_state, all_collections_to_get], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
+    delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer, embedding_radio], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
+    delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state, embedding_radio], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
+    get_all_collection_names_button.click(list_collections, inputs=[all_collections_state, embedding_radio], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
     clear_btn.click(clear_chat, inputs = [chatbot, history_state], outputs = [chatbot, history_state])
     # Whenever chain parameters change, destroy the agent.
-    input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
+    input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, embedding_radio]
     output_list = [agent_state]
     for input_item in input_list:
         input_item.change(
-            destroy_agent,
+            destroy_state,
             inputs=output_list,
             outputs=output_list,
         )
-    all_collections_state.value = list_collections(all_collections_state)
+    all_collections_state.value = list_collections(all_collections_state, embedding_radio)
     block.load(update_checkboxgroup, inputs = all_collections_state, outputs = collections_viewer)
+    log_textbox_handler = LogTextboxHandler(gr.TextArea(interactive=False, placeholder="Logs will appear here...", visible=False))
+    log_textbox = log_textbox_handler.textbox
+    logging.getLogger().addHandler(log_textbox_handler)
+    log_textbox_visibility_state = gr.State()
+    log_textbox_visibility_state.value = False
+    log_toggle_button = gr.Button("Toggle Log", variant="secondary")
+    log_toggle_button.click(toggle_log_textbox, inputs=[log_textbox_visibility_state], outputs=[log_textbox_visibility_state,log_textbox])
+    block.queue(concurrency_count=40)
 block.launch(debug=True)
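Note: the new `LogTextboxHandler` is the core of the log-viewer feature: a custom `logging` handler whose `emit` appends each formatted record to the Gradio textbox's `value`. A minimal self-contained sketch of the same pattern (the names below are illustrative, not from the commit); one caveat is that mutating a component's `value` server-side may not push a live update to already-connected Gradio clients, which would explain the explicit toggle button:

```python
import logging

class BufferHandler(logging.Handler):
    """Collect formatted log records in a string buffer,
    as LogTextboxHandler does with textbox.value."""
    def __init__(self):
        super().__init__()
        self.buffer = ""

    def emit(self, record):
        self.buffer += self.format(record) + "\n"

handler = BufferHandler()
handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(logging.INFO)

logging.info("ingesting repository...")
print(handler.buffer)  # a UI layer can render this buffer on demand
```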
chain.py CHANGED
@@ -1,5 +1,6 @@
 # chat-pykg/chain.py
-
+from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
+from pydantic import Extra, Field, root_validator
 from langchain.chains.base import Chain
 from langchain import HuggingFaceHub
 from langchain.chains.question_answering import load_qa_chain
@@ -10,12 +11,21 @@ from langchain.chains.llm import LLMChain
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.llm import LLMChain
+from langchain.schema import BaseLanguageModel, BaseRetriever, Document
+from langchain.prompts.prompt import PromptTemplate
+
 
 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 # logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
 
 def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
 
+    # def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+    #     docs = self.retriever.vectorstore._collection.query(question, n_results=self.retriever.search_kwargs["k"], where = {"source":{"$contains":"search_string"}}, where_document = {"$contains":"search_string"})
+    #     return self._reduce_tokens_below_limit(docs)
+
     template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
     You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
     Do NOT make up any hyperlinks that are not in the code.
@@ -34,8 +44,8 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
         llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")#, model_kwargs={"temperature":0, "max_length":64})
         doc_chain_llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")
     question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
-    doc_chain = load_qa_chain(doc_chain_llm, chain_type="stuff", prompt=QA_PROMPT)
-
+    doc_chain = load_qa_chain(doc_chain_llm, chain_type="stuff", prompt=QA_PROMPT)#, document_prompt = PromptTemplate(input_variables=["source", "page_content"], template="{source}\n{page_content}"))
+
     # memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
     memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
     retriever = vectorstore.as_retriever(search_type="similarity")
@@ -45,5 +55,6 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
     retriever.search_kwargs = {"k": 10}
     qa = ConversationalRetrievalChain(
         retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator)
+    # qa._get_docs = _get_docs.__get__(qa, ConversationalRetrievalChain)
 
     return qa
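Note: the commented-out `qa._get_docs = _get_docs.__get__(qa, ConversationalRetrievalChain)` line sketches the same per-instance monkeypatching trick that ingest.py actually applies to `_merge_splits`: `function.__get__(instance, Class)` binds a plain function as a method of one object without touching the class. A toy demonstration in plain Python (all names here are illustrative):

```python
class Splitter:
    def merge(self, parts):
        return " ".join(parts)

def patched_merge(self, parts):
    # Replacement behaviour; still receives `self` like a normal method.
    return " | ".join(parts)

s = Splitter()
s.merge = patched_merge.__get__(s, Splitter)  # bind to this one instance
print(s.merge(["a", "b"]))            # a | b
print(Splitter().merge(["a", "b"]))   # a b -- other instances are unaffected
```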
ingest.py CHANGED
@@ -1,8 +1,9 @@
 # chat-pykg/ingest.py
 import tempfile
+import gradio as gr
 from langchain.document_loaders import SitemapLoader, ReadTheDocsLoader, TextLoader
 from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter, TextSplitter
 from langchain.vectorstores.faiss import FAISS
 import os
 from langchain.vectorstores import Chroma
@@ -10,8 +11,12 @@ import shutil
 from pathlib import Path
 import subprocess
 import chromadb
-from chromadb.config import Settings
-import chromadb.utils.embedding_functions as ef
+import magic
+from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
+from pydantic import Extra, Field, root_validator
+import logging
+logger = logging.getLogger()
+from langchain.docstore.document import Document
 
 # class CachedChroma(Chroma, ABC):
 #     """
@@ -65,6 +70,62 @@ import chromadb.utils.embedding_functions as ef
 #         )
 #         raise ValueError("Either documents or collection_name must be specified.")
 
+def embedding_chooser(embedding_radio):
+    if embedding_radio == "Sentence Transformers":
+        embedding_function = HuggingFaceEmbeddings()
+    elif embedding_radio == "OpenAI":
+        embedding_function = OpenAIEmbeddings()
+    else:
+        embedding_function = HuggingFaceEmbeddings()
+    return embedding_function
+
+# Monkeypatch pending PR
+def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
+    # We now want to combine these smaller pieces into medium size
+    # chunks to send to the LLM.
+    separator_len = self._length_function(separator)
+
+    docs = []
+    current_doc: List[str] = []
+    total = 0
+    for index, d in enumerate(splits):
+        _len = self._length_function(d)
+        if (
+            total + _len + (separator_len if len(current_doc) > 0 else 0)
+            > self._chunk_size
+        ):
+            if total > self._chunk_size:
+                logger.warning(
+                    f"Created a chunk of size {total}, "
+                    f"which is longer than the specified {self._chunk_size}"
+                )
+            if len(current_doc) > 0:
+                doc = self._join_docs(current_doc, separator)
+                if doc is not None:
+                    docs.append(doc)
+                # Keep on popping if:
+                # - we have a larger chunk than in the chunk overlap
+                # - or if we still have any chunks and the length is long
+                while total > self._chunk_overlap or (
+                    total + _len + (separator_len if len(current_doc) > 0 else 0)
+                    > self._chunk_size
+                    and total > 0
+                ):
+                    total -= self._length_function(current_doc[0]) + (
+                        separator_len if len(current_doc) > 1 else 0
+                    )
+                    current_doc = current_doc[1:]
+
+        if index > 0:
+            current_doc.append(separator + d)
+        else:
+            current_doc.append(d)
+        total += _len + (separator_len if len(current_doc) > 1 else 0)
+    doc = self._join_docs(current_doc, separator)
+    if doc is not None:
+        docs.append(doc)
+    return docs
+
 def get_text(content):
     relevant_part = content.find("div", {"class": "markdown"})
     if relevant_part is not None:
@@ -72,83 +133,96 @@ def get_text(content):
     else:
         return ""
 
-def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
-    """Get documents from web pages."""
+def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap, embedding_radio, debug=False):
+    cleared_list = urls.copy()
+    def sanitize_folder_name(folder_name):
+        if folder_name != '':
+            folder_name = folder_name.strip().rstrip('/')
+        else:
+            folder_name = '.' # current directory
+        return folder_name
+
+    def is_hidden(path):
+        return os.path.basename(path).startswith('.')
+
+    embedding_function = embedding_chooser(embedding_radio)
     all_docs = []
-    folders=[]
-    documents = []
     shutil.rmtree('downloaded/', ignore_errors=True)
     known_exts = ["py", "md"]
+    # Initialize text splitters
     py_splitter = PythonCodeTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
     md_splitter = MarkdownTextSplitter(chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap))
-    for url in urls:
+    py_splitter._merge_splits = _merge_splits.__get__(py_splitter, TextSplitter)
+    # Process input URLs
+    urls = [[url.strip(), [sanitize_folder_name(folder) for folder in url_folders.split(',')]] for url, url_folders in urls]
+    for j in range(len(urls)):
+        orgrepo = urls[j][0]
+        repo_folders = urls[j][1]
+        if orgrepo == '':
+            continue
+        if orgrepo.replace('/','-') in all_collections_state:
+            logging.info(f"Skipping {orgrepo} as it is already in the database")
+            continue
+        documents = []
+        paths = []
         paths_by_ext = {}
         docs_by_ext = {}
         for ext in known_exts + ["other"]:
             docs_by_ext[ext] = []
             paths_by_ext[ext] = []
-        url = url[0]
-        if url == '':
-            continue
-        if "." in url:
-            if len(url) > 1:
-                folder = url.split('.')[1]
-            else:
-                folder = '.'
+
+        if orgrepo[0] == '/' or orgrepo[0] == '.':
+            # Ingest local folder
+            local_repo_path = sanitize_folder_name(orgrepo[1:])
         else:
-            destination = Path(os.path.join('downloaded',url))
-            destination.mkdir(exist_ok=True, parents=True)
-            destination = destination.as_posix()
-            if url[0] == '/':
-                url = url[1:]
-            org = url.split('/')[0]
-            repo = url.split('/')[1]
+            # Ingest remote git repo
+            org = orgrepo.split('/')[0]
+            repo = orgrepo.split('/')[1]
             repo_url = f"https://github.com/{org}/{repo}.git"
-            # join all strings after 2nd slash
-            folder = '/'.join(url.split('/')[2:])
-            if folder[-1] == '/':
-                folder = folder[:-1]
-            if folder:
-                with tempfile.TemporaryDirectory() as temp_dir:
-                    temp_path = Path(temp_dir)
-
-                    # Initialize the Git repository
-                    subprocess.run(["git", "init"], cwd=temp_path)
-
-                    # Add the remote repository
-                    subprocess.run(["git", "remote", "add", "-f", "origin", repo_url], cwd=temp_path)
-
-                    # Enable sparse-checkout
-                    subprocess.run(["git", "config", "core.sparseCheckout", "true"], cwd=temp_path)
+            local_repo_path = os.path.join('.downloaded', orgrepo) if debug else tempfile.mkdtemp()
 
-                    # Specify the folder to checkout
-                    with open(temp_path / ".git" / "info" / "sparse-checkout", "w") as f:
-                        f.write(f"{folder}/\n")
+            # Initialize the Git repository
+            subprocess.run(["git", "init"], cwd=local_repo_path)
+            # Add the remote repository
+            subprocess.run(["git", "remote", "add", "-f", "origin", repo_url], cwd=local_repo_path)
+            # Enable sparse-checkout
+            subprocess.run(["git", "config", "core.sparseCheckout", "true"], cwd=local_repo_path)
+            # Specify the folder to checkout
+            cmd = ["git", "sparse-checkout", "set"] + [i for i in repo_folders]
+            subprocess.run(cmd, cwd=local_repo_path)
+            # Check if branch is called main or master
 
-                    # Checkout the desired branch
-                    res = subprocess.run(["git", "checkout", 'main'], cwd=temp_path)
-                    if res.returncode == 1:
-                        res = subprocess.run(["git", "checkout", "master"], cwd=temp_path)
-                res = subprocess.run(["cp", "-r", (temp_path / folder).as_posix(), '/'.join(destination.split('/')[:-1])])
-            folder = destination
-        local_repo_path_1 = folder
-        if local_repo_path_1 == '.':
-            local_repo_path_1 = os.getcwd()
-        for root, dirs, files in os.walk(local_repo_path_1):
+            # Checkout the desired branch
+            res = subprocess.run(["git", "checkout", 'main'], cwd=local_repo_path)
+            if res.returncode == 1:
+                res = subprocess.run(["git", "checkout", "master"], cwd=local_repo_path)
+            #res = subprocess.run(["cp", "-r", (Path(local_repo_path) / repo_folders[i]).as_posix(), '/'.join(destination.split('/')[:-1])])#
+        # Iterate through files and process them
+        if local_repo_path == '.':
+            orgrepo='chat-pykg'
+        for root, dirs, files in os.walk(local_repo_path):
+            dirs[:] = [d for d in dirs if not is_hidden(d)] # Ignore hidden directories
             for file in files:
-                file_path = os.path.join(root, file)
-                rel_file_path = os.path.relpath(file_path, local_repo_path_1)
-                ext = rel_file_path.split('.')[-1]
-                if rel_file_path.startswith('.'):
+                if is_hidden(file):
                     continue
+                file_path = os.path.join(root, file)
+                rel_file_path = os.path.relpath(file_path, local_repo_path)
                 try:
-                    if paths_by_ext.get(rel_file_path.split('.')[-1]) is None:
-                        paths_by_ext["other"].append(rel_file_path)
-                        docs_by_ext["other"].append(TextLoader(os.path.join(local_repo_path_1, rel_file_path)).load()[0])
+                    if '.' not in rel_file_path:
+                        inferred_filetype = magic.from_file(file_path, mime=True)
+                        if "python" in inferred_filetype or "text/plain" in inferred_filetype:
+                            ext = "py"
+                        else:
+                            ext = "other"
                     else:
-                        paths_by_ext[ext].append(rel_file_path)
-                        docs_by_ext[ext].append(TextLoader(os.path.join(local_repo_path_1, rel_file_path)).load()[0])
+                        ext = rel_file_path.split('.')[-1]
+                    if docs_by_ext.get(ext) is None:
+                        ext = "other"
+                    doc = TextLoader(os.path.join(local_repo_path, rel_file_path)).load()[0]
+                    doc.metadata["source"] = os.path.join(orgrepo, rel_file_path)
+                    docs_by_ext[ext].append(doc)
+                    paths_by_ext[ext].append(rel_file_path)
                 except Exception as e:
                     continue
         for ext in docs_by_ext.keys():
@@ -157,25 +231,20 @@ def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
         if ext == "md":
             documents += md_splitter.split_documents(docs_by_ext[ext])
         # else:
-        #     documents += text_splitter.split_documents(docs_by_ext[ext]
+        #         documents += text_splitter.split_documents(docs_by_ext[ext]
         all_docs += documents
-        if 'downloaded/' in folder:
-            folder = '-'.join(folder.split('/')[1:])
-        if folder == '.':
-            folder = 'chat-pykg'
-        collection = Chroma.from_documents(documents=documents, collection_name=folder, embedding=HuggingFaceEmbeddings(), persist_directory=".persisted_data")
+        # For each document, add the metadata to the page_content
+        for doc in documents:
+            doc.page_content = f'# source:{doc.metadata["source"]}\n{doc.page_content}'
+        if type(embedding_radio) == gr.Radio:
+            embedding_radio = embedding_radio.value
+        persist_directory = os.path.join(".persisted_data", embedding_radio.replace(' ','_'))
+        collection_name = orgrepo.replace('/','-')
+        collection = Chroma.from_documents(documents=documents, collection_name=collection_name, embedding=embedding_function, persist_directory=persist_directory)
         collection.persist()
-        all_collections_state.append(folder)
-        return all_collections_state
-    # embeddings = HuggingFaceEmbeddings()
-    # merged_vectorstore = Chroma.from_documents(persist_directory=".persisted_data", documents=documents, embedding=embeddings, collection_name='merged_collections')
-    # #vectorstore = FAISS.from_documents(documents, embeddings)
-    # # # Save vectorstore
-    # # with open("vectorstore.pkl", "wb") as f:
-    # #     pickle.dump(vectorstore. , f)
-
-    # return merged_vectorstore
-
+        all_collections_state.append(collection_name)
+        cleared_list[j][0], cleared_list[j][1] = '', ''
+    return all_collections_state, gr.update(value=cleared_list)
 
 if __name__ == "__main__":
     ingest_docs()
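Note: `ingest_docs` now drives `git sparse-checkout set` through subprocess instead of hand-writing `.git/info/sparse-checkout`. Stripped of the surrounding loop, the sequence it runs per repository looks roughly like the sketch below; the repo URL and folder are placeholders, and `check=True` is added here for clarity (the commit itself only inspects the checkout's return code):

```python
import subprocess
import tempfile

repo_url = "https://github.com/hwchase17/langchain.git"  # placeholder
folders = ["langchain"]                                   # placeholder subfolders

workdir = tempfile.mkdtemp()
subprocess.run(["git", "init"], cwd=workdir, check=True)
# -f fetches immediately; sparse-checkout then limits what lands in the working tree
subprocess.run(["git", "remote", "add", "-f", "origin", repo_url], cwd=workdir, check=True)
subprocess.run(["git", "config", "core.sparseCheckout", "true"], cwd=workdir, check=True)
subprocess.run(["git", "sparse-checkout", "set", *folders], cwd=workdir, check=True)
# Try main first, fall back to master, as the commit does.
if subprocess.run(["git", "checkout", "main"], cwd=workdir).returncode != 0:
    subprocess.run(["git", "checkout", "master"], cwd=workdir, check=True)
```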
requirements.txt CHANGED
@@ -6,4 +6,5 @@ Flask
 transformers
 gradio
 chromadb
-sentence_transformers
+sentence_transformers
+python-magic
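Note: the new `python-magic` dependency backs the MIME sniffing that ingest.py applies to extensionless files. It wraps the libmagic C library, which must be installed on the system separately from the pip package (e.g. `libmagic1` on Debian/Ubuntu). A quick sanity check:

```python
import magic

# mime=True yields strings like "text/x-python" or "text/plain",
# which is what ingest.py's substring checks rely on.
print(magic.from_file("app.py", mime=True))
```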