xangma committed on
Commit
1433ce9
1 Parent(s): bc3a2c2
Files changed (3) hide show
  1. app.py +39 -32
  2. chain.py +2 -2
  3. ingest.py +1 -2
app.py CHANGED
@@ -41,21 +41,24 @@ def merge_collections(collection_load_names, vs_state):
41
  merged_vectorstore = Chroma.from_documents(documents=merged_documents, embeddings=merged_embeddings, collection_name=merged_collection_name)
42
  return merged_vectorstore
43
 
44
- def set_chain_up(openai_api_key, model_selector, k_textbox, vectorstore, agent):
45
- if vectorstore == None:
46
- return 'no_vectorstore'
47
- if vectorstore != None:
48
- if model_selector in ["gpt-3.5-turbo", "gpt-4"]:
49
- if openai_api_key:
50
- os.environ["OPENAI_API_KEY"] = openai_api_key
51
- qa_chain = get_new_chain1(vectorstore, model_selector, k_textbox)
52
- os.environ["OPENAI_API_KEY"] = ""
53
- return qa_chain
 
54
  else:
55
- return 'no_open_aikey'
 
56
  else:
57
- qa_chain = get_new_chain1(vectorstore, model_selector, k_textbox)
58
- return qa_chain
 
59
 
60
  def delete_vs(all_collections_state, collections_viewer):
61
  client = chromadb.Client(Settings(
@@ -83,6 +86,10 @@ def update_checkboxgroup(all_collections_state):
83
  new_options = [i for i in all_collections_state]
84
  return gr.CheckboxGroup.update(choices=new_options)
85
 
 
 
 
 
86
  def chat(inp, history, agent):
87
  history = history or []
88
  if type(agent) == str:
@@ -129,7 +136,14 @@ with block:
129
  show_label=True,
130
  lines=1,
131
  )
132
- k_textbox.value = "10"
 
 
 
 
 
 
 
133
  chatbot = gr.Chatbot()
134
  with gr.Row():
135
  message = gr.Textbox(
@@ -174,31 +188,24 @@ with block:
174
  all_collections_state = gr.State()
175
  chat_state = gr.State()
176
 
177
- submit.click(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
178
  message.submit(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
179
 
180
- get_vs_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state]).then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, vs_state, agent_state], outputs=[agent_state, tabs])
181
  make_vs_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
182
  delete_vs_button.click(delete_vs, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
183
  delete_all_vs_button.click(delete_all_vs, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
184
  get_all_vs_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
185
 
186
- #I need to also parse this code in the docstore so I can ask it to fix silly things like this below:
187
- openai_api_key_textbox.change(
188
- set_chain_up,
189
- inputs=[openai_api_key_textbox, model_selector, k_textbox, vs_state, agent_state],
190
- outputs=[agent_state],
191
- )
192
- model_selector.change(
193
- set_chain_up,
194
- inputs=[openai_api_key_textbox, model_selector, k_textbox, vs_state, agent_state],
195
- outputs=[agent_state],
196
- )
197
- k_textbox.change(
198
- set_chain_up,
199
- inputs=[openai_api_key_textbox, model_selector, k_textbox, vs_state, agent_state],
200
- outputs=[agent_state],
201
- )
202
  all_collections_state.value = list_collections(all_collections_state)
203
  block.load(update_checkboxgroup, inputs = all_collections_state, outputs = collections_viewer)
204
  block.launch(debug=True)
 
41
  merged_vectorstore = Chroma.from_documents(documents=merged_documents, embeddings=merged_embeddings, collection_name=merged_collection_name)
42
  return merged_vectorstore
43
 
44
+ def set_chain_up(openai_api_key, model_selector, k_textbox, max_tokens_textbox, vectorstore, agent):
45
+ if agent == None or type(agent) == str:
46
+ if vectorstore != None:
47
+ if model_selector in ["gpt-3.5-turbo", "gpt-4"]:
48
+ if openai_api_key:
49
+ os.environ["OPENAI_API_KEY"] = openai_api_key
50
+ qa_chain = get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox)
51
+ os.environ["OPENAI_API_KEY"] = ""
52
+ return qa_chain
53
+ else:
54
+ return 'no_open_aikey'
55
  else:
56
+ qa_chain = get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox)
57
+ return qa_chain
58
  else:
59
+ return 'no_vectorstore'
60
+ else:
61
+ return agent
62
 
63
  def delete_vs(all_collections_state, collections_viewer):
64
  client = chromadb.Client(Settings(
 
86
  new_options = [i for i in all_collections_state]
87
  return gr.CheckboxGroup.update(choices=new_options)
88
 
89
+ def destroy_agent(agent):
90
+ agent = None
91
+ return agent
92
+
93
  def chat(inp, history, agent):
94
  history = history or []
95
  if type(agent) == str:
 
136
  show_label=True,
137
  lines=1,
138
  )
139
+ k_textbox.value = "20"
140
+ max_tokens_textbox = gr.Textbox(
141
+ placeholder="max_tokens: Maximum number of tokens to generate",
142
+ label="max_tokens",
143
+ show_label=True,
144
+ lines=1,
145
+ )
146
+ max_tokens_textbox.value="2000"
147
  chatbot = gr.Chatbot()
148
  with gr.Row():
149
  message = gr.Textbox(
 
188
  all_collections_state = gr.State()
189
  chat_state = gr.State()
190
 
191
+ submit.click(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state]).then(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
192
  message.submit(chat, inputs=[message, history_state, agent_state], outputs=[chatbot, history_state])
193
 
194
+ get_vs_button.click(merge_collections, inputs=[collections_viewer, vs_state], outputs=[vs_state])#.then(set_chain_up, inputs=[openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox, vs_state, agent_state], outputs=[agent_state])
195
  make_vs_button.click(ingest_docs, inputs=[all_collections_state, all_collections_to_get], outputs=[all_collections_state], show_progress=True).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
196
  delete_vs_button.click(delete_vs, inputs=[all_collections_state, collections_viewer], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
197
  delete_all_vs_button.click(delete_all_vs, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
198
  get_all_vs_names_button.click(list_collections, inputs=[all_collections_state], outputs=[all_collections_state]).then(update_checkboxgroup, inputs = [all_collections_state], outputs = [collections_viewer])
199
 
200
+ # Whenever chain parameters change, destroy the agent.
201
+ input_list = [openai_api_key_textbox, model_selector, k_textbox, max_tokens_textbox]
202
+ output_list = [agent_state]
203
+ for input_item in input_list:
204
+ input_item.change(
205
+ destroy_agent,
206
+ inputs=output_list,
207
+ outputs=output_list,
208
+ )
 
 
 
 
 
 
 
209
  all_collections_state.value = list_collections(all_collections_state)
210
  block.load(update_checkboxgroup, inputs = all_collections_state, outputs = collections_viewer)
211
  block.launch(debug=True)
chain.py CHANGED
@@ -13,12 +13,12 @@ from langchain.callbacks.base import CallbackManager
13
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
14
  from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
15
 
16
- def get_new_chain1(vectorstore, model_selector, k_textbox) -> Chain:
17
  max_tokens_dict = {'gpt-4': 2000, 'gpt-3.5-turbo': 1000}
18
 
19
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
20
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
21
- doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=max_tokens_dict[model_selector])
22
  if model_selector == 'other':
23
  llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")#, model_kwargs={"temperature":0, "max_length":64})
24
  doc_chain_llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")
 
13
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
14
  from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
15
 
16
+ def get_new_chain1(vectorstore, model_selector, k_textbox, max_tokens_textbox) -> Chain:
17
  max_tokens_dict = {'gpt-4': 2000, 'gpt-3.5-turbo': 1000}
18
 
19
  if model_selector in ['gpt-4', 'gpt-3.5-turbo']:
20
  llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
21
+ doc_chain_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=int(max_tokens_textbox))
22
  if model_selector == 'other':
23
  llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")#, model_kwargs={"temperature":0, "max_length":64})
24
  doc_chain_llm = HuggingFaceHub(repo_id="chavinlo/gpt4-x-alpaca")
ingest.py CHANGED
@@ -72,7 +72,7 @@ def get_text(content):
72
  def ingest_docs(all_collections_state, urls):
73
  """Get documents from web pages."""
74
  all_docs = []
75
- local = False
76
  folders=[]
77
  documents = []
78
  shutil.rmtree('downloaded/', ignore_errors=True)
@@ -90,7 +90,6 @@ def ingest_docs(all_collections_state, urls):
90
  if url == '':
91
  continue
92
  if "." in url:
93
- local = True
94
  if len(url) > 1:
95
  folder = url.split('.')[1]
96
  else:
 
72
  def ingest_docs(all_collections_state, urls):
73
  """Get documents from web pages."""
74
  all_docs = []
75
+
76
  folders=[]
77
  documents = []
78
  shutil.rmtree('downloaded/', ignore_errors=True)
 
90
  if url == '':
91
  continue
92
  if "." in url:
 
93
  if len(url) > 1:
94
  folder = url.split('.')[1]
95
  else: