xangma commited on
Commit
572a6c9
1 Parent(s): df62f91
Files changed (3) hide show
  1. app.py +65 -41
  2. chain.py +62 -10
  3. requirements.txt +2 -1
app.py CHANGED
@@ -8,14 +8,12 @@ import string
8
  import sys
9
  from pathlib import Path
10
  import numpy as np
11
-
12
  import chromadb
13
  import gradio as gr
14
  from chromadb.config import Settings
15
  from langchain.docstore.document import Document
16
  from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
17
  from langchain.vectorstores import Chroma
18
- from langchain.retrievers import SVMRetriever
19
  from chain import get_new_chain1
20
  from ingest import embedding_chooser, ingest_docs
21
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -97,14 +95,18 @@ def merge_collections(collection_load_names, vs_state, k_textbox, search_type_se
97
  # merged_vectorstore.append(f.readlines())
98
  return merged_vectorstore
99
 
100
- def set_chain_up(openai_api_key, model_selector, k_textbox, search_type_selector, max_tokens_textbox, vectorstore_radio, vectorstore, agent):
101
  if not agent or type(agent) == str:
102
  if vectorstore != None:
103
  if model_selector in ["gpt-3.5-turbo", "gpt-4"]:
104
  if openai_api_key:
105
  os.environ["OPENAI_API_KEY"] = openai_api_key
 
 
106
  qa_chain = get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox)
107
  os.environ["OPENAI_API_KEY"] = ""
 
 
108
  return qa_chain
109
  else:
110
  return 'no_open_aikey'
@@ -197,39 +199,7 @@ with block:
197
  with gr.Tabs() as tabs:
198
  with gr.TabItem("Chat", id=0):
199
  with gr.Row():
200
- openai_api_key_textbox = gr.Textbox(
201
- placeholder="Paste your OpenAI API key (sk-...)",
202
- show_label=False,
203
- lines=1,
204
- type="password",
205
- )
206
- model_selector = gr.Dropdown(
207
- choices=["gpt-3.5-turbo", "gpt-4", "other"],
208
- label="Model",
209
- show_label=True,
210
- value = "gpt-3.5-turbo"
211
- )
212
- k_textbox = gr.Textbox(
213
- placeholder="k: Number of search results to consider",
214
- label="Search Results k:",
215
- show_label=True,
216
- lines=1,
217
- value="20",
218
- )
219
- search_type_selector = gr.Dropdown(
220
- choices=["similarity", "mmr", "svm"],
221
- label="Search Type",
222
- show_label=True,
223
- value = "similarity"
224
- )
225
- max_tokens_textbox = gr.Textbox(
226
- placeholder="max_tokens: Maximum number of tokens to generate",
227
- label="max_tokens",
228
- show_label=True,
229
- lines=1,
230
- value="1000",
231
- )
232
- chatbot = gr.Chatbot()
233
  with gr.Row():
234
  clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
235
  message = gr.Textbox(
@@ -240,12 +210,66 @@ with block:
240
  submit = gr.Button(value="Send").style(full_width=False)
241
  gr.Examples(
242
  examples=[
243
- "What does this code do?",
244
  "I want to change the chat-pykg app to have a log viewer, where the user can see what python is doing in the background. How could I do that?",
245
- "Hello, I want to allow chat-pykg to search the internet before answering, can you help me change the code to do that? Thanks.",
 
246
  ],
247
  inputs=message,
248
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  gr.HTML(
251
  """
@@ -318,8 +342,8 @@ with block:
318
  debug_state.value = False
319
  radio_state = gr.State()
320
 
321
- submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, search_type_selector, max_tokens_textbox, select_vectorstore_radio, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
322
- message.submit(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, search_type_selector, max_tokens_textbox, select_vectorstore_radio, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
323
 
324
  load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state, k_textbox, search_type_selector, select_vectorstore_radio, select_embedding_radio], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
325
  make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox, select_vectorstore_radio, select_embedding_radio, debug_state], outputs=[all_collections_state, all_collections_to_get], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
@@ -334,7 +358,7 @@ with block:
334
  select_vectorstore_radio.change(update_radio, inputs = select_vectorstore_radio, outputs = make_vectorstore_radio)
335
 
336
  # Whenever chain parameters change, destroy the agent.
337
- input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, select_vectorstore_radio, make_embedding_radio]
338
  output_list = [agent_state]
339
  for input_item in input_list:
340
  input_item.change(
 
8
  import sys
9
  from pathlib import Path
10
  import numpy as np
 
11
  import chromadb
12
  import gradio as gr
13
  from chromadb.config import Settings
14
  from langchain.docstore.document import Document
15
  from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
16
  from langchain.vectorstores import Chroma
 
17
  from chain import get_new_chain1
18
  from ingest import embedding_chooser, ingest_docs
19
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
95
  # merged_vectorstore.append(f.readlines())
96
  return merged_vectorstore
97
 
98
+ def set_chain_up(openai_api_key, google_api_key, google_cse_id, model_selector, k_textbox, search_type_selector, max_tokens_textbox, vectorstore_radio, vectorstore, agent):
99
  if not agent or type(agent) == str:
100
  if vectorstore != None:
101
  if model_selector in ["gpt-3.5-turbo", "gpt-4"]:
102
  if openai_api_key:
103
  os.environ["OPENAI_API_KEY"] = openai_api_key
104
+ os.environ["GOOGLE_API_KEY"] = google_api_key
105
+ os.environ["GOOGLE_CSE_ID"] = google_cse_id
106
  qa_chain = get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox)
107
  os.environ["OPENAI_API_KEY"] = ""
108
+ os.environ["GOOGLE_API_KEY"] = ""
109
+ os.environ["GOOGLE_CSE_ID"] = ""
110
  return qa_chain
111
  else:
112
  return 'no_open_aikey'
 
199
  with gr.Tabs() as tabs:
200
  with gr.TabItem("Chat", id=0):
201
  with gr.Row():
202
+ chatbot = gr.Chatbot()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  with gr.Row():
204
  clear_btn = gr.Button("Clear Chat", variant="secondary").style(full_width=False)
205
  message = gr.Textbox(
 
210
  submit = gr.Button(value="Send").style(full_width=False)
211
  gr.Examples(
212
  examples=[
 
213
  "I want to change the chat-pykg app to have a log viewer, where the user can see what python is doing in the background. How could I do that?",
214
+ "Hello, I want to allow chat-pykg to search google before answering. In the langchain docs it says you can use a tool to do this: from langchain.agents import load_tools\ntools = load_tools([“google-search”]). How would I need to change get_new_chain1 function to use tools when it needs to as well as searching the vectorstore? Thanks!",
215
+ "Great, thanks. What if I want to add other tools in the future? Can you please change get_new_chain1 function to do that?"
216
  ],
217
  inputs=message,
218
  )
219
+ with gr.Row():
220
+ with gr.Column(scale=1):
221
+ model_selector = gr.Dropdown(
222
+ choices=["gpt-3.5-turbo", "gpt-4", "other"],
223
+ label="Model",
224
+ show_label=True,
225
+ value = "gpt-4"
226
+ )
227
+ k_textbox = gr.Textbox(
228
+ placeholder="k: Number of search results to consider",
229
+ label="Search Results k:",
230
+ show_label=True,
231
+ lines=1,
232
+ value="20",
233
+ )
234
+ search_type_selector = gr.Dropdown(
235
+ choices=["similarity", "mmr", "svm"],
236
+ label="Search Type",
237
+ show_label=True,
238
+ value = "similarity"
239
+ )
240
+ max_tokens_textbox = gr.Textbox(
241
+ placeholder="max_tokens: Maximum number of tokens to generate",
242
+ label="max_tokens",
243
+ show_label=True,
244
+ lines=1,
245
+ value="500",
246
+ )
247
+ with gr.Column(scale=1):
248
+ gr.HTML("")
249
+ with gr.Column(scale=1):
250
+ gr.HTML("")
251
+ with gr.Column(scale=1):
252
+ openai_api_key_textbox = gr.Textbox(
253
+ placeholder="Paste your OpenAI API key (sk-...)",
254
+ show_label=True,
255
+ lines=1,
256
+ type="password",
257
+ label="OpenAI API Key",
258
+ )
259
+ google_api_key_textbox = gr.Textbox(
260
+ placeholder="Paste your Google API key (AIza...)",
261
+ show_label=True,
262
+ lines=1,
263
+ type="password",
264
+ label="Google API Key",
265
+ )
266
+ google_cse_id_textbox = gr.Textbox(
267
+ placeholder="Paste your Google CSE ID (0123...)",
268
+ show_label=True,
269
+ lines=1,
270
+ type="password",
271
+ label="Google CSE ID",
272
+ )
273
 
274
  gr.HTML(
275
  """
 
342
  debug_state.value = False
343
  radio_state = gr.State()
344
 
345
+ submit.click(set_chain_up, inputs=[openai_api_key_textbox, google_api_key_textbox, google_cse_id_textbox, model_selector, k_textbox, search_type_selector, max_tokens_textbox, select_vectorstore_radio, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
346
+ message.submit(set_chain_up, inputs=[openai_api_key_textbox, google_api_key_textbox, google_cse_id_textbox, model_selector, k_textbox, search_type_selector, max_tokens_textbox, select_vectorstore_radio, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
347
 
348
  load_collections_button.click(merge_collections, inputs=[collections_viewer, vs_state, k_textbox, search_type_selector, select_vectorstore_radio, select_embedding_radio], outputs=[vs_state])#.then(change_tab, None, tabs) #.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
349
  make_collections_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get, chunk_size_textbox, chunk_overlap_textbox, select_vectorstore_radio, select_embedding_radio, debug_state], outputs=[all_collections_state, all_collections_to_get], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
 
358
  select_vectorstore_radio.change(update_radio, inputs = select_vectorstore_radio, outputs = make_vectorstore_radio)
359
 
360
  # Whenever chain parameters change, destroy the agent.
361
+ input_list = [openai_api_key_textbox, model_selector, k_textbox, search_type_selector, max_tokens_textbox, select_vectorstore_radio, make_embedding_radio]
362
  output_list = [agent_state]
363
  for input_item in input_list:
364
  input_item.change(
chain.py CHANGED
@@ -15,28 +15,70 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
15
  from langchain.chains.llm import LLMChain
16
  from langchain.schema import BaseLanguageModel, BaseRetriever, Document
17
  from langchain.prompts.prompt import PromptTemplate
 
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox) -> Chain:
21
  retriever = None
22
  if vectorstore_radio == 'Chroma':
23
  retriever = vectorstore.as_retriever(search_type=search_type_selector)
24
  retriever.search_kwargs = {"k":int(k_textbox)}
 
 
25
  if vectorstore_radio == 'raw':
26
  if search_type_selector == 'svm':
27
  retriever = SVMRetriever.from_texts(merged_vectorstore, embedding_function)
28
  retriever.k = int(k_textbox)
29
 
30
- template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about various open source libraries.
31
- You are given the following extracted parts of code and a question. Provide a conversational answer to the question.
32
- Do NOT make up any hyperlinks that are not in the code.
 
 
 
33
  If you don't know the answer, just say that you don't know, don't try to make up an answer.
34
- Question: {question}
35
  =========
36
- {context}
37
  =========
38
- Answer in Markdown:"""
39
- QA_PROMPT.template = template
 
 
 
 
 
 
 
 
 
 
 
 
40
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
41
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
42
  doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
@@ -49,8 +91,18 @@ def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, se
49
  # memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
50
  memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
51
 
52
- qa = ConversationalRetrievalChain(
53
- retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator, verbose=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
54
- # qa._get_docs = _get_docs.__get__(qa, ConversationalRetrievalChain)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return qa
 
15
  from langchain.chains.llm import LLMChain
16
  from langchain.schema import BaseLanguageModel, BaseRetriever, Document
17
  from langchain.prompts.prompt import PromptTemplate
18
+ from langchain.utilities.google_serper import GoogleSerperAPIWrapper
19
+ from langchain.utilities.google_search import GoogleSearchAPIWrapper
20
+ from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
21
+ from langchain.agents.self_ask_with_search.prompt import PROMPT
22
 
23
+ class ConversationalRetrievalChainWithGoogleSearch(ConversationalRetrievalChain):
24
+ google_search_tool: GoogleSearchAPIWrapper
25
+
26
+ def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
27
+ # Get documents from the retriever
28
+ docs_from_retriever = self.retriever.get_relevant_documents(question)
29
+
30
+ # Get search results from Google Search
31
+ search_results = self.google_search_tool.results(question, num_results=self.google_search_tool.k)
32
+
33
+ # Create documents from the search results
34
+ docs_from_search = []
35
+ for result in search_results:
36
+ content = result.get("snippet", "")
37
+ metadata = {"title": result["title"], "link": result["link"]}
38
+ docs_from_search.append(Document(page_content=content, metadata=metadata))
39
+
40
+ # Combine both lists of documents
41
+ docs = docs_from_retriever + docs_from_search
42
+
43
+ return self._reduce_tokens_below_limit(docs)
44
 
45
  def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox) -> Chain:
46
  retriever = None
47
  if vectorstore_radio == 'Chroma':
48
  retriever = vectorstore.as_retriever(search_type=search_type_selector)
49
  retriever.search_kwargs = {"k":int(k_textbox)}
50
+ if search_type_selector == 'mmr':
51
+ retriever.search_kwargs = {"k":int(k_textbox), "fetch_k":4*int(k_textbox)}
52
  if vectorstore_radio == 'raw':
53
  if search_type_selector == 'svm':
54
  retriever = SVMRetriever.from_texts(merged_vectorstore, embedding_function)
55
  retriever.k = int(k_textbox)
56
 
57
+ qa_template = """You are called chat-pykg and are an AI assistant coded in python using langchain and gradio. You are very helpful for answering questions about programming with various open source packages and libraries.
58
+ You are given snippets of code and information in the Context below, as well as a Question to give a Helpful answer to.
59
+ Due to data size limitations, the snippets of code in the Context have been specifically filtered/selected for their relevance from a document store containing code from one or many packages and libraries.
60
+ Each of the code snippets is marked with '# source: package/filename' so you can attempt to establish where they are located in their package structure and gain more understanding of the code.
61
+ Please provide a helpful answer in markdown to the Question.
62
+ Do not make up any hyperlinks that are not in the Context.
63
  If you don't know the answer, just say that you don't know, don't try to make up an answer.
64
+
65
  =========
66
+ Context:{context}
67
  =========
68
+ Question: {question}
69
+ Helpful answer:"""
70
+ QA_PROMPT.template = qa_template
71
+
72
+ condense_question_template = """Given the following conversation and a Follow Up Input, rephrase the Follow Up Input to be a Standalone question.
73
+ The Standalone question will be used for retrieving relevant source code and information from a document store, where each document is marked with '# source: package/filename'.
74
+ Therefore, in your Standalone question you must try to include references to related code or sources that have been mentioned in the Follow Up Input or Chat History.
75
+ =========
76
+ Chat History:
77
+ {chat_history}
78
+ =========
79
+ Follow Up Input: {question}
80
+ Standalone question in markdown:"""
81
+ CONDENSE_QUESTION_PROMPT.template = condense_question_template
82
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
83
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
84
  doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
 
91
  # memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
92
  memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
93
 
94
+ google_search_tool = GoogleSearchAPIWrapper(search_engine = "google", k = int(int(k_textbox)/2))
 
 
95
 
96
+ qa_orig = ConversationalRetrievalChain(
97
+ retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator, verbose=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
98
+ qa_with_google_search = ConversationalRetrievalChainWithGoogleSearch(
99
+ retriever=retriever,
100
+ memory=memory,
101
+ combine_docs_chain=doc_chain,
102
+ question_generator=question_generator,
103
+ google_search_tool=google_search_tool,
104
+ verbose=True,
105
+ callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])
106
+ )
107
+ qa = qa_orig
108
  return qa
requirements.txt CHANGED
@@ -7,4 +7,5 @@ transformers
7
  gradio
8
  chromadb
9
  sentence_transformers
10
- python-magic
 
 
7
  gradio
8
  chromadb
9
  sentence_transformers
10
+ python-magic
11
+ google-api-python-client