File size: 3,512 Bytes
139fefe
 
38ed905
 
 
 
 
139fefe
 
4b4bf28
37b1e7a
38ed905
139fefe
 
 
 
 
4b4bf28
 
 
 
 
 
 
 
 
 
139fefe
 
 
4b4bf28
 
 
 
 
 
 
139fefe
 
 
 
 
8edfef8
139fefe
 
37b1e7a
 
139fefe
 
 
37b1e7a
 
139fefe
 
 
 
 
 
 
 
4b4bf28
 
37b1e7a
8edfef8
 
37b1e7a
4b4bf28
139fefe
 
8edfef8
37b1e7a
4b4bf28
8edfef8
 
4b4bf28
 
 
 
 
 
 
8edfef8
4b4bf28
8edfef8
 
 
 
139fefe
 
38ed905
139fefe
 
 
4b4bf28
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.reformulation import make_reformulation_chain
from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain

# Default per-document prompt used by _combine_documents: the template is
# just the "{page_content}" placeholder, with nothing else around it.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render documents as numbered, single-line entries joined by ``sep``.

    Each document is formatted with ``document_prompt``, prefixed with
    ``"Doc <n>: "`` (numbering starts at 1), and has its newlines replaced
    by spaces so that every entry occupies exactly one line.

    Args:
        docs: Iterable of documents accepted by ``format_document``.
        document_prompt: Prompt template used to render each document.
        sep: Separator inserted between the rendered entries.

    Returns:
        A single string with all rendered entries joined by ``sep``
        (an empty string when ``docs`` is empty).
    """
    # Every chunk is currently labelled "Doc", whatever its chunk type.
    label = "Doc"

    rendered = [
        (f"{label} {position}: " + format_document(doc, document_prompt)).replace("\n", " ")
        for position, doc in enumerate(docs, start=1)
    ]

    return sep.join(rendered)


def get_text_docs(x):
    """Return a list of the documents in ``x`` whose ``chunk_type`` metadata is ``"text"``.

    Input order is preserved. A document whose metadata has no
    ``"chunk_type"`` key raises ``KeyError``.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "text", x))

def get_image_docs(x):
    """Return a list of the documents in ``x`` whose ``chunk_type`` metadata is ``"image"``.

    Input order is preserved. A document whose metadata has no
    ``"chunk_type"`` key raises ``KeyError``.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "image", x))


def make_rag_chain(retriever,llm):
    """Assemble the retrieval-augmented answering chain.

    The returned chain pipes three stages together: question
    reformulation, document retrieval, and answer generation. The answer
    stage branches on whether any documents were retrieved, using the
    prompt with documents when there are some and the prompt without
    documents otherwise.

    Args:
        retriever: Runnable fed with the ``"question"`` value to fetch
            the documents stored under ``"docs"``.
        llm: Model passed to ``make_reformulation_chain`` and, renamed to
            ``"answer"``, used to generate the final answer.

    Returns:
        The composed chain ``reformulation | find_documents | answer``.
        The answer stage produces a dict with the key ``"answer"`` plus
        the keys ``"question"``, ``"audience"``, ``"language"``,
        ``"query"`` and ``"docs"`` supplied by ``pass_values``.
    """
    # Prompts for the two answering modes (with / without retrieved docs).
    docs_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    no_docs_prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

    # ------- STAGE 0 - Reformulation
    # NOTE(review): prepare_chain presumably forwards the stage inputs along
    # with its outputs (later stages read "question", "audience", ...);
    # confirm in climateqa.engine.utils.
    reformulation_step = prepare_chain(make_reformulation_chain(llm),"reformulation")

    # ------- STAGE 1 - Retrieval
    # The question is sent to the retriever; results are stored under "docs".
    retrieval_step = prepare_chain(
        {"docs": itemgetter("question") | retriever} | RunnablePassthrough(),
        "find_documents",
    )

    # ------- STAGE 2 - Inputs for the llm
    # Retrieved docs are flattened into a single "context" string.
    llm_inputs = {
        "context":lambda state : _combine_documents(state["docs"]),
        **pass_values(["question","audience","language"])
    }

    # ------- STAGE 3 - Bot answer
    answer_llm = rename_chain(llm,"answer")

    def with_passthrough(answer_step):
        # Pair the generated answer with the values carried through as-is.
        return {
            "answer": answer_step,
            **pass_values(["question","audience","language","query","docs"]),
        }

    answer_with_docs = with_passthrough(
        llm_inputs | docs_prompt | answer_llm | StrOutputParser()
    )
    answer_without_docs = with_passthrough(
        no_docs_prompt | answer_llm | StrOutputParser()
    )

    def has_docs(x):
        return len(x["docs"]) > 0

    # Use the document-grounded prompt only when retrieval returned something.
    answer_step = RunnableBranch(
        (lambda x: has_docs(x), answer_with_docs),
        answer_without_docs,
    )

    # ------- FINAL CHAIN
    return reformulation_step | retrieval_step | answer_step



def make_illustration_chain(llm):
    """Build the chain that prompts the llm with the retrieved image documents.

    The chain input is a mapping with a ``"docs"`` entry plus the
    ``"question"``, ``"audience"``, ``"language"`` and ``"answer"`` values.
    Image documents are selected from ``"docs"``, rendered into one string
    under ``"images"``, and sent with the other values through the images
    prompt to ``llm``.

    Args:
        llm: Model invoked with the formatted images prompt.

    Returns:
        A chain whose output is the llm response parsed by
        ``StrOutputParser``.
    """
    images_prompt = ChatPromptTemplate.from_template(answer_prompt_images_template)

    # Only the image chunks from "docs" are described to the model.
    images_inputs = {
        "images":lambda state : _combine_documents(get_image_docs(state["docs"])),
        **pass_values(["question","audience","language","answer"]),
    }

    return images_inputs | images_prompt | llm | StrOutputParser()