from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.reformulation import make_reformulation_chain
from climateqa.engine.prompts import answer_prompt_template, answer_prompt_without_docs_template, answer_prompt_images_template
from climateqa.engine.utils import pass_values, flatten_dict, prepare_chain, rename_chain

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Format the retrieved documents into a single numbered context string."""
    doc_strings = []
    for i, doc in enumerate(docs):
        # chunk_type = "Doc" if doc.metadata["chunk_type"] == "text" else "Image"
        chunk_type = "Doc"
        doc_string = f"{chunk_type} {i+1}: " + format_document(doc, document_prompt)
        doc_string = doc_string.replace("\n", " ")  # flatten newlines so each doc stays on one line
        doc_strings.append(doc_string)
    return sep.join(doc_strings)
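
# Illustrative example: with two retrieved chunks, _combine_documents returns a string like
#   "Doc 1: <first chunk content, newlines flattened>\n\nDoc 2: <second chunk content>"
# which is what gets injected as the "context" (or "images") variable of the prompts below.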


def get_text_docs(x):
    return [doc for doc in x if doc.metadata["chunk_type"] == "text"]


def get_image_docs(x):
    return [doc for doc in x if doc.metadata["chunk_type"] == "image"]


def make_rag_chain(retriever, llm):
    """Assemble the full RAG chain: reformulate the query, retrieve documents, then answer."""

    # Construct the prompts
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

    # ------- CHAIN 0 - Reformulation
    reformulation = make_reformulation_chain(llm)
    reformulation = prepare_chain(reformulation, "reformulation")

    # ------- CHAIN 1
    # Retrieved documents
    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
    find_documents = prepare_chain(find_documents, "find_documents")

    # ------- CHAIN 2
    # Construct inputs for the llm
    input_documents = {
        "context": lambda x: _combine_documents(x["docs"]),
        **pass_values(["question", "audience", "language"]),
    }

    # ------- CHAIN 3
    # Bot answer
    llm_final = rename_chain(llm, "answer")

    answer_with_docs = {
        "answer": input_documents | prompt | llm_final | StrOutputParser(),
        **pass_values(["question", "audience", "language", "query", "docs"]),
    }

    answer_without_docs = {
        "answer": prompt_without_docs | llm_final | StrOutputParser(),
        **pass_values(["question", "audience", "language", "query", "docs"]),
    }

    # def has_images(x):
    #     image_docs = [doc for doc in x["docs"] if doc.metadata["chunk_type"]=="image"]
    #     return len(image_docs) > 0

    def has_docs(x):
        return len(x["docs"]) > 0

    # Answer with the retrieved documents when there are any, otherwise fall back
    # to the prompt that works without documents
    answer = RunnableBranch(
        (lambda x: has_docs(x), answer_with_docs),
        answer_without_docs,
    )

    # ------- FINAL CHAIN
    # Build the final chain
    rag_chain = reformulation | find_documents | answer

    return rag_chain
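
# Usage sketch (illustrative only; the retriever and llm objects are assumed to be
# built elsewhere in climateqa from the project's vectorstore and LLM helpers):
#
#   rag_chain = make_rag_chain(retriever, llm)
#   result = rag_chain.invoke({"query": "Is sea level rise accelerating?", "audience": "experts"})
#   print(result["answer"])
#
# The "query" and "audience" input keys are an assumption based on the values passed
# through by pass_values above; the exact keys expected depend on make_reformulation_chain.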


def make_illustration_chain(llm):
    """Build a chain that uses the retrieved image documents to illustrate the generated answer."""

    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)

    input_description_images = {
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **pass_values(["question", "audience", "language", "answer"]),
    }

    illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
    return illustration_chain
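
# Usage sketch (assumed wiring): the illustration chain expects the dict produced by the
# RAG chain above (question, audience, language, answer, docs) and builds its "images"
# context from the image-type documents, e.g.:
#
#   illustration_chain = make_illustration_chain(llm)
#   illustration = illustration_chain.invoke(rag_output)   # rag_output returned by the RAG chain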