File size: 3,803 Bytes
f0fc5f8
 
 
 
 
cc2ce8c
f0fc5f8
 
 
6e28a81
787d3cb
f0fc5f8
 
6e28a81
f0fc5f8
6e28a81
f0fc5f8
 
6e28a81
f0fc5f8
 
 
 
cc2ce8c
f0fc5f8
 
 
 
 
 
 
6e28a81
f0fc5f8
6e28a81
 
 
f0fc5f8
 
6e28a81
 
 
 
 
f0fc5f8
 
 
3d561c7
6e28a81
 
 
 
 
3d561c7
 
6e28a81
3d561c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e28a81
 
 
3d561c7
 
 
 
 
 
6e28a81
 
3d561c7
6e28a81
3d561c7
 
 
6e28a81
3d561c7
f0fc5f8
 
 
 
787d3cb
6e28a81
f0fc5f8
6e28a81
 
787d3cb
f0fc5f8
 
 
 
6e28a81
787d3cb
6e28a81
f0fc5f8
 
6e28a81
 
 
 
 
f0fc5f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json

from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from climateqa.prompts import answer_prompt, reformulation_prompt, audience_prompts
from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain


def load_reformulation_chain(llm):
    """Build a chain that reformulates a user query and detects its language.

    The chain runs ``llm`` on ``reformulation_prompt`` (filled with ``query``),
    expects the raw completion to be a JSON object, and extracts its
    ``question`` and ``language`` fields.

    Args:
        llm: Language model used for the reformulation step.

    Returns:
        A ``SequentialChain`` taking ``{"query": ...}`` and producing
        ``{"question": ..., "language": ...}``. If the completion is not a
        JSON object, or lacks one of the fields, ``question`` falls back to
        the original query and ``language`` to ``"English"``.
    """
    import logging

    logger = logging.getLogger(__name__)

    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    # The raw LLM completion is stored under "json" for the parser below.
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        """Turn the raw JSON completion into ``question`` / ``language``."""
        # "query" is not declared as an input of the transform chain, but the
        # enclosing SequentialChain forwards all known values, so it is present.
        query = output["query"]
        raw = output["json"]
        logger.debug("Reformulation output: %s", output)
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            # LLMs do not always emit valid JSON; degrade to the original
            # query instead of aborting the whole QA pipeline.
            logger.warning("Could not parse reformulation output as JSON: %r", raw)
            parsed = {}
        if not isinstance(parsed, dict):
            # e.g. the model answered with a JSON list or bare string.
            parsed = {}
        return {
            "question": parsed.get("question", query),
            "language": parsed.get("language", "English"),
        }

    transform_chain = TransformChain(
        input_variables=["json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )


def load_combine_documents_chain(llm):
    """Create the "stuff" QA-with-sources chain that writes the final answer.

    Args:
        llm: Language model used to generate the answer.

    Returns:
        The chain returned by ``load_qa_with_sources_chain`` around
        ``answer_prompt``. That prompt is filled with ``summaries``,
        ``question``, ``audience`` and ``language``.
    """
    answer_template = PromptTemplate(
        input_variables=["summaries", "question", "audience", "language"],
        template=answer_prompt,
    )
    return load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=answer_template,
    )


def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain that works on already retrieved documents.

    Use this when retrieval happens elsewhere. Call the returned chain with::

        output = chain({
            "question": query,
            "audience": "experts climate scientists",
            "docs": docs,
            "language": "English",
        })

    Args:
        llm: Language model used to write the answer.

    Returns:
        A ``QAWithSourcesChain`` that reads its documents from the ``"docs"``
        key and is configured with ``return_source_documents=True``.
    """
    return QAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        input_docs_key="docs",
        return_source_documents=True,
    )


def load_qa_chain_with_text(llm):
    """Build a plain ``LLMChain`` over ``answer_prompt``.

    No documents or retriever are involved: the caller fills every prompt
    variable directly (``question``, ``audience``, ``language`` and
    ``summaries``; presumably ``summaries`` is the pre-assembled source
    text — confirm against callers).

    Args:
        llm: Language model used to write the answer.

    Returns:
        An ``LLMChain`` with the default output key.
    """
    answer_template = PromptTemplate(
        input_variables=["question", "audience", "language", "summaries"],
        template=answer_prompt,
    )
    return LLMChain(prompt=answer_template, llm=llm)


def load_qa_chain_with_retriever(retriever, llm):
    """Build the retrieval + answer chain used by the ClimateQA pipeline.

    Args:
        retriever: Retriever handed to ``CustomRetrievalQAWithSourcesChain``.
        llm: Language model used to write the answer.

    Returns:
        A verbose ``CustomRetrievalQAWithSourcesChain`` that returns its
        source documents. Its ``fallback_answer`` is presumably emitted when
        no relevant passage is retrieved; see
        ``climateqa.custom_retrieval_chain`` to confirm.
    """
    no_passage_message = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"

    # Possible improvement: supply a document prompt rather than modifying
    # page_content in the docs.
    # See https://github.com/langchain-ai/langchain/issues/3523
    return CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer=no_passage_message,
    )


def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
    """Assemble the full ClimateQA pipeline: query reformulation, then retrieval QA.

    Args:
        retriever: Retriever used by the answering step.
        llm_reformulation: Model that rewrites the query and detects its language.
        llm_answer: Model that writes the final answer.

    Returns:
        A verbose ``SequentialChain`` taking ``query`` and ``audience`` and
        exposing ``answer``, ``question``, ``language`` and ``source_documents``.
    """
    steps = [
        load_reformulation_chain(llm_reformulation),
        load_qa_chain_with_retriever(retriever, llm_answer),
    ]
    return SequentialChain(
        chains=steps,
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        # NOTE(review): in the langchain versions I know, return_all only takes
        # effect when output_variables is omitted — confirm it does anything here.
        return_all=True,
        verbose=True,
    )