Spaces:
Runtime error
import os
import time
from operator import itemgetter

import streamlit as st
from streamlit_chat import message
from pinecone import Pinecone
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain.memory import ConversationBufferMemory

# Streamlit App Configuration
st.set_page_config(page_title="Docu-Help")

# Dropdown for namespace selection
namespace_name = st.sidebar.selectbox("Select Website:", ('crawlee', ''), key='namespace_name')

# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_TRACING_V2 = 'true'
LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
LANGCHAIN_PROJECT = "docu-help"
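
# Note: LangSmith tracing is configured through environment variables, so the three plain
# Python variables above are not picked up on their own. If tracing is intended, they would
# need to be exported, e.g. (sketch, assuming tracing should be enabled for this Space):
# os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
# os.environ["LANGCHAIN_ENDPOINT"] = LANGCHAIN_ENDPOINT
# os.environ["LANGCHAIN_PROJECT"] = LANGCHAIN_PROJECT
# os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY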

# Sidebar for model selection and API key input; the Pinecone index name is read from the environment
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "Claude-Sonnet", "mixtral-groq"))
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ")
groq_api_key = st.sidebar.text_input("Groq API Key: ")
anthropic_api_key = st.sidebar.text_input("Claude API Key: ")
pinecone_index_name = os.getenv("pinecone_index_name")
namespace_name = "crawlee"
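# Note: this hard-coded assignment overrides whatever was picked in the "Select Website:"
# dropdown above; remove it if the sidebar selection should take effect.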

# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0
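# 'past' stores the user's prompts and 'generated' stores the assistant's replies;
# refresh_text() below replays both on every Streamlit rerun so the conversation stays on screen.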

def refresh_text():
    with response_container:
        for i in range(len(st.session_state['past'])):
            try:
                user_message_content = st.session_state["past"][i]
                message = st.chat_message("user")
                message.write(user_message_content)
            except Exception:
                print("Past error")
            try:
                ai_message_content = st.session_state["generated"][i]
                message = st.chat_message("assistant")
                message.write(ai_message_content)
            except Exception:
                print("Generated Error")

# Function to generate a response using App 2's functionality
def generate_response(prompt):
    st.session_state['messages'].append({"role": "user", "content": prompt})
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)
    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    time.sleep(1)  # Ensure index is ready
    index.describe_index_stats()
    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    retriever = vectorstore.as_retriever()
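    # as_retriever() uses similarity search with the library's default number of chunks
    # (typically 4); pass search_kwargs={"k": ...} here if more or fewer chunks are wanted.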

    template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context:
{context}
Chat History:
{chat_history}
Question: {question}
"""
    prompt_template = ChatPromptTemplate.from_template(template)

    if model_name == "Claude-Sonnet":
        chat_model = ChatAnthropic(temperature=0, model="claude-3-sonnet-20240229", anthropic_api_key=anthropic_api_key)
    elif model_name == "mixtral-groq":
        chat_model = ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    else:
        chat_model = ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2)

    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )
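    # return_messages=True makes the memory return message objects instead of one long string;
    # input_key/output_key name the fields passed to save_context for the human and AI turns.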
    # Load the previous chat messages into memory; "Answer: " is stripped so the model does not
    # learn to prepend it to its own replies.
    for i in range(len(st.session_state['generated'])):
        memory.save_context({"question": st.session_state["past"][i]}, {"answer": st.session_state["generated"][i].replace("Answer: ", "")})
    # Print the memory the model will be using
    print(f"Memory: {memory.load_memory_variables({})}")

    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)
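    # RunnableParallel fans the incoming question out to three branches: the retriever (context),
    # a passthrough of the raw question, and the buffered chat history pulled from memory.
    # .assign(answer=rag_chain) then adds the model's answer, so streaming this chain yields
    # chunks keyed by 'question', 'context', and 'answer'.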

    # Generator that yields the answer tokens and the source list as they stream from the chain
    def make_stream():
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]
        for chunk in rag_chain_with_source.stream(prompt):
            if list(chunk.keys())[0] == 'answer':
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']
            elif list(chunk.keys())[0] == 'context':
                # sources = chunk['context']
                sources = [doc.metadata['source'] for doc in chunk['context']]
                sources_txt = "\n\nSources:\n" + "\n".join(sources)
                st.session_state['generated'][-1] += sources_txt
                yield sources_txt
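
    # Each chunk from .stream() is a dict with a single key; 'answer' chunks arrive token by
    # token, while 'context' arrives once with the retrieved documents and their metadata.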
    # Sending the message as a stream using the function above
    print("Running the response streamer...")
    with response_container:
        message = st.chat_message("assistant")
        my_generator = make_stream()
        message.write_stream(my_generator)
    formatted_response = st.session_state['generated'][-1]
    # response = rag_chain_with_source.invoke(prompt)
    # sources = [doc.metadata['source'] for doc in response['context']]
    # answer = response['answer']  # Extracting the 'answer' part
    # formatted_response = f"Answer: {answer}\n\nSources:\n" + "\n".join(sources)
    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})
    return formatted_response

# Containers for the chat history and the text box
response_container = st.container()
container = st.container()

# chat_input is used instead of a form because chat_input stays pinned to the bottom of the page
if prompt := st.chat_input("Ask a question..."):
    # Response generation was moved into this block because, for some reason, a separate if-statement on user_input later raises an error
    st.session_state['past'].append(prompt)
    refresh_text()
    response = generate_response(prompt)
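
# To try this locally (assuming the script is saved as app.py and the packages imported above
# are installed), set OPENAI_API_KEY, PINE_API_KEY and pinecone_index_name in the environment
# and run: streamlit run app.py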