# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json
import logging

from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from anyqa.prompts import answer_prompt, reformulation_prompt
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain

logger = logging.getLogger(__name__)

# Language assumed when the reformulation step does not report one.
_DEFAULT_LANGUAGE = "English"


def load_qa_chain_with_docs(llm):
    """Load a QA chain that answers from documents supplied by the caller.

    Useful when the documents have already been retrieved. The returned
    chain reads them from the ``docs`` input key and is configured to
    return the source documents alongside the answer.

    To be called with this input
    ```
    output = chain({
        "question":query,
        "audience":"experts scientists",
        "docs":docs,
        "language":"English",
    })
    ```
    """
    qa_chain = load_combine_documents_chain(llm)
    chain = QAWithSourcesChain(
        input_docs_key="docs",
        combine_documents_chain=qa_chain,
        return_source_documents=True,
    )
    return chain


def load_combine_documents_chain(llm):
    """Build the "stuff" combine-documents chain driven by ``answer_prompt``.

    The prompt is declared with the variables ``summaries``, ``question``,
    ``audience`` and ``language``.
    """
    prompt = PromptTemplate(
        template=answer_prompt,
        input_variables=["summaries", "question", "audience", "language"],
    )
    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
    return qa_chain


def load_qa_chain_with_text(llm):
    """Build a plain ``LLMChain`` over ``answer_prompt``.

    The caller provides ``question``, ``audience``, ``language`` and
    ``summaries`` directly; no document handling is involved.
    """
    prompt = PromptTemplate(
        template=answer_prompt,
        input_variables=["question", "audience", "language", "summaries"],
    )
    qa_chain = LLMChain(llm=llm, prompt=prompt)
    return qa_chain


def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Build the full pipeline: reformulate the query, then retrieve and answer.

    Args:
        retriever: Retriever handed to the answering chain.
        llm_reformulation: LLM used by the reformulation step.
        llm_answer: LLM used by the answering step.

    Returns:
        A ``SequentialChain`` taking ``query`` and ``audience`` and exposing
        ``answer``, ``question``, ``language`` and ``source_documents``.
    """
    reformulation_chain = load_reformulation_chain(llm_reformulation)
    answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)

    qa_chain = SequentialChain(
        chains=[reformulation_chain, answer_chain],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        # NOTE(review): return_all looks redundant when output_variables is
        # given explicitly - confirm against the SequentialChain docs.
        return_all=True,
        verbose=True,
    )
    return qa_chain


def _parse_reformulation_output(inputs):
    """Extract ``question`` and ``language`` from the reformulation LLM output.

    Args:
        inputs: Mapping holding the original ``query`` and the raw LLM text
            under ``json``.

    Returns:
        A dict with ``question`` and ``language``. When the LLM text is not
        valid JSON, is not a JSON object, or lacks a usable value, the
        original query and ``"English"`` are used instead, so that one badly
        formatted completion does not abort the whole chain.

    Raises:
        KeyError: If ``query`` or ``json`` is missing from ``inputs``.
    """
    query = inputs["query"]
    raw = inputs["json"]
    logger.debug("Reformulation output: %s", raw)

    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-str input.
        logger.warning("Reformulation output is not valid JSON, using the original query")
        parsed = None

    # A JSON list, string or number carries no usable fields.
    if not isinstance(parsed, dict):
        parsed = {}

    return {
        "question": parsed.get("question") or query,
        "language": parsed.get("language") or _DEFAULT_LANGUAGE,
    }


def load_reformulation_chain(llm):
    """Build the chain that turns a raw ``query`` into ``question`` and ``language``.

    An ``LLMChain`` over ``reformulation_prompt`` writes its text to the
    ``json`` key, then a ``TransformChain`` parses that text with
    ``_parse_reformulation_output``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_step = LLMChain(llm=llm, prompt=prompt, output_key="json")

    # The parser falls back to the original query, so "query" is declared as
    # an input next to the LLM output.
    transform_chain = TransformChain(
        input_variables=["json", "query"],
        output_variables=["question", "language"],
        transform=_parse_reformulation_output,
    )

    reformulation_chain = SequentialChain(
        chains=[llm_step, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
    return reformulation_chain


def load_qa_chain_with_retriever(retriever, llm):
    """Build the answering chain that fetches its documents from ``retriever``.

    The chain returns the source documents and carries a fixed fallback
    answer that is passed to ``CustomRetrievalQAWithSourcesChain``.
    """
    qa_chain = load_combine_documents_chain(llm)

    # This could be improved by providing a document prompt to avoid modifying page_content in the docs
    # See here https://github.com/langchain-ai/langchain/issues/3523
    answer_chain = CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer="**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**",
    )
    return answer_chain