"""Conversational retrieval-augmented question answering over documents.

Builds a LangChain RAG pipeline (Chroma vector store + history-aware
retriever + stuff-documents chain) with per-session chat history.
"""
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import (
    create_history_aware_retriever,
)
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
class ConversationalQA:
    """
    Conversational question answering over a set of documents using a
    retrieval-augmented generation (RAG) pipeline with per-session history.

    Pipeline: documents are chunked and embedded into a Chroma vector
    store; a history-aware retriever rewrites follow-up questions into
    standalone ones before retrieval; a stuff-documents chain answers
    from the retrieved context. Chat history is kept per ``session_id``.
    """

    def __init__(
        self,
        docs: list,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        collection_name: str = "youtube",
    ):
        """
        Build the retrieval pipeline from *docs*.

        The OpenAI API key is read from the environment
        (``OPENAI_API_KEY``) by ``ChatOpenAI`` / ``OpenAIEmbeddings``;
        it is not passed explicitly.

        :param docs: Documents to index for retrieval and answering.
        :param chunk_size: Maximum size of each text chunk for processing.
        :param chunk_overlap: Number of characters to overlap between chunks.
        :param collection_name: Chroma collection to store embeddings in
            (default ``"youtube"`` preserves the previous hard-coded value).
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        self.splits = self.text_splitter.split_documents(docs)
        self.llm = ChatOpenAI()
        self.vectorstore = Chroma.from_documents(
            documents=self.splits,
            embedding=OpenAIEmbeddings(),
            collection_name=collection_name,
        )
        self.retriever = self.vectorstore.as_retriever()
        # Answering prompt; {context} is filled with the retrieved documents.
        # Written as adjacent string literals so no stray indentation/newlines
        # from the source file leak into the prompt text.
        self.qa_system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise.\n\n{context}"
        )
        self.qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.qa_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Question-rewriting prompt: turns a history-dependent follow-up
        # into a standalone question; it must not answer the question.
        self.contextualize_q_system_prompt = (
            "Given a chat history and the latest user question which might "
            "reference context in the chat history, formulate a standalone "
            "question which can be understood without the chat history. "
            "Do NOT answer the question, just reformulate it if needed and "
            "otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt
        )
        self.history_aware_chain = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )
        self.rag_chain = create_retrieval_chain(
            self.history_aware_chain, self.question_answer_chain
        )
        # Per-session chat histories, keyed by session_id.
        self.store = {}
        # Built once here instead of on every invoke_chain() call — the
        # wrapper is invariant; get_session_history is looked up lazily
        # at invoke time, so self.store is consulted per call as before.
        self.conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Retrieve or create a chat history for a given session ID.

        :param session_id: Unique session identifier.
        :return: ChatMessageHistory object for the session.
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def invoke_chain(self, session_id: str, user_input: str) -> str:
        """
        Invoke the conversational QA chain with user input and session
        history.

        :param session_id: Unique session identifier.
        :param user_input: User's question input.
        :return: Answer generated by the system.
        """
        return self.conversational_rag_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}},
        )["answer"]