from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import (
    create_history_aware_retriever,
)
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter


class ConversationalQA:
    """
    Conversational question-answering over a set of documents.

    Builds a retrieval-augmented generation (RAG) pipeline: documents are
    split into chunks, embedded into a Chroma vector store, and queried
    through a history-aware retriever so follow-up questions are rewritten
    into standalone queries. Per-session chat histories are kept in memory.

    NOTE(review): `ChatOpenAI()` / `OpenAIEmbeddings()` are constructed with
    no explicit key, so the OPENAI_API_KEY environment variable must be set.
    """

    def __init__(
        self,
        docs: list,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Initialize the ConversationalQA pipeline.

        :param docs: List of documents to be used for retrieval and answering
        :param chunk_size: Maximum size of each text chunk for processing
        :param chunk_overlap: Number of characters to overlap between chunks
        """
        # Split documents into overlapping chunks for embedding/retrieval.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        self.splits = self.text_splitter.split_documents(docs)
        self.llm = ChatOpenAI()
        # Embed the chunks into a Chroma collection and expose a retriever.
        self.vectorstore = Chroma.from_documents(
            documents=self.splits,
            embedding=OpenAIEmbeddings(),
            collection_name="youtube",
        )
        self.retriever = self.vectorstore.as_retriever()
        # Answering prompt: retrieved chunks are injected via {context}.
        self.qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. 
Use three sentences maximum and keep the answer concise.\n\n{context}"""
        self.qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.qa_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Rewriting prompt: turns a history-dependent follow-up question
        # into a standalone query before it hits the retriever.
        self.contextualize_q_system_prompt = """Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
        self.contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt
        )
        self.history_aware_chain = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )
        self.rag_chain = create_retrieval_chain(
            self.history_aware_chain, self.question_answer_chain
        )
        # In-memory per-session chat histories, keyed by session_id.
        self.store: dict[str, BaseChatMessageHistory] = {}
        # Built once here rather than per invoke_chain() call: the wrapper is
        # stateless (all mutable state lives in self.store), so rebuilding it
        # on every question was pure overhead.
        self.conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Retrieve or create a chat history for a given session ID.

        :param session_id: Unique session identifier
        :return: ChatMessageHistory object for the session
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def invoke_chain(self, session_id: str, user_input: str) -> str:
        """
        Invoke the conversational QA chain with user input and session history.

        :param session_id: Unique session identifier
        :param user_input: User's question input
        :return: Answer generated by the system
        """
        return self.conversational_rag_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}},
        )["answer"]