Spaces:
Sleeping
Sleeping
Update app.py
Browse filesAttempt to add Claude Sonnet 3.5 option
app.py
CHANGED
@@ -31,6 +31,7 @@ from langchain.prompts import HumanMessagePromptTemplate
|
|
31 |
from langchain.prompts import ChatMessagePromptTemplate
|
32 |
from langchain.prompts import ChatPromptTemplate
|
33 |
from wordcloud import WordCloud
|
|
|
34 |
|
35 |
|
36 |
# Function to get base64 encoding of an image
|
@@ -85,6 +86,21 @@ worksheet = sheet.get_worksheet(0)
|
|
85 |
# Retrieve the API key from the environment variables
|
86 |
api_key = os.getenv("OPENAI_API_KEY")
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
# Check if the API key is available, if not, raise an error
|
89 |
if api_key is None:
|
90 |
raise ValueError("API key not found. Ensure that the OPENAI_API_KEY environment variable is set.")
|
@@ -121,6 +137,9 @@ def create_copy_button(text_to_copy):
|
|
121 |
"""
|
122 |
return copy_js
|
123 |
|
|
|
|
|
|
|
124 |
|
125 |
# Create a Chroma database instance using the selected directory
|
126 |
def create_chroma_instance(directory):
|
@@ -188,6 +207,43 @@ def ask_alans_ai(query, vectordb, chat_history, aoc_qa):
|
|
188 |
|
189 |
return answer
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
def clear_input_box():
|
192 |
st.session_state["new_item"] = ""
|
193 |
|
@@ -287,6 +343,7 @@ def main():
|
|
287 |
'gpt-3.5-turbo-16k',
|
288 |
'gpt-3.5-turbo-1106',
|
289 |
'gpt-4',
|
|
|
290 |
|
291 |
# Other custom or fine-tuned models can be added here
|
292 |
]
|
@@ -389,14 +446,15 @@ def main():
|
|
389 |
if st.button("Select 12"):
|
390 |
st.session_state.user_selected_avatar = avatar_12
|
391 |
|
392 |
-
############ Set up the LangChain Conversational Retrieval Chain ################
|
|
|
393 |
aoc_qa = ConversationalRetrievalChain.from_llm(
|
394 |
-
|
395 |
retriever=vectordb.as_retriever(search_kwargs={'k': k_value}, search_type=selected_search_type),
|
396 |
chain_type='stuff',
|
397 |
return_source_documents=True,
|
398 |
verbose=False,
|
399 |
-
combine_docs_chain_kwargs={"prompt": qa_prompt})
|
400 |
|
401 |
# HTML for social media links with base64-encoded images
|
402 |
social_media_html = f"""
|
|
|
31 |
from langchain.prompts import ChatMessagePromptTemplate
|
32 |
from langchain.prompts import ChatPromptTemplate
|
33 |
from wordcloud import WordCloud
|
34 |
+
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT # New import for Anthropic
|
35 |
|
36 |
|
37 |
# Function to get base64 encoding of an image
|
|
|
86 |
# Retrieve the API key from the environment variables
|
87 |
api_key = os.getenv("OPENAI_API_KEY")
|
88 |
|
89 |
+
# Function to get Claude Sonnet model
|
90 |
+
def get_claude_sonnet():
|
91 |
+
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
92 |
+
if not anthropic_api_key:
|
93 |
+
raise ValueError("Anthropic API key not found. Set the ANTHROPIC_API_KEY environment variable.")
|
94 |
+
return Anthropic(api_key=anthropic_api_key)
|
95 |
+
|
96 |
+
# Function to get the appropriate LLM based on the selected model
|
97 |
+
def get_llm(model_name, temperature):
|
98 |
+
if model_name == 'claude-3-sonnet-20240229':
|
99 |
+
return get_claude_sonnet()
|
100 |
+
else:
|
101 |
+
return ChatOpenAI(temperature=temperature, model_name=model_name)
|
102 |
+
|
103 |
+
|
104 |
# Check if the API key is available, if not, raise an error
|
105 |
if api_key is None:
|
106 |
raise ValueError("API key not found. Ensure that the OPENAI_API_KEY environment variable is set.")
|
|
|
137 |
"""
|
138 |
return copy_js
|
139 |
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
|
144 |
# Create a Chroma database instance using the selected directory
|
145 |
def create_chroma_instance(directory):
|
|
|
207 |
|
208 |
return answer
|
209 |
|
210 |
+
|
211 |
+
# Update the ask_alans_ai function to handle Claude Sonnet
|
212 |
+
def ask_alans_ai(query, vectordb, chat_history, aoc_qa):
|
213 |
+
filtered_chat_history = [(q, a) for q, a in chat_history if a is not None]
|
214 |
+
|
215 |
+
try:
|
216 |
+
if isinstance(aoc_qa.llm, Anthropic):
|
217 |
+
# Handle Claude Sonnet
|
218 |
+
context = aoc_qa.retriever.get_relevant_documents(query)
|
219 |
+
context_str = "\n".join([doc.page_content for doc in context])
|
220 |
+
|
221 |
+
prompt = f"{HUMAN_PROMPT} Here's some context:\n{context_str}\n\nNow answer this question: {query}{AI_PROMPT}"
|
222 |
+
|
223 |
+
response = aoc_qa.llm.complete(
|
224 |
+
prompt=prompt,
|
225 |
+
max_tokens_to_sample=1000,
|
226 |
+
model="claude-3-sonnet-20240229",
|
227 |
+
temperature=ai_temp
|
228 |
+
)
|
229 |
+
answer = response.completion
|
230 |
+
else:
|
231 |
+
# Handle other models
|
232 |
+
result = aoc_qa.invoke({"question": query, "chat_history": filtered_chat_history, "vectordb": vectordb})
|
233 |
+
answer = result["answer"]
|
234 |
+
|
235 |
+
chat_history.append((query, answer))
|
236 |
+
return answer
|
237 |
+
except Exception as e:
|
238 |
+
st.error(f"An error occurred: {str(e)}")
|
239 |
+
return "I'm sorry, but I encountered an error while processing your request. Please try again later."
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
def clear_input_box():
|
248 |
st.session_state["new_item"] = ""
|
249 |
|
|
|
343 |
'gpt-3.5-turbo-16k',
|
344 |
'gpt-3.5-turbo-1106',
|
345 |
'gpt-4',
|
346 |
+
'claude-3-sonnet-20240229'
|
347 |
|
348 |
# Other custom or fine-tuned models can be added here
|
349 |
]
|
|
|
446 |
if st.button("Select 12"):
|
447 |
st.session_state.user_selected_avatar = avatar_12
|
448 |
|
449 |
+
############ Set up the LangChain Conversational Retrieval Chain ################
|
450 |
+
|
451 |
aoc_qa = ConversationalRetrievalChain.from_llm(
|
452 |
+
get_llm(selected_model, ai_temp),
|
453 |
retriever=vectordb.as_retriever(search_kwargs={'k': k_value}, search_type=selected_search_type),
|
454 |
chain_type='stuff',
|
455 |
return_source_documents=True,
|
456 |
verbose=False,
|
457 |
+
combine_docs_chain_kwargs={"prompt": qa_prompt})
|
458 |
|
459 |
# HTML for social media links with base64-encoded images
|
460 |
social_media_html = f"""
|