fsal's picture
first commit
c8ebe28
raw history blame
No virus
2.24 kB
from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnableSequence, RunnablePassthrough
# Prompt for the initial step of the refine summarization chain.
# ``{query}`` is bound as a partial variable by ``get_summarization_chain``;
# ``{text}`` receives the text being summarized.
prompt_template = """Write a concise summary of the following text, based on the user input.
User input: {query}
Text:
```
{text}
```
CONCISE SUMMARY:"""

# Prompt for the follow-up (refine) steps: update the summary produced so far
# (``{existing_answer}``) using the next piece of ``{text}``, still guided by
# the user's ``{query}``.
refine_template = (
    "You are iteratively crafting a summary of the text below based on the user input.\n"
    "User input: {query}\n"
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    # Trailing space is required: adjacent string literals are joined with no
    # separator, which previously produced "summary(only if needed)".
    "We have the opportunity to refine the existing summary "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{text}\n"
    "------------\n"
    "Given the new context, refine the original summary.\n"
    "If the context isn't useful, return the original summary.\n"
    "If the context is useful, refine the summary to include the new context.\n"
    "Your contribution is helping to build a comprehensive summary of a large body of knowledge.\n"
    "You do not have the complete context, so do not discard pieces of the original summary."
)
def get_summarization_chain(
    llm: BaseLanguageModel,
    prompt: str,
) -> Chain:
    """Create a "refine"-strategy summarization chain steered by a user query.

    ``prompt_template`` produces the initial summary and ``refine_template``
    updates it with further context. The user's query is bound into both
    templates as the ``{query}`` partial variable, so callers only have to
    supply documents.

    Args:
        llm: Language model used by the summarization chain.
        prompt: User query that guides what the summary focuses on.

    Returns:
        A chain that reads documents from ``"input_documents"`` and writes the
        summary to ``"output_text"``; intermediate steps are not returned.
    """

    def _with_query(template: str) -> PromptTemplate:
        # Each template gets its own freshly built partial-variables mapping.
        return PromptTemplate.from_template(
            template,
            partial_variables={"query": prompt},
        )

    return load_summarize_chain(
        llm=llm,
        chain_type="refine",
        question_prompt=_with_query(prompt_template),
        refine_prompt=_with_query(refine_template),
        return_intermediate_steps=False,
        input_key="input_documents",
        output_key="output_text",
    )
def get_rag_summarization_chain(
    prompt: str,
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Compose retrieval and query-guided summarization into one runnable.

    The value given to the returned runnable is sent to ``retriever``, whose
    result becomes ``"input_documents"``. The same value is also forwarded
    unchanged under ``input_key``. That mapping feeds the chain from
    ``get_summarization_chain``, and only its ``"output_text"`` entry is
    returned.

    NOTE(review): the retrieval input and ``prompt`` are independent; the
    former chooses documents, the latter steers the summary wording.

    Args:
        prompt: User query bound into the summarization prompts.
        retriever: Supplies the documents to summarize for the given input.
        llm: Language model used by the summarization chain.
        input_key: Key under which the raw input is passed alongside the
            retrieved documents.

    Returns:
        A runnable sequence whose output is the summary text.
    """
    # Fan the incoming value out to the retriever and a passthrough.
    fan_out = {"input_documents": retriever, input_key: RunnablePassthrough()}
    summarizer = get_summarization_chain(llm, prompt)
    retrieve_then_summarize = fan_out | summarizer
    # The chain returns a dict of inputs and outputs; keep only the summary.
    return retrieve_then_summarize | (lambda result: result["output_text"])