AlanOC commited on
Commit
2eabc61
1 Parent(s): 43d8dea

Update app.py

Browse files

Attempt to add Claude Sonnet 3.5 option

Files changed (1) hide show
  1. app.py +61 -3
app.py CHANGED
@@ -31,6 +31,7 @@ from langchain.prompts import HumanMessagePromptTemplate
31
  from langchain.prompts import ChatMessagePromptTemplate
32
  from langchain.prompts import ChatPromptTemplate
33
  from wordcloud import WordCloud
 
34
 
35
 
36
  # Function to get base64 encoding of an image
@@ -85,6 +86,21 @@ worksheet = sheet.get_worksheet(0)
85
  # Retrieve the API key from the environment variables
86
  api_key = os.getenv("OPENAI_API_KEY")
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  # Check if the API key is available, if not, raise an error
89
  if api_key is None:
90
  raise ValueError("API key not found. Ensure that the OPENAI_API_KEY environment variable is set.")
@@ -121,6 +137,9 @@ def create_copy_button(text_to_copy):
121
  """
122
  return copy_js
123
 
 
 
 
124
 
125
  # Create a Chroma database instance using the selected directory
126
  def create_chroma_instance(directory):
@@ -188,6 +207,43 @@ def ask_alans_ai(query, vectordb, chat_history, aoc_qa):
188
 
189
  return answer
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def clear_input_box():
192
  st.session_state["new_item"] = ""
193
 
@@ -287,6 +343,7 @@ def main():
287
  'gpt-3.5-turbo-16k',
288
  'gpt-3.5-turbo-1106',
289
  'gpt-4',
 
290
 
291
  # Other custom or fine-tuned models can be added here
292
  ]
@@ -389,14 +446,15 @@ def main():
389
  if st.button("Select 12"):
390
  st.session_state.user_selected_avatar = avatar_12
391
 
392
- ############ Set up the LangChain Conversational Retrieval Chain ################
 
393
  aoc_qa = ConversationalRetrievalChain.from_llm(
394
- ChatOpenAI(temperature=ai_temp, model_name=selected_model),
395
  retriever=vectordb.as_retriever(search_kwargs={'k': k_value}, search_type=selected_search_type),
396
  chain_type='stuff',
397
  return_source_documents=True,
398
  verbose=False,
399
- combine_docs_chain_kwargs={"prompt": qa_prompt})
400
 
401
  # HTML for social media links with base64-encoded images
402
  social_media_html = f"""
 
31
  from langchain.prompts import ChatMessagePromptTemplate
32
  from langchain.prompts import ChatPromptTemplate
33
  from wordcloud import WordCloud
34
+ from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT # New import for Anthropic
35
 
36
 
37
  # Function to get base64 encoding of an image
 
86
  # Retrieve the API key from the environment variables
87
  api_key = os.getenv("OPENAI_API_KEY")
88
 
89
+ # Function to get Claude Sonnet model
90
+ def get_claude_sonnet():
91
+ anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
92
+ if not anthropic_api_key:
93
+ raise ValueError("Anthropic API key not found. Set the ANTHROPIC_API_KEY environment variable.")
94
+ return Anthropic(api_key=anthropic_api_key)
95
+
96
+ # Function to get the appropriate LLM based on the selected model
97
+ def get_llm(model_name, temperature):
98
+ if model_name == 'claude-3-sonnet-20240229':
99
+ return get_claude_sonnet()
100
+ else:
101
+ return ChatOpenAI(temperature=temperature, model_name=model_name)
102
+
103
+
104
  # Check if the API key is available, if not, raise an error
105
  if api_key is None:
106
  raise ValueError("API key not found. Ensure that the OPENAI_API_KEY environment variable is set.")
 
137
  """
138
  return copy_js
139
 
140
+
141
+
142
+
143
 
144
  # Create a Chroma database instance using the selected directory
145
  def create_chroma_instance(directory):
 
207
 
208
  return answer
209
 
210
+
211
+ # Update the ask_alans_ai function to handle Claude Sonnet
212
+ def ask_alans_ai(query, vectordb, chat_history, aoc_qa):
213
+ filtered_chat_history = [(q, a) for q, a in chat_history if a is not None]
214
+
215
+ try:
216
+ if isinstance(aoc_qa.llm, Anthropic):
217
+ # Handle Claude Sonnet
218
+ context = aoc_qa.retriever.get_relevant_documents(query)
219
+ context_str = "\n".join([doc.page_content for doc in context])
220
+
221
+ prompt = f"{HUMAN_PROMPT} Here's some context:\n{context_str}\n\nNow answer this question: {query}{AI_PROMPT}"
222
+
223
+ response = aoc_qa.llm.complete(
224
+ prompt=prompt,
225
+ max_tokens_to_sample=1000,
226
+ model="claude-3-sonnet-20240229",
227
+ temperature=ai_temp
228
+ )
229
+ answer = response.completion
230
+ else:
231
+ # Handle other models
232
+ result = aoc_qa.invoke({"question": query, "chat_history": filtered_chat_history, "vectordb": vectordb})
233
+ answer = result["answer"]
234
+
235
+ chat_history.append((query, answer))
236
+ return answer
237
+ except Exception as e:
238
+ st.error(f"An error occurred: {str(e)}")
239
+ return "I'm sorry, but I encountered an error while processing your request. Please try again later."
240
+
241
+
242
+
243
+
244
+
245
+
246
+
247
  def clear_input_box():
248
  st.session_state["new_item"] = ""
249
 
 
343
  'gpt-3.5-turbo-16k',
344
  'gpt-3.5-turbo-1106',
345
  'gpt-4',
346
+ 'claude-3-sonnet-20240229'
347
 
348
  # Other custom or fine-tuned models can be added here
349
  ]
 
446
  if st.button("Select 12"):
447
  st.session_state.user_selected_avatar = avatar_12
448
 
449
+ ############ Set up the LangChain Conversational Retrieval Chain ################
450
+
451
  aoc_qa = ConversationalRetrievalChain.from_llm(
452
+ get_llm(selected_model, ai_temp),
453
  retriever=vectordb.as_retriever(search_kwargs={'k': k_value}, search_type=selected_search_type),
454
  chain_type='stuff',
455
  return_source_documents=True,
456
  verbose=False,
457
+ combine_docs_chain_kwargs={"prompt": qa_prompt})
458
 
459
  # HTML for social media links with base64-encoded images
460
  social_media_html = f"""