|
|
|
|
|
import json |
|
|
|
from langchain import PromptTemplate, LLMChain |
|
from langchain.chains import QAWithSourcesChain |
|
from langchain.chains import TransformChain, SequentialChain |
|
from langchain.chains.qa_with_sources import load_qa_with_sources_chain |
|
|
|
from anyqa.prompts import answer_prompt, reformulation_prompt |
|
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain |
|
|
|
|
|
def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain that answers from caller-supplied documents.

    Intended for the case where the documents have already been retrieved,
    so no retriever is involved. The chain reads its documents from the
    ``"docs"`` input key and is configured to return the source documents.

    Example call::

        output = chain({
            "question": query,
            "audience": "experts scientists",
            "docs": docs,
            "language": "English",
        })

    Args:
        llm: Language model handed to ``load_combine_documents_chain``.

    Returns:
        The configured ``QAWithSourcesChain``.
    """
    return QAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        input_docs_key="docs",
        return_source_documents=True,
    )
|
|
|
|
|
def load_combine_documents_chain(llm):
    """Build the "stuff" QA-with-sources chain driven by ``answer_prompt``.

    The prompt declares four variables: ``summaries``, ``question``,
    ``audience`` and ``language``.

    Args:
        llm: Language model passed to ``load_qa_with_sources_chain``.

    Returns:
        The chain produced by ``load_qa_with_sources_chain``.
    """
    answer_template = PromptTemplate(
        input_variables=["summaries", "question", "audience", "language"],
        template=answer_prompt,
    )
    return load_qa_with_sources_chain(
        llm,
        prompt=answer_template,
        chain_type="stuff",
    )
|
|
|
|
|
def load_qa_chain_with_text(llm):
    """Build a plain ``LLMChain`` that fills ``answer_prompt`` directly.

    Unlike the document-based loaders, no document-combining step is
    created here: the prompt variables ``question``, ``audience``,
    ``language`` and ``summaries`` are all supplied by the caller.

    Args:
        llm: Language model used by the chain.

    Returns:
        The configured ``LLMChain``.
    """
    return LLMChain(
        prompt=PromptTemplate(
            input_variables=["question", "audience", "language", "summaries"],
            template=answer_prompt,
        ),
        llm=llm,
    )
|
|
|
|
|
def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Build the full pipeline: query reformulation followed by retrieval QA.

    Args:
        retriever: Retriever handed to ``load_qa_chain_with_retriever``.
        llm_reformulation: Language model for the reformulation step.
        llm_answer: Language model for the answering step.

    Returns:
        A verbose ``SequentialChain`` taking ``query`` and ``audience`` and
        exposing ``answer``, ``question``, ``language`` and
        ``source_documents``.
    """
    # Order matters: the reformulation step runs before the answer step.
    steps = [
        load_reformulation_chain(llm_reformulation),
        load_qa_chain_with_retriever(retriever, llm_answer),
    ]
    return SequentialChain(
        chains=steps,
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )
|
|
|
|
|
def load_reformulation_chain(llm):
    """Build the chain that reformulates a raw user query.

    The returned chain runs two steps in sequence:

    1. An ``LLMChain`` fills ``reformulation_prompt`` with ``query`` and
       stores the raw completion under the ``"json"`` key.
    2. A ``TransformChain`` parses that completion as JSON and extracts
       ``"question"`` and ``"language"``.

    If the completion is not valid JSON, or decodes to something other than
    a JSON object, the original query and ``"English"`` are used. These are
    the same defaults already applied when individual keys are missing, so
    a badly formatted completion no longer aborts the whole pipeline.

    Args:
        llm: Language model used for the reformulation step.

    Returns:
        A ``SequentialChain`` taking ``query`` and producing ``question``
        and ``language``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        """Turn the raw completion into ``question`` / ``language`` values."""
        # NOTE(review): "query" is read here although the TransformChain only
        # declares "json" as input; this presumably relies on SequentialChain
        # passing all known values to each step -- confirm.
        query = output["query"]
        print("output", output)
        try:
            json_output = json.loads(output["json"])
        except json.JSONDecodeError:
            # Completion was not valid JSON: fall back to the defaults below.
            json_output = None
        if not isinstance(json_output, dict):
            # Valid JSON that is not an object (list, string, number, null)
            # has no .get(); treat it like an empty object.
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        input_variables=["json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    reformulation_chain = SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
    return reformulation_chain
|
|
|
|
|
def load_qa_chain_with_retriever(retriever, llm):
    """Build the retrieval-backed QA-with-sources chain.

    Args:
        retriever: Retriever passed to ``CustomRetrievalQAWithSourcesChain``.
        llm: Language model handed to ``load_combine_documents_chain``.

    Returns:
        A verbose ``CustomRetrievalQAWithSourcesChain`` configured to return
        source documents.
    """
    # NOTE(review): fallback_answer is presumably what the custom chain
    # returns when no relevant passage is found -- that logic lives in
    # anyqa.custom_retrieval_chain, confirm there.
    return CustomRetrievalQAWithSourcesChain(
        retriever=retriever,
        combine_documents_chain=load_combine_documents_chain(llm),
        fallback_answer="**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**",
        return_source_documents=True,
        verbose=True,
    )
|
|