File size: 3,447 Bytes
139fefe
 
38ed905
 
 
 
139fefe
088e816
 
09457a7
c3b815e
 
38ed905
139fefe
 
 
 
 
4b4bf28
 
 
 
 
 
caf1faa
 
 
 
 
4b4bf28
 
 
139fefe
 
 
4b4bf28
 
 
 
 
 
088e816
139fefe
088e816
 
09457a7
088e816
 
 
 
 
4b4bf28
088e816
 
 
 
8edfef8
088e816
139fefe
088e816
 
 
 
139fefe
088e816
49acaf1
09457a7
49acaf1
088e816
09457a7
 
 
 
 
 
 
088e816
4b4bf28
088e816
caf1faa
 
 
 
c3b815e
caf1faa
c3b815e
 
 
 
 
caf1faa
c3b815e
 
caf1faa
c3b815e
caf1faa
 
4b4bf28
 
 
 
c3b815e
481f3b1
c3b815e
481f3b1
c3b815e
 
 
 
481f3b1
c3b815e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
import time
from ..utils import rename_chain, pass_values


# Default ``document_prompt`` for ``_combine_documents``: renders nothing but
# the document's ``page_content`` (metadata is not included in the output).
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render a sequence of documents into one numbered context string.

    Each item becomes ``"Doc <n>: <content>"`` (``n`` starts at 1), with every
    newline inside the entry flattened to a single space. The entries are then
    joined with ``sep``.

    Args:
        docs: Iterable of items. ``str`` items are used verbatim; any other item
            is rendered with ``format_document(doc, document_prompt)``.
        document_prompt: Prompt used to render non-string items.
        sep: Separator placed between rendered entries. It is inserted after
            flattening, so its own newlines are kept.

    Returns:
        The joined string, or ``""`` when ``docs`` is empty.
    """

    def _render(position, item):
        # Plain strings are already text; everything else goes through the prompt.
        if isinstance(item, str):
            body = item
        else:
            body = format_document(item, document_prompt)
        # Every entry is labelled "Doc"; metadata such as chunk_type is not consulted.
        return ("Doc " + str(position) + ": " + body).replace("\n", " ")

    return sep.join(_render(n, item) for n, item in enumerate(docs, start=1))


def get_text_docs(x):
    """Return the items of ``x`` whose ``metadata["chunk_type"]`` equals ``"text"``.

    Input order is preserved. A ``KeyError`` propagates if an item's metadata
    has no ``"chunk_type"`` entry.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "text", x))

def get_image_docs(x):
    """Return the items of ``x`` whose ``metadata["chunk_type"]`` equals ``"image"``.

    Input order is preserved. A ``KeyError`` propagates if an item's metadata
    has no ``"chunk_type"`` entry.
    """

    def _is_image(doc):
        return doc.metadata["chunk_type"] == "image"

    return list(filter(_is_image, x))

def make_rag_chain(llm):
    """Build the document-grounded answer chain.

    The chain expects a mapping with ``"documents"``, ``"query"``,
    ``"language"`` and ``"audience"`` keys. The documents are flattened into a
    single numbered context string with ``_combine_documents``. That string and
    the other fields fill ``answer_prompt_template``, the prompt is sent to
    ``llm``, and the reply is parsed into a plain string.

    Args:
        llm: Chat model runnable placed between the prompt and the output parser.

    Returns:
        A runnable whose output is the answer as a ``str``.

    Raises:
        KeyError: If one of the expected input keys is missing.
    """
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)

    def _prepare_inputs(x):
        # Combine the documents once and reuse the result for both the prompt
        # and the length log, instead of rebuilding the whole context just to
        # measure it.
        context = _combine_documents(x["documents"])
        print("CONTEXT LENGTH : ", len(context))
        return {
            "context": context,
            # Character count of the context. The key is kept for any template
            # that references {context_length}.
            "context_length": len(context),
            "query": x["query"],
            "language": x["language"],
            "audience": x["audience"],
        }

    # A plain function on the left of ``|`` is coerced into a RunnableLambda.
    return _prepare_inputs | prompt | llm | StrOutputParser()

def make_rag_chain_without_docs(llm):
    """Build the answer chain used when no documents are supplied.

    The input mapping goes straight into ``answer_prompt_without_docs_template``,
    so it must provide whatever variables that template declares. No document
    formatting happens here.

    Args:
        llm: Chat model runnable placed between the prompt and the output parser.

    Returns:
        A runnable whose output is the model reply as a ``str``.
    """
    return (
        ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
        | llm
        | StrOutputParser()
    )

def make_rag_node(llm, with_docs=True):
    """Create an async node function that answers using a RAG chain.

    Args:
        llm: Chat model passed to the chain factory.
        with_docs: If truthy, use ``make_rag_chain`` (an answer grounded in the
            state's documents); otherwise use ``make_rag_chain_without_docs``.

    Returns:
        ``answer_rag(state, config)``, a coroutine function. It invokes the chain
        on ``state`` with ``config``, prints timing and answer-size diagnostics,
        and returns ``{"answer": <str>}``.
    """
    # Build the chain once, when the node is created, not on every call.
    if with_docs:
        rag_chain = make_rag_chain(llm)
    else:
        rag_chain = make_rag_chain_without_docs(llm)

    # NOTE(review): the parameter names (state, config) are kept as-is because
    # graph runtimes may inject the config by name. Confirm before renaming.
    async def answer_rag(state, config):
        print("---- Answer RAG ----")
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # when the system clock is adjusted, which would skew the measurement.
        start_time = time.perf_counter()

        answer = await rag_chain.ainvoke(state, config)

        elapsed_time = time.perf_counter() - start_time
        print("RAG elapsed time: ", elapsed_time)
        print("Answer size : ", len(answer))

        return {"answer": answer}

    return answer_rag




def make_rag_papers_chain(llm):
    """Build the chain that answers a question from a list of paper documents.

    Reads ``"docs"`` from the input mapping and renders it into the
    ``"context"`` prompt variable with ``_combine_documents``. The
    ``"question"`` and ``"language"`` keys are handled by ``pass_values``
    (presumably forwarded unchanged; confirm against ``..utils``). The model
    reply is parsed to a string, and the chain is wrapped with
    ``rename_chain(..., "answer")``.

    Args:
        llm: Chat model runnable placed between the prompt and the output parser.

    Returns:
        The renamed runnable produced by ``rename_chain``.
    """
    papers_prompt = ChatPromptTemplate.from_template(papers_prompt_template)

    def _docs_context(x):
        return _combine_documents(x["docs"])

    prompt_inputs = {"context": _docs_context}
    prompt_inputs = {**prompt_inputs, **pass_values(["question", "language"])}

    answer_chain = prompt_inputs | papers_prompt | llm | StrOutputParser()
    return rename_chain(answer_chain, "answer")






def make_illustration_chain(llm):
    """Build the chain that prompts ``llm`` with the image documents of a query.

    From the input mapping's ``"docs"``, only items whose
    ``metadata["chunk_type"]`` is ``"image"`` are kept. They are rendered with
    ``_combine_documents`` into the ``"images"`` prompt variable. The
    ``"question"``, ``"audience"``, ``"language"`` and ``"answer"`` keys are
    handled by ``pass_values`` (presumably forwarded unchanged; confirm against
    ``..utils``).

    Args:
        llm: Chat model runnable placed between the prompt and the output parser.

    Returns:
        A runnable whose output is the model reply as a ``str``.
    """
    images_prompt = ChatPromptTemplate.from_template(answer_prompt_images_template)

    def _images_context(x):
        image_docs = get_image_docs(x["docs"])
        return _combine_documents(image_docs)

    prompt_inputs = {
        "images": _images_context,
        **pass_values(["question", "audience", "language", "answer"]),
    }

    return prompt_inputs | images_prompt | llm | StrOutputParser()