xangma committed on
Commit
80b4f00
1 Parent(s): a3ab5d3

latest fixes

Browse files
Files changed (3) hide show
  1. app.py +34 -23
  2. chain.py +15 -3
  3. ingest.py +5 -4
app.py CHANGED
@@ -1,18 +1,23 @@
1
  # chat-pykg/app.py
2
  import datetime
3
  import os
4
- import gradio as gr
 
 
 
5
  import chromadb
 
6
  from chromadb.config import Settings
7
- # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
8
- # logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
9
- from langchain.vectorstores import Chroma
10
  from langchain.docstore.document import Document
11
- import shutil
12
- import random, string
 
13
  from chain import get_new_chain1
14
  from ingest import ingest_docs
15
 
 
 
 
16
  def randomword(length):
17
  letters = string.ascii_lowercase
18
  return ''.join(random.choice(letters) for i in range(length))
@@ -23,22 +28,21 @@ def change_tab():
23
  def merge_collections(collection_load_names, vs_state):
24
  merged_documents = []
25
  merged_embeddings = []
26
- client = chromadb.Client(Settings(
27
- chroma_db_impl="duckdb+parquet",
28
- persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
29
- ))
30
-
31
  for collection_name in collection_load_names:
32
- collection_name = collection_name
 
 
 
 
33
  if collection_name == '':
34
  continue
35
- collection = client.get_collection(collection_name)
36
- collection = collection.get(include=["metadatas", "documents", "embeddings"])
37
  for i in range(len(collection['documents'])):
38
  merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
39
  merged_embeddings.append(collection['embeddings'][i])
40
- merged_collection_name = "merged_collection"
41
- merged_vectorstore = Chroma.from_documents(documents=merged_documents, embeddings=merged_embeddings, collection_name=merged_collection_name)
42
  return merged_vectorstore
43
 
44
  def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox, vectorstore, agent):
@@ -66,9 +70,12 @@ def delete_collection(all_collections_state, collections_viewer):
66
  persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
67
  ))
68
  for collection in collections_viewer:
69
- client.delete_collection(collection)
70
- all_collections_state.remove(collection)
71
- collections_viewer.remove(collection)
 
 
 
72
  return all_collections_state, collections_viewer
73
 
74
  def delete_all_collections(all_collections_state):
@@ -91,6 +98,9 @@ def destroy_agent(agent):
91
  agent = None
92
  return agent
93
 
 
 
 
94
  def chat(inp, history, agent):
95
  history = history or []
96
  if type(agent) == str:
@@ -144,15 +154,16 @@ with block:
144
  show_label=True,
145
  lines=1,
146
  )
147
- max_tokens_textbox.value="2000"
148
  chatbot = gr.Chatbot()
149
  with gr.Row():
 
150
  message = gr.Textbox(
151
  label="What's your question?",
152
  placeholder="What is this code?",
153
  lines=1,
154
  )
155
- submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
156
  gr.Examples(
157
  examples=[
158
  "What does this code do?",
@@ -207,14 +218,14 @@ with block:
207
  chat_state = gr.State()
208
 
209
  submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
210
- message.submit(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
211
 
212
  load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
213
  make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
214
  delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
215
  delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
216
  get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
217
-
218
  # Whenever chain parameters change, destroy the agent.
219
  input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
220
  output_list = [agent_state]
 
1
  # chat-pykg/app.py
2
  import datetime
3
  import os
4
+ import random
5
+ import shutil
6
+ import string
7
+
8
  import chromadb
9
+ import gradio as gr
10
  from chromadb.config import Settings
 
 
 
11
  from langchain.docstore.document import Document
12
+ from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
13
+ from langchain.vectorstores import Chroma
14
+
15
  from chain import get_new_chain1
16
  from ingest import ingest_docs
17
 
18
+ # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19
+ # logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
20
+
21
def randomword(length):
    """Return a random string of `length` lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    return ''.join(random.choice(alphabet) for _ in range(length))
 
28
def merge_collections(collection_load_names, vs_state):
    """Merge several persisted Chroma collections into one in-memory vectorstore.

    Args:
        collection_load_names: iterable of collection names to load; empty
            strings are skipped.
        vs_state: gradio state holding the current vectorstore. Unused here,
            but kept so the gradio callback signature stays unchanged.

    Returns:
        A Chroma vectorstore (collection "temp") containing every document,
        metadata entry, and embedding from the selected collections.
    """
    merged_documents = []
    merged_embeddings = []
    # The client and the embedding function are loop-invariant: build them
    # once instead of re-instantiating the chromadb client (and re-loading
    # the HuggingFace model) on every iteration.
    client = chromadb.Client(Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=".persisted_data",
        anonymized_telemetry=True,
    ))
    embedding_function = HuggingFaceEmbeddings()
    for collection_name in collection_load_names:
        if collection_name == '':
            continue
        collection_obj = client.get_collection(collection_name, embedding_function=embedding_function)
        collection = collection_obj.get(include=["metadatas", "documents", "embeddings"])
        # Walk the three parallel result lists in lockstep.
        for text, metadata, embedding in zip(collection['documents'],
                                             collection['metadatas'],
                                             collection['embeddings']):
            merged_documents.append(Document(page_content=text, metadata=metadata))
            merged_embeddings.append(embedding)
    merged_vectorstore = Chroma(collection_name="temp", embedding_function=embedding_function)
    merged_vectorstore.add_documents(documents=merged_documents, embeddings=merged_embeddings)
    return merged_vectorstore
47
 
48
  def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox, vectorstore, agent):
 
70
  persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
71
  ))
72
  for collection in collections_viewer:
73
+ try:
74
+ client.delete_collection(collection)
75
+ all_collections_state.remove(collection)
76
+ collections_viewer.remove(collection)
77
+ except:
78
+ continue
79
  return all_collections_state, collections_viewer
80
 
81
  def delete_all_collections(all_collections_state):
 
98
  agent = None
99
  return agent
100
 
101
def clear_chat(chatbot, history):
    """Gradio callback: reset both the chatbot display and the history state.

    The incoming values are ignored; the UI components are simply replaced
    with empty lists.
    """
    empty_display, empty_history = [], []
    return empty_display, empty_history
103
+
104
  def chat(inp, history, agent):
105
  history = history or []
106
  if type(agent) == str:
 
154
  show_label=True,
155
  lines=1,
156
  )
157
+ max_tokens_textbox.value="1000"
158
  chatbot = gr.Chatbot()
159
  with gr.Row():
160
+ clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
161
  message = gr.Textbox(
162
  label="What's your question?",
163
  placeholder="What is this code?",
164
  lines=1,
165
  )
166
+ submit = gr.Button(value="Send").style(full_width=False)
167
  gr.Examples(
168
  examples=[
169
  "What does this code do?",
 
218
  chat_state = gr.State()
219
 
220
  submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
221
+ message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
222
 
223
  load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
224
  make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
225
  delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
226
  delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
227
  get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
228
+ clear_btn.click(clear_chat, inputs = [chatbot, history_state], outputs = [chatbot, history_state])
229
  # Whenever chain parameters change, destroy the agent.
230
  input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
231
  output_list = [agent_state]
chain.py CHANGED
@@ -1,8 +1,6 @@
1
  # chat-pykg/chain.py
2
 
3
  from langchain.chains.base import Chain
4
- # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
5
- # logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
6
  from langchain import HuggingFaceHub
7
  from langchain.chains.question_answering import load_qa_chain
8
  from langchain.chat_models import ChatOpenAI
@@ -13,8 +11,22 @@ from langchain.callbacks.base import CallbackManager
13
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
14
  from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
15
 
 
 
 
16
  def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
19
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
20
  doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
@@ -26,7 +38,7 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -
26
 
27
  # memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
28
  memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
29
- retriever = vectorstore.as_retriever()
30
  if len(k_textbox) != 0:
31
  retriever.search_kwargs = {"k": int(k_textbox)}
32
  else:
 
1
  # chat-pykg/chain.py
2
 
3
  from langchain.chains.base import Chain
 
 
4
  from langchain import HuggingFaceHub
5
  from langchain.chains.question_answering import load_qa_chain
6
  from langchain.chat_models import ChatOpenAI
 
11
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
12
  from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
13
 
14
+ # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
15
+ # logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
16
+
17
  def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
18
 
19
+ template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
20
+ You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
21
+ Do NOT make up any hyperlinks that are not in the code.
22
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
23
+ If the question is not about the package documentation, politely inform them that you are tuned to only answer questions about the package documentations.
24
+ Question: {question}
25
+ =========
26
+ {context}
27
+ =========
28
+ Answer in Markdown:"""
29
+ QA_PROMPT.template = template
30
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
31
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
32
  doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
 
38
 
39
  # memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
40
  memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
41
+ retriever = vectorstore.as_retriever(search_type="similarity")
42
  if len(k_textbox) != 0:
43
  retriever.search_kwargs = {"k": int(k_textbox)}
44
  else:
ingest.py CHANGED
@@ -9,6 +9,9 @@ from langchain.vectorstores import Chroma
9
  import shutil
10
  from pathlib import Path
11
  import subprocess
 
 
 
12
 
13
  # class CachedChroma(Chroma, ABC):
14
  # """
@@ -72,7 +75,6 @@ def get_text(content):
72
  def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
73
  """Get documents from web pages."""
74
  all_docs = []
75
-
76
  folders=[]
77
  documents = []
78
  shutil.rmtree('downloaded/', ignore_errors=True)
@@ -157,13 +159,12 @@ def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
157
  # else:
158
  # documents += text_splitter.split_documents(docs_by_ext[ext]
159
  all_docs += documents
160
- embeddings = HuggingFaceEmbeddings()
161
  if 'downloaded/' in folder:
162
  folder = '-'.join(folder.split('/')[1:])
163
  if folder == '.':
164
  folder = 'chat-pykg'
165
- vectorstore = Chroma.from_documents(persist_directory=".persisted_data", documents=documents, embedding=embeddings, collection_name=folder)
166
- vectorstore.persist()
167
  all_collections_state.append(folder)
168
  return all_collections_state
169
  # embeddings = HuggingFaceEmbeddings()
 
9
  import shutil
10
  from pathlib import Path
11
  import subprocess
12
+ import chromadb
13
+ from chromadb.config import Settings
14
+ import chromadb.utils.embedding_functions as ef
15
 
16
  # class CachedChroma(Chroma, ABC):
17
  # """
 
75
  def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
76
  """Get documents from web pages."""
77
  all_docs = []
 
78
  folders=[]
79
  documents = []
80
  shutil.rmtree('downloaded/', ignore_errors=True)
 
159
  # else:
160
  # documents += text_splitter.split_documents(docs_by_ext[ext]
161
  all_docs += documents
 
162
  if 'downloaded/' in folder:
163
  folder = '-'.join(folder.split('/')[1:])
164
  if folder == '.':
165
  folder = 'chat-pykg'
166
+ collection = Chroma.from_documents(documents=documents, collection_name=folder, embedding=HuggingFaceEmbeddings(), persist_directory=".persisted_data")
167
+ collection.persist()
168
  all_collections_state.append(folder)
169
  return all_collections_state
170
  # embeddings = HuggingFaceEmbeddings()