# NOTE: the three lines below ("Spaces: / Sleeping / Sleeping") are a Hugging
# Face Spaces page header captured during extraction — not part of the program.
# --- Environment ----------------------------------------------------------
import os

from dotenv import load_dotenv

# Load .env before anything else reads OPENAI_API_KEY / SERPAPI_API_KEY etc.
load_dotenv()

# --- Standard / general third-party ---------------------------------------
import time
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Chainlit (chat UI) ---------------------------------------------------
import chainlit as cl
from chainlit.prompt import Prompt, PromptMessage
# Aliased: the original bare import was silently shadowed by langchain's
# ChatOpenAI imported below, making the chainlit one unreachable.
from chainlit.playground.providers.openai import ChatOpenAI as PlaygroundChatOpenAI

# --- LangChain ------------------------------------------------------------
from langchain.agents import AgentExecutor, load_tools
from langchain.agents.agent_toolkits import (
    create_conversational_retrieval_agent,
    create_retriever_tool,
)
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.dataframe import DataFrameLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAIChat
from langchain.memory import ConversationBufferMemory
from langchain.prompts import MessagesPlaceholder
from langchain.schema.messages import SystemMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utilities import SerpAPIWrapper
from langchain.vectorstores import Chroma, DeepLake

warnings.filterwarnings("ignore")  # silence langchain deprecation chatter
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"  # trace chains to Weights & Biases
# --- Retrieval setup ------------------------------------------------------
# One-time ingestion recipe: the Chroma collection under ./data/chroma was
# built with this; kept (commented) for reproducibility.
# review_df = pd.read_csv("./data/justice.csv")
# text_splitter = RecursiveCharacterTextSplitter(
#     chunk_size=7000,      # character length of each chunk
#     chunk_overlap=700,    # character overlap between consecutive chunks
#     length_function=len,  # measure length in characters (python len())
# )
# loader = DataFrameLoader(review_df, page_content_column="facts")
# docs = text_splitter.split_documents(loader.load())
# vectorstore.add_documents(docs)
# Alternative store:
# vectorstore = DeepLake(dataset_path="./legalminds/", embedding=embedder, overwrite=True)

embedder = OpenAIEmbeddings()

# Key under which conversation history lives — shared by the agent's
# token-buffer memory and the prompt's MessagesPlaceholder below.
memory_key = "history"

# Re-open the already-persisted Chroma collection (no re-embedding here).
persist_directory = "./data/chroma"
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedder)

# NOTE: the original also built a ConversationBufferMemory here; it was dead
# code — unconditionally rebound to AgentTokenBufferMemory before first use —
# so it has been removed.

# Kept for the (commented-out) RetrievalQA/ConversationalRetrievalChain
# variants further down; the live agent uses `llm` defined below.
primary_qa_llm = ChatOpenAI(
    model="gpt-3.5-turbo-16k",
    temperature=0,
)

retriever = vectorstore.as_retriever()

CUSTOM_TOOL_N_DOCS = 3  # number of retrieved docs to consider
CUSTOM_TOOL_DOCS_SEPARATOR = "\n\n"  # joiner for retrieved docs -> one string
def retrieve_n_docs_tool(query: str) -> str:
    """ Searches for relevant documents that may contain the answer to the query."""
    # Take the top-N hits from the retriever and merge their text into a
    # single string separated by CUSTOM_TOOL_DOCS_SEPARATOR.
    top_docs = retriever.get_relevant_documents(query)[:CUSTOM_TOOL_N_DOCS]
    return CUSTOM_TOOL_DOCS_SEPARATOR.join(doc.page_content for doc in top_docs)
# --- Tools ----------------------------------------------------------------
# Web search via SerpAPI (requires SERPAPI_API_KEY in the environment).
serp_tool = load_tools(["serpapi"])

# Expose the vector-store retriever to the agent as a named tool.
data_tool = create_retriever_tool(
    retriever,
    "retrieve_n_docs_tool",
    "Searches and returns documents regarding the query asked.",
)

tools = [data_tool, serp_tool[0]]

# --- Agent ----------------------------------------------------------------
llm = ChatOpenAI(temperature=0)

# Token-aware conversation buffer keyed by `memory_key`.
memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)

system_message = SystemMessage(
    content=(
        "Do your best to answer the questions. "
        "Feel free to use any tools available to look up "
        "relevant information, only if necessary"
    )
)

# OpenAI-functions agent prompt: system message plus a slot where the
# buffered history is injected on every turn.
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=system_message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)

agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

handler = StdOutCallbackHandler()
# marks a function that will be executed at the start of a user session
# NOTE(review): the registration decorator was evidently lost in transit —
# without it chainlit never calls this function; restored below.
@cl.on_chat_start
async def start_chat():
    """Build the tool-using agent executor and stash it in the user session.

    The executor wraps the module-level `agent`, `tools`, and token-buffer
    `memory`; `main` retrieves it per message under the key "agent".
    """
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
    )
    cl.user_session.set("agent", agent_executor)
# marks a function that should be run each time the chatbot receives a message from a user
# NOTE(review): the registration decorator was evidently lost in transit —
# without it chainlit never routes messages here; restored below.
@cl.on_message
async def main(message: str):
    """Run the session's agent on the incoming message and send the answer.

    NOTE(review): `message` is used as a plain string (older chainlit API);
    newer chainlit passes a cl.Message object — confirm against the pinned
    chainlit version.
    """
    agent_executor = cl.user_session.get("agent")
    # NOTE(review): synchronous call inside a coroutine — it blocks the event
    # loop while the agent runs; consider `await cl.make_async(agent_executor)(...)`
    # if UI responsiveness matters.
    result = agent_executor({"input": message})
    msg = cl.Message(content=result["output"])
    print("message:", msg)
    print("output message:", msg.content)
    await msg.send()