File size: 3,670 Bytes
f0fc5f8
 
 
 
 
cc2ce8c
f0fc5f8
 
 
93decd4
 
f0fc5f8
 
3d561c7
 
 
 
 
 
 
 
 
93decd4
3d561c7
 
 
 
 
 
 
 
6e28a81
 
 
3d561c7
 
 
 
cde6d5c
 
 
 
 
 
 
 
 
3d561c7
 
6e28a81
 
3d561c7
6e28a81
3d561c7
 
 
93decd4
cde6d5c
 
 
93decd4
cde6d5c
 
 
 
 
 
93decd4
cde6d5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e28a81
3d561c7
f0fc5f8
 
 
 
787d3cb
6e28a81
f0fc5f8
6e28a81
 
93decd4
f0fc5f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json

from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from anyqa.prompts import answer_prompt, reformulation_prompt
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain


def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain that answers from pre-retrieved documents.

    Useful when you already have retrieved docs.

    Expected invocation:

    ```
    output = chain({
        "question":query,
        "audience":"experts scientists",
        "docs":docs,
        "language":"English",
    })
    ```
    """
    combine_chain = load_combine_documents_chain(llm)
    return QAWithSourcesChain(
        input_docs_key="docs",
        combine_documents_chain=combine_chain,
        return_source_documents=True,
    )


def load_combine_documents_chain(llm):
    """Build the "stuff"-type QA-with-sources chain around the shared answer prompt."""
    answer_template = PromptTemplate(
        input_variables=["summaries", "question", "audience", "language"],
        template=answer_prompt,
    )
    return load_qa_with_sources_chain(llm, chain_type="stuff", prompt=answer_template)


def load_qa_chain_with_text(llm):
    """Build a plain LLM chain that answers from already-formatted text.

    Takes the same answer prompt variables ("question", "audience",
    "language", "summaries") but skips any document combination step.
    """
    answer_template = PromptTemplate(
        template=answer_prompt,
        input_variables=["question", "audience", "language", "summaries"],
    )
    return LLMChain(llm=llm, prompt=answer_template)


def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Build the full QA pipeline: query reformulation, then retrieval + answer.

    Input keys: "query", "audience".
    Output keys: "answer", "question", "language", "source_documents".
    """
    pipeline = [
        load_reformulation_chain(llm_reformulation),
        load_qa_chain_with_retriever(retriever, llm_answer),
    ]
    return SequentialChain(
        chains=pipeline,
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )


def load_reformulation_chain(llm):
    """Build a chain that rewrites the raw user query into a standalone
    question and detects its language.

    The LLM is prompted (via ``reformulation_prompt``) to emit a JSON object
    with "question" and "language" keys; a TransformChain then parses that
    payload into plain output keys.

    Input keys: "query". Output keys: "question", "language".
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        # Parse the model's JSON payload, falling back to the raw query /
        # "English" for any key it failed to produce.
        query = output["query"]
        try:
            json_output = json.loads(output["json"])
        except json.JSONDecodeError:
            # LLMs occasionally emit malformed JSON; degrade gracefully to
            # the defaults instead of crashing the whole pipeline.
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        # "query" is declared explicitly because parse_output reads it for
        # the fallback question (previously it relied on SequentialChain
        # passing all known values through undeclared).
        input_variables=["json", "query"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )


def load_qa_chain_with_retriever(retriever, llm):
    """Build a retrieval QA chain that answers with sources, emitting a
    fallback message when no relevant passages are retrieved."""
    combine_chain = load_combine_documents_chain(llm)

    # This could be improved by providing a document prompt to avoid modifying page_content in the docs
    # See here https://github.com/langchain-ai/langchain/issues/3523
    return CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=combine_chain,
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer="**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**",
    )