"""
Python package defining the different chains that will be tested in the thesis.
"""
from langchain_openai import ChatOpenAI | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_chroma import Chroma | |
from langchain.chains import create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_core.language_models.llms import BaseLLM | |
from langchain_core.output_parsers.base import BaseOutputParser | |
from langchain_core.runnables import RunnablePassthrough, RunnableParallel | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.documents import Document | |
# define prompt | |
# Shared system instructions for question answering over retrieved book
# extracts.  The "{context}" placeholder is filled with the formatted
# retrieved documents.
BASE_SYSTEM_PROMPT: str = (
    "You are an assistant for question-answering tasks over books. "
    "Only use the following book extracts to answer the question. "
    "If you don't know the answer, say that you don't know. "
    "\n\n"
    "{context}"
)

# Default QA prompt: the system instructions above plus the user's question.
# Built from BASE_SYSTEM_PROMPT so the instruction text exists in one place.
BASE_QA_PROMPT: ChatPromptTemplate = ChatPromptTemplate.from_messages(
    [
        ("system", BASE_SYSTEM_PROMPT),
        ("human", "{question}"),
    ]
)
def format_docs(docs):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    contents = (document.page_content for document in docs)
    return "\n\n".join(contents)
def build_naive_rag_chain(
    retriever: BaseRetriever,
    llm: BaseLLM,
    retrieval_prompt: ChatPromptTemplate = BASE_QA_PROMPT,
    output_parser: BaseOutputParser | None = None,
):
    """Build a naive RAG chain: retrieve, stuff documents into the prompt, ask the LLM.

    Args:
        retriever: Retriever invoked with the raw question to fetch documents.
        llm: Language model that answers the question.
        retrieval_prompt: Prompt expecting ``context`` and ``question`` variables.
        output_parser: Parser applied to the LLM output; defaults to a fresh
            ``StrOutputParser`` per call (a ``None`` sentinel is used so a single
            parser instance is not shared across all calls via the def-time default).

    Returns:
        A runnable that maps a question string to a dict with ``context``
        (retrieved documents), ``question``, and ``answer`` keys.
    """
    if output_parser is None:
        output_parser = StrOutputParser()
    # Format the retrieved documents into a single string, then run the
    # prompt -> LLM -> parser pipeline.
    answer_chain = (
        RunnablePassthrough.assign(context=lambda inputs: format_docs(inputs["context"]))
        | retrieval_prompt
        | llm
        | output_parser
    )
    # Fan the input question out to the retriever (as "context") and pass it
    # through unchanged (as "question"); attach the generated answer.
    return RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=answer_chain)
class RATChain:
    """Retrieval-Augmented Thoughts (RAT) chain.

    First asks the LLM for an initial step-by-step answer (one thought per
    paragraph), then revises the answer once per thought using documents
    fetched from *retriever* with an LLM-generated query.
    """

    # Prompt used to obtain the initial step-by-step thoughts.  The literal
    # "\n\n" below is rendered as a blank line, instructing the model to
    # separate thoughts with paragraph breaks.
    THOUGHTS_PROMPT_TEMPLATE: str = """
IMPORTANT:
Answer this question with step-by-step thoughts.
Split your different thoughts with \n\n to structure your answer into several paragraphs.
Only reply to the question directly.
DO NOT add additional explanations or information that is not related to the question unless you are asked to.
"""
    THOUGHTS_PROMPT = ChatPromptTemplate.from_messages([
        ("system", THOUGHTS_PROMPT_TEMPLATE),
        ("user", "{question}"),
    ])

    # Prompt that turns one thought-paragraph ({content}) into a retrieval query.
    GENERATE_QUERY_PROMPT = ChatPromptTemplate.from_messages([
        ("user",
         "I want to verify the content correctness of the given question. "
         "Summarize the main points of the content and provide a query that I can use "
         "to retrieve information from a textbook. "
         "Make the query as relevant as possible to the last content. "
         "**IMPORTANT** "
         "Just output the query directly. DO NOT add additional explanations or introduction "
         "in the answer unless you are asked to. "
         "CONTENT: {content}"),
    ])

    # Prompt that revises the current answer against the retrieved information.
    REVISE_ANSWER_PROMPT = ChatPromptTemplate.from_messages([
        ("user",
         "Verify the answer according to the retrieved information, "
         "while keeping the initial question in mind. "
         "If you find any mistakes, correct them. "
         "If you find any missing information, add them. "
         "If you find any irrelevant information, remove them. "
         "If you find the answer is correct and does not need improvement, output the original answer. "
         "**IMPORTANT** "
         "Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structured for understanding. "
         "Add more details from retrieved text to the answer. "
         "Split the paragraphs with \n\n characters. "
         "Just output the revised answer directly. DO NOT add additional explanations or announcement in the revised answer unless you are asked to. "
         "INITIAL QUESTION:{question} "
         "ANSWER:{answer} "
         "\n\n"
         "retrieved information={retrieved_info}"),
    ])

    def __init__(self, retriever: BaseRetriever, llm: BaseLLM):
        """Store the retriever and LLM used by the chains built below."""
        self.retriever = retriever
        self.llm = llm

    # NOTE: @staticmethod is required — the original definition had no
    # decorator and no `self`, so `self.split_thoughts(...)` raised TypeError.
    @staticmethod
    def split_thoughts(thoughts: str) -> list[str]:
        """Split an answer into its paragraph-level thoughts."""
        return thoughts.split("\n\n")

    @staticmethod
    def get_page_content(docs: list[list[Document]]) -> list[list[str]]:
        """Extract the page-content strings from nested lists of documents."""
        return [[doc.page_content for doc in doc_list] for doc_list in docs]

    def get_initial_thought_chain(self):
        """Chain producing the initial step-by-step thoughts as a string."""
        return self.THOUGHTS_PROMPT | self.llm | StrOutputParser()

    def get_revise_answer_chain(self):
        """Chain revising the current answer against retrieved information."""
        return self.REVISE_ANSWER_PROMPT | self.llm | StrOutputParser()

    def get_generate_query_chain(self):
        """Chain producing a retrieval query from one thought-paragraph."""
        return self.GENERATE_QUERY_PROMPT | self.llm | StrOutputParser()

    def iteratively_improve_thoughts(self, question: str, thoughts: str) -> dict:
        """Revise *thoughts* once per paragraph using retrieved context.

        For each thought: generate a query, retrieve documents, and ask the
        LLM to revise the running answer with that evidence.

        Returns:
            Dict with ``question``, ``splited_thoughts``, ``queries``,
            ``context`` (retrieved documents per step), ``responses``
            (intermediate answers), and the final ``answer``.
        """
        splited_thoughts = self.split_thoughts(thoughts)
        generate_query_chain = self.get_generate_query_chain()
        revise_answer_chain = self.get_revise_answer_chain()
        responses: list[str] = []
        queries: list[str] = []
        contexts: list[list[Document]] = []
        answer = thoughts
        for content in splited_thoughts:
            # Pass the prompt variables explicitly as a mapping.
            query = generate_query_chain.invoke({"content": content})
            queries.append(query)
            retrieved_info = self.retriever.invoke(query)
            contexts.append(retrieved_info)
            # Each revision starts from the previous revised answer.
            answer = revise_answer_chain.invoke({
                "question": question,
                "answer": answer,
                "retrieved_info": retrieved_info,
            })
            responses.append(answer)
        return {
            "question": question,
            "splited_thoughts": splited_thoughts,
            "queries": queries,
            "context": contexts,
            "responses": responses,
            "answer": answer,
        }

    def invoke(self, question: str) -> dict:
        """Run the full RAT pipeline: initial thoughts, then iterative revision."""
        initial_thought_chain = self.get_initial_thought_chain()
        thoughts = initial_thought_chain.invoke(question)
        return self.iteratively_improve_thoughts(question, thoughts)

    def retrival_augmented_thoughts(self, question: str) -> dict:
        """Alias for :meth:`invoke`, kept (typo'd name included) for backward compatibility."""
        return self.invoke(question)
# retrival_augmented_regneration( | |
# question: str, | |
# subquestion_chain: LLMChain, | |
# loop_function:Callable, | |
# retriever: BaseRetriever, | |
# ) -> dict: | |
# response = subquestion_chain.invoke(question) | |
# return loop_function(question, response, retriever) | |