# CongresoRAG / RAG_public.py
# Imports
import os
#from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pymilvus import connections, utility
from langchain_milvus.vectorstores import Milvus
from langchain.chains import create_retrieval_chain
from langchain.chains import create_history_aware_retriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains.combine_documents import create_stuff_documents_chain
# Environment Settings
#load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
cloud_api_key = os.getenv("CLOUD_API_KEY")
cloud_uri = os.getenv("URI")
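# Optional sanity check (a hedged sketch, not part of the original flow): warn
# early if a required credential is missing, since the Milvus connection and the
# OpenAI embedding calls below would otherwise fail with less obvious errors.
for _name, _value in (("OPENAI_API_KEY", openai_api_key),
                      ("CLOUD_API_KEY", cloud_api_key),
                      ("URI", cloud_uri)):
    if not _value:
        print(f"Warning: environment variable {_name} is not set")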
# Database Connection
class DatabaseManagement:
    """
    Manages the connection to the Milvus database
    """
    def __init__(self):
        """
        Connects to the Milvus server
        """
        # Connect to the Milvus server
        connections.connect(alias="default", uri=cloud_uri, token=cloud_api_key, timeout=120)
        print("Connected to the Milvus Server")
# Manages the vectorstore
class VectorStoreManagement:
    """
    Creates or loads the vectorstore backed by Milvus

    Methods
    -------
    create_vectorstore()
        Checks whether the collection already exists in Milvus. If it does not exist,
        splits the documents into smaller chunks and creates the vectorstore from them;
        otherwise loads the existing collection
    """
    def __init__(self, document):
        """
        Initializes the document and the vectorstore and calls create_vectorstore

        Parameters
        ----------
        document: list
            Documents from langchain_core.documents inside a list
        """
        self.document = document
        self.vectorstore = None
        self.create_vectorstore()
    def create_vectorstore(self):
        """
        Checks whether the collection already exists in Milvus. If it does not exist,
        splits the documents into smaller chunks and creates the vectorstore from them;
        otherwise loads the existing collection
        """
        # Define the collection name
        collection_name = "RAG_Milvus"
        # Create the collection under the ChatRAG database if it does not exist yet
        if collection_name not in utility.list_collections():
            print("RAG_Milvus collection does not exist under the ChatRAG database")
            # Split the documents into smaller chunks
            textsplitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
            chunks_data = textsplitter.split_documents(documents=self.document)
            # Create the vectorstore from Milvus
            self.vectorstore = Milvus.from_documents(documents=chunks_data,
                                                     embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
                                                     collection_name=collection_name,
                                                     connection_args={"uri": cloud_uri,
                                                                      "token": cloud_api_key})
            print("RAG_Milvus collection is created under the ChatRAG database")
        else:
            print("RAG_Milvus collection already exists")
            # Load the existing collection
            self.vectorstore = Milvus(embedding_function=OpenAIEmbeddings(openai_api_key=openai_api_key),
                                      collection_name=collection_name,
                                      connection_args={"uri": cloud_uri,
                                                       "token": cloud_api_key})
# RAG class that retrieves the AI response for a given user query
class RAG:
    """
    ChatRAG that uses Retrieval Augmented Generation on top of a large language
    model with LangChain

    Methods
    -------
    model():
        Creates the llm from OpenAI, using the model gpt-3.5-turbo-0125 with temperature=0
        Creates the retriever from the vectorstore
        Defines contextualize_q_prompt and combines it with the llm and the retriever in a history-aware retriever
        Defines qa_prompt (question/answer) and combines it with the llm in create_stuff_documents_chain to build the question_answer_chain
        Defines the rag chain by combining the history-aware retriever and the question_answer_chain
    get_session_history(session_id):
        Stores the chat history per session_id in a dictionary
    conversational_rag_chain(input):
        Creates the conversational rag chain and invokes it to get the AI response
    """
    def __init__(self, document):
        """
        Initializes the document, the database and vectorstore managers, and the
        store that holds the chat history

        Parameters
        ----------
        document: list
            Documents from langchain_core.documents inside a list
        """
        self.document = document
        self.database_manager = DatabaseManagement()
        self.vectorstore_manager = VectorStoreManagement(self.document)
        self.store = {}
    # RAG model
    def model(self):
        """
        Creates the llm from OpenAI, using the model gpt-3.5-turbo-0125 with temperature=0
        Creates the retriever from the vectorstore
        Defines contextualize_q_prompt and combines it with the llm and the retriever in a history-aware retriever
        Defines qa_prompt (question/answer) and combines it with the llm in create_stuff_documents_chain to build the question_answer_chain
        Defines the rag chain by combining the history-aware retriever and the question_answer_chain
        """
        # Create the llm from ChatOpenAI
        llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
        # Create the retriever. It returns the documents most similar to the user input.
        retriever = self.vectorstore_manager.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})
        # System prompt that tells the language model how to handle the latest user query in the context of the
        # entire conversation history: take the chat history and the latest user question and rephrase the
        # question so it can be understood independently of the history
        contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
        # Create a customized chat prompt template with the system prompt above
        contextualize_q_prompt = ChatPromptTemplate.from_messages([
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        # Create the history-aware retriever. It combines the current user query with the chat history so that
        # the AI response stays relevant to the previous questions and answers
        history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
        # Custom question/answer system prompt
        qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise. \
{context}"""
        # Create the custom question/answer chat prompt
        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        # Create the question/answer chain. It combines the llm and qa_prompt
        # and uses the retrieved context to answer the question.
        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
        # RAG chain that combines the history-aware retriever and the question/answer chain,
        # so the retrieved documents are related to both the chat history and the user query
        self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    # Stores and retrieves the chat history for a session
    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Stores the chat history per session_id in a dictionary and returns the
        history for the given session

        Parameters
        ----------
        session_id: str
            Session identifier in string format

        Returns
        -------
        BaseChatMessageHistory
            The chat history associated with the given session_id
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]
    # Create the conversational RAG chain
    def conversational_rag_chain(self, input):
        """
        Creates the conversational rag chain and invokes it

        Parameters
        ----------
        input: str
            User's query

        Returns
        -------
        tuple
            The AI response and the list of pdf urls of the retrieved documents
        """
        conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer")
        # The session_id is hardcoded, so all calls share a single chat history
        result = conversational_rag_chain.invoke({"input": str(input)},
                                                 config={"configurable": {"session_id": "6161"}})
        # Collect the pdf urls of the retrieved source documents
        pdf_urls = []
        for doc in result["context"]:
            pdf_urls.append(doc.metadata["pdf_url"])
        return result["answer"], pdf_urls
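
# Example usage, guarded so it only runs when the file is executed directly.
# This is a hedged sketch: the Document text, the pdf_url value, and the query
# are illustrative placeholders, not data from the project.
if __name__ == "__main__":
    from langchain_core.documents import Document
    docs = [Document(page_content="Example congressional record text...",
                     metadata={"pdf_url": "https://example.com/record.pdf"})]
    rag = RAG(docs)
    # model() must be called before conversational_rag_chain, since it builds self.rag_chain
    rag.model()
    answer, sources = rag.conversational_rag_chain("What does the record discuss?")
    print(answer)
    print(sources)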