File size: 2,405 Bytes
f0fc5f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from climateqa.prompts import answer_prompt, reformulation_prompt,audience_prompts
def load_reformulation_chain(llm):
prompt = PromptTemplate(
template = reformulation_prompt,
input_variables=["query"],
)
reformulation_chain = LLMChain(llm = llm,prompt = prompt,output_key="json")
# Parse the output
def parse_output(output):
query = output["query"]
json_output = json.loads(output["json"])
question = json_output.get("question", query)
language = json_output.get("language", "English")
return {
"question": question,
"language": language,
}
transform_chain = TransformChain(
input_variables=["json"], output_variables=["question","language"], transform=parse_output
)
reformulation_chain = SequentialChain(chains = [reformulation_chain,transform_chain],input_variables=["query"],output_variables=["question","language"])
return reformulation_chain
def load_answer_chain(retriever,llm):
prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
# This could be improved by providing a document prompt to avoid modifying page_content in the docs
# See here https://github.com/langchain-ai/langchain/issues/3523
answer_chain = RetrievalQAWithSourcesChain(
combine_documents_chain = qa_chain,
retriever=retriever,
return_source_documents = True,
)
return answer_chain
def load_climateqa_chain(retriever,llm):
reformulation_chain = load_reformulation_chain(llm)
answer_chain = load_answer_chain(retriever,llm)
climateqa_chain = SequentialChain(
chains = [reformulation_chain,answer_chain],
input_variables=["query","audience"],
output_variables=["answer","question","language","source_documents"],
return_all = True,
)
return climateqa_chain
|