# Legal_MINDS / legalminds.py
# (Hugging Face upload-page residue: uploader Megatron17, "Upload 3 files", commit 4ff740e)
import os
from dotenv import load_dotenv
load_dotenv()
import numpy as np
import pandas as pd
import time
from tqdm import tqdm
import warnings
from langchain.chains import RetrievalQA
from langchain.callbacks import StdOutCallbackHandler
import chainlit as cl # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage
from chainlit.playground.providers.openai import ChatOpenAI # importing ChatOpenAI tools
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma, DeepLake
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders.dataframe import DataFrameLoader
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.openai import OpenAIChat
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.utilities import SerpAPIWrapper
from langchain.agents import load_tools
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.prompts import MessagesPlaceholder
from langchain.agents import AgentExecutor
warnings.filterwarnings("ignore")
# Enable Weights & Biases tracing of LangChain runs.
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
# --- One-time ingestion pipeline, kept for reference. The persisted Chroma DB
# --- under ./data/chroma is loaded below instead of re-embedding the CSV.
# review_df = pd.read_csv("./data/justice.csv")
# data = review_df
# text_splitter = RecursiveCharacterTextSplitter(
# chunk_size = 7000, # the character length of the chunk
# chunk_overlap = 700, # the character length of the overlap between chunks
# length_function = len, # the length function - in this case, character length (aka the python len() fn.)
# )
# loader = DataFrameLoader(review_df, page_content_column="facts")
# base_docs = loader.load()
# docs = text_splitter.split_documents(base_docs)
# Embedding function for the vector store (requires OPENAI_API_KEY from .env).
embedder = OpenAIEmbeddings()
# This is needed for both the memory and the prompt
memory_key = "history"
# Embed and persist db
persist_directory = "./data/chroma"
# Load the already-persisted Chroma collection from disk.
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedder)
# vectorstore = DeepLake(dataset_path="./legalminds/", embedding=embedder, overwrite=True)
# vectorstore.add_documents(docs)
# NOTE(review): this `memory` is re-bound to an AgentTokenBufferMemory further
# down the file, so this ConversationBufferMemory appears unused — confirm.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# NOTE(review): primary_qa_llm is never referenced later in this file — likely
# a leftover from the commented-out RetrievalQA chain below.
primary_qa_llm = ChatOpenAI(
    model="gpt-3.5-turbo-16k",
    temperature=0,
)
retriever = vectorstore.as_retriever()
CUSTOM_TOOL_N_DOCS = 3 # number of retrieved docs from deep lake to consider
CUSTOM_TOOL_DOCS_SEPARATOR ="\n\n" # how to join together the retrieved docs to form a single string
def retrieve_n_docs_tool(query: str) -> str:
    """Look up the top documents relevant to *query* and merge them into one string.

    At most CUSTOM_TOOL_N_DOCS documents are kept, joined with
    CUSTOM_TOOL_DOCS_SEPARATOR between their page contents.
    """
    top_docs = retriever.get_relevant_documents(query)[:CUSTOM_TOOL_N_DOCS]
    return CUSTOM_TOOL_DOCS_SEPARATOR.join(doc.page_content for doc in top_docs)
# SerpAPI web-search tool (requires SERPAPI_API_KEY in the environment).
serp_tool = load_tools(["serpapi"])
# print("Serp Tool:",serp_tool[0])
# Retrieval tool over the Chroma store. Note: this wraps the retriever object
# directly — the retrieve_n_docs_tool() function above is not actually invoked,
# only its name is reused as the tool's name.
data_tool = create_retriever_tool(
    retriever,
    "retrieve_n_docs_tool",
    "Searches and returns documents regarding the query asked."
)
tools = [data_tool, serp_tool[0]]
# llm = OpenAIChat(model="gpt-3.5-turbo", temperature=0)
# Agent LLM: default ChatOpenAI model at temperature 0 for deterministic output.
llm = ChatOpenAI(temperature = 0)
# Token-bounded agent memory; re-binds the `memory` name defined earlier.
memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
system_message = SystemMessage(
    content=(
        "Do your best to answer the questions. "
        "Feel free to use any tools available to look up "
        "relevant information, only if necessary"
    )
)
# Agent prompt with a placeholder slot where the conversation history is injected.
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=system_message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)]
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
# Stdout callback handler (currently unused — was for the commented RetrievalQA chain).
handler = StdOutCallbackHandler()
@cl.on_chat_start  # executed once at the start of each user session
async def start_chat():
    """Build this session's agent executor and stash it in the Chainlit session.

    The executor pairs the module-level OpenAI-functions agent with its tools
    and token-buffer memory; intermediate steps are returned for inspection.
    """
    executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
    )
    cl.user_session.set("agent", executor)
@cl.on_message  # run each time the chatbot receives a message from a user
async def main(message: str):
    """Route a user message through the session's agent and send back the answer.

    Fix: the original invoked the executor synchronously
    (``agent_executor({"input": message})``), which blocks Chainlit's event
    loop for the entire LLM/tool round-trip. Use the async entry point
    ``acall`` instead, matching the awaited pattern this file already used
    for the RetrievalQA chain.

    NOTE(review): in newer Chainlit versions the handler receives a
    ``cl.Message`` rather than a ``str`` — confirm against the installed
    version; the agent input would then be ``message.content``.
    """
    agent_executor = cl.user_session.get("agent")
    # Async invocation keeps the event loop free while the agent works.
    result = await agent_executor.acall({"input": message})
    # The executor's result dict carries the final answer under "output".
    msg = cl.Message(content=result["output"])
    await msg.send()