xangma
commited on
Commit
•
80b4f00
1
Parent(s):
a3ab5d3
latest fixes
Browse files
app.py
CHANGED
@@ -1,18 +1,23 @@
|
|
1 |
# chat-pykg/app.py
|
2 |
import datetime
|
3 |
import os
|
4 |
-
import
|
|
|
|
|
|
|
5 |
import chromadb
|
|
|
6 |
from chromadb.config import Settings
|
7 |
-
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
8 |
-
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
9 |
-
from langchain.vectorstores import Chroma
|
10 |
from langchain.docstore.document import Document
|
11 |
-
import
|
12 |
-
import
|
|
|
13 |
from chain import get_new_chain1
|
14 |
from ingest import ingest_docs
|
15 |
|
|
|
|
|
|
|
16 |
def randomword(length):
|
17 |
letters = string.ascii_lowercase
|
18 |
return ''.join(random.choice(letters) for i in range(length))
|
@@ -23,22 +28,21 @@ def change_tab():
|
|
23 |
def merge_collections(collection_load_names, vs_state):
|
24 |
merged_documents = []
|
25 |
merged_embeddings = []
|
26 |
-
client = chromadb.Client(Settings(
|
27 |
-
chroma_db_impl="duckdb+parquet",
|
28 |
-
persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
|
29 |
-
))
|
30 |
-
|
31 |
for collection_name in collection_load_names:
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
if collection_name == '':
|
34 |
continue
|
35 |
-
|
36 |
-
collection =
|
37 |
for i in range(len(collection['documents'])):
|
38 |
merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
|
39 |
merged_embeddings.append(collection['embeddings'][i])
|
40 |
-
|
41 |
-
merged_vectorstore
|
42 |
return merged_vectorstore
|
43 |
|
44 |
def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox, vectorstore, agent):
|
@@ -66,9 +70,12 @@ def delete_collection(all_collections_state, collections_viewer):
|
|
66 |
persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
|
67 |
))
|
68 |
for collection in collections_viewer:
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
72 |
return all_collections_state, collections_viewer
|
73 |
|
74 |
def delete_all_collections(all_collections_state):
|
@@ -91,6 +98,9 @@ def destroy_agent(agent):
|
|
91 |
agent = None
|
92 |
return agent
|
93 |
|
|
|
|
|
|
|
94 |
def chat(inp, history, agent):
|
95 |
history = history or []
|
96 |
if type(agent) == str:
|
@@ -144,15 +154,16 @@ with block:
|
|
144 |
show_label=True,
|
145 |
lines=1,
|
146 |
)
|
147 |
-
max_tokens_textbox.value="
|
148 |
chatbot = gr.Chatbot()
|
149 |
with gr.Row():
|
|
|
150 |
message = gr.Textbox(
|
151 |
label="What's your question?",
|
152 |
placeholder="What is this code?",
|
153 |
lines=1,
|
154 |
)
|
155 |
-
submit = gr.Button(value="Send"
|
156 |
gr.Examples(
|
157 |
examples=[
|
158 |
"What does this code do?",
|
@@ -207,14 +218,14 @@ with block:
|
|
207 |
chat_state = gr.State()
|
208 |
|
209 |
submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
210 |
-
message.submit(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
211 |
|
212 |
load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
|
213 |
make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
214 |
delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
215 |
delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
216 |
get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
217 |
-
|
218 |
# Whenever chain parameters change, destroy the agent.
|
219 |
input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
|
220 |
output_list = [agent_state]
|
|
|
1 |
# chat-pykg/app.py
|
2 |
import datetime
|
3 |
import os
|
4 |
+
import random
|
5 |
+
import shutil
|
6 |
+
import string
|
7 |
+
|
8 |
import chromadb
|
9 |
+
import gradio as gr
|
10 |
from chromadb.config import Settings
|
|
|
|
|
|
|
11 |
from langchain.docstore.document import Document
|
12 |
+
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
|
13 |
+
from langchain.vectorstores import Chroma
|
14 |
+
|
15 |
from chain import get_new_chain1
|
16 |
from ingest import ingest_docs
|
17 |
|
18 |
+
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
19 |
+
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
20 |
+
|
21 |
def randomword(length):
|
22 |
letters = string.ascii_lowercase
|
23 |
return ''.join(random.choice(letters) for i in range(length))
|
|
|
28 |
def merge_collections(collection_load_names, vs_state):
|
29 |
merged_documents = []
|
30 |
merged_embeddings = []
|
|
|
|
|
|
|
|
|
|
|
31 |
for collection_name in collection_load_names:
|
32 |
+
chroma_obj_get = chromadb.Client(Settings(
|
33 |
+
chroma_db_impl="duckdb+parquet",
|
34 |
+
persist_directory=".persisted_data",
|
35 |
+
anonymized_telemetry = True
|
36 |
+
))
|
37 |
if collection_name == '':
|
38 |
continue
|
39 |
+
collection_obj = chroma_obj_get.get_collection(collection_name, embedding_function=HuggingFaceEmbeddings())
|
40 |
+
collection = collection_obj.get(include=["metadatas", "documents", "embeddings"])
|
41 |
for i in range(len(collection['documents'])):
|
42 |
merged_documents.append(Document(page_content=collection['documents'][i], metadata = collection['metadatas'][i]))
|
43 |
merged_embeddings.append(collection['embeddings'][i])
|
44 |
+
merged_vectorstore = Chroma(collection_name="temp", embedding_function=HuggingFaceEmbeddings())
|
45 |
+
merged_vectorstore.add_documents(documents=merged_documents, embeddings=merged_embeddings)
|
46 |
return merged_vectorstore
|
47 |
|
48 |
def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox, vectorstore, agent):
|
|
|
70 |
persist_directory=".persisted_data" # Optional, defaults to .chromadb/ in the current directory
|
71 |
))
|
72 |
for collection in collections_viewer:
|
73 |
+
try:
|
74 |
+
client.delete_collection(collection)
|
75 |
+
all_collections_state.remove(collection)
|
76 |
+
collections_viewer.remove(collection)
|
77 |
+
except:
|
78 |
+
continue
|
79 |
return all_collections_state, collections_viewer
|
80 |
|
81 |
def delete_all_collections(all_collections_state):
|
|
|
98 |
agent = None
|
99 |
return agent
|
100 |
|
101 |
+
def clear_chat(chatbot, history):
|
102 |
+
return [], []
|
103 |
+
|
104 |
def chat(inp, history, agent):
|
105 |
history = history or []
|
106 |
if type(agent) == str:
|
|
|
154 |
show_label=True,
|
155 |
lines=1,
|
156 |
)
|
157 |
+
max_tokens_textbox.value="1000"
|
158 |
chatbot = gr.Chatbot()
|
159 |
with gr.Row():
|
160 |
+
clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
|
161 |
message = gr.Textbox(
|
162 |
label="What's your question?",
|
163 |
placeholder="What is this code?",
|
164 |
lines=1,
|
165 |
)
|
166 |
+
submit = gr.Button(value="Send").style(full_width=False)
|
167 |
gr.Examples(
|
168 |
examples=[
|
169 |
"What does this code do?",
|
|
|
218 |
chat_state = gr.State()
|
219 |
|
220 |
submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
221 |
+
message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
|
222 |
|
223 |
load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
|
224 |
make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
225 |
delete_collections_button.click(delete_collection, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state, collections_viewer]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
226 |
delete_all_collections_button.click(delete_all_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
227 |
get_all_collection_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
|
228 |
+
clear_btn.click(clear_chat, inputs = [chatbot, history_state], outputs = [chatbot, history_state])
|
229 |
# Whenever chain parameters change, destroy the agent.
|
230 |
input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
|
231 |
output_list = [agent_state]
|
chain.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
# chat-pykg/chain.py
|
2 |
|
3 |
from langchain.chains.base import Chain
|
4 |
-
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
5 |
-
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
6 |
from langchain import HuggingFaceHub
|
7 |
from langchain.chains.question_answering import load_qa_chain
|
8 |
from langchain.chat_models import ChatOpenAI
|
@@ -13,8 +11,22 @@ from langchain.callbacks.base import CallbackManager
|
|
13 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
14 |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
|
15 |
|
|
|
|
|
|
|
16 |
def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
|
19 |
llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
|
20 |
doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
|
@@ -26,7 +38,7 @@ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -
|
|
26 |
|
27 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
28 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
29 |
-
retriever = vectorstore.as_retriever()
|
30 |
if len(k_textbox) != 0:
|
31 |
retriever.search_kwargs = {"k": int(k_textbox)}
|
32 |
else:
|
|
|
1 |
# chat-pykg/chain.py
|
2 |
|
3 |
from langchain.chains.base import Chain
|
|
|
|
|
4 |
from langchain import HuggingFaceHub
|
5 |
from langchain.chains.question_answering import load_qa_chain
|
6 |
from langchain.chat_models import ChatOpenAI
|
|
|
11 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
12 |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
|
13 |
|
14 |
+
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
15 |
+
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
16 |
+
|
17 |
def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
|
18 |
|
19 |
+
template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
|
20 |
+
You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
|
21 |
+
Do NOT make up any hyperlinks that are not in the code.
|
22 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
23 |
+
If the question is not about the package documentation, politely inform them that you are tuned to only answer questions about the package documentations.
|
24 |
+
Question: {question}
|
25 |
+
=========
|
26 |
+
{context}
|
27 |
+
=========
|
28 |
+
Answer in Markdown:"""
|
29 |
+
QA_PROMPT.template = template
|
30 |
if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
|
31 |
llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
|
32 |
doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
|
|
|
38 |
|
39 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
40 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
41 |
+
retriever = vectorstore.as_retriever(search_type="similarity")
|
42 |
if len(k_textbox) != 0:
|
43 |
retriever.search_kwargs = {"k": int(k_textbox)}
|
44 |
else:
|
ingest.py
CHANGED
@@ -9,6 +9,9 @@ from langchain.vectorstores import Chroma
|
|
9 |
import shutil
|
10 |
from pathlib import Path
|
11 |
import subprocess
|
|
|
|
|
|
|
12 |
|
13 |
# class CachedChroma(Chroma, ABC):
|
14 |
# """
|
@@ -72,7 +75,6 @@ def get_text(content):
|
|
72 |
def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
|
73 |
"""Get documents from web pages."""
|
74 |
all_docs = []
|
75 |
-
|
76 |
folders=[]
|
77 |
documents = []
|
78 |
shutil.rmtree('downloaded/', ignore_errors=True)
|
@@ -157,13 +159,12 @@ def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
|
|
157 |
# else:
|
158 |
# documents += text_splitter.split_documents(docs_by_ext[ext]
|
159 |
all_docs += documents
|
160 |
-
embeddings = HuggingFaceEmbeddings()
|
161 |
if 'downloaded/' in folder:
|
162 |
folder = '-'.join(folder.split('/')[1:])
|
163 |
if folder == '.':
|
164 |
folder = 'chat-pykg'
|
165 |
-
|
166 |
-
|
167 |
all_collections_state.append(folder)
|
168 |
return all_collections_state
|
169 |
# embeddings = HuggingFaceEmbeddings()
|
|
|
9 |
import shutil
|
10 |
from pathlib import Path
|
11 |
import subprocess
|
12 |
+
import chromadb
|
13 |
+
from chromadb.config import Settings
|
14 |
+
import chromadb.utils.embedding_functions as ef
|
15 |
|
16 |
# class CachedChroma(Chroma, ABC):
|
17 |
# """
|
|
|
75 |
def ingest_docs(all_collections_state, urls, chunk_size, chunk_overlap):
|
76 |
"""Get documents from web pages."""
|
77 |
all_docs = []
|
|
|
78 |
folders=[]
|
79 |
documents = []
|
80 |
shutil.rmtree('downloaded/', ignore_errors=True)
|
|
|
159 |
# else:
|
160 |
# documents += text_splitter.split_documents(docs_by_ext[ext]
|
161 |
all_docs += documents
|
|
|
162 |
if 'downloaded/' in folder:
|
163 |
folder = '-'.join(folder.split('/')[1:])
|
164 |
if folder == '.':
|
165 |
folder = 'chat-pykg'
|
166 |
+
collection = Chroma.from_documents(documents=documents, collection_name=folder, embedding=HuggingFaceEmbeddings(), persist_directory=".persisted_data")
|
167 |
+
collection.persist()
|
168 |
all_collections_state.append(folder)
|
169 |
return all_collections_state
|
170 |
# embeddings = HuggingFaceEmbeddings()
|