Annual_Report_Summarization_Demo / mapReduceSummarizer.py
RMWeerasinghe's picture
Initial Commit
99e744f
raw
history blame
No virus
1.84 kB
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain, LLMChain, StuffDocumentsChain
from langchain.prompts import PromptTemplate
def get_map_reduce_chain(
    pipeline_or_llm,
    model_type: str,
    *,
    token_max: int = 16384,
    verbose: bool = True,
) -> MapReduceDocumentsChain:
    """Build a LangChain map-reduce summarization chain.

    The map step runs ``map_prompt`` through an ``LLMChain``. The reduce step
    stuffs the mapped outputs into ``reduce_prompt`` via a
    ``StuffDocumentsChain``. That stuff chain is wrapped in a
    ``ReduceDocumentsChain``, which uses the same chain both to collapse
    intermediate results and to produce the final combined answer, subject
    to ``token_max``.

    Args:
        pipeline_or_llm: When ``model_type == "openai"``, an LLM object used
            as-is. Otherwise it is passed to
            ``HuggingFacePipeline(pipeline=...)`` -- presumably a
            ``transformers`` pipeline; confirm against callers.
        model_type: ``"openai"`` (exact, case-sensitive match) selects
            theme-extraction map/reduce prompts. Any other value selects
            bare ``"{docs}"`` pass-through prompts and wraps
            ``pipeline_or_llm`` in ``HuggingFacePipeline``.
        token_max: Token limit given to ``ReduceDocumentsChain``. Defaults to
            the value that used to be hard-coded.
        verbose: Verbose flag for the reduce ``LLMChain``, the
            ``ReduceDocumentsChain`` and the returned
            ``MapReduceDocumentsChain``. The map chain is never verbose,
            matching the original behaviour.

    Returns:
        A ``MapReduceDocumentsChain`` whose document variable is ``"docs"``
        and which does not return intermediate steps. The annotation used to
        say ``LLMChain``, which did not match the returned object.
    """
    # NOTE(review): LLMChain and the *DocumentsChain classes are marked
    # deprecated in recent LangChain releases -- confirm the pinned version.
    if model_type == "openai":
        llm = pipeline_or_llm
        map_template = (
            "The following is a set of documents\n"
            "{docs}\n"
            "Based on this list of docs, please identify the main themes.\n"
            "Helpful Answer:"
        )
        map_prompt = PromptTemplate.from_template(map_template)
        reduce_template = (
            "The following is set of summaries:\n"
            "{docs}\n"
            "Take these and distill into a final, consolidated summary of the main themes.\n"
            "Helpful Answer:"
        )
        reduce_prompt = PromptTemplate.from_template(reduce_template)
    else:
        # Pass-through prompts: the model sees only the document text,
        # presumably because the HF model is a dedicated summarizer (verify).
        map_prompt = PromptTemplate.from_template(template="{docs}")
        reduce_prompt = PromptTemplate.from_template(template="{docs}")
        llm = HuggingFacePipeline(pipeline=pipeline_or_llm)

    map_chain = LLMChain(llm=llm, prompt=map_prompt)
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt, verbose=verbose)

    # The same stuff chain both collapses oversized intermediate results
    # and produces the final combined answer.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=token_max,
        verbose=verbose,
    )

    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
        verbose=verbose,
    )