# Map-reduce summarization adapter for knowledge-base documents.
# (Removed non-Python upload-page residue that made this file unparseable.)
from typing import List, Optional
from langchain.schema.language_model import BaseLanguageModel
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import (logger)
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
import sys
import asyncio
class SummaryAdapter:
    """Map-reduce summarization over knowledge-base documents.

    Wraps a LangChain ``MapReduceDocumentsChain``: every document is
    summarized individually (map step), then the partial summaries are
    merged into one (reduce step).
    """

    # Character overlap between adjacent chunks; used by _drop_overlap to
    # strip text duplicated between consecutive documents.
    _OVERLAP_SIZE: int
    # Maximum chunk length/token budget for the reduce step; chunks longer
    # than this raise an error on the first summarization pass.
    token_max: int
    # Separator inserted between documents when joining their texts.
    _separator: str = "\n\n"
    # The configured map-reduce summarization chain.
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """Build a configured :class:`SummaryAdapter`.

        :param llm: LLM used to summarize each individual document (map step).
        :param reduce_llm: LLM used to merge the partial summaries (reduce step).
        :param overlap_size: character overlap between adjacent chunks.
        :param token_max: maximum chunk length for the reduce step; chunks
            longer than this raise an error during the first summary pass.
        :return: a new ``SummaryAdapter`` instance.
        """
        # Controls how each document is rendered before prompt insertion.
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )
        # Map prompt: must expose `document_variable_name` ("context") as an
        # input variable. (Prompt text is user-facing; left untranslated.)
        prompt_template = (
            "根据文本执行任务。以下任务信息"
            "{task_briefing}"
            "文本内容如下: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)

        # Reduce prompt: combines the per-document summaries into one.
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # Also return the per-document intermediate summaries.
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)

    def summarize(self,
                  file_description: str,
                  docs: Optional[List[DocumentWithVSId]] = None
                  ) -> List[Document]:
        """Blocking entry point: drive :meth:`asummarize` on an event loop.

        :param file_description: description stored in the summary metadata.
        :param docs: documents to summarize (defaults to an empty list).
        :return: single-element list containing the combined summary Document.
        """
        # Avoid the shared-mutable-default pitfall of `docs=[]`.
        if docs is None:
            docs = []
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                # NOTE(review): if a loop is already running, run_until_complete
                # below will raise — preexisting limitation, kept as-is.
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        # Run the coroutine synchronously to completion.
        return loop.run_until_complete(
            self.asummarize(file_description=file_description, docs=docs))

    async def asummarize(self,
                         file_description: str,
                         docs: Optional[List[DocumentWithVSId]] = None
                         ) -> List[Document]:
        """Summarize *docs* with the map-reduce chain.

        Two stages happen inside ``chain.combine_docs``:
        1. Map: each document is summarized independently (parallelized).
        2. Reduce: the partial summaries are merged; since the chain was built
           with ``return_intermediate_steps=True`` the per-document summaries
           are returned alongside the combined result.

        :param file_description: stored in the summary document's metadata.
        :param docs: documents to summarize; each carries a vector-store id.
        :return: single-element list containing the combined summary Document.
        """
        # Avoid the shared-mutable-default pitfall of `docs=[]`.
        if docs is None:
            docs = []
        logger.info("start summary")

        summary_combine, summary_intermediate_steps = self.chain.combine_docs(
            docs=docs,
            task_briefing="描述不同方法之间的接近度和相似性,"
                          "以帮助读者理解它们之间的关系。")
        print(summary_combine)
        print(summary_intermediate_steps)
        logger.info("end summary")

        doc_ids = ",".join([doc.id for doc in docs])
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)
        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """Strip the overlapping prefix from each document's page_content.

        Adjacent chunks share an overlap region; for each document after the
        first, the longest suffix of the previous document that is also a
        prefix of the current one is removed from the current document.

        :param docs: documents whose page_content may overlap pairwise.
        :return: list of de-overlapped page_content strings.
        """
        merge_docs: List[str] = []
        pre_doc = None
        for doc in docs:
            # The first document is taken verbatim.
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue
            # Shrink the previous text from the front one character at a time;
            # the first remaining suffix that prefixes the current document is
            # the duplicated overlap. Stop once the candidate is shorter than
            # _OVERLAP_SIZE // 2 - 2 * len(separator).
            for _ in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # Drop the duplicated prefix from the current document.
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break
            else:
                # Bug fix: no overlap found — keep the whole document instead
                # of silently dropping it from the merged result.
                merge_docs.append(doc.page_content)
            pre_doc = doc.page_content
        return merge_docs

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        """Join *docs* with the separator; return None if the result is empty."""
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text
if __name__ == '__main__':
    # Demo: de-duplicate the overlapping regions of consecutive text chunks,
    # then join them into a single passage.
    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"

    merge_docs = []
    pre_doc = None
    for doc in docs:
        # The first chunk is taken verbatim.
        if not merge_docs:
            pre_doc = doc
            merge_docs.append(doc)
            continue
        # Shrink the previous chunk from the front one character at a time;
        # the first remaining suffix that the current chunk starts with is
        # the duplicated overlap, which is cut from the current chunk.
        # Stop once the candidate is shorter than
        # _OVERLAP_SIZE - 2 * len(separator).
        lower_bound = _OVERLAP_SIZE - 2 * len(separator)
        for _ in range(len(pre_doc), lower_bound, -1):
            pre_doc = pre_doc[1:]
            if doc.startswith(pre_doc):
                merge_docs.append(doc[len(pre_doc):])
                pre_doc = doc
                break

    # Join the de-overlapped chunks into a single document.
    text = separator.join(merge_docs).strip()
    print(text)